拣回U-Net架构后,计算成本降了80%?华为诺亚等提出用于图像生成的 U 型扩散模型U-DiT ,

受到 U-Net 的主干特征是以低频为主的启发,作者对 Self-attention 中的 Query 和 Key 执行了 token 的下采样。计算量显著减少,而且带来了性能上的进一步改进。基于这个方法作者提出了 U-DiT 模型,在计算成本仅为 1/6 的情况下优于 DiT-XL/2。

扩散 Transformer (DiTs) 将 Transformer 架构引入到扩散任务中,用于 latent 空间图像生成。DiT 使用连续的 Transformer 块构建,展示出了极具竞争力的性能以及良好的可扩展性。但与此同时,DiT 放弃了 U-Net 架构这个举动也同样值得重新思考。

为此,作者进行了一个简单的实验来比较 U-Net 架构下的 DiT 以及原始 DiT 架构的优越性。结果表明,U-Net 架构的 DiT 仅获得了轻微的优势,表明 U-Net 风格的 DiT 中存在潜在的冗余。受到 U-Net 的主干特征是以低频为主的启发,作者对 Self-attention 中的 Query 和 Key 执行了 token 的下采样。计算量显著减少,而且带来了性能上的进一步改进。基于这个方法作者提出了 U-DiT 模型,在计算成本仅为 1/6 的情况下优于 DiT-XL/2。

U-DiT_Self

图1:左:U-DiT 和其他架构的性能和 FLOPs 比较。右:不同模型尺寸的 U-DiT 和 DiT 的比较

1 U-DiT:一种用于图像生成的 U 型扩散模型

论文名称:U-DiTs: Downsample Tokens in U-Shaped Diffusion Transformers (Arxiv 2024.05)

论文地址: http://arxiv.org/pdf/2405.02730

代码链接: http://github.com/YuchuanTian/U-DiT

1.1 U-Net 架构怎么就被 DiT 抛弃了?

DiT[1]将完整的 Transformers[2]的主干引入到扩散模型中,在图像空间和潜在空间生成任务上表现出出色的性能和可扩展性。最近的后续工作通过将扩散变换器的应用扩展到灵活的分辨率图像生成[3]、逼真的视频生成[4]等, 证明了扩散 Transformer 的前景。

但是有趣的是,以 DiT 为代表的很多工作都抛弃了之前工作极其常见的 U-Net[5]架构。U-Net 架构无论是在 Pixel Space[6][7],还是在 Latent Space[8]中都很常用。而 DiT 中使用直筒型的架构确实是成功的,因为放大的 DiT 模型实现了最高的性能。但是,DiT 放弃了广泛使用的 U-Net 架构,及其在 Latent Space 作图像生成任务的改进激发了作者的好奇心,因为 U-Net 架构带来的归纳偏置总是被认为有利于去噪。因此,作者重新思考在规范的 U-Net 架构上面去部署 DiT 模型。

为了测试 U-Net 架构与 DiT 模型的结合,作者首先提出了一个朴素的 U-Net 风格的 DiT (DiT-UNet),并将其与原始 DiT 进行比较。结果表明,在计算成本相近的情况下,DiT-UNets 仅与 DiTs 相当。从这个 toy experiment 中我们可以推断,当 U-Net 架构与 Transformer Block 简单组合时,其归纳偏置没得到充分利用。

因此,作者重新思考了 DiT-UNet 中的 Self-attention 机制。latent U-Net 去噪模型的 Backbone 架构提供了一个低频分量主导的特征[9]。这个发现意味着 Backbone 的特征中存在冗余,即 U-Net diffuser 中的 Self-attention 应该更加突出低频分量。之前的一些理论研究[10]提出下采样来过滤高频噪声分量,本文通过对 Self-attention 的特征执行 token 下采样来利用这种自然的低通滤波器。作者对 query,key,value 都执行了 token 下采样。如此一来,Self-attention 就是在下采样的 Latent Space 中进行了。令人惊讶的是,当作者将 Self-attention 与下采样 token 合并到 DiT-UNet 中时,Latent 扩散模型取得了更好的结果,而且计算量显著减少。

1.2 在 Latent 空间探索 U-Net DiT 模型

对 U-Net 去噪模型的理论研究[10]表明了它们的优势,即下采样的 U-Net 可以过滤主导高频的噪声。作者通过一个 toy experiment,重新思考了在 Latent Space 下的 U-Net 去噪模型。

U-DiT_去噪_02

