让扩散模型听话的小秘籍?CAN:通过操控权重来控制条件生成模型,图像生成效率大升级!

共 12373字,需浏览 25分钟

 ·

2024-04-28 22:00

↑ 点击蓝字 关注极市平台
作者丨科技猛兽
编辑丨极市平台

极市导读

 

本文提出的 CAN 模型 (Condition-Aware Neural Network) 是一种对图像生成模型添加控制的方法。CAN 可以通过动态操纵神经网络的权重来控制图像生成过程。作者在 ImageNet 图像生成任务以及 COCO 文生图任务上面测试了 CAN 方法。CAN 始终为扩散 Transformer 模型提供显著的改进,比如 DiT 和 UViT。 >>加入极市CV技术交流群,走在计算机视觉的最前沿

扩散模型解读合集:

1.Sora的幕后功臣?详解大火的DiT:拥抱Transformer的扩散模型

本文目录

1 CAN:条件感知的扩散模型
(来自 MIT HAN LAB, Song Han 团队)
1 CAN 论文解读
1.1 可控图像生成模型:从控制特征到控制权重
1.2 哪一层设计成条件感知?
1.3 与自适应核选择的对比
1.4 CAN 方法具体实现
1.5 实验设置和评价指标
1.6 消融实验结果
1.7 与 SOTA 模型对比

太长不看版

本文提出的 CAN 模型 (Condition-Aware Neural Network) 是一种对图像生成模型添加控制的方法。CAN 可以通过动态操纵神经网络的权重来控制图像生成过程。如图1所示,具体的方法是通过一个条件感知权重生成模块,这个模块的输入是条件 (比如类别标签,时间步),作用是为卷积/线性层生成权重。作者在 ImageNet 图像生成任务以及 COCO 文生图任务上面测试了 CAN 方法。CAN 始终为扩散 Transformer 模型提供显著的改进,比如 DiT 和 UViT。

图1:CAN 的输入是条件信息,然后动态生成神经网络的权重,与原始模型的权重混合

本文做了什么工作

  1. 引入了一个控制图像生成模型的新机制:通过操控权重来控制条件生成模型。
  2. 提出了条件感知神经网络,一种用于条件图像生成的控制方法。
  3. CAN 可以用来提升图像生成模型的性能,大大优于之前的方法,而且对部署很有帮助。比如在 ImageNet 512×512 的图像生成任务,CAN 方法的 FID 比 DiT-XL/2 更小,且每个采样步骤的 MACs 少 52 倍,为边缘设备上的扩散模型应用提供支持。

效果图如下图2所示。

图2:在 ImageNet 512×512 图像生成任务中不同模型的结果对比

1 CAN:条件感知的扩散模型

论文名称:Condition-Aware Neural Network for Controlled Image Generation (CVPR 2024)

论文地址:http://arxiv.org/pdf/2404.01143.pdf

1.1 可控图像生成模型:从控制特征到控制权重

大规模图像和视频生成模型[1][2][3]在合成逼真图像和视频方面表现出了惊人的能力。为了将这些模型转换为人类的生产工具,关键步骤是添加控制。我们希望生成模型遵循我们的指令 (例如类别标签、文本、姿势等等[4]),而不是让模型随机生成数据样本。

一些先前的工作通过添加 Cross-Attention[5]或者 Self-Attention[6]将条件特征与图像特征融合。虽然使用的操作不同,但这些方法的共同之处是:通过特征空间操作来添加控制。同时,对于不同的条件 (Condition),神经网络权重 (卷积/线性层) 保持不变。

这项工作旨在回答以下问题:

  1. 是否可以通过操控图像生成模型的权重来控制图像的生成过程?
  2. 图像生成模型是否能够受益于这种新的控制方法?

本文提出一种条件感知神经网络 (Condition-Aware Neural Network, CAN),一种基于权重空间操作的新条件控制方法。CAN 引入了一个权重生成模块来产生权重,这个模块的输入是条件的嵌入,比如用户指令 (类别标签) 和扩散模型的时间步。模块的输出是卷积层/线性层的权重。

作者通过消融实验调研了 CAN 对于扩散模型的实际作用。首先,作者发现,仔细选择一部分模块使其权重是条件感知的,而不是将所有的模块都变成条件感知的,这样做更有利于性能和效率的权衡。其次,作者发现根据条件直接生成权重比自适应地合并静态的权重更加有效。

