《Scalable Diffusion Models with Transformers》

《基于transformer的可扩展扩散模型》

论文地址:
https://arxiv.org/pdf/2212.09748.pdf

项目地址:
https://github.com/facebookresearch/DiT

摘要

论文提出了一类使用transformer的扩散模型。 将其中的主干网络 U-Nettransformer替代 ,以获取更好的效果。

实验证明了transformer架构在扩散模型上的scalability能力,分析发现 DiTs速度更快(Gflops更高),并且始终具有较低的FID(FID是反应生成图片和真实图片的距离,数据越小越好)。

1GLOPs=10亿次浮点运算。是Paper里比较流行的单位。

FID是反应生成图片和真实图片的距离,数据越小越好

最大的模型DiT-XL/2ImageNet数据集,类别条件生成任务上 512×512和256×256 表现优于所有先前的扩散模型,256×256上实现了SOTAFID指标(2.27)。

1. 前言

trasnformer的提出使机器学习经历复兴。在过去的五年中,用于自然语言处理,视觉和其他几个领域的神经架构基本上都包含trasnformer。许多图像级通用模型仍然坚持这一趋势。transformer在自回归模型得到广泛使用,但是在其他通用建模框架中采用的较少。例如,扩散模型一直处于图像级最新进展的前沿生成模型,然而,它们都采用了卷积U-Net架构作为主干。

Ho等人的开创性工作首次引入了扩散模型的U-Net主干。最初我们看到像素级自回归模型和传统GANs的成功,U-Net是从Pixel CNN++继承而来。U-Net是卷积网络,主要由ResNet块组成。在与标准的U-Net相比,额外的 spatial self-attention blockstrasnformer中是必不可少的组成部分,穿插在较低分辨率下。Dhariwal和Nichol取消了U-Net的几个结构 ,例如注入条件信息的adaptive normalization layers 和 卷积层中的卷积通道数。然而,Ho等人提出的U-Net网络的高层设计在很大程度上保持完整。

通过这项工作,我们旨在揭开扩散模型的架构选择的意义,并为未来生成建模研究提供经验基线。我们表明U-Net归纳偏差对扩散模型的性能表现不是至关重要的,并且可以很容易地将它们重新与transformer等标准设计放在一起。结果就是,扩散模型很好地从结构上进行统一,符合最新的发展趋势。 通过继承来自其他领域的最佳实践和训练方法,如以及保持良好的性能,如可扩展性,鲁棒性和效率。标准化的架构将会也为跨领域研究开辟了新的可能性。

本文主要研究基于transformer的一类新的扩散模型。我们称它们为 Diffusion Transformers,或简称DiTsDiTs遵循的是的最佳实践视觉transformer (vit),已被证明可以比传统的视觉识别更有效的缩放卷积网络(例如ResNet)。

更具体地说,本文研究了不同规模的transformer在 网络复杂性vs样本质量 之间的平衡。通过在潜在扩散模型(LDMs)框架下构建 DiT设计空间并进行基准测试,其中扩散模型在VAE的潜空间中训练,我们可以成功用transformer替换U-Net主干。我们进一步显示DiTs是扩散模型的可扩展架构:网络计算复杂性(由Gflops测量)与样本质量(测量FID)。通过简单地扩大DiT和训练LDM有了高容量的骨干网(118.6 Gflops),我们可以做到在class有条件的256 × 256 ImageNet生成基准上取得了2.27 FID的最新结果。

2. 相关工作

Transformers

transformer已经取代了跨语言、视觉、强化学习 和元学习 领域特定架构。他们在不断增加的模型大小、训练计算和语言数据下显示出显著的扩展性,作为通用自回归模型和除了语言,transformer已经训练自回归预测像素。他们也在离散编码上训练过 自回归模型和掩码生成模型 ;前者具有良好的扩展性能多达20B参数。最后,transformer已经在DDPM中探索非空间数据合成。例如,对在DALL·E 2中生成CLIP图像嵌入。

在本文研究了transformer的缩放特性 用作图像扩散模型的主干。

DDPMs

去噪扩散概率模型,Denoising diffusion probabilistic models (DDPMs)

扩散和基于分数的生成模型 在图像生成尤其成功。在许多情况下,图像的性能优于 迭代对抗网络(GANs)。

过去的两年,DDPMS的改进很大程度上是由采样技术改进带来的,最著名的分类指导,重新制定扩散模型预测噪声而不是像素,并使用级联DDPM piplines,低分辨率的基础扩散模型与上采样器并行训练 。对于上面列出的所有融合模型,backbone架构选择了卷积U-Nets。当前的Work引入了一种新颖、高效的在DDPMS中引入attention,而我们研究纯transformer。

