8.6 Diffusion Model——DDPM

前言

dm-sketch

2020年,DDPM横空出世,将扩散模型概念迅速扩散到深度学习各个领域,并在2022年随着Stable Diffusion的提出及开源,该技术实现了破圈,走入了大众的视野,而不再是科技工作者才了解的概念。

为此,对扩散模型原理进行简要介绍,并通过代码实现DDPM模型,同时介绍Stable Diffusion 模型背后的原理。

本文主要内容包括

  1. Diffusion Model 概念介绍
  2. DDPM 模型原理及代码实现,训练、推理
  3. Guided Diffusion:引导条件的扩散模型介绍,包括classifier-base 和 classifier-free 两大主流模型
  4. Stable Diffusion:让技术出圈的模型
  5. Latent Diffusion Model(LDM):Stable Diffusion背后的核心技术

Diffusion Model 简介

扩散模型(Diffusion Model)发展至今已成为一个大的概念、思想,扩散是借鉴物理学中的扩散过程(Diffusion Process)概念。

物理学中扩散是一种物质分子在非均匀环境中的运动,物质从高浓度区域向低浓度区域传输,最终实现浓度均衡。

在深度学习中,则是将噪声加入到原始图像中进行扩散,最终使图片变为噪声,然后利用深度学习模型学习从噪声变到图像的过程,最后可以随机生成噪声,并利用模型将噪声生成图片的过程。

深度学习中扩散的概念自2015就有了,并在2019年发表于论文《Generative Modeling by Estimating Gradients of the Data Distribution》,

最终在2020年的《Denoising Diffusion Probabilistic Models》中被大众熟知,随后就开启了扩散模型的学术界扩散,如DALL-E 2,imagen, Stable Diffusion等强大应用。

DDPM 实现噪声到图片步骤

此处借鉴李宏毅教授2023年春季期ML课程中课件进行讲解。

DDPM模型推理过程是将一个标准正态分布中采样的噪声图片(与原图同尺寸),经过T步(1000步)的去噪(Denoising),生成高质量图像的过程。

ddpm-process

DDPM模型推理过程,可以看似将噪声逐步的去除,先获得图像大体轮廓,逐步精雕细琢,获得清晰的图像。

这就像雕像制作过程,工匠常说:“雕像本身就在石头里,我只是把多余的部分剔除掉”,雕刻雕像的过程就像噪声变到高质量图像的过程,一开始它们都是可以生成“万物”的本源。

statue

如何对Denoise模块进行数学建模,使得噪声逐步变清晰?

可以这么做,设计一个神经网络,它接收噪声图以及当前步数,输出一个噪声,然后与原图相减,获得更清晰的图片。

denoise-module

如何训练这样的神经网络模型?训练数据如何构建?

前面提到噪声如何变图像,现在反过来看看,图片如何变噪声的过程

对于原图,经过T步的逐渐加高斯噪声,使图像逐步模糊,最终趋近于标准高斯分布。

这其中就可以构建Noise Predicter的训练数据,例如蓝色框中为输入,红色框噪声则是反向过程时期望预测的标签。

ddpm-data-collection

对于具体模型,DDPM中采用了Unet架构的神经网络实现数据预测。

到这里,DDPM实现从噪声生成图像的步骤就清晰了:

  1. 前向过程:将原图逐步添加噪声, 共1000步
  2. 反向过程:利用神经网络学习加噪图像到噪声的变换,使得模型可以去噪
  3. 推理使用:随机采样,得到高斯噪声,然后逐步去噪,经过1000步去噪,得到清晰图像。

DDPM 公式理解

根据上述步骤,可以将DDPM训练、推理过程采用数学形式表达,如下图所示

ddpm-formula

训练过程:

  • q(x0) 表示原始图像数据集(分布),x0表示一张原始图像
  • t 看成是1到1000的均匀分布采样
  • ε 表示从标准正态分布中采样得到的噪声图像
  • εθ 表示模型需要学习到的噪声图像,该图像是利用unet生成的,unet接收上一步去噪图与当前步数t,预测出一个噪声图像,并且期望它与高斯噪声越接近越好。即ε - εθ 趋于0。
  • αt_bar:均值系数,可由重参数方法训练而来,或是固定值。固定值如0.0001 到0.02线性插值。

推理过程:

  • xT:从正态分布中随机采样的噪声
  • z:从正态分布中随机采样的噪声
  • xt-1:主要是由Xt减去模型生成的噪声图像,并且以一定的权重加权后,加上标准差乘以随机噪声。至于原因需要看原文及进行公式推导理解了