CAN 可以单独为图像生成模型提供有效的条件控制,提供比以前条件控制方法更低的 FID 和更高的 CLIP 分数。除了将 CAN 应用于现有的扩散 Transformer 模型之外,作者还通过结合 CAN 和 EfficientViT 进一步构建了一个名为 CaT 的新型扩散 Transformer 模型。

1.2 哪一层设计成条件感知?

理论上,可以使神经网络中的所有层做成条件感知的。但在实践中,这不一定好。

首先,从性能的角度看,使用过多的条件感知层可能会使得优化过程不稳定,为模型优化带来挑战。

其次, 从效率的角度来看, 虽然生成条件权重的计算开销可以忽略不计, 但会带来显著的参数开销。假设我们定义 Condition Embedding 为 (比如 384,512,1024 等等), 模型的静态参数大小为 #params 。使用单个线性层将 Condition Embedding 映射到条件权重需要 #params 个参数。这对于现实世界的使用是不切实际的。因此在这项工作中,作者只选择一部分模块来应用 CAN。

将 CAN 应用与 Diffusion Transformer 的具体做法如图3所示。Depth-Wise 卷积[7]的参数量比常规卷积小得多,将其作为条件感知的成本比较低。因此,作者按照[8]的设计,在 FFN 中间添加一个 Depth-Wise 卷积。作者使用 UViT-S/2 在 ImageNet 256×256 的图像生成任务上进行了消融实验研究,哪些模块要使用条件感知方法。在消融实验中,所有模型的架构都相同,唯一的区别是条件感知的模块不同。

图3:将 CAN 应用与 Diffusion Transformer 的方法

消融实验结果如图4所示。作者给出了两个观察:

  • 使用条件感知的模块并不总是能够提升性能,比如图4第2行和第4行,使用静态头比使用条件感知头得到更低的 FID 和更高的 CLIP 分数。
  • 将 Depth-Wise 卷积层、Patch Embedding 层和输出投影层设置为条件感知,可以带来,显著的性能提升:将 FID 从 28.32 提高到 8.82,CLIP 分数从 30.09 提高到 31.74。
图4:哪一层设计成条件感知的消融实验

基于这些结果,作者为 CAN 选择了图3这样的设计。对于 Depth-Wise 卷积层和 Patch Embedding 层,作者为每个层使用单独的条件权重生成模块,因为它们的参数大小很小。对输出投影层使用共享的条件权重生成模块,因为它们的参数量很大。由于不同的输出投影层本身具有不同的静态权重,因此不同的输出投影层还是具有不同的权重。

1.3 与自适应核选择的对比

自适应核选择 (Adaptive Kernel Selection[9][10])是另一种动态输出神经网络参数的方法。Adaptive Kernel Selection 维护了一组基本的卷积核,然后动态地生成缩放参数来组合这些卷积核。这种方法的参数开销小于 CAN。但是,这种自适应核选择策略的性能不如 CAN 方法,如图5所示。这表明仅动态参数化并不是提高性能的关键,更好的条件感知适应能力至关重要。

图5:CAN 方法比自适应核选择更有效

1.4 CAN 方法具体实现

由于条件感知层在给定不同样本的情况下具有不同的权重, 因此不能进行批量化的训练和推理。因此, 必须针对每个样本单独运行内核, 如图6左侧所示。为了解决这个问题, 作者提出了一种 CAN 的高效版本实现。核心思想是把所有的卷积核的调用封装为一个分组卷积, 其中组数 #Groups 为 Batch Size

在分组卷积之前, 作者进行了一步 batch-to-channel 的转换, 把维度为 的特征转换为维度为 的特征, 然后进行 Grouped Conv 操作。在结束之后, 再反过来通过 channel-to-batch 把特征变回原来的形状。

图6:CAN 的实际实现。左图:条件感知层对于不同的样本有不同的权重,需要为每个样本独立运行内核调用,这会给训练和批处理推理带来很大的开销。右侧:CAN 的高效版实现,将所有内核调用融合到分组卷积中

理论上,通过这种高效的实现,与运行静态模型相比,额外的训练的开销将可以忽略不计。在实践中,由于 NVIDIA GPU 相比于分组卷积,对常规卷积的支持更友好,作者仍然观察到 30%-40% 的训练开销。这个问题可以通过编写定制的 CUDA Kernel 来解决。

1.5 实验设置和评价指标

数据集: 由于资源限制,作者使用 ImageNet 数据集进行类条件图像生成实验,并使用 COCO 进行文本到图像生成实验。对于大规模的文本到图像实验[11],作者将其留给未来的工作。