架构复杂度

对于图片生成的迭代过程,我们可以使用参数量来衡量不同模型的复杂度。一般而言,参数量来评估模型复杂度不是很合适,因为参数量并不能代表模型的计算复杂度,比如当模型参数量相同时,图片分辨率不同会导致计算复杂度上较大的差异。所以文章采用Gflops来衡量模型架构的复杂度。

3. 扩散Transformer

3.1 准备知识

扩散公式
  • DDPM 主要分为两个过程:

    • forward 加噪过程(从右往左):加噪过程是指向数据集中的真实图像逐步加入高斯噪声, 加噪过程满足一定的数学规律,不需要学习
    • reverse 去噪过程(从左往右):去噪过程是指对加了噪声的图片逐步去噪,从而还原出真实图像。去噪过程则采用神经网络模型来学习。这样一来,神经网络模型就可以从一堆杂乱无章的噪声图片中生成真实图片了。
      在这里插入图片描述

扩散模型需要训练反向过程
输入 x t , 输出 x t − 1 输入x_t, 输出x_t-1 输入xt,输出xt1
扩散过程的nosie scheduler采用简单的linear scheduler(timesteps=1000,beta_start=0.0001,beta_end=0.02),这个和SD是不同的。

其次,DiT所使用的扩散模型沿用了OpenAI的**Improved DDPM,相比原始DDPM一个重要的变化是不再采用固定的方差,而是采用网络来预测方差**。在DDPM中,生成过程的分布采用一个参数化的高斯分布来建模。

在这里插入图片描述

Classifier-free Guidance

模型在训练时,使用一个网络架构优化两个模型(uncond,cond)。众所周知,与通用采样技术相比,无分类器指导可以产生显着改进的样本,并且这种趋势也适用于DiT模型。

LDMs

潜在扩散模型,我们使用现成的卷积VAE和基于Transformer的DDPM。

3.2. Diffusion Transformer Design Space

DiTs 模型架构图,如下所示:
在这里插入图片描述

  • 左部分:我们训练了传统的latent DiT 模型。输入的latent被分解为patches , 并经过几个 DiT blocks;

  • 右部分:是 DiT blocks内部的详细结构。

Patch化

DiT的输入是通过VAE后的一个稀疏的表示z(输入为256×256×3的图片,输出得到压缩后的latent为32×32×4),其中采用的autoencoder是SD所使用的KL-f8,这就降低了扩散模型的计算量。

然后将输入转成patch,文章采用超参p=2,4,8进行对比实验。
在这里插入图片描述

DiT模块设计
  • **In-context **

    in-context条件是将时间步t 的embedding 和c 作为额外的token拼接到 DiT的输入序列中;

    输入:latent image, t的embedding 和 类别标签c

    将两个embeddings看成两个tokens合并在输入的tokens中,这种处理方式有点类似ViT中的cls token,实现起来比较简单,也不基本上不额外引入计算量。

  • Cross-attention

    DiT结构与Condition交互的方式,与原来U-Net结构类似;将两个embeddings拼接成一个数量为2的序列,然后在transformer block中插入一个cross attention,条件embeddings作为cross attention的key和value;这种方式也是目前文生图模型所采用的方式,它需要额外引入15%的Gflops。

  • Adaptive layer norm(adaLN)

    采用adaLN,这里是将time embedding和class embedding相加,然后来回归scale和shift两个参数,这种方式也基本不增加计算量。

  • adaLN-zero

    采用zero初始化的adaLN,这里是将adaLN的linear层参数初始化为zero,这样网络初始化时transformer block的残差模块就是一个identity函数;另外一点是,这里除了在LN之后回归scale和shift,还在每个残差模块结束之前回归一个scale。

上面四种嵌入,adaLN-Zero最好,DiT默认这种方式来嵌入条件embedding。DiT发现adaLN-Zero最好,但是这种方式只适合这种只有类别信息的简单条件嵌入,只需要引入一个class embedding,但对于文生图来说,条件往往是序列化的text embeddings,因此采用cross-attention通常是更合适的方式。

模型大小

与ViT大小相似,分别使用DiT-S、DiT-B、DiT-L和DiT-XL(2.5G),Gflops从0.3到118.6。
在这里插入图片描述

Transformer Decoder

在Transformer最上层需要预测噪音,因为Transformer可以保证大小与输入一致,所以在最上层使用一层线性进行decoder。

