

[PyTorch] Transformer w/o self-attention implementation compatible with TensorRT

import torch
import torch.nn as nn


class TransFormer(nn.Module):

    def __init__(self, dim, heads, dim_head, drop=0.1, qkv_bias=True):
        super(TransFormer, self).__init__()

        self.dim_head = dim_head
        self.scale = dim_head ** -0.5
        self.heads = heads

        self.to_q = nn.Linear(dim, heads * dim_head, bias=qkv_bias)
        self.to_k = nn.Linear(dim, heads * dim_head, bias=qkv_bias)
        self.to_v = nn.Linear(dim, heads * dim_head, bias=qkv_bias)

        self.drop1 = nn.Dropout(drop)
        self.drop2 = nn.Dropout(drop)

        self.proj = nn.Linear(heads * dim_head, dim)
        self.prenorm = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(nn.Linear(dim, 2 * dim), nn.GELU(), nn.Dropout(drop), nn.Linear(2 * dim, dim))
        self.postnorm = nn.LayerNorm(dim)

    def simple_attention(self, q, k, v):
        """
        q : b Lq d
        k : b Lk d
        v : b Lk d
        """

        b, Lq, _ = q.size()
        b, Lk, _ = k.size()

        # Linear projections for query, key, and value
        q = self.to_q(q)                                # b Lq (heads dim_head)
        k = self.to_k(k)                                # b (H W) (heads dim_head)
        v = self.to_v(v)                                # b (H W) (heads dim_head)

        # Split into multiple heads
        q = q.reshape(b, Lq, self.heads, self.dim_head)  # b Lq heads dim_head
        k = k.reshape(b, Lk, self.heads, self.dim_head)  # b (H W) heads dim_head
        v = v.reshape(b, Lk, self.heads, self.dim_head)  # b (H W) heads dim_head

        q = q.permute(0, 2, 1, 3).contiguous()  # b heads Lq dim_head
        k = k.permute(0, 2, 1, 3).contiguous()  # b heads (H W) dim_head
        v = v.permute(0, 2, 1, 3).contiguous()  # b heads (H W) dim_head

        q = q.reshape(b * self.heads, Lq, self.dim_head)  # (b heads) Lq dim_head
        k = k.reshape(b * self.heads, Lk, self.dim_head)  # (b heads) (H W) dim_head
        v = v.reshape(b * self.heads, Lk, self.dim_head)  # (b heads) (H W) dim_head

        # Scaled dot product between queries and keys
        dot = self.scale * q @ k.permute(0, 2, 1)                    # (b heads) Lq (H W)
        dot = dot.softmax(dim=-1)

        # Combine values (image-level features).
        a = dot @ v                                                  # (b heads) Lq dim_head
        a = a.reshape(b, self.heads, Lq, self.dim_head)              # b heads Lq dim_head
        a = a.permute(0, 2, 1, 3).contiguous()                       # b Lq heads dim_head
        a = a.reshape(b, Lq, self.heads * self.dim_head)

        # Combine multiple heads
        return self.proj(a)

    def forward(self, q, k, v):
        """
        q : b Lq d
        k : b Lk d
        v : b Lk d
        """

        attn = self.simple_attention(q, k, v)

        # Skip connection
        output = q + self.drop1(attn)
        output = self.prenorm(output)

        # FFN
        output = output + self.drop2(self.mlp(output))
        output = self.postnorm(output)

        return output
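
For reference, a minimal smoke test of the block above might look like the following sketch. The sizes (batch 2, a single query token, a 16x16 feature map flattened into 256 key/value tokens, dim=128, 4 heads of width 32) and the file name transformer.onnx are assumptions chosen for illustration, not part of the original post; the ONNX export call is simply the usual route for handing a PyTorch module to TensorRT (e.g. building an engine with trtexec --onnx=transformer.onnx).

import torch

# Assumed toy sizes for illustration only.
dim, heads, dim_head = 128, 4, 32
b, Lq, H, W = 2, 1, 16, 16

model = TransFormer(dim, heads, dim_head).eval()

q = torch.randn(b, Lq, dim)       # query token(s)
kv = torch.randn(b, H * W, dim)   # flattened image features used as keys and values

with torch.no_grad():
    out = model(q, kv, kv)
print(out.shape)                  # torch.Size([2, 1, 128])

# Sketch: export to ONNX so a TensorRT engine can be built from the file
# (e.g. trtexec --onnx=transformer.onnx); opset and tensor names are assumptions.
torch.onnx.export(
    model,
    (q, kv, kv),
    "transformer.onnx",
    input_names=["q", "k", "v"],
    output_names=["out"],
    opset_version=13,
)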