import torch
import torch.nn as nn


class TransFormer(nn.Module):
    def __init__(self, dim, heads, dim_head, drop=0.1, qkv_bias=True):
        super(TransFormer, self).__init__()
        self.dim_head = dim_head
        self.scale = dim_head ** -0.5
        self.heads = heads
        self.to_q = nn.Linear(dim, heads * dim_head, bias=qkv_bias)
        self.to_k = nn.Linear(dim, heads * dim_head, bias=qkv_bias)
        self.to_v = nn.Linear(dim, heads * dim_head, bias=qkv_bias)
        self.drop1 = nn.Dropout(drop)
        self.drop2 = nn.Dropout(drop)
        self.proj = nn.Linear(heads * dim_head, dim)
        self.prenorm = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, 2 * dim),
            nn.GELU(),
            nn.Dropout(drop),
            nn.Linear(2 * dim, dim),
        )
        self.postnorm = nn.LayerNorm(dim)

    def simple_attention(self, q, k, v):
        """
        q : b Lq d
        k : b Lk d
        v : b Lk d
        """
        b, Lq, _ = q.size()
        b, Lk, _ = k.size()
        # Linear projections to query/key/value space
        # (Lq is 1 when a single query token attends over flattened H*W image features)
        q = self.to_q(q)  # b Lq (heads dim_head)
        k = self.to_k(k)  # b Lk (heads dim_head)
        v = self.to_v(v)  # b Lk (heads dim_head)
        # Split into multiple heads
        q = q.reshape(b, Lq, self.heads, self.dim_head)  # b Lq heads dim_head
        k = k.reshape(b, Lk, self.heads, self.dim_head)  # b Lk heads dim_head
        v = v.reshape(b, Lk, self.heads, self.dim_head)  # b Lk heads dim_head
        q = q.permute(0, 2, 1, 3).contiguous()  # b heads Lq dim_head
        k = k.permute(0, 2, 1, 3).contiguous()  # b heads Lk dim_head
        v = v.permute(0, 2, 1, 3).contiguous()  # b heads Lk dim_head
        q = q.reshape(b * self.heads, Lq, self.dim_head)  # (b heads) Lq dim_head
        k = k.reshape(b * self.heads, Lk, self.dim_head)  # (b heads) Lk dim_head
        v = v.reshape(b * self.heads, Lk, self.dim_head)  # (b heads) Lk dim_head
        # Scaled dot-product attention scores
        dot = self.scale * q @ k.permute(0, 2, 1)  # (b heads) Lq Lk
        dot = dot.softmax(dim=-1)
        # Combine values (image-level features when k, v come from an H*W feature map).
        a = dot @ v  # (b heads) Lq dim_head
        a = a.reshape(b, self.heads, Lq, self.dim_head)  # b heads Lq dim_head
        a = a.permute(0, 2, 1, 3).contiguous()  # b Lq heads dim_head
        a = a.reshape(b, Lq, self.heads * self.dim_head)  # b Lq (heads dim_head)
        # Merge the heads back to the model dimension
        return self.proj(a)

    def forward(self, q, k, v):
        """
        q : b Lq d
        k : b Lk d
        v : b Lk d
        """
        attn = self.simple_attention(q, k, v)
        # Residual (skip) connection around attention
        output = q + self.drop1(attn)
        output = self.prenorm(output)
        # FFN with its own residual connection
        output = output + self.drop2(self.mlp(output))
        output = self.postnorm(output)
        return output
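
For reference, here is a minimal usage sketch (not part of the original code): it instantiates the module and runs one forward pass, with a single query token per sample attending over a flattened 16x16 feature map, matching the shape comments above. The dimensions (dim=256, heads=8, dim_head=32) and the batch size are arbitrary example values.

import torch

model = TransFormer(dim=256, heads=8, dim_head=32)
q = torch.randn(2, 1, 256)         # b Lq d : one query token per sample
kv = torch.randn(2, 16 * 16, 256)  # b Lk d : flattened H*W image features
out = model(q, kv, kv)
print(out.shape)                   # torch.Size([2, 1, 256])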