DiG:使用门控线性注意力机制的高效可扩展 Diffusion Transformer

极市平台

共 23824字,需浏览 48分钟

 ·

2024-06-21 22:00

↑ 点击蓝字 关注极市平台
作者丨科技猛兽
编辑丨极市平台

极市导读

 

在相同的模型尺寸下,DiG-XL/2 比基于 Mamba 的扩散模型在 1024 的分辨率下快 4.2 倍,在 2048 的分辨率下比带有 CUDA 优化的 FlashAttention2 的 DiT 快 1.8 倍。这些结果都证明了其优越性能。 >>加入极市CV技术交流群,走在计算机视觉的最前沿

本文目录

1 DiG:使用门控线性注意力机制的高效可扩展 Diffusion Transformer
(来自华科,字节跳动)
1 DiM 论文解读
1.1 DiG:一种轻量级 DiT 架构
1.2 门控线性注意力 Transformer
1.3 扩散模型
1.4 Diffusion GLA 模型
1.5 DiG Block 架构
1.6 复杂度分析
1.7 实验结果

太长不看版

Diffusion Transformer 模型面临的一个问题是计算复杂度与序列长度呈二次方关系,这不利于扩散模型的缩放。本文通过门控线性注意力机制 (Gated Linear Attention) 的长序列建模能力来应对这个问题,来提升扩散模型的适用性。

本文提出的模型称为 Diffusion Gated Linear Attention Transformers (DiG),是一种基于门控线性注意力机制和 DiT[1]的简单高效的扩散 Transformer 模型。除了比 DiT 更好的性能外,DiG-S/2 的训练速度比 DiT-S/2 高 2.5 倍,并在 1792 的分辨率节省 75.7% 的 GPU 显存。此外,作者分析了 DiG 在各种计算复杂度下的可扩展性。结果是随着模型的缩放,DiG 模型始终表现出更优的 FID。作者还将 DiG 与其他次 subquadratic-time 的扩散模型进行了比较。在相同的模型尺寸下,DiG-XL/2 比基于 Mamba 的扩散模型在 1024 的分辨率下快 4.2 倍,在 2048 的分辨率下比带有 CUDA 优化的 FlashAttention2 的 DiT 快 1.8 倍。这些结果都证明了其优越性能。

本文做了哪些具体的工作

  1. 提出了 Diffusion GLA (DiG),通过分层扫描和局部视觉感知进行全局视觉上下文建模。DiG 使用线性注意力 Transformer 来实现 diffusion backbone。
  2. DiG 在训练速度和 GPU 显存成本方面都表现出更高的效率,同时保持与 DiT 相似的建模能力。具体而言,DiG 比 DiT 快 2.5 倍,并在 1792×1792 的分辨率中节省 75.7% 的 GPU 显存,如图1所示。
  3. 作者在 ImageNet 数据集上进行了广泛的实验。结果表明,与 DiT 相比,DiG 表现出可扩展的能力并实现了卓越的性能。在大规模长序列生成的背景下,DiG 有望成为下一代 Backbone。
图1:DiT、DiS 和 DiG 模型的效率比较。DiG 在处理高分辨率图像时实现了更高的训练速度,同时成本更低的 GPU 显存
图2:DiS、DiT、带有Flash Attention-2 (Flash-DiT) 的 DiT 和不同模型大小的 DiG 模型之间的 FPS 对比

1 DiG:使用门控线性注意力机制的高效可扩展 Diffusion Transformer

论文名称:DiG: Scalable and Efficient Diffusion Models with Gated Linear Attention (Arxiv 2024.05)

论文地址:

http://arxiv.org/pdf/2405.18428

代码链接:

http://github.com/hustvl/DiG

1.1 DiG:一种轻量级 DiT 架构