图2:从 DiT 到本文的 U-DiT。(a) 原始 DiT 架构;(b) toy experiment 中的 DiT-UNet 架构;(c) 本文提出的 U-DiT 架构

首先如图 2(b) 所示,作者直接将 DiT Block 嵌入到原始 U-Net 架构中,提出了一种名为 DiT-UNet 的基于 Transformer 的 U-Net 降噪器。遵循 U-Net 的设计,DiT-UNet 由一个 Encoder 和一个具有相同阶段数量的 Decoder 组成。Encoder 通过下采样处理输入图像;Decoder 将编码后的图像从最压缩重新扩展到输入大小。在每个 Encoder 的 stage 转换中,执行因子为2的空间下采样,同时特征的维度加倍。在每个 stage 转换中都有 skip-connection。skipped 的特征与上个 stage 的 Decoder 的输出特征进行 Concat 和 Fuse 操作,来弥补上采样带来的信息损失。

考虑到 Latent Space 的维度较小 (256×256 大小的原始图片对应 32×32 的 Latent Space),作者指定了3个 stages,即特征被下采样2次,然后上采样2次恢复到其原始大小。为了跨多尺度 stage 拟合各种特征维度的 time embedding 和 condition embedding,作者对各个 stage 使用独立的 Embedder。

通过 toy experiment,作者将所提出的 U-Net 风格的 DiT 与原始的 DiT 进行比较。为了将模型与 DiT 设计对齐,作者在每个阶段重复使用普通的 DiT Block。每个 DiT Block 包括一个 Self-attention 模块作为 token mixer,一个2层前馈网络作为 channel mixer。作者训练 U-Net-Style DiT 400K iterations,并将其与 DiT-S/4 进行比较。所有训练超参数保持不变。实验结果如下图3所示,U-Net 风格的 DiT 模型相比于原始 DiT 模型只表现出了很有限的优势,说明 U-Net 架构的归纳偏置没有得到充分利用。

U-DiT_Self_03

图3:U-Net 风格的 DiT 的 toy experiments 实验结果

1.3 通过 token 下采样增强 U-Net 风格的 DiT 模型

为了更好地将注意力纳入到 U-Net 扩散模型中,作者回顾了 U-Net Backbone 作为扩散去噪器的作用。最近一项关于 Latent 扩散模型的工作[9]对 U-Net Backbone 的中间特征进行了频率分析,得出结论是能量集中在低频域。这种频域发现暗示了 Backbone 中潜在的冗余,即:U-Net Backbone 应该从全局角度突出粗略对象,而不是高频细节。

很自然地,作者使用了带有下采样 token 的注意力机制,因为下采样操作天然地是丢弃高频分量的低通滤波器。下采样操作的低通滤波属性已经在扩散模型的场景下研究过,得出的结论是下采样有助于扩散模型的去噪器,因为它自动地 "丢弃了噪声主导的高频子空间[10]"。因此,作者选择对 token 进行下采样操作进行 Self-attention

事实上,对下采样的 tokens 进行 attention 这件事并不新鲜。在视觉 Transformer 中,[11][12]曾提出对 Key-Value 对进行下采样以降低计算成本。最近关于无训练扩散加速的工作[13]也对稳定扩散模型应用键值下采样。但是这些工作保持了 Query 的数量。而且,这些下采样度量通常涉及减小张量大小,这可能会导致信息丢失。

与这些工作不同的是,本文提出了一种简单的 token 下采样方法,即:同时对 Query、Key 和 Value 进行下采样,以实现对扩散模型友好的 Self-attention,同时保持整体张量大小以避免信息丢失。

流程是:特征图输入首先通过下采样器转换为4个 2× 的下采样特征。然后,每个下采样的特征用来计算 Self-attention。在这之后,将下采样的 tokens 在空间上合并为一个统一的,以恢复原始数量的 tokens。值得注意的是,在整个过程中,特征的维度保持不变。与 U-Net 下采样不同,作者在下采样过程中不减少或增加特征中的元素数量。相反,则是以并行的方式将4个下采样的 token 逐个进行 Self-attention 操作。

使用下采样 tokens 的 Self-attention 的确有助于 DiT-UNet 扩散模型的性能提升。如上图3所示,FLOPs 显著减少,且 FID 指标还有轻微提升。

1.4 复杂度分析

U-DiT_去噪_04

每个 Self-attention 操作仅花费 1/16 的全尺寸 Self-attention 的计算复杂度,组数变为4之后,下采样 tokens 的 Self-attention 的计算复杂度仅仅为原来的 1/4,通过 token 下采样节约了 Self-attention 3/4 的计算成本。