更多公式推导,推荐阅读

ddpm-code-debug

DDPM 模型结构

下面通过代码进行DDPM模型结构的剖析,官方代码为TF版,在这里采用非官方的PyTorch版

论文采用TPU v3-8(相当于8张V100 GPU),在cifar10上花了10.6小时,由此可见,要想在256x256的图片上训练,会非常耗时。

为了快速使用DDPM,这里采用cifar10进行学习。

通过代码分析,DDPM模型结构如下图所示,是在unet结构上进行了一些改进,包括加入时间步t的embedding,卷积中采用了ResBlock,并且采用了Self-Attention机制。

ddpm-struct

如图例所示,模型整体有7个组件,彩色箭头是一组操作,包含2-3个网络层的堆叠,通常最后一个网络层才会改变图像分辨率。

第一个,时间步的embedding,它会输入到除了head,tail的其它网络层当中,并且是add的形式添加的(h += self.temb_proj(temb)[:, :, None, None])

ddpm-code

第二个,head模块,是一个3x3卷积,主要是为了改变通道,没有特殊的地方。

第三个,down block,是下采样的核心,一个block由2个ResBlock与一个下采样层构成。ResBlock内部如图所示:

ddpm-resblock

第四个,middle block,由两个ResBlock构成

self.middleblocks = nn.ModuleList([
    ResBlock(now_ch, now_ch, tdim, dropout, attn=True),
    ResBlock(now_ch, now_ch, tdim, dropout, attn=False),])

第五个,Up block,由3个ResBlock+1个上采样层,

第六个,tail,由GN + swish + conv构成,输出最终图像

self.tail = nn.Sequential(
    nn.GroupNorm(32, now_ch),
    Swish(),
    nn.Conv2d(now_ch, 3, 3, stride=1, padding=1))

第七个,concat,是unet的low-level特征融合到high-level特征当中

总的来说,特色在于时间步t的embedding是加入了每一个ResBlock中进行使用,并且ResBlock采用了self-attention机制。

DDPM——训练Cifar-10

接下来使用配套代码中的Main.py进行训练,并且使用Main.py进行推理,训练和推理需要调整"state": "eval"。

  • 数据准备:在Main.py同级目录下创建cifar文件夹,并且将cifar-10-python.tar.gz放到里边。
  • 运行训练:python Main.py,1080ti上训练,约16小时,训练完毕,在Checkpoints文件下有ckpt_last_.pt。

下面观察cifar-10的训练细节,在配套代码中可以看到,数据采用torchvision提供的cifar10 dataset接口,模型为Unet,优化器为AdamW,学习率为warmup+consine。

在主循环中,只用了images,labels是没有使用到的

模型的迭代,封装在了GaussianDiffusionTrainer类的forward函数,这也是核心代码之一,下面详细看看forward函数。

第5行:进行时间步的采样,即为每一个样本配一个时间步,并不需要为每个样本采样1000个时间步进行训练,这是因为公式推导的时候,xt可以用x0直接表示的,不需要依赖xt-1。

第6行:对需要加入的噪声进行采样,这里为标准正态分布

第11行:根据公式计算得到x_t,x_t由x0与noise加权得到,细节可看公式理解部分,这里的两个权重之和不等于一,但是接近1。

第14行:模型接收x_t与t,去预测噪声图像,并且通过mse_loss进行损失计算。

def forward(self, x_0):
    """
    Algorithm 1.
    """
    t = torch.randint(self.T, size=(x_0.shape[0], ), device=x_0.device)  # 不需要1000步都训练,随机batchsize个
    noise = torch.randn_like(x_0)  # 标准正态分布
    # 基于x0,获得xt, 随后得到训练数据[(xt, t, noise), ]
    # x_t.shape     [bs, 3, 32, 32]
    # noise.shape   [bs, 3, 32, 32]
    # t.shape       (bs,)
    x_t = (extract(self.sqrt_alphas_bar, t, x_0.shape) * x_0 +
           extract(self.sqrt_one_minus_alphas_bar, t, x_0.shape) * noise)

    loss = F.mse_loss(self.model(x_t, t), noise, reduction='none')
    return loss

代码与上文的原理介绍是一致的,时间步step2与加噪后的图像是输入数据,标签ground truth是noise。

ddpm-data-collection

DDPM——推理Cifar-10

训练完毕后,先看看推理效果。

首先在Main.py中修改 "state": "eval", # train or eval,然后运行python Main.py,即可在"sampled_dir": "./SampledImgs/ 文件夹下获得如下图片,看上去还像个样子,毕竟数据量、算力、时间摆在这里。