扩散模型以其生成高质量的图像生成能力而闻名。随着采样算法的快速发展,主要技术根据其 Backbone 架构演变为2个主要类别:基于 U-Net 的方法[2]和基于 ViT 的方法[3]。基于 U-Net 的方法继续利用卷积神经网络 (CNN) 架构,其分层特征建模能力有利于视觉生成任务。另一方面,基于 ViT 的方法结合注意力机制。由于其出色的性能与可扩展性,基于 ViT 的方法已被用作最先进的扩散工作中的 Backbone,包括 PixArt、Sora、Stable Diffusion 3 等。然而,基于 ViT 的架构的 Self-attention 机制与输入序列长度呈二次方关系,使得它们在处理长序列生成任务 (例如高分辨率图像生成、视频生成等) 时资源消耗较大。最近的架构 Mamba[4]、RWKV[5]和 Gated Linear Attention Transformer (GLA)[6],试图通过集成 RNN 类的架构,以及硬件感知算法来提高长序列处理效率。其中,GLA 将依赖于数据的门控操作和硬件高效的实现结合到线性注意力 Transformer 中,显示出具有竞争力的性能,但吞吐量更高。

受 GLA 在自然语言处理领域的成功的启发,作者将这种成功从语言生成转移到视觉内容生成领域,即使用高级线性注意力设计可扩展且高效的 Diffusion Backbone。然而,使用 GLA 进行视觉生成面临两个挑战,即单向扫描建模和缺乏局部信息。为了应对这些挑战,本文提出了 Diffusion GLA (DiG) 模型,该模型结合了一个轻量级的空间重定向和增强模块 (Spatial Reorient & Enhancement Module, SREM),用于分层扫描方向控制和局部感知。扫描方向包含四个基本模式,并使序列中的每个 Patch 能够感知沿纵横方向的其他 Patch。此外,作者还在 SREM 中加入了深度卷积 (DWConv),使用很少的参数为模型注入局部信息。

1.2 门控线性注意力 Transformer

Gated Linear Attention Transformer (GLA) 结合依赖于数据的门控机制和线性注意力, 实现了卓越的循环建模性能。给定输入 ( 是序列长度, 是维度),GLA 计算 Query、Key 和 Value 向量:

式中 是线性投影权重。 是维度数。接下来, GLA 计算门控矩阵 ,如下所示:

其中 是 token 的索引, 是 sigmoid 函数, 是偏置项, 是温度项。如图3所示, 最终输出 如下:

图3:GLA Pipeline

其中, Swish 是 Swish 激活函数, 是逐元素乘法运算。在接下来的部分中, 使用 来指代输入序列的门控线性注意力计算。

1.3 扩散模型

DDPM[7]通过迭代去噪输入将噪声作为输入和采样图像。DDPM 的前向过程是随机过程,其中初始图像 逐渐被噪声破坏,最后转化为更简单、噪声主导的状态。前向噪声过程可以表示如下:

其中 是从时间 的噪声图像序列。然后, DDPM 使用可学习的 恢复原始图像的反向过程:

其中, 是去噪模型的参数, 使用 variational lower bound 在观测数据 的分布下训练:

其中, 是总的损失函数。为了进一步简化 DDPM 的训练过程, 研究人员将 重参数化为噪声预测网络 , 使 与真实高斯噪声 之间的均方误差损失 做最小化:

然而, 为了训练能够学习反向过程协方差  Σ 𝜃  的扩散模型, 就需要优化完整的  𝐷 𝐾 𝐿  项。本文作者遵循 DiT 训练网络, 其中使用损失  𝐿 simple   来训练噪声预测网络  𝜖 𝜃 , 并使用全损失  𝐿  来训练协方差预测网络  Σ 𝜃  。

1.4 Diffusion GLA 模型

本文提出了 Diffusion GLA (DiG),一种用于生成任务的新架构。本文的目标是尽可能忠实于标准的 GLA 架构,以保持其缩放能力和高效率的特性。GLA 的概述如图 3 所示。

