반응형
class FPN(nn.Module):
    """Top-down feature-pyramid merge over a dict of multi-level feature maps.

    Every level is first projected to ``dim`` channels with a 1x1 conv.
    Levels are then processed in ``sizes`` order:

    * The first level's projection is emitted as-is.
    * Each later level has the previous level's output added to it. That
      output is first bilinearly resized to this level's spatial size. The
      sum is then smoothed by a 3x3 conv.

    Args:
        dim: Number of output channels for every level.
        sizes: Level identifiers, in processing order. ``str(size)`` is the
            key used for both the input ``feats`` dict and the returned
            dict. Example from the original author: ``[57, 113, 225, 450]``.
            Presumably these are spatial resolutions, coarse to fine; note
            that the actual resize target is taken from the tensors, not
            from these numbers.
        channels: Input channel count of each level, aligned with ``sizes``
            by index. Example: ``[1024, 512, 256, 64]``.
    """

    def __init__(self, dim, sizes, channels):
        super().__init__()
        self.sizes = sizes
        self.channels = channels
        self.dim_reduce = nn.ModuleDict()
        self.merge = nn.ModuleDict()
        for idx, size in enumerate(sizes):
            key = str(size)
            # Index channels[idx] rather than zip(): a too-short ``channels``
            # still fails loudly (IndexError) instead of silently dropping
            # levels.
            self.dim_reduce[key] = nn.Conv2d(
                channels[idx], dim, kernel_size=1, stride=1, padding=0
            )
            # NOTE: forward() never uses merge[str(sizes[0])]. It is still
            # created so the module's parameter names stay unchanged.
            self.merge[key] = nn.Conv2d(
                dim, dim, kernel_size=3, stride=1, padding=1
            )

    def upsample_add(self, up, bottom):
        """Bilinearly resize ``up`` to ``bottom``'s spatial size, then add ``bottom``.

        Args:
            up: 4-D ``(N, C, h, w)`` tensor to be resized.
            bottom: 4-D ``(N, C, H, W)`` tensor. Its ``H`` and ``W`` set
                the output size. It must be addable to the resized ``up``.

        Returns:
            ``(N, C, H, W)`` tensor: ``resize(up) + bottom``.
        """
        _, _, height, width = bottom.size()
        # F.interpolate replaces the deprecated F.upsample. For bilinear mode,
        # align_corners=False is what F.upsample used by default, so results
        # are unchanged and the deprecation warning goes away.
        resized = F.interpolate(
            up, size=(height, width), mode="bilinear", align_corners=False
        )
        return resized + bottom

    def forward(self, feats):
        """Run the top-down pass.

        Args:
            feats: Mapping from ``str(size)``, for every entry of
                ``self.sizes``, to a 4-D ``(N, C, H, W)`` tensor. ``C`` must
                equal the matching entry of ``channels``. A missing key
                raises whatever ``feats[key]`` raises (``KeyError`` for a
                dict).

        Returns:
            Dict with keys ``str(size)`` for each size, in ``sizes`` order.
            Each value is a ``(N, dim, H, W)`` tensor at that level's input
            resolution; the 1x1 and padded 3x3 convs preserve ``H`` and
            ``W``.
        """
        outputs = {}
        previous = None
        for size in self.sizes:
            key = str(size)
            lateral = self.dim_reduce[key](feats[key])
            if previous is None:
                # First (top) level: no coarser level to merge, no smoothing.
                current = lateral
            else:
                # The *smoothed* output of the previous level is propagated
                # downward, matching the original implementation.
                merged = self.upsample_add(up=previous, bottom=lateral)
                current = self.merge[key](merged)
            outputs[key] = current
            previous = current
        return outputs