这里提供一个训练好的模型参数,ckptlast.pt,下载后放到Checkpoints文件夹。链接:https://pan.baidu.com/s/17X_L9oH4lmrGwnD-V9D5HQ 提取码:w4ki

ddpm-inference

推理过程的代码理解相对有一点绕,主要还是参照论文中的sampling过程,如红框所示,首先获得均值(为什么叫均值?可能要去推一下公式了),然后加上时间步对应的标准差乘以随机噪声。

其中,均值主要是由Xt减去模型生成的噪声图像,并且以一定的权重加权后得到。

ddpm-formula

核心代码涉及3个函数,

  1. forward()为主函数。
  2. p_mean_variance()为调用Unet模型,获得mean和var。
  3. predict_xt_prev_mean_from_eps()是进行上图中红色框运算的过程。

第26行,获得模型预测的噪声图片

第27行,获得mean,即上图中的红色框

第15行,加上标准差乘以随机噪声,获得t时刻的输出,反复迭代1000次,得到最终输出图像。

def forward(self, x_T):
    """
    Algorithm 2.
    """
    x_t = x_T
    for time_step in reversed(range(self.T)):
        print(time_step)
        t = x_t.new_ones([x_T.shape[0], ], dtype=torch.long) * time_step
        mean, var = self.p_mean_variance(x_t=x_t, t=t)  # mean是 xt图 减去 噪声图
        # no noise when t == 0
        if time_step > 0:
            noise = torch.randn_like(x_t)
        else:
            noise = 0
        x_t = mean + torch.sqrt(var) * noise
        assert torch.isnan(x_t).int().sum() == 0, "nan in tensor."
    x_0 = x_t
    return torch.clip(x_0, -1, 1)   

def p_mean_variance(self, x_t, t):
    # below: only log_variance is used in the KL computations
    # posterior_var: 由 betas计算得到,betas=[0.0001 to 0.02]
    var = torch.cat([self.posterior_var[1:2], self.betas[1:]])  # betas=[0.0001 to 0.02]
    var = extract(var, t, x_t.shape)

    eps = self.model(x_t, t)  # eps是unet输出的图像
    xt_prev_mean = self.predict_xt_prev_mean_from_eps(x_t, t, eps=eps)  # 加权减法,xt图 减去 噪声图

    return xt_prev_mean, var

def predict_xt_prev_mean_from_eps(self, x_t, t, eps):
    assert x_t.shape == eps.shape
    return (
        extract(self.coeff1, t, x_t.shape) * x_t -
        extract(self.coeff2, t, x_t.shape) * eps
    )

Diffusion Model 拓展 —— Guided Diffusion

guided diffusion是加入了引导信息,让生成的图像变为我们想要的形式,而不是随机的图片。

引导式的扩散模型从有无分类器,可以分为两种,classifier-base和classifier-free,classifier-free由于不需要分类器,引导信息直接embedding到模型中,所以应用更为广泛。

classifier-base ——《Diffusion Models Beat GANs on Image Synthesis》

DDPM提出后,其实效果并未惊艳大家,在DDPM发表后的几个月,《Diffusion Models Beat GANs on Image Synthesis》的发表, 其github,把扩散模型带入了高潮,因为它效果比GAN更好,并且针对DDPM,引入了classifier-guidance思想,可以在生成时加入条件约束,可控制生成特定类别的图像。

具体公式详见原文。在使用时,采用unet估计mean时,需要额外加上分类器的分类结果的梯度,详见openai的github:https://github.com/openai/guided-diffusion

第4行:均值除了unet的,还需要加入分类器得到的梯度

第7行:分类器推理,计算梯度过程,这里有个重要参数是args.classifier_scale

# guided_diffusion/gaussian_diffusion.py
def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
    gradient = cond_fn(x, self._scale_timesteps(t), **model_kwargs)
    new_mean = (p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float())
    return new_mean
# scripts/classifier_sample.py
def cond_fn(x, t, y=None):
    assert y is not None
    with th.enable_grad():
        x_in = x.detach().requires_grad_(True)
        logits = classifier(x_in, t)
        log_probs = F.log_softmax(logits, dim=-1)
        selected = log_probs[range(len(logits)), y.view(-1)]
        return th.autograd.grad(selected.sum(), x_in)[0] * args.classifier_scale

classifier-free —— 《classifier free diffusion guidance》

