DynamicViT:动态Token稀疏化ViT 论文解读

AI人工智能初学者

共 9812字,需浏览 20分钟

 ·

2021-06-22 00:28

关注"Smarter",加"星标"置顶

及时获取最优质的CV内容

【Vision Transformer 解读】仅代表个人理解,如果有理解错误的地方欢迎纠正~

原文地址:DynamicViT: Efficient Vision Transformers with Dynamic Token Sparsification
github源码:raoyongming/DynamicViT

文中提出了一种动态token稀疏化的视觉Transformer,通过分层剪枝的操作,即在ViT的不同层中动态地筛选所需要传入下一层的token数量,来达到加速inference的目的。文中的分层剪枝策略,最终剪枝了约66%的token数量,可以减少31%~37%的FLOPs,提高40%左右的模型运行速度。并且准确率只下降了0.5%不到。


01

Motivation


文中通过可视化发现,在Vision Transformer的视觉推理过程中,Attention主要集中在一部分信息丰富的patch上,如图1所示。这意味着在prediction的过程中,我们可以只保留重要性较大的token,动态地去除重要性较低的token,不会对最终的预测结果带来较大的影响。

图1:Vision Transformer的Attention可视化

文中的可视化使用的以下paper中的Attention可视化方法,收录于CVPR2021,感兴趣的可以去看看源码。

论文地址:

Transformer Interpretability Beyond Attention Visualization

源码地址:

hila-chefer/Transformer-Explainability

在传统的卷积神经网络中,为了加速模型的计算,往往会用到池化等方法来减少输入的feature map大小,也就是一种结构化的下采样方法,而对于Transformer模型来说,以往的加速工作侧重点在于,如何降低Attention Map计算的复杂度,例如Reformer、Linformer、Performer等,但由于Transformer本身的结构,可以处理动态长度的输入序列,因此可以从另外一个思路去进行剪枝加速,即针对输入的token,将对于prediction贡献度不大的token减去,如图2中的(b)所示,只保留重要的token,将不重要的token去除。并且针对不同的样本,包含重要information的token不一致,因此需要实现一个data-dependent的sparsification策略,根据不同的样本,自适应地去选择包含重要information的token

图2:(a)结构化下采样(b)动态token稀疏化

02

Method

基于以上的motivation及对应分析,文中提出了一个data-dependent的predictor模块,来自适应地选择出重要的token。

图3:加入了prediction module的ViT整体架构

整体架构如图3所示,包含一个主干的网络,例如:DeiT,LV-ViT等,另外包含了额外插入的预测模块,来动态稀疏化token。

为了减少预测的成本开销:文中使用的prediction模块是比较轻量化的MLP网络

对于输入包含N个C维Token的   而言,Predictor模块首先对于每个token通过一层Linear Projection来整合每个token单独的信息,也就是local的信息,表示如下:

  

然后在通过简单的全局池化,得到一个包含全局信息的向量:

  

对于该公式中的几个部分说明,其中  表示的操作是,对当前阶段参与Transformer计算的token进行的全局池化操作,因为文中的sparse策略是渐进式的,是在所保留的token基础上进一步稀疏化,在不同阶段所参与计算的token数量不一致,所以在预测mask时,仅仅对保留下来的token进行全局池化,用以表示全局信息,  是用来表示的是在这一层中所保留下来的token。

  

  

最终将全局向量与每个token拼接,输入到轻量化的MLP模块中预测出一个二值化的Mask来表示这一层中sparse策略所保留的token。

需要注意的是,在整个网络中都维护着一个初始化全为1的二值mask,在文中用   表示,初始化为1表示在网络的浅层中全部token都会参与计算,在经过不同阶段的sparse prediction模块后,mask中的一部分1会用0替代,来表示这个位置的token不参与计算。

可视化流程如下所示:

在初始阶段,默认所有token参与计算。

在经过一次sparsification操作后,模型会根据输入的data,predict出一个新的二值mask,并与当前的mask合并,经过predict出来不重要的token不参与计算,如下图所示,不重要的部分用0表示。

整个mask的稀疏化过程可以用下图表示:需要注意的是,每个阶段中所被放弃的token,在之后的阶段不会再参与更深层次中的attention的计算,所以这个二值化mask中为0的部分是不会变化的,在每一层中都会保留。

整个predictor模块的代码说明如下:从官方代码拷贝下来的,感兴趣的可以去阅读源码~

