transformer 中的 attention

机器学习与生成对抗网络

共 2466字,需浏览 5分钟

 ·

2022-05-10 18:10

来源:知乎—皮特潘

地址:https://zhuanlan.zhihu.com/p/444811538
大火的transformer 本质就是:
使用attention机制的seq2seq。
所以它的核心就是attention机制,今天就讲attention。直奔代码VIT-pytorch:
https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/vit.py
中的

class Attention(nn.Module):    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):        super().__init__()        inner_dim = dim_head *  heads        project_out = not (heads == 1 and dim_head == dim)
self.heads = heads self.scale = dim_head ** -0.5
self.attend = nn.Softmax(dim = -1) self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
self.to_out = nn.Sequential( nn.Linear(inner_dim, dim), nn.Dropout(dropout) ) if project_out else nn.Identity()
def forward(self, x): qkv = self.to_qkv(x).chunk(3, dim = -1) q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
attn = self.attend(dots)
out = torch.matmul(attn, v) out = rearrange(out, 'b h n d -> b n (h d)')        return self.to_out(out)
看吧!就是这么简单。今天就彻底搞懂这个东西。
先记住attention的这么几个点:
  • attention和CNN、RNN、FC、GCN等都是一个级别的东西,用来提取特征;既然是特征提取,一定有权重(W+B)存在。
  • attention的优点:可以像CNN一样并行运算 + 像RNN一样通过一层就拥有全局资讯。有一个东西也可以做到,那就是FC,但是FC有个弱点:对输入尺寸有限制,说白了不好适应可变输入数据,这对于序列无疑是非常不友好的。
  • 可以像CNN一样并行运算 ,其实CNN运算也是通过im2col或winograd等转化为矩阵运算的。
  • RNN不能并行,所以通常它处理的数据有“时序”这个特点。既然是“时序”,那么就不是同一个时刻完成的,所以不能并行化。
综上所述:attention优点 = CNN并行+RNN全局资讯+对输入尺寸(时序长度维度上)没有限制。
如果你能创造一个拥有上面三点优点的东西出来,你也可以引领潮流。
然后回到代码,再熟悉这么几个设置:
  • batch维度:大家利用同样的权重和操作提取特征,可以理解为for循环式,相互之间没有信息交互;
  • multi head维度:同batch类似,不过是利用的不同权重和相同操作提取特征,最后concate一起使用;
  • FC层:是作用在每一个特征上,类似CNN中的1X1,可以叫“pointwise”,和序列长度没有关系;因为序列中所有的特征经过的是同一个FC。
下面看这个图,看完不懂的可以扇自己了:
attention的顺序是:
1. 你有长度为n(序列)的特征,每个特征都是一个向量;
2. 每个向量都经过FC1,FC2,FC3获取到q,k,v三个向量(长度自己定),记住,不同特征用的是同一个FC1,FC2,FC3。可以说对于一个head,就一组FC1,FC2,FC3。
3. 特征1的q1和所有特征的k 进行点乘,获取一串值,注意:和自己的k也进行点乘;点乘向量变标量,表示相似性。多个K可不就是一串标量。
4. 3中的那一串值进行softmax操作,作为权重 对所有v加权求和,获得特征1输出;
5. 其他所有的特征和特征1的操作一样,注意所有特征是一块并行计算的;
6. 最后获取的和输入一样长度的特征序列再经过FC进行长度(特征维度)调整,也可以不要;
对了,softmax之前不要忘记 除以 qkv长度开方进行scaled,其实就是标准化操作(我觉得可以理解为各种N(BN,GN,LN等))。



猜您喜欢:

 戳我,查看GAN的系列专辑~!
一顿午饭外卖,成为CV视觉前沿弄潮儿!
CVPR 2022 | 25+方向、最新50篇GAN论文
 ICCV 2021 | 35个主题GAN论文汇总
超110篇!CVPR 2021最全GAN论文梳理
超100篇!CVPR 2020最全GAN论文梳理


拆解组新的GAN:解耦表征MixNMatch

StarGAN第2版:多域多样性图像生成


附下载 | 《可解释的机器学习》中文版

附下载 |《TensorFlow 2.0 深度学习算法实战》

附下载 |《计算机视觉中的数学方法》分享


《基于深度学习的表面缺陷检测方法综述》

《零样本图像分类综述: 十年进展》

《基于深度神经网络的少样本学习综述》


浏览 34
点赞
评论
收藏
分享

手机扫一扫分享

分享
举报
评论
图片
表情
推荐
点赞
评论
收藏
分享

手机扫一扫分享

分享
举报