评价指标: 按照常见的做法,作者使用 FID[12]作为图像质量的评估指标。此外,作者使用 CLIP 分数[13]作为可控性的指标。使用公共 CLIP ViT-B/32[14]来测量 CLIP 分数,遵循[15]的做法。text prompt 按照 CLIP 的 Zero-Shot 图像分类设置来构建。

具体实现: 作者将 CAN 应用在了 DiT[16]和 UViT[17]模型中。所有模型都使用无分类器指导 (Classifier-Free Guidance),除非另有说明。基线模型的架构与 CAN 模型的架构相同,在 FFN 层中有 Depth-Wise Convolution。在训练期间使用自动混合精度。除了将 CAN 应用于现有模型外,作者还通过将 CAN 和 EfficientViT 结合起来构建了一个称为 CaT 的新型 Diffusion Transformer 模型。宏观的架构如图7所示。

图7:CaT 模型宏观架构

1.6 消融实验结果

除非另有说明,否则作者在消融实验中训练 80 个 Epoch,Batch Size 为 1024。所有模型都使用 DPM-Solver[18] 和 50 步对图像进行采样。如下图8所示为 UViT 和 DiT 模型应用了 CAN 方法之后的结果。CAN 显著地提高所有模型的生成图像质量和可控性,而且这些改进的计算成本开销可以忽略不计。

图8:在不同 UViT 和 DiT 模型上的实验结果

下图9比较了 CAN 方法在 UViT-S/2 和 DiT-S/2 上的训练曲线。可以看到,当两个模型的训练时间更长时,绝对的改进仍然显著。这表明改进不是由于更快的收敛。相反,添加 CAN 可以提高模型的性能上限。

图9:训练曲线

对于扩散模型,Condition Embedding 包含类别标签和时间步长。为了剖析哪个对条件权重生成过程更重要,作者使用 UViT-S/2 进行了消融实验,并把结果总结在图10中,可以发现:

  • 类别标签信息比权重生成过程中的时间步信息更重要。 仅添加类别标签比单独添加时间步得到更好的 FID 和 CLIP 分数。
  • 在 Condition Embedding 中包含类别标签和时间步长可以获得最佳的结果。 因此,在接下来的实验中,作者坚持这种设计。
图10:条件类型的消融实验结果

与之前的条件控制方法对比

为了对比 CAN 和之前的条件控制方法,作者在下图11中展示了实验结果,并有以下发现:

  • CAN 就已经可以作为一种有效的条件控制方法。
  • CAN 可以与其他条件控制方法相结合,以获得更好的结果。
  • 对于 UViT 模型,将 CAN 与注意力 (Condition 作为 tokens) 相结合会略微损害性能。因此,在接下来的实验中作者仅在 UViT 模型上使用 CAN。
图11:与之前的条件控制方法对比实验结果

1.7 与 SOTA 模型对比

作者将本文 CaT 模型与其他方法在 ImageNet 图像生成任务和 COCO 文生图任务中进行了比较,实验结果如图12和15所示。对于 CaT 模型,作者使用了 UniPC[19]的技术加速采样。

ImageNet 256×256 类别条件图像生成任务

使用无分类器指导 (classifier-free guidance, cfg),本文的 CaT-B0 在 ImageNet 上实现了 2.09 的 FID,超越了 DiT-XL/2 和 UViT-H/2。更重要的是,CaT-B0 比这些模型的计算效率要高得多:MAC 比 DiT-XL/2 少 9.9 倍,MAC 比 UViT-H/2 少 11.1 倍。在没有无分类器指导的情况下,CaT-B0 在所有比较模型中也实现了最低的 FID。

ImageNet 512×512 类别条件图像生成任务

在更具挑战性的 512×512 图像生成任务中,可以观察到 CAN 的涨点变得更加显著。例如,CAN (UViT-S-Deep/4) 可以匹配 UViT-H (4.04 vs. 4.05) 的性能,而每个扩散步骤只需要 12% 的 UViT-H 的计算成本。此外,CaT-L0 在 ImageNet 512×512 上得到 2.78 的 FID,优于 DiT-XL/2 (3.04 FID),DiT-XL/2 的每个扩散步骤需要 52× 高的计算成本。此外,通过缩放模型, CaT-L1 进一步将 FID 从 2.78 提高到 2.48。

图12:ImageNet 类别条件图像生成任务实验结果