标准 GLA 一般用于一维序列的因果语言建模。为了适配图像的 DDPM 训练, 本文遵循 ViT 架构的实践。DiG 以 VAE 编码器的输出的空间表征 作为输入。对于 的图像, VAE 编码器的空间表征 的形状为 。DiG 随后通过 Patchify 层将空间输入转换为 token 序列 , 其中 为序列的长度, 为空间表示通道数, 为图像补丁的大小, 因此 的减半将使得 变为 4 倍。接下来, 将 线性投影到维度为 的向量上, 并将基于频率的位置嵌入 添加到所有投影 token 中, 如下所示:

其中 的第 个 Patch, 是可学习的投影矩阵。至于噪声时间步 和类标签 等条件信息, 作者分别采用多层感知 (MLP) 和嵌入层作为 timestep embedder 和 label embedder。

其中 是 time Embedding, 是 label Embedding。然后, 作者将令牌序列 发送到 DiG 编码器的第 层, 得到输出 。最后, 对输出标记序列 进行归一化, 并将其馈送到线性投影头以获得最终预测的噪声 和预测的协方差 , 如下所示:

其中, 是第 个扩散 Block, 是层数, Norm 是归一化层。 和预测的协方差 与输入空间表示具有相同的形状, 即

1.5 DiG Block 架构

原始的 GLA Block 以循环格式处理输入序列,这只能对 1-D 序列进行因果建模。本文提出的 DiG 的 Block 架构集成了一种空间重定向和增强模块 (Spatial Reorient & Enhancement Module, SREM),用于控制逐层扫描方向。DiG Block 架构如下图4所示。

图4:DiG 模型架构

作者通过调整回归自适应层范数 (adaLN) 参数来启动门控线性注意 (GLA) 和前馈网络 (FFN)。

图5:DiG 算法流程

然后,作者把序列改为 2D 的形状,并使用一个轻量级的 3×3 深度卷积来感知局部空间信息。但使用传统的 DWConv2d 初始化会导致收敛速度慢,因为卷积权重分散在周围。为了解决这个问题,作者提出了 Identity 初始化,将卷积核中心设置为1,将周围其他设置为0。最后,每两个块执行转置 2D token 矩阵,并翻转展平的序列,来控制下一个 Block 的扫描方向。如图4右侧所示,每层只处理一个方向的扫描。

1.6 复杂度分析

DiG 架构共有4种尺寸,分别是 DiG-S, DiG-B, DiG-L, 和 DiG-XL,配置如下图6所示。其参数量从 31.5M 到 644.6M,计算量从 1.09GFLOPs 到 22.53GFLOPs。值得注意的是,与相同大小的基线模型 (即 DiT) 相比,DiG 只消耗 77.0% 到 78.9% 的 GFLOPs。

图6:DiG 架构配置

GPU 包含两个重要的组件, 即高带宽内存 (HBM) 和 SRAM。HBM 具有更大的内存大小, 但 SRAM 具有更大的带宽。为了以并行形式充分利用 SRAM 和建模序列, GLA 将整个序列拆分为许多块, 可以在 SRAM 上完成计算。定义块的尺寸为  𝑀 , 训练复杂度是 𝑂 ( 𝑇 𝑀 ( 𝑀 2 𝐷 + 𝑀 𝐷 2 ) ) = 𝑂 ( 𝑇 𝑀 𝐷 + 𝑇 𝐷 2 )  。当  𝑇 < 𝐷  时, 略小于传统注意力机制的计算复杂度  𝑂 ( 𝑇 2 𝐷 )  。此外, DiT Block 中的 Depth-wise 卷积和高效矩阵运算也保证了效率。

1.7 实验结果

作者使用 ImageNet 进行 class-conditional 图像生成任务的训练,分辨率为 256×256。作者使用水平翻转作为数据增强,使用Frechet Inception Distance (FID)、Inception Score、sFID 和 Precision/Recall 来衡量生成性能。

使用恒定学习率为 1e-4 的 AdamW 优化器。遵循 DiT 的做法在训练期间对 DiG 权重进行指数移动平均 (EMA),衰减率为 0.9999。使用 EMA 模型生成图像。对于 ImageNet 的训练,使用现成的预训练的 VAE。

