Feature Pyramid Networks for Object Detection

PDF: https://arxiv.org/pdf/1612.03144v2.pdf | PyTorch代码: https://github.com/shanglianlm0525/PyTorch-Networks

[图 1:目标检测论文 Feature Pyramid Networks for Object Detection 及其 PyTorch 实现]


[图 2:目标检测论文 Feature Pyramid Networks for Object Detection 及其 PyTorch 实现]


PyTorch代码:

import torch
import torch.nn as nn
import torchvision

class FPN(nn.Module):
    """Feature Pyramid Network built on a torchvision ResNet-50 backbone.

    The bottom-up pathway reuses the ResNet-50 stages to produce feature maps
    ``c2``..``c5`` (256, 512, 1024 and 2048 channels respectively, matching
    the ``in_channels`` of the lateral convolutions). The top-down pathway
    projects each of them to 256 channels with a 1x1 lateral convolution,
    upsamples the coarser level by 2x with a transposed convolution, adds the
    two, and finally smooths each merged map with its own 3x3 convolution.

    ``forward`` returns ``(p2, p3, p4, p5)``, ordered from the finest to the
    coarsest level; every output has 256 channels.

    Note: each upsample layer (kernel 4, stride 2, padding 1) exactly doubles
    the spatial size, so the element-wise additions only line up when every
    backbone stage halves its input exactly (e.g. a 224x224 input).
    """

    def __init__(self):
        super().__init__()
        # The backbone is created without pretrained weights; all Conv2d and
        # BatchNorm2d layers are (re-)initialised by init_params() below.
        resnet = torchvision.models.resnet50()
        layers = list(resnet.children())

        # Bottom-up pathway. The first five children form the stem plus the
        # first residual stage; the following three children are the
        # remaining residual stages.
        self.layer1 = nn.Sequential(*layers[:5])
        self.layer2 = nn.Sequential(*layers[5])
        self.layer3 = nn.Sequential(*layers[6])
        self.layer4 = nn.Sequential(*layers[7])

        # 1x1 lateral connections: project every backbone level to 256
        # channels so levels can be summed.
        self.lateral5 = nn.Conv2d(in_channels=2048, out_channels=256, kernel_size=1)
        self.lateral4 = nn.Conv2d(in_channels=1024, out_channels=256, kernel_size=1)
        self.lateral3 = nn.Conv2d(in_channels=512, out_channels=256, kernel_size=1)
        self.lateral2 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=1)

        # Learned 2x upsampling for the top-down pathway; upsampleN produces
        # the map that is merged into level N.
        self.upsample2 = nn.ConvTranspose2d(in_channels=256, out_channels=256, kernel_size=4, stride=2, padding=1)
        self.upsample3 = nn.ConvTranspose2d(in_channels=256, out_channels=256, kernel_size=4, stride=2, padding=1)
        self.upsample4 = nn.ConvTranspose2d(in_channels=256, out_channels=256, kernel_size=4, stride=2, padding=1)

        # 3x3 smoothing convolutions, one per merged level.
        self.smooth2 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1)
        self.smooth3 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1)
        self.smooth4 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1)

        self.init_params()

    def init_params(self):
        """Initialise Conv2d weights (Kaiming normal) and BatchNorm2d affine terms.

        Only ``nn.Conv2d`` and ``nn.BatchNorm2d`` modules are touched; the
        ``nn.ConvTranspose2d`` upsample layers and all convolution biases
        keep PyTorch's default initialisation.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Compute the feature pyramid for input ``x``.

        Args:
            x: Input tensor fed to the backbone stem.

        Returns:
            Tuple ``(p2, p3, p4, p5)`` of pyramid feature maps, finest first.
        """
        # Bottom-up pathway.
        c2 = self.layer1(x)
        c3 = self.layer2(c2)
        c4 = self.layer3(c3)
        c5 = self.layer4(c4)

        # Top-down pathway with lateral connections. The merges use the
        # un-smoothed maps; smoothing is applied afterwards.
        p5 = self.lateral5(c5)
        p4 = self.upsample4(p5) + self.lateral4(c4)
        p3 = self.upsample3(p4) + self.lateral3(c3)
        p2 = self.upsample2(p3) + self.lateral2(c2)

        # Each merged level goes through its own smoothing convolution.
        # (Previously p2 was passed through smooth4, leaving smooth2 unused.)
        p4 = self.smooth4(p4)
        p3 = self.smooth3(p3)
        p2 = self.smooth2(p2)
        return p2, p3, p4, p5

if __name__ == '__main__':
    # Smoke test: build the network, print its structure, then print the
    # shape of each pyramid level for one random 224x224 RGB image.
    model = FPN()
    print(model)

    # Named `images` rather than `input` so the builtin is not shadowed.
    images = torch.randn(1, 3, 224, 224)
    for feature_map in model(images):  # order: p2, p3, p4, p5
        print(feature_map.shape)