본문 바로가기

Deep Learning

Feature Pyramid Network (FPN) pytorch implementation

반응형
class FPN(nn.Module):
    """Top-down Feature Pyramid Network neck.

    For each level listed in ``sizes``, a 1x1 convolution projects the input
    feature map to ``dim`` channels. The first level's projection is returned
    as-is. Every later level adds the bilinearly resized output of the level
    before it to its own projection, then smooths the sum with a 3x3
    convolution.
    """

    def __init__(self, dim, sizes, channels):
        """Build the lateral (1x1) and merge (3x3) convolutions.

        Args:
            dim: number of output channels shared by every pyramid level.
            sizes: level identifiers, processed in the given order (the first
                entry is the top of the pyramid). ``str(size)`` is used as
                the key into both the input and the output dicts.
                Example: ``[57, 113, 225, 450]``.
            channels: input channel count of each level, aligned with
                ``sizes``. Example: ``[1024, 512, 256, 64]``. Entries beyond
                ``len(sizes)`` are ignored.

        Raises:
            ValueError: if ``channels`` has fewer entries than ``sizes``.
        """
        super().__init__()

        if len(channels) < len(sizes):
            raise ValueError(
                f"channels must have one entry per size "
                f"(got {len(channels)} channels for {len(sizes)} sizes)"
            )

        self.sizes = sizes
        self.channels = channels
        # Lateral 1x1 projections: channels[i] -> dim, spatial size unchanged.
        self.dim_reduce = nn.ModuleDict()
        # 3x3 smoothing (stride 1, padding 1) applied after each top-down add.
        # NOTE: merge[str(sizes[0])] is never used by forward() because the
        # top level is not merged. It is kept so existing state_dicts still
        # load with strict=True. Under DistributedDataParallel it may require
        # find_unused_parameters=True.
        self.merge = nn.ModuleDict()
        for size, in_channels in zip(sizes, channels):
            key = str(size)
            self.dim_reduce[key] = nn.Conv2d(in_channels, dim, kernel_size=1, stride=1, padding=0)
            self.merge[key] = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1)

    def upsample_add(self, up, bottom):
        """Resize ``up`` to ``bottom``'s spatial size and add the two.

        ``bottom`` must be 4-D; ``up`` is bilinearly resized to its last two
        dimensions (H, W). The resized tensor and ``bottom`` must therefore
        agree (or broadcast) on their leading two dimensions.
        """
        _, _, height, width = bottom.size()
        # F.upsample is deprecated. interpolate with align_corners=False is
        # the numerically identical replacement (upsample's None default
        # resolves to False for bilinear) and does not emit a warning.
        resized = F.interpolate(up, size=(height, width), mode='bilinear', align_corners=False)
        return resized + bottom

    def forward(self, feats):
        """Run the top-down pathway over ``self.sizes`` in order.

        Args:
            feats: dict mapping ``str(size)`` (for every entry of
                ``self.sizes``) to a 4-D feature tensor whose channel count
                is the matching entry of ``channels``. Extra keys are ignored.

        Returns:
            dict with one entry per ``str(size)``. Each value has ``dim``
            channels and the spatial size of the corresponding input level.
        """
        outputs = {}
        prev = None  # output of the previously processed (higher) level
        for size in self.sizes:
            key = str(size)
            lateral = self.dim_reduce[key](feats[key])
            if prev is None:
                # Top level: projection only, no merge.
                out = lateral
            else:
                out = self.merge[key](self.upsample_add(up=prev, bottom=lateral))
            outputs[key] = out
            prev = out

        return outputs