由于classifier-base需要训练分类器,并且在推理时有超参数args.classifier_scale的影响,以及引导条件的加入过于单一,没有办法通用性的加入各类条件。

为此,谷歌大脑的两位工程师提出了classifier free的方式,文中将分类信息通过embedding的方式加入到模型中训练,这里类似时间步t的embedding。

训练时会结合有条件与无条件进行训练,无条件则将分类标签embedding全部设置为0,具体细节可参见论文。

由于论文中没有提供代码,所以找到的代码是这个DDPM,其中的condition模式就是classifier-free。

第2行:训练时,有10%的是无条件的,90%是有条件的

第9行:标签信息与时间步一样,通过embedding汇入模型中,称为引导信息

# DiffusionFreeGuidence/TrainCondition.py
if np.random.rand() < 0.1:
    labels = torch.zeros_like(labels).to(device)

# DiffusionFreeGuidence/ModelCondition.py
def forward(self, x, t, labels):
    # Timestep embedding
    temb = self.time_embedding(t)
    cemb = self.cond_embedding(labels)
    # Downsampling
    h = self.head(x)
    hs = [h]
    for layer in self.downblocks:
        h = layer(h, temb, cemb)
        hs.append(h)
    ...略

classifier free diffusion是打开了一扇大门,既然类别标签可以embedding,那么文本信息也可以通过该方式注入模型中进行引导,火爆的Stable Diffusion就是这么做的。

Diffusion Model 拓展 —— Stable Diffusion

Stable Diffusion 是2022年火爆全球的文图生成(text-to-image)扩散模型,由于它开源,并且效果相当炸裂,因此已经被大多数人使用。

Stable Diffusion 背后的技术是LDM(latent diffusion model),之所以叫Stable Diffusion,或许与其背后的公司由Stability AI有关。

Stable Diffusion 是由CompVis、Stability AI和LAION三家公司共同创建,CompVis提供的技术LDM(latent diffusion model)源自论文《High-Resolution Image Synthesis with Latent Diffusion Models》,对应的github。LAION公司是一家致力于推动人工智能和数据科学发展的科技公司,其从互联网上抓取的 58 亿「图像-文本」数据,并开源了 LAION-5B数据集。而Stability AI的贡献,或许是出钱出力出人吧。

Stable Diffusion 的开源代码: https://github.com/CompVis/stable-diffusion 与 LDM(latent diffusion model)的开源代码:https://github.com/CompVis/latent-diffusion都在CompVis下,代码几乎一样。

下面简要介绍Stable Diffusion用到的latent diffusion model技术。

LDM之前,扩散模型在像素域进行扩散与去噪,这样的计算量过大。因此,考虑将扩散过程放到隐空间(latent space),即将数据经过encoder,来到特征空间,在特征空间上进行扩散和去噪。

这样一来,有以下好处:

  1. 计算量减小,训练和推理速度变快
  2. 可以加入更多引导信息,例如文本信息。

LDM论文中有一幅图很好的解释了LDM的思想:首先在pixel space,需要有encoder和decoder,在latent space采用了多头注意力机制,并且除了时间步信息,加入了conditioning模块,其中的引导信息可以是文本、图片、表征向量等等一切内容,然后为引导信息配一个embedding模块,就可以将引导信息加入模型中。

stable-diffusion-struct

这里配上李宏毅老师的结构示意图,可知道LDM的核心在于2当中,处理的不再是像素空间,而是一个特征空间

stable-diffusion-sketch

stable diffusion 的使用与安装,网上有太多教程,这里不进行介绍,主要了解LDM的架构。推荐阅读:文生图模型之Stable Diffusion

与Stable Diffusion同一时期,叫得上名称的文图生成模型还有Midjourney、DALL-E 2,不过它们都是不开源的。

小结

本案例借助DDPM的代码剖析,了解扩散模型实现去噪,从而生成图像的过程和原理,并且对Guided Diffusion Model进行介绍,模型要能根据我们的“指示”生成特定的图像,这样的模型才有更大的应用价值。

在Guided Diffusion Model中,包含classifier-base 和 classifier-free,classifier-free是后来的主流。

classifier-free的代表即出圈的Stable Diffusion,Stable Diffusion是完全开源的,因此得到了全球的使用与关注。

在扩散模型中,LDM(latent diffusion model)思想值得仔细研究,它将一切信息都放到隐空间(特征空间)进行处理,使得图片处理起来更小,还可以进行多模态处理。

Copyright © TingsongYu 2021 all right reserved,powered by Gitbook文件修订时间: 2024年04月26日21:48:10

results matching ""

    No results matching ""