如下图7所示,作者分析了所提出的空间重定向和增强模块 (SREM) 的有效性。作者将 DiT-S/2 作为基线方法。原始的 DiG 模型只有 causal modeling,计算量和参数量都很少。但是因为缺乏全局上下文,因此 FID 很差。作者首先向 DiG 添加双向扫描,并观察到了显著的改进,证明了全局上下文的重要性。而且,使用 Identity 初始化的 DWConv2d 也可以大大提高性能。DWConv2d 的实验证明了 Identity 初始化和局部信息的重要性。最后一行的实验表明,完整的 SREM 可以实现最佳的性能,且同时关注局部和全局上下文。

图7:SREM 模块的消融实验结果

缩放模型尺寸

作者研究了 DiG 在 ImageNet 上的四种不同模型尺度之间的缩放能力。如图 8(a) 所示,随着模型从 S/2 扩展到 XL/2,性能有所提高。结果表明了 DiG 的缩放能力,以及作为基础扩散模型的潜力。

Patch Size 的影响

作者在 ImageNet 上训练了 Patch Size 从 2、4 和 8 不等的 DiG-S。如图 8(b) 所示,通过减少 DiG 的 Patch Size,可以在整个训练过程中观察到明显的 FID 优化。因此,最佳性能需要更小的 Patch Size 和更长的序列长度。与 DiT 基线相比,DiG 在处理长序列生成任务方面更有效。

图8:DiG 模型大小和 Patch Size 的缩放分析

作者将所提出的 DiG 与基线方法 DiT 进行比较,二者具有相同的超参数,结果如下图9所示。所提出的 DiG 在 400K 训练迭代的4个模型尺度上优于 DiT。此外,与以前的最先进方法相比,classifier-free guidance 的 DiG-XL/2-1200K 也显示出具有竞争力的结果。

图9:ImageNet 256×256 class-conditional 图像生成任务实验结果

图10展示了从 DiG-XL/2 中采样的结果,这些结果来自 ImageNet 训练的模型,分辨率为 256×256。结果表明,DiG 生成结果的正确的语义和精确的空间关系。

图10:DiG-XL/2 模型生成结果

参考

  1. ^Scalable Diffusion Models with Transformers
  2. ^Denoising Diffusion Probabilistic Models
  3. ^An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale
  4. ^Mamba: Linear-Time Sequence Modeling with Selective State Spaces
  5. ^RWKV: Reinventing RNNs for the Transformer Era
  6. ^abGated Linear Attention Transformers with Hardware-Efficient Training
  7. ^Denoising Diffusion Probabilistic Models

公众号后台回复“极市直播”获取100+期极市技术直播回放+PPT

极市干货

极视角动态2023GCVC全球人工智能视觉产业与技术生态伙伴大会在青岛圆满落幕!极视角助力构建城市大脑中枢,芜湖市湾沚区智慧城市运行管理中心上线!
数据集:面部表情识别相关开源数据集资源汇总打架识别相关开源数据集资源汇总(附下载链接)口罩识别检测开源数据集汇总
经典解读:多模态大模型超详细解读专栏

极市平台签约作者#


科技猛兽

知乎:科技猛兽


清华大学自动化系19级硕士

研究领域:AI边缘计算 (Efficient AI with Tiny Resource):专注模型压缩,搜索,量化,加速,加法网络,以及它们与其他任务的结合,更好地服务于端侧设备。


作品精选

搞懂 Vision Transformer 原理和代码,看这篇技术综述就够了
用Pytorch轻松实现28个视觉Transformer,开源库 timm 了解一下!(附代码解读)
轻量高效!清华智能计算实验室开源基于PyTorch的视频 (图片) 去模糊框架SimDeblur



投稿方式:
添加小编微信Fengcall(微信号:fengcall19),备注:姓名-投稿
△长按添加极市平台小编

觉得有用麻烦给个在看啦~  

浏览 107
点赞
评论
收藏
分享

手机扫一扫分享

分享
举报
评论
图片
表情
推荐
点赞
评论
收藏
分享

手机扫一扫分享

分享
举报