在最后一个DiT块之后,我们需要将我们的图像标记列解码为输出噪声和输出对角协方差预测。
这两个输出的形状都等于原始的spatial输入。我们使用标准的线性解码器来完成这项工作。

我们应用最终层范数(如果使用adaLN,则为自适应)和lin将每个token提前解码为p×p×2C张量,其中C为
空间输入到DiT的通道数。最后, 将解码的token重新排列到其原始空间中布局以获得预测的噪声和协方差。

4. 实验设置

我们探索了DiT 设计空间,并研究了模型类别的扩展特性。

模型使用 【结构/patch数量】 方式命名,比如【DiT-XL/2】表示模型采用DiT-XL,patch size为2。

训练

在ImageNet 256×256和512×512分辨率的数据集上训练 条件类别的 latent DiTs。初始化最后一层线性层为0,另外,其他初始化都与ViT一致。训练模型采用AdamW,常量学习率1e-4,no weight decay,batch size为256,数据增广仅使用水平翻转。和之前的很多工作不同,此处没有使用学习率warmup和正则化。

尽管没有使用这些技术,训练非常稳定,在所有模型配置中,我们没有观察到任何训练transformer时常见的峰值损失。和很多生成模型相同,我们在训练中保持了 exponential moving average (EMA) (decay 0.9999)。所有的结果都使用了EMA model。我们在所有DiT模型大小和patch 大小上使用完全相同的训练超参数。训练超参数大部分来源于ADM,并且我们并没有进行 learning rates, decay/warm-up, schedules, Adam, β1*/ β*2 or weight decays的调参

扩散

我们使用Stable Diffusion中一个现成的预训练变分自编码器。

VAE encoder使用下采样参数8—— 将RGB图像256×256×3的图像编码到32×32×4的隐空间,

扩散模型作用于隐空间,采样得到新的 latent后,使用VAE decoder将32×32×4的隐空间还原到256×256×3的图像。

我们保留了ADM 中的超参数,同时沿用了标签 和时间步embedding 的方法,扩散过程的nosie scheduler采用简单的linear scheduler(timesteps=1000,beta_start=0.0001,beta_end=0.02),

评估指标

我们使用FID测量扩展性能, FID是反应生成图片和真实图片的距离,数据越小越好。

在与之前的论文进行比较时,我们遵循惯例并使用250 DDPM 采样步骤报告 FID-50K。众所周知,FID对小范围实现很敏感。

为了保证对比的公平性,本文中移植的所有值都是通过导出样本和获得的使用ADM的TensorFlow评估套件。

本节中报告的FID 值不使用无分类的,除非另有说明。

另外我们增加了 Score [51], sFID [34] and Precision/Recall [32] 作为第二评价指标。

计算

我们在JAX上实现了所有的模型,并在TPU-v3 pods上进行训练。

DiT-XL/2, 是计算密集的模型,在全局批量大小为256的TPU v3-256 pod上以大约5.7次迭代/秒的速度进行训练

5. 实验

DiT block 设计

我们训练四个Gflop最高的DiT-XL/2模型,每个使用不同的块设计-

  • in-context(119.4 Gflops)

  • cross-attention (137.6 Gflops)

  • adaptive layer norm(adaLN, 118.6 Gflops)

  • adaLN- Zero(118.6 Gflops) ,蓝色,FID最低,效果最好

我们在训练过程中测量FID。如下图所示
在这里插入图片描述

结果显示在400 k迭代后,adaLN-Zero实现的FID是in-context FID的一半, 表明条件作用机制对模型的质量很重要。

初始化也很重要——adaLNZero,它将每个DiT块初始化,显著优于 vanilla adaLN。

在本文中,所有模型都将使用adaLN-Zero DiT块。

模型大小和patch大小

我们训练了12个DiT models,模型配置(S, B, L, XL)和patch size(8、4、2)。

注意,DiT-L和DiT-XL的 Gflops是很接近的

图2(左)给出了不同Gflop模型及其在400K训练迭代时的FID。

在所有情况下,我们发现模型大小越大,patch size 越小,产生的扩散模型越好。

在这里插入图片描述

图6 (top)在固定patch size的情况下,模型越大,FID越小

图6 (bottom)在固定模型 size 的情况下, patch size越小,FID越小

在这里插入图片描述

DiT Gflops 对改进模型很重要

图6的结果表明,参数计数并不是唯一决定DiT模型的质量。

当模型size保持不变,patch大小减少,transformer的总参数有效保持不变(actually,总参数略有下降),只有

Gflops增加了。

这些结果表明,模型Gflops实际上是提高性能的关键。

为了进一步研究,我们在400K训练步骤中绘制FID-50K。图8中的模型Gflops。

