Xiao Hei's Algorithm Growth Diary 23: Self-Attention and Multi-Head Attention

The self-attention operation

From the perspective of a single token:

$q_i = h_i W_Q,\quad k_j = h_j W_K,\quad v_j = h_j W_V$

$e_{i,j} = q_i k_j^T$

$\alpha_i = \mathrm{Softmax}([e_{i,1}, \dots, e_{i,T}])$

$h'_i = \left( \sum_{j=1}^{T} \alpha_{i,j} v_j \right) W_O$
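
To make the per-token view concrete, here is a minimal sketch in PyTorch (made-up sizes, with random matrices standing in for the learned projections $W_Q, W_K, W_V, W_O$):

import torch

T, d_model, d_head = 4, 8, 6          # made-up sizes for illustration
H = torch.randn(T, d_model)           # hidden states h_1 ... h_T
W_Q = torch.randn(d_model, d_head)
W_K = torch.randn(d_model, d_head)
W_V = torch.randn(d_model, d_head)
W_O = torch.randn(d_head, d_model)

i = 2                                                        # the query position
q_i = H[i] @ W_Q                                             # q_i = h_i W_Q
e_i = torch.stack([q_i @ (H[j] @ W_K) for j in range(T)])    # e_{i,j} = q_i k_j^T
alpha_i = torch.softmax(e_i, dim=0)                          # weights over all T positions
h_i_new = sum(alpha_i[j] * (H[j] @ W_V) for j in range(T)) @ W_O
print(h_i_new.shape)                                         # torch.Size([8]), back to d_model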


In matrix form:

$Q = H W_Q,\quad K = H W_K,\quad V = H W_V$

$E = Q K^T$

$E' = \mathrm{Softmax}(E)$

$H' = E' V$
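
The matrix form computes every position at once. As a quick sanity check, here is a sketch (same made-up sizes, no batch dimension) showing that row $i$ of $E'V$ equals the weighted sum $\sum_j \alpha_{i,j} v_j$ from the per-token view:

import torch

T, d_model, d_head = 4, 8, 6
H = torch.randn(T, d_model)
W_Q, W_K, W_V = (torch.randn(d_model, d_head) for _ in range(3))

Q, K, V = H @ W_Q, H @ W_K, H @ W_V    # [T, d_head] each
E = Q @ K.T                            # [T, T]
E_prime = torch.softmax(E, dim=-1)     # row-wise softmax
H_prime = E_prime @ V                  # [T, d_head]

i = 2
alpha_i = torch.softmax(Q[i] @ K.T, dim=0)                   # weights for position i
print(torch.allclose(H_prime[i], alpha_i @ V, atol=1e-5))    # True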

Single-head self-attention

import torch
import torch.nn as nn

class SelfAttention(nn.Module):
    def __init__(self, d_model, d_head):
        super().__init__()
        self.w_q = nn.Linear(d_model, d_head)
        self.w_k = nn.Linear(d_model, d_head)
        self.w_v = nn.Linear(d_model, d_head)
        self.w_o = nn.Linear(d_head, d_model)

    def forward(self, x):
        # x: [batch_size, max_len, d_model]
        # q, k, v: [batch_size, max_len, d_head]
        q = self.w_q(x)
        k = self.w_k(x)
        v = self.w_v(x)

        # swap the last two dims of k with permute (a transpose, not a reshape)
        attn_score = torch.matmul(q, k.permute(0, 2, 1))    # [batch_size, max_len, max_len]
        attn_score = torch.softmax(attn_score, dim=-1)      # [batch_size, max_len, max_len]
        output = torch.matmul(attn_score, v)                # [batch_size, max_len, d_head]
        return self.w_o(output)                             # [batch_size, max_len, d_model]

x = torch.randn(3, 9, 100)
model = SelfAttention(100, 80)
print(model(x).shape)    # torch.Size([3, 9, 100])
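
The formulas and code above use the raw dot products $q_i k_j^T$. The original Transformer additionally divides the scores by $\sqrt{d_{head}}$ before the softmax so that they do not grow with the head size; a sketch of that variant (the ScaledSelfAttention name is just for illustration, everything else mirrors the class above):

import math
import torch
import torch.nn as nn

class ScaledSelfAttention(nn.Module):
    # same layout as SelfAttention above, plus 1/sqrt(d_head) scaling of the scores
    def __init__(self, d_model, d_head):
        super().__init__()
        self.w_q = nn.Linear(d_model, d_head)
        self.w_k = nn.Linear(d_model, d_head)
        self.w_v = nn.Linear(d_model, d_head)
        self.w_o = nn.Linear(d_head, d_model)
        self.scale = 1.0 / math.sqrt(d_head)

    def forward(self, x):
        q, k, v = self.w_q(x), self.w_k(x), self.w_v(x)
        attn_score = torch.matmul(q, k.permute(0, 2, 1)) * self.scale
        attn_score = torch.softmax(attn_score, dim=-1)
        return self.w_o(torch.matmul(attn_score, v))

print(ScaledSelfAttention(100, 80)(torch.randn(3, 9, 100)).shape)    # torch.Size([3, 9, 100])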

Multi-head self-attention

# Multi-head self-attention
class MultiHeadSelfAttention(nn.Module):
    def __init__(self, d_model=768, d_head=64):
        super().__init__()
        assert d_model % d_head == 0
        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.w_o = nn.Linear(d_model, d_model)

        self.n_heads = d_model // d_head
        self.d_model = d_model
        self.d_head = d_head

    def forward(self, x, mask=None):
        batch_size = x.shape[0]
        max_len = x.shape[1]
        # project, then split the last dimension into (n_heads, d_head)
        q = self.w_q(x).view(batch_size, max_len, self.n_heads, self.d_head)
        k = self.w_k(x).view(batch_size, max_len, self.n_heads, self.d_head)
        v = self.w_v(x).view(batch_size, max_len, self.n_heads, self.d_head)

        q = q.permute(0, 2, 1, 3)
        k = k.permute(0, 2, 1, 3)
        v = v.permute(0, 2, 1, 3)    # [batch_size, n_heads, max_len, d_head]

        attn_score = torch.matmul(q, k.permute(0, 1, 3, 2))    # [batch_size, n_heads, max_len, max_len]

        if mask is not None:
            # padding mask over the key positions, broadcast to [batch_size, 1, 1, max_len]
            mask = mask.unsqueeze(1).unsqueeze(2)
            # large negative score -> ~0 weight after softmax for padded positions
            attn_score = attn_score.masked_fill(mask == 0, -1e9)
        attn_score = torch.softmax(attn_score, dim=-1)    # [batch_size, n_heads, max_len, max_len]
        out = torch.matmul(attn_score, v).permute(0, 2, 1, 3)    # [batch_size, max_len, n_heads, d_head]
        out = out.contiguous().view(batch_size, max_len, -1)     # [batch_size, max_len, d_model]
        return self.w_o(out)

if __name__ == "__main__":
    x = torch.randn(2, 9, 768)
    # 1 = real token, 0 = padding
    mask = torch.tensor([
        [1, 1, 1, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 0, 0, 0, 0, 0],
    ]).bool()

    model = MultiHeadSelfAttention()
    print(model(x, mask).shape)    # torch.Size([2, 9, 768])
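
The mask used here is a padding mask: 1 marks a real token and 0 a padded position, so the two sequences have effective lengths 3 and 4. If you start from the lengths instead, here is a small sketch that builds the same mask (the build_padding_mask helper is just for illustration):

import torch

def build_padding_mask(lengths, max_len):
    # [batch_size, max_len], True where the position holds a real token
    positions = torch.arange(max_len).unsqueeze(0)          # [1, max_len]
    return positions < torch.tensor(lengths).unsqueeze(1)   # broadcasts to [batch_size, max_len]

mask = build_padding_mask([3, 4], max_len=9)
print(mask.int())
# tensor([[1, 1, 1, 0, 0, 0, 0, 0, 0],
#         [1, 1, 1, 1, 0, 0, 0, 0, 0]], dtype=torch.int32)

With this mask, every query position puts (approximately) zero attention weight on the padded key positions, and the output shape stays [batch_size, max_len, d_model].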