1.5 模型的可扩展性

作者使用 DiT 的训练设置,采用为 Latent 扩散模型设计的 VAE (sd-vae-ft-ema)[14],以及 AdamW 优化器。训练的数据集是 ImageNet[15]。

作者还额外把一些已有的技术用在了 U-DiT 上:cosine similarity attention[16][17],RoPE2D[18][19][20],depthwise conv FFN[21][22][23],re-parametrization[24][25]。

消融实验结果

如下图4所示是这部分的消融实验结果。可以首先看到,把 full-scale self-attention 换成 downsampled self-attention 可以减少约 1/3 的 FLOPs。为了评估通过模型性能下采样的改进,作者还设计了一个精简版的 DiT-UNet,即 DIT-UNet (Slim)。这个 DiT-UNet (Slim) 使用完整的 Self-attention,但只花费与 U-DiT 大致相同的计算量 (~ 0.9GFLOPs)。从图4的结果显示,DiT-UNet 中的下采样 tokens 可以带来 ~18 FID 的性能提升。

同时也看得出来,已有的技术如 cosine similarity attention 带来了 ~2.5FIDs 的性能提升;RoPE2D 带来了 ~2.5FIDs 的性能提升;depthwise convolution layer 带来了 ~5FIDs 的性能提升;re-parametrization 带来了约 ~3.5FIDs 的性能提升。

U-DiT_人工智能_05

图4:U-DiT 组件的消融实验结果

下采样操作将完整的特征转换为多个空间下采样特征。作者指出,之前的方法一般是做 Pixel Shuffling,或者是在 Pixel Shuffling 之前通过一层卷积,但是会带来较大的计算代价。因此作为折衷方案,作者改为使用 Depth-wise Convolution。作者还为这个 Depth-wise Convolution 加了一个 Shortcut 操作。这个操作带来的额外计算成本可以忽略不计,而且,可以在推理阶段删掉。实验结果如图5所示。

U-DiT_性能提升_06

图5:不同下采样器的消融实验结果

可扩展性实验结果

为了验证所提出的 U-DiT 模型的有效性,作者将它们放大并与更大尺寸的 DiT 模型进行比较。为了公平比较,作者使用与 DiT 相同的训练超参数:所有模型都训练 400K iterations。ImageNet 256×256 的结果如下图6所示。作者还将 U-DiTs 分别扩展到 ~6e9、~20e9、~80e9 FLOPs,并将它们与相似计算成本的 DiT 模型进行比较。  Taobao 天皓智联 

U-DiT_人工智能_07

图6:放大模型之后的实验结果

所有 U-DiT 模型都可以以相当大的优势打败 DiT 模型。具体来说,U-DiT-S 和 U-DiT-B 可以超过 DiT ~30FID,U-DiT-L 可以超过 DiT-XL/2 ~10FID。U-DiT-B 可以在计算成本仅为 1/6 的情况下胜过 DiT-XL/2。

除了 DiTs 和 U-DiTs,作者还在图1中对比了很多其他方法,比如:SiT[26],它为 DiT 提出了一个插值框架,SiT-LLaMA 结合了 DiT 的 Backbone VisionLLaMA[20] 和 SiT。U-DiTs 相对于其他基线的优势在图1中尤为突出。

这些结果证实了 U-DiT 模型的可扩展性。

U-DiTs 在具有 classifier-free guidance 的生成场景中也表现出色。作者将 U-DiT 与 cfg = 1.5 的 DiT 进行比较。为了公平比较,作者在相同的设置下训练 U-DiT 和 DiT 400K iterations,结果如图7所示。

U-DiT_性能提升_08

图7:classifier-free guidance 的实验结果

延长训练步数

作者通过将 training iterations 扩展到 1M 来评估 U-DiT 的潜力。图1也进一步证明 U-DiT 的优势在所有训练步骤中都是一致的。随着训练步骤逐渐上升到 1M,U-DiT 的性能不断提高,如图8所示。作者也可视化了图像质量逐渐变好的过程 (图9)。如图10所示,U-DiT-L 模型只需 1M training iterations 就可以有条件地生成真实图像。

U-DiT_Self_09

图8:U-DiT-B 和 U-DiT-L 的性能随着训练迭代数的变化

U-DiT_性能提升_10

图9:随着训练的继续,生成样本的质量在逐步改进

U-DiT_人工智能_11

图10:训练了 1M iterations 的 U-DiT-L 模型的生成的样本