结果表明,不同的DiT配置会获得相似的FID值,它们的Gflops总量相似(例如DiT-S/2和DiT-B/4)。

发现模型Gflops和FID-50K之间存在很强的负相关性,表明额外的模型计算是改进DiT模型的关键因素。

在图 12(附录)中,我们发现这种趋势适用于其他像Inception分数这样的指标。

更大的 DiT 模型计算效率更高

在图9,我们绘制适用于所有DiT训练模型FID函数。

我们评估了训练计算:Gflops · batch size · training steps ·3,其中的因子3大致将向后传递近似为两次和前向传

递一样计算量。

我们发现即使是小的DiT 模型,训练时间很长时, 相比训练少量step的较大DiT模型也会变得计算效率低。

类似地,当控制Gflops时,我们发现除了patch size之外, 不同配置的模型也会有不同的性能。

例如,XL/4在大约10的10次方Gflops后的表现要优于XL/2。

缩放可视化

我们将缩放的效果可视化如图7中的样本质量。

在400K训练步骤中,我们使用相同的起始噪声,从我们的12个DiT模型中的每个模型中sample图像

例如,对噪声和类标签进行采样。

这让我们可以直观地解释缩放如何影响DiT sample质量。

实际上,扩展了模型的大小和数量token的使用可以显著提高视觉质量。

5.1. State-of-the-Art Diffusion Models

256*256 ImageNet

根据我们的缩放设置,我们训练了最高Gflops的模型 DiT-XL/2, for 7M steps。

figures1中展示了生成的样本。

我们还和当前类别条件生成模型进行了对比,如表2所示

当使用无类别引导时 , DiT-XL/2 表现优于所有先前的扩散模型,将之前最好的FID-50K LDM 的3.6降低到2.27.

图2(右)表示DiT-XL/2 (118.6 Gflops) 计算高效 相比latent space U-Net models like LDM-4 (103.6 Gflops) ,

并且比像素空间U-Net 更有效,如ADM (1120 Gflops)或ADM- u (742 Gflops)。

我们的模型相比之前的生成式模型包括StyleGAN-XL, 都获得了更低的FID。

最后,我们还发现DiT-XL/2相比LDM-4 和 LDM-8 获得了更高的recall。

当训练到2.35M steps时, XL/2 以FID2.55同样优于先前所有的扩散模型。
在这里插入图片描述

512*512 ImageNet

我们在ImageNet的分辨率为512 × 512上,训练一个新的DiT-XL/2模型,迭代次数为3M,使用与256 × 256模型相

同的超参数。

patch大小为2,XL/2模型总共处理1024个tokens的标记,64 × 64 × 4 input latent(524.6Gflops)。

表3显示了与最新技术的比较方法。

XL/2再次优于所有先前的扩散模型在此分辨率下,将之前的最佳FID 3.85(ADM ) 改进为3.04。

即使增加了token的number, XL/2仍然是计算高效的。

例如 ADM使用1983 Gflops, ADM-U使用2813 Gflops; XL/2使用了524.6 Gflops。

我们在图1和附录中展示了高分辨率XL/2模型的示例

5.2. Scaling Model vs. Sampling Compute

扩散模型的独特之处在于,它们可以在训练后通过增加模型的数量来使用额外的计算生成图像时的采样步骤。

鉴于模型Gflops对样本质量的影响,在本节中,我们研究较小模型的计算DiT是否可以胜过较大模型
一种是通过使用更多的采样计算。我们计算FID, 对于经过400K训练步骤的所有12个DiT模型,我们使用[16, 32, 64, 128, 256, 1000]每个图像的采样

步骤。考虑一下DiT-L/2 使用1000步采样和DiT-XL/2的128步采样。

在这种情况下,L/2使用80.7 Tflops对每个图像进行采样; XL/2对每个样本的计算量减少了5倍,即15.2 tflops。尽管如此,XL/2有更好的FID-10K (23.7vs 25.9)。一般来说,上扩采样计算不能弥补模型计算的不足

总结

本文介绍扩散transformer (DiTs),一种简单的基于transformer的扩散模型骨干,表现优于之前的U-Net模型,并继承了transformer模型类的优秀可扩展特性。考虑到有希望的扩展结果,未来的工作应该
继续将dit扩展到更大的模型和token计数。DiT还可以作为文本到图像模型(如DALL·E 2和稳定扩散)的基本框架。

参考资料

https://zhuanlan.zhihu.com/p/557971459
https://blog.csdn.net/u012193416/article/details/134268353

Logo

科技之力与好奇之心,共建有温度的智能世界

更多推荐