class PredictorLG(nn.Module):
""" Image to Patch Embedding
"""
def __init__(self, embed_dim=384):
super().__init__()
# local建模
self.in_conv = nn.Sequential(
nn.LayerNorm(embed_dim),
nn.Linear(embed_dim, embed_dim),
nn.GELU()
)
# 为每个token预测一个两维的向量,来表示当前token是否需要被mask掉
self.out_conv = nn.Sequential(
nn.Linear(embed_dim, embed_dim // 2),
nn.GELU(),
nn.Linear(embed_dim // 2, embed_dim // 4),
nn.GELU(),
nn.Linear(embed_dim // 4, 2),
nn.LogSoftmax(dim=-1)
)

def forward(self, x, policy):
# x 表示当前输入的tokens
# policy表示当前的mask,由0和1组成,0表示不需要参与后序计算的token
x = self.in_conv(x) # 对于输入的每一个token先经过一层linear projection对局部信息进行建模
B, N, C = x.size()
local_x = x[:,:, :C//2]
# 在计算全局向量的时候,只对参与后序计算的token进行全局池化操作
global_x = (x[:,:, C//2:] * policy).sum(dim=1, keepdim=True) / torch.sum(policy, dim=1, keepdim=True)
# 将全局向量与局部向量拼接
x = torch.cat([local_x, global_x.expand(B, N, C//2)], dim=-1)
# 通过简单的MLP来输出每个token是否需要保留的一个分数,一组score
return self.out_conv(x)

在进行local向量的linear projection和预测每个token是否需要保留的score的时使用MLP的原因是为了让整个predictor模块轻量化

笔者用随机的数据进行测试的结果如下:

x = torch.randn(2, 14, 384) # 随机初始化一个batch size为2,token数量为14,每个token维度为384的输入数据
prev_decision = torch.ones(2, 14 ,1) # 初始化一个全为1的mask
pred_score = predictor(x, prev_decision) # 将policy和x同时输入到predictor模块中,得到一个score向量
print(pred_score)

所输出的pred_score如下:

tensor([[[-0.6901, -0.6963],
[-0.6839, -0.7025],
[-0.6914, -0.6949],
[-0.6991, -0.6873],
[-0.6854, -0.7009],
[-0.6829, -0.7035],
[-0.7014, -0.6849],
[-0.7043, -0.6821],
[-0.6982, -0.6881],
[-0.6844, -0.7020],
[-0.7058, -0.6806],
[-0.6888, -0.6975],
[-0.6916, -0.6947],
[-0.7000, -0.6863]],

[[-0.7011, -0.6853],
[-0.6940, -0.6923],
[-0.6988, -0.6875],
[-0.6869, -0.6994],
[-0.6981, -0.6882],
[-0.6957, -0.6906],
[-0.6889, -0.6974],
[-0.7022, -0.6841],
[-0.7027, -0.6837],
[-0.6872, -0.6991],
[-0.6941, -0.6922],
[-0.6915, -0.6948],
[-0.6931, -0.6932],
[-0.6863, -0.7000]]], grad_fn=<LogSoftmaxBackward>)

对于输出的score,文中使用了Gumbel-Softmax trick进行了二值化操作,之所以使用Gumbel-Softmax操作是为了让整个二值化的过程是可导的。

hard_keep_decision = F.gumbel_softmax(pred_score, hard=True)[:, :, 0:1] * prev_decision
print(hard_keep_decision)

二值化后的mask输出如下:

tensor([[[0.],
[0.],
[1.],
[0.],
[1.],
[1.],
[0.],
[0.],
[0.],
[1.],
[1.],
[1.],
[1.],
[1.]],

[[1.],
[1.],
[1.],
[1.],
[1.],
[0.],
[1.],
[1.],
[0.],
[1.],
[1.],
[1.],
[0.],
[0.]]], grad_fn=<MulBackward0>)

预测出了一个稀疏的二值化mask,其中1的位置表示需要保留的token,0表示不重要的token,不参与后序计算。

Training Details

在网络的训练过程中,为了实现并行化,不会将无用的token剪去,而是采用了一个注意力masking策略,注意力masking策略的意思是,在self-attention中计算attention map时,加上一个mask,来显示地切断token之间的联系,已经被稀疏化的token不参与attention的计算。让模型的训练更加稳定。

稍加注释的源码如下:

class Attention(nn.Module):
def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
# NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
self.scale = qk_scale or head_dim ** -0.5

self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)

def softmax_with_policy(self, attn, policy, eps=1e-6):
# 将输入的policy转化为对应的mask,保证mask中值为0的token不参与attention的计算
B, N, _ = policy.size()
B, H, N, N = attn.size()
attn_policy = policy.reshape(B, 1, 1, N) # * policy.reshape(B, 1, N, 1)
eye = torch.eye(N, dtype=attn_policy.dtype, device=attn_policy.device).view(1, 1, N, N)
attn_policy = attn_policy + (1.0 - attn_policy) * eye
max_att = torch.max(attn, dim=-1, keepdim=True)[0]
attn = attn - max_att
# attn = attn.exp_() * attn_policy
# return attn / attn.sum(dim=-1, keepdim=True)

# for stable training
attn = attn.to(torch.float32).exp_() * attn_policy.to(torch.float32)
attn = (attn + eps/N) / (attn.sum(dim=-1, keepdim=True) + eps)
return attn.type_as(max_att)

def forward(self, x, policy):
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)

attn = (q @ k.transpose(-2, -1)) * self.scale

if policy is None:
attn = attn.softmax(dim=-1)
else:
attn = self.softmax_with_policy(attn, policy)

x = (attn @ v).transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x

可视化过程如下所示:其中被稀疏化的token只与自身作Attention计算

首先将当前的二值化mask转化为attention mask:

根据二值化Mask生成Attention Mask

被稀疏化的token只与自身交互,其余token参与self-attention计算

将最终的attention mask加到计算后的attention map上,这样就可以显示地切断已经被筛选掉的token与其他token之间的交互,保证最后prediction的结果只与保留下来的token有关。

仅对保留下来的token作self-attention操作

Testing Details

在测试阶段,为了加速推理速度,DynamicViT的做法是,对每个token,按照PredictorLG模块所输出的保留概率进行排序,每次保留排序靠前的,固定比率的token。假设有12层的Vision Transformer,文中的做法是将PredictorLG模块分别放在第4层,第7层,第9层之前,在训练阶段不直接剔除token,而是采用attention masking策略,而在测试阶段每层剔除一定比率的token,如果文中设置的比率为0.7,那么经过三个PredictorLG模块后,所保留的token数量大约为:    ,也就是34.3%左右的token数量。

Training and Inference

为了最小化 token sparsification对模型性能带来的影响,文中以original ViT backbone作为teacher model,并且希望DynamicViT的behavior能尽可能地接近original ViT backbone。

First

让最终保存下来的token特征尽可能地接近teacher model对应同样位置的token特征,不需要考虑在之前的sparsification stage未保存下来的token。可以看作是一种自蒸馏的方式。

其中  和  分别表示经过DynamicViT和Teacher Model处理后的第i个token。

  表示第b个样本,在第s个 sparsification stage时是否保留第i个token,和之前一样,是一个二值化的mask。

Second

最小化DynamicViT和Teacher Model之间的预测差异:

Finally

希望DynamicViT在每个sparsification stage最终保留的token比例接近我们人为预设的值

其中  表示在第  个sparsification stage认为预设的比例。

最终的损失函数如下:

DynamicViT的实验设置为  ,用来平衡各个Loss。

在Inference阶段,在每一个sparsification stage给定一个所需保留token的比率   ,直接根据PredictorLG模块所预测的score进行排序,保留固定比率的token,其余token直接抛弃。假设  表示初始化时的token数量,那么在第s个stage,我们保留下来的token数量为  ,如果初始化的token数量为100个,假设前两个sparsification stage保留token的数量比率都为0.7,那么在s=2时,所保留的token数量约为  个。

这个保留的token是根据score排序得到的,也就是以下公式,score大意味着information丰富。

03

Experiments

可以看到在每个stage都只保留  也就是百分之70%的token,准确率只drop了0.5不到。

可以看出在仅增加少量参数的情况下,DynamicViT可以大幅度减少GFLOPs,并依然可以保证模型的准确率。

为了验证文中提出的动态模块的有效性,文中还对其他sparsification策略进行了比较,如下表所示:

其中 Structural 表示的是 kernel size=2 和 stride=2 的 Average Pooling,文中将这个pooling操作放在网络的第6层,Static表示的是,整个网络学习出来一个统一的sparsification mask,而非数据依赖的,意味着不能根据样本动态调整所需要的舍弃的token,对于每个样本而言所舍弃的token都是相同位置的。

可以看出在同等GFLOPs下,文中的Dynamic Sparsification策略为模型带来的影响最小。

Visualization

可视化结果可以看出,经过每个stage后,包含信息量比较丰富的token都被保存了下来,不相关的背景token都被舍弃了。

Reference

[1] DynamicViT: Efficient Vision Transformers with Dynamic Token Sparsification

全新激活函数 | 详细解读:HP-x激活函数(附论文下载)


遮挡人脸问题 | 详细解读Attention-Based方法解决遮挡人脸识别问题(附论文下载)


效率新秀 | 详细解读:如何让EfficientNet更加高效、速度更快

浏览 69
点赞
评论
收藏
分享

手机扫一扫分享

分享
举报
评论
图片
表情
推荐
点赞
评论
收藏
分享

手机扫一扫分享

分享
举报