【论文复现】FPN

以RetinaNet里面的FPN为例,总共有5层,原始的FPN论文中只有P3,P4,P5三层。P6和P7是RetinaNet论文里面特有的。其中P6有两种实现方式,可以由P5得到也可以由C5得到。

import torch.nn as nn
import torch
from torch.nn import functional as F

class FPN(nn.Module):
    """Feature Pyramid Network as used in RetinaNet.

    Builds five pyramid levels P3-P7 from backbone feature maps C3-C5.
    P3-P5 follow the original FPN paper (lateral 1x1 projection, top-down
    nearest-neighbor upsampling with addition, then a 3x3 smoothing conv).
    P6 and P7 are the extra RetinaNet levels: P6 is a stride-2 conv over
    either P5 (``use_p5=True``) or the raw C5, and P7 is a stride-2 conv
    over ReLU(P6).

    Args:
        C3_size: channel count of the C3 backbone map.
        C4_size: channel count of the C4 backbone map.
        C5_size: channel count of the C5 backbone map.
        feature_size: common channel count of every output level.
        use_p5: if True derive P6 from P5, otherwise from C5.
    """

    def __init__(self, C3_size, C4_size, C5_size, feature_size=256, use_p5=True):
        super().__init__()

        # 1x1 lateral projections onto the shared channel count.
        self.prj_5 = nn.Conv2d(C5_size, feature_size, kernel_size=1, stride=1, padding=0)
        self.prj_4 = nn.Conv2d(C4_size, feature_size, kernel_size=1, stride=1, padding=0)
        self.prj_3 = nn.Conv2d(C3_size, feature_size, kernel_size=1, stride=1, padding=0)

        # 3x3 convs that smooth the fused maps (reduce upsampling aliasing).
        self.conv_5 = nn.Conv2d(feature_size, feature_size, kernel_size=3, padding=1)
        self.conv_4 = nn.Conv2d(feature_size, feature_size, kernel_size=3, padding=1)
        self.conv_3 = nn.Conv2d(feature_size, feature_size, kernel_size=3, padding=1)

        # P6 input is P5 (feature_size channels) or raw C5 (C5_size channels).
        p6_in = feature_size if use_p5 else C5_size
        self.conv_out6 = nn.Conv2d(p6_in, feature_size, kernel_size=3, padding=1, stride=2)
        self.conv_out7 = nn.Conv2d(feature_size, feature_size, kernel_size=3, padding=1, stride=2)

        self.use_p5 = use_p5
        self.apply(self.init_conv_kaiming)

    def upsample(self, inputs):
        """Resize inputs[0] to the spatial size of inputs[1] (nearest-neighbor)."""
        src, ref = inputs
        return F.interpolate(src, size=ref.shape[-2:], mode='nearest')

    def init_conv_kaiming(self, module):
        """Kaiming-normal init (a=1) for every Conv2d; biases are zeroed."""
        if not isinstance(module, nn.Conv2d):
            return
        nn.init.kaiming_normal_(module.weight, a=1)
        if module.bias is not None:
            nn.init.constant_(module.bias, 0)

    def forward(self, inputs):
        """Map backbone features (C3, C4, C5) to the list [P3, P4, P5, P6, P7]."""
        C3, C4, C5 = inputs

        # Lateral projections.
        p5 = self.prj_5(C5)
        p4 = self.prj_4(C4)
        p3 = self.prj_3(C3)

        # Top-down pathway: upsample the coarser level and add it in.
        p4 = p4 + self.upsample([p5, p4])
        p3 = p3 + self.upsample([p4, p3])

        # Smooth each fused level.
        p3, p4, p5 = self.conv_3(p3), self.conv_4(p4), self.conv_5(p5)

        # Extra RetinaNet levels.
        p6 = self.conv_out6(p5 if self.use_p5 else C5)
        p7 = self.conv_out7(F.relu(p6))

        return [p3, p4, p5, p6, p7]

if __name__ == '__main__':
    # Smoke test: dummy backbone features with the C3/C4/C5 resolution ratio.
    feats = [
        torch.randn(2, 16, 200, 200),
        torch.randn(2, 32, 100, 100),
        torch.randn(2, 64, 50, 50),
    ]

    model = FPN(16, 32, 64)

    for level in model(feats):
        print(level.shape)

    # Expected output:
    # torch.Size([2, 256, 200, 200])
    # torch.Size([2, 256, 100, 100])
    # torch.Size([2, 256, 50, 50])
    # torch.Size([2, 256, 25, 25])
    # torch.Size([2, 256, 13, 13])

版权声明:本文为CSDN博主「是王同学呀」的原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接及本声明。
原文链接:https://blog.csdn.net/wxd1233/article/details/121645640

是王同学呀

我还没有学会写个人说明!

暂无评论

发表评论

相关推荐

YOLOX训练自己的数据集,txt形式

YOLOX官方支持训练VOC和COCO数据集，但习惯了Yolov3~v5的txt加载数据集，尤其是训练自己的数据集时，标签写入txt文本更方便些，但是YOLOX官方要你自己写，