除了计算成本比较之外,图13也比较了在 NVIDIA Jetson AGX Orin 上的 CaT-L0 和 DiT-XL/2 的延时。延迟是用 TensorRT, fp16 测量的。CaT-L0 可以在 ImageNet 512×512 图像生成任务中得到更好的 FID 结果,且在与快速采样方法 UniPC 结合之后,在 Orin 上的运行比 DiT-XL/2 快 229 倍。

图13:NVIDIA Jetson AGX Orin 上的 Latency 和 FID 结果对比

除了定量结果之外,下图14为 CAN 模型随机生成的图像的样本,证明了本文模型在生成高质量图像方面的能力。

图14:CAN 模型随机生成的图片样本

COCO 256×256 文生图任务

对于 COCO 文生图实验,作者遵循 UViT 中使用的相同设置。模型在 COCO 2014 训练集上从头开始训练。在 UViT 之后,作者从 COCO 2014 验证集中随机抽取 30K 个文本提示来生成图像,然后计算 FID。作者使用与 UViT 中相同的 CLIP 编码器来编码文本提示。

图15:COCO 256×256 文生图任务实验结果

实验结果如图15所示,CaT-S0 实现了与 UViTS-Deep/2 相似的 FID 结果,同时计算成本要低得多 (19GMACs → 3GMACs),证明了本文模型的泛化能力。这个实验说明 CAN 方法不仅仅适用于图像生成任务,也适用于文生图任务。

参考

  1. ^High-resolution image synthesis with latent diffusion models
  2. ^Video generation models as world simulators
  3. ^Stable video diffusion: Scaling latent video diffusion models to large datasets
  4. ^Adding conditional control to text-to-image diffusion models
  5. ^High-Resolution Image Synthesis with Latent Diffusion Models
  6. ^All are Worth Words: a ViT Backbone for Score-based Diffusion Models
  7. ^Xception: Deep Learning with Depthwise Separable Convolutions
  8. ^EfficientViT: Lightweight Multi-Scale Attention for High-Resolution Dense Prediction
  9. ^Scaling up GANs for Text-to-Image Synthesis
  10. ^CondConv: Conditionally Parameterized Convolutions for Efficient Inference
  11. ^Improving image captioning with better use of captions
  12. ^GANs Trained by a Two Time-Scale Update Rule Converge to a Local Nash Equilibrium
  13. ^CLIPScore: A Reference-free Evaluation Metric for Image Captioning
  14. ^Learning Transferable Visual Models From Natural Language Supervision
  15. ^Improving Image Captioning with Better Use of Captions
  16. ^Scalable Diffusion Models with Transformers
  17. ^All are Worth Words: a ViT Backbone for Score-based Diffusion Models
  18. ^DPM-Solver: A Fast ODE Solver for Diffusion Probabilistic Model Sampling in Around 10 Steps
  19. ^UniPC: A Unified Predictor-Corrector Framework for Fast Sampling of Diffusion Models

公众号后台回复“极市直播”获取100+期极市技术直播回放+PPT

极市干货

极视角动态2023GCVC全球人工智能视觉产业与技术生态伙伴大会在青岛圆满落幕!极视角助力构建城市大脑中枢,芜湖市湾沚区智慧城市运行管理中心上线!
数据集:面部表情识别相关开源数据集资源汇总打架识别相关开源数据集资源汇总(附下载链接)口罩识别检测开源数据集汇总
经典解读:多模态大模型超详细解读专栏

极市平台签约作者#


科技猛兽

知乎:科技猛兽


清华大学自动化系19级硕士

研究领域:AI边缘计算 (Efficient AI with Tiny Resource):专注模型压缩,搜索,量化,加速,加法网络,以及它们与其他任务的结合,更好地服务于端侧设备。


作品精选

搞懂 Vision Transformer 原理和代码,看这篇技术综述就够了
用Pytorch轻松实现28个视觉Transformer,开源库 timm 了解一下!(附代码解读)
轻量高效!清华智能计算实验室开源基于PyTorch的视频 (图片) 去模糊框架SimDeblur



投稿方式:
添加小编微信Fengcall(微信号:fengcall19),备注:姓名-投稿
△长按添加极市平台小编

觉得有用麻烦给个在看啦~  

浏览 246
点赞
评论
收藏
分享

手机扫一扫分享

分享
举报
评论
图片
表情
推荐
点赞
评论
收藏
分享

手机扫一扫分享

分享
举报