AI大模型教程
一起来学习

天才程序员周弈帆 | Stable Diffusion 解读(三):原版实现源码解读(篇幅略长,建议收藏!)

本文来源公众号“天才程序员周弈帆”,仅用于学术分享,侵权删,干货满满。

原文链接:Stable Diffusion 解读(三):原版实现源码解读

天才程序员周弈帆 | Stable Diffusion 解读(一):回顾早期工作-CSDN博客

天才程序员周弈帆 | Stable Diffusion 解读(二):论文精读-CSDN博客

看完了Stable Diffusion的论文,在最后这几篇文章里,我们来学习Stable Diffusion的代码实现。具体来说,我们会学习Stable Diffusion官方仓库及Diffusers开源库中有关采样算法和U-Net的代码,而不会学习有关训练、VAE、text encoder (CLIP) 的代码。如今大多数工作都只会用到预训练的Stable Diffusion,只学采样算法和U-Net代码就能理解大多数工作了。

受字数限制,Diffusers的介绍会放到下一篇文章里。

建议读者在阅读本文之前了解DDPM、ResNet、U-Net、Transformer。

本文用到的Stable Diffusion版本是v1.5。Diffusers版本是0.25.0。为了提升可读性,本文对源代码做了一定的精简,部分不会运行到的分支会被略过。

1 算法梳理

在正式读代码之前,我们先用伪代码梳理一下Stable Diffusion的采样过程,并回顾一下U-Net架构的组成。实现Stable Diffusion的代码库有很多,各个库之间的API差异很大。但是,它们实际上都是在描述同一个算法,同一个模型。如果我们理解了算法和模型本身,就可以在学习时主动去找一个算法对应哪一段代码,而不是被动地去理解每一行代码在干什么。

1.1 LDM 采样算法

让我们从最早的DDPM开始,一步一步还原Latent Diffusion Model (LDM)的采样算法。DDPM的采样算法如下所示:

def ddpm_sample(image_shape):
  ddpm_scheduler = DDPMScheduler()
  unet = UNet()
  xt = randn(image_shape)
  T = 1000
  for t in T ... 1:
    eps = unet(xt, t)
    std = ddpm_scheduler.get_std(t)
    xt = ddpm_scheduler.get_xt_prev(xt, t, eps, std)
  return xt

在DDPM的实现中,一般会有一个类专门维护扩散模型的alpha,beta等变量。我们这里把这个类称为DDPMScheduler。此外,DDPM会用到一个U-Net神经网络unet,用于计算去噪过程中图像应该去除的噪声eps。准备好这两个变量后,就可以用randn()标准正态分布中采样一个纯噪声图像xt。它会被逐渐去噪,最终变成一幅图片。去噪过程中,时刻t会从总时刻T遍历至1(总时刻T一般取1000)。在每一轮去噪步骤中,U-Net会根据这一时刻的图像xt当前时间戳t估计出此刻应去除的噪声eps,根据xteps就能知道下一步图像的均值。除了均值,我们还要获取下一步图像的方差,这一般可以从DDPM调度类中直接获取。有了下一步图像的均值和方差,我们根据DDPM的公式,就能采样出下一步的图像。反复执行去噪循环,xt会从纯噪声图像变成一幅有意义的图像。

DDIM对DDPM的采样过程做了两点改进:1) 去噪的有效步数可以少于T步,由另一个变量ddim_steps决定;2) 采样的方差大小可以由eta决定。

因此,改进后的DDIM算法可以写成这样:

def ddim_sample(image_shape, ddim_steps = 20, eta = 0):
  ddim_scheduler = DDIMScheduler()
  unet = UNet()
  xt = randn(image_shape)
  T = 1000
  timesteps = ddim_scheduler.get_timesteps(T, ddim_steps) # [1000, 950, 900, ...]
  for t in timesteps:
    eps = unet(xt, t)
    std = ddim_scheduler.get_std(t, eta)
    xt = ddim_scheduler.get_xt_prev(xt, t, eps, std)
  return xt

其中,ddim_steps是去噪循环的执行次数。根据ddim_steps,DDIM调度器可以生成所有被使用到的t。比如对于T=1000, ddim_steps=20,被使用到的就只有[1000, 950, 900, ..., 50]这20个时间戳,其他时间戳就可以跳过不算了。eta会被用来计算方差,一般这个值都会设成0

DDIM是早期的加速扩散模型采样的算法。如今有许多比DDIM更好的采样方法,但它们多数都保留了stepseta这两个参数。因此,在使用所有采样方法时,我们可以不用关心实现细节,只关注多出来的这两个参数。

在DDIM的基础上,LDM从生成像素空间上的图像变为生成隐空间上的图像。隐空间图像需要再做一次解码才能变回真实图像。从代码上来看,使用LDM后,只需要多准备一个VAE,并对最后的隐空间图像zt解码。

def ldm_ddim_sample(image_shape, ddim_steps = 20, eta = 0):
  ddim_scheduler = DDIMScheduler()
  vae = VAE()
  unet = UNet()
  zt = randn(image_shape)
  T = 1000
  timesteps = ddim_scheduler.get_timesteps(T, ddim_steps) # [1000, 950, 900, ...]
  for t in timesteps:
    eps = unet(zt, t)
    std = ddim_scheduler.get_std(t, eta)
    zt = ddim_scheduler.get_xt_prev(zt, t, eps, std)
  xt = vae.decoder.decode(zt)
  return xt

而想用LDM实现文生图,则需要给一个额外的文本输入text。文本编码器会把文本编码成张量c,输入进unet。其他地方的实现都和之前的LDM一样。

def ldm_text_to_image(image_shape, text, ddim_steps = 20, eta = 0):
  ddim_scheduler = DDIMScheduler()
  vae = VAE()
  unet = UNet()
  zt = randn(image_shape)
  T = 1000
  timesteps = ddim_scheduler.get_timesteps(T, ddim_steps) # [1000, 950, 900, ...]

  text_encoder = CLIP()
  c = text_encoder.encode(text)

  for t = timesteps:
    eps = unet(zt, t, c)
    std = ddim_scheduler.get_std(t, eta)
    zt = ddim_scheduler.get_xt_prev(zt, t, eps, std)
  xt = vae.decoder.decode(zt)
  return xt

最后这个能实现文生图的LDM就是我们熟悉的Stable Diffusion。Stable Diffusion的采样算法看上去比较复杂,但如果能够从DDPM开始把各个功能都拆开来看,理解起来就不是那么困难了。

1.2 U-Net 结构组成

Stable Diffusion代码实现中的另一个重点是去噪网络U-Net的实现。仿照上一节的学习方法,我们来逐步学习Stable Diffusion中的U-Net是怎么从最经典的纯卷积U-Net逐渐发展而来的。

最早的U-Net的结构如下图所示:

可以看出,U-Net的结构有以下特点

  • 整体上看,U-Net由若干个大层组成。特征在每一大层会被下采样成尺寸更小的特征,再被上采样回原尺寸的特征。整个网络构成一个U形结构

  • 下采样后,特征的通道数会变多。一般情况下,每次下采样后图像尺寸减半,通道数翻倍。上采样过程则反之。

  • 为了防止信息在下采样的过程中丢失,U-Net每一大层在下采样前的输出会作为额外输入拼接到每一大层上采样前的输入上。这种数据连接方式类似于ResNet中的「短路连接」

DDPM则使用了一种改进版的U-Net。改进主要有两点:

  • 原来的卷积层被替换成了ResNet中的残差卷积模块。每一大层有若干个这样的子模块。对于较深的大层,残差卷积模块后面还会接一个自注意力模块。

  • 原来模型每一大层只有一个短路连接。现在每个大层下采样部分的每个子模块的输出都会额外输入到其对称的上采样部分的子模块上。直观上来看,就是短路连接更多了一点,输入信息更不容易在下采样过程中丢失。

最后,LDM提出了一种给U-Net添加额外约束信息的方法:把U-Net中的自注意力模块换成交叉注意力模块。具体来说,DDPM的U-Net的自注意力模块被换成了标准的Transformer模块。约束信息可以作为Cross Attention的K, V输入进模块中。

Stable Diffusion的U-Net还在结构上有少许修改,该U-Net的每一大层都有Transformer块,而不是只有较深的大层有。

至此,我们已经学完了Stable Diffusion的采样原理和U-Net结构。接下来我们来看一看它们在不同框架下的代码实现。

2 Stable Diffusion 官方 GitHub 仓库

2.1 安装

克隆仓库后,照着官方Markdown文档安装即可。

git clone git@github.com:CompVis/stable-diffusion.git

先用下面的命令创建conda环境,此后ldm环境就是运行Stable Diffusiion的conda环境。

conda env create -f environment.yaml
conda activate ldm

之后去网上下一个Stable Diffusion的模型文件。比较常见一个版本是v1.5,该模型在Hugging Face上:https://huggingface.co/runwayml/stable-diffusion-v1-5 (推荐下载v1-5-pruned.ckpt)。下载完毕后,把模型软链接到指定位置。

mkdir -p models/ldm/stable-diffusion-v1/
ln -s  models/ldm/stable-diffusion-v1/model.ckpt 

准备完毕后,只要输入下面的命令,就可以生成实现文生图了。

python scripts/txt2img.py --prompt "a photograph of an astronaut riding a horse" 

在默认的参数下,“一幅骑着马的飞行员的照片”的绘制结果会被保存在outputs/txt2img-samples中。你也可以通过--outdir 参数来指定输出到的文件夹。我得到的一些绘制结果为:

【说明】如果你在安装时碰到了错误,可以在搜索引擎上或者GitHub的issue里搜索,一般都能搜到其他人遇到的相同错误。

2.2 主函数

接下来,我们来探究一下scripts/txt2img.py的执行过程。为了方便阅读,我们可以简化代码中的命令行处理,得到下面这份精简代码。(你可以把这份代码复制到仓库根目录下的一个新Python脚本里并直接运行。别忘了修改代码中的模型路径

import os
import torch
import numpy as np
from omegaconf import OmegaConf
from PIL import Image
from tqdm import tqdm, trange
from einops import rearrange
from pytorch_lightning import seed_everything
from torch import autocast
from torchvision.utils import make_grid

from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler


def load_model_from_config(config, ckpt, verbose=False):
    print(f"Loading model from {ckpt}")
    pl_sd = torch.load(ckpt, map_location="cpu")
    if "global_step" in pl_sd:
        print(f"Global Step: {pl_sd['global_step']}")
    sd = pl_sd["state_dict"]
    model = instantiate_from_config(config.model)
    m, u = model.load_state_dict(sd, strict=False)
    if len(m) > 0 and verbose:
        print("missing keys:")
        print(m)
    if len(u) > 0 and verbose:
        print("unexpected keys:")
        print(u)

    model.cuda()
    model.eval()
    return model


def main():
    seed = 42
    config = 'configs/stable-diffusion/v1-inference.yaml'
    ckpt = 'ckpt/v1-5-pruned.ckpt'
    outdir = 'tmp'
    n_samples = batch_size = 3
    n_rows = batch_size
    n_iter = 2
    prompt = 'a photograph of an astronaut riding a horse'
    data = [batch_size * [prompt]]
    scale = 7.5
    C = 4
    f = 8
    H = W = 512
    ddim_steps = 50
    ddim_eta = 0.0

    seed_everything(seed)

    config = OmegaConf.load(config)
    model = load_model_from_config(config, ckpt)

    device = torch.device(
        "cuda") if torch.cuda.is_available() else torch.device("cpu")
    model = model.to(device)
    sampler = DDIMSampler(model)

    os.makedirs(outdir, exist_ok=True)
    outpath = outdir

    sample_path = os.path.join(outpath, "samples")
    os.makedirs(sample_path, exist_ok=True)
    grid_count = len(os.listdir(outpath)) - 1

    start_code = None
    precision_scope = autocast
    with torch.no_grad():
        with precision_scope("cuda"):
            with model.ema_scope():
                all_samples = list()
                for n in trange(n_iter, desc="Sampling"):
                    for prompts in tqdm(data, desc="data"):
                        uc = None
                        if scale != 1.0:
                            uc = model.get_learned_conditioning(
                                batch_size * [""])
                        if isinstance(prompts, tuple):
                            prompts = list(prompts)
                        c = model.get_learned_conditioning(prompts)
                        shape = [C, H // f, W // f]
                        samples_ddim, _ = sampler.sample(S=ddim_steps,
                                                         conditioning=c,
                                                         batch_size=n_samples,
                                                         shape=shape,
                                                         verbose=False,
                                                         unconditional_guidance_scale=scale,
                                                         unconditional_conditioning=uc,
                                                         eta=ddim_eta,
                                                         x_T=start_code)

                        x_samples_ddim = model.decode_first_stage(samples_ddim)
                        x_samples_ddim = torch.clamp(
                            (x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)

                        all_samples.append(x_samples_ddim)
                grid = torch.stack(all_samples, 0)
                grid = rearrange(grid, 'n b c h w -> (n b) c h w')
                grid = make_grid(grid, nrow=n_rows)

                # to image
                grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
                img = Image.fromarray(grid.astype(np.uint8))
                img.save(os.path.join(outpath, f'grid-{grid_count:04}.png'))
                grid_count += 1

    print(f"Your samples are ready and waiting for you here: n{outpath} n"
          f" nEnjoy.")


if __name__ == "__main__":
    main()

抛开前面一大堆初始化操作,代码的核心部分只有下面几行。

uc = None
if scale != 1.0:
    uc = model.get_learned_conditioning(
        batch_size * [""])
if isinstance(prompts, tuple):
    prompts = list(prompts)
c = model.get_learned_conditioning(prompts)
shape = [C, H // f, W // f]
samples_ddim, _ = sampler.sample(S=ddim_steps,
                                  conditioning=c,
                                  batch_size=n_samples,
                                  shape=shape,
                                  verbose=False,
                                  unconditional_guidance_scale=scale,
                                  unconditional_conditioning=uc,
                                  eta=ddim_eta,
                                  x_T=start_code)

x_samples_ddim = model.decode_first_stage(samples_ddim)

我们来逐行分析一下这段代码。一开始的几行是执行Classifier-Free Guidance (CFG)uc表示的是CFG中的无约束下的约束张量。scale表示的是执行CFG的程度,scale不等于1.0即表示启用CFG。model.get_learned_conditioning表示用CLIP把文本编码成张量。对于文本约束的模型,无约束其实就是输入文本为空字符串("")。因此,在代码中,若启用了CFG,则会用CLIP编码空字符串,编码结果为uc

如果你没学过CFG,也不用担心。你可以暂时不要去理解上面这段话。等读完了后文中有关CFG的代码后,你差不多就能理解CFG的用法了。

uc = None
if scale != 1.0:
    uc = model.get_learned_conditioning(
        batch_size * [""])

之后的几行是在把用户输入的文本编码成张量。同样,model.get_learned_conditioning表示用CLIP把输入文本编码成张量c

if isinstance(prompts, tuple):
    prompts = list(prompts)
c = model.get_learned_conditioning(prompts)

接着是用扩散模型的采样器生成图片。在这份代码中,sampler是DDIM采样器,sampler.sample函数直接完成了图像生成。

shape = [C, H // f, W // f]
samples_ddim, _ = sampler.sample(S=ddim_steps,
                                  conditioning=c,
                                  batch_size=n_samples,
                                  shape=shape,
                                  verbose=False,
                                  unconditional_guidance_scale=scale,
                                  unconditional_conditioning=uc,
                                  eta=ddim_eta,
                                  x_T=start_code)

后,LDM生成的隐空间图片被VAE解码成真实图片。函数model.decode_first_stage负责图片解码。x_samples_ddim在后续的代码中会被后处理成正确格式的RGB图片,并输出至文件里。

x_samples_ddim = model.decode_first_stage(samples_ddim)

Stable Diffusion 官方实现的主函数主要就做了这些事情。这份实现还是有一些凌乱的。采样算法的一部分内容被扔到了主函数里,另一部分放到了DDIM采样器里。在阅读官方实现的源码时,既要去读主函数里的内容,也要去读采样器里的内容。

接下来,我们来看一看DDIM采样器的部分代码,学完采样算法的剩余部分的实现。

2.3 DDIM 采样器

回头看主函数的前半部分,DDIM采样器是在下面的代码里导入的:

from ldm.models.diffusion.ddim import DDIMSampler

跳转到ldm/models/diffusion/ddim.py文件,我们可以找到DDIMSampler的实现。

先看一下这个类的构造函数。构造函数主要是把U-Net model给存了下来。后文中的self.model都指的是U-Net。

def __init__(self, model, schedule="linear", **kwargs):
    super().__init__()
    self.model = model
    self.ddpm_num_timesteps = model.num_timesteps
    self.schedule = schedule

# in main

config = OmegaConf.load(config)
model = load_model_from_config(config, ckpt)
model = model.to(device)
sampler = DDIMSampler(model)

再沿着类的self.sample方法,看一下DDIM采样的实现代码。以下是self.sample方法的主要内容。这个方法其实就执行了一个self.make_schedule,之后把所有参数原封不动地传到了self.ddim_sampling里。

@torch.no_grad()
def sample(self,
            S,
            batch_size,
            shape,
            conditioning=None,
            ...
            ):
    if conditioning is not None:
        ...

    self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
    # sampling
    C, H, W = shape
    size = (batch_size, C, H, W)
    print(f'Data shape for DDIM sampling is {size}, eta {eta}')

    samples, intermediates = self.ddim_sampling(...)

self.make_schedule用于预处理扩散模型的中间计算参数。它的大部分实现细节可以略过。DDIM用到的有效时间戳列表就是在这个函数里设置的,该列表通过make_ddim_timesteps获取,并保存在self.ddim_timesteps中。此外,由ddim_eta决定的扩散模型的方差也是在这个方法里设置的。大致扫完这个方法后,我们可以直接跳到self.ddim_sampling的代码。

def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
    self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
                                              num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
    ...

穿越重重的嵌套,我们总算能看到DDIM采样的实现方法self.ddim_sampling了。它的主要内容如下所示:

@torch.no_grad()
def ddim_sampling(self, ...):
    device = self.model.betas.device
    b = shape[0]
    img = torch.randn(shape, device=device)
    timesteps = self.ddim_timesteps
    intermediates = ...
    time_range = np.flip(timesteps)
    total_steps = timesteps.shape[0]

    iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)

    for i, step in enumerate(iterator):
        index = total_steps - i - 1
        ts = torch.full((b,), step, device=device, dtype=torch.long)

        outs = self.p_sample_ddim(img, cond, ts, ...)
        img, pred_x0 = outs

    return img, intermediates

这段代码和我们之前自己写的伪代码非常相似。一开始,方法获取了在make_schedule里初始化的DDIM有效时间戳列表self.ddim_timesteps,并预处理成一个iterator。该迭代器用于控制DDIM去噪循环。每一轮循环会根据当前时刻的图像img和时间戳ts计算下一步的图像img。具体来说,代码每次用当前的时间戳step创建一个内容全部为step,形状为(b,)的张量ts。该张量会和当前的隐空间图像img,约束信息张量cond一起传给执行一轮DDIM去噪的p_sample_ddim方法。p_sample_ddim方法会返回下一步的图像img。最后,经过多次去噪后,ddim_sampling方法将去噪后的隐空间图像img返回。

p_sample_ddim里的p_sample看上去似乎意义不明,实际上这个叫法来自于DDPM论文。在DDPM论文中,扩散模型的前向过程用字母q表示,反向过程用字母p表示。因此,反向过程的一轮去噪在代码里被叫做p_sample

最后来看一下p_sample_ddim这个方法,它的主体部分如下:

@torch.no_grad()
def p_sample_ddim(self, x, c, t, ...):
    b, *_, device = *x.shape, x.device

    if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
        e_t = self.model.apply_model(x, t, c)
    else:
        x_in = torch.cat([x] * 2)
        t_in = torch.cat([t] * 2)
        c_in = torch.cat([unconditional_conditioning, c])
        e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
        e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)


    # Prepare variables
    ...

    # current prediction for x_0
    pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
    if quantize_denoised:
        pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
    # direction pointing to x_t
    dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
    noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
    if noise_dropout > 0.:
        noise = torch.nn.functional.dropout(noise, p=noise_dropout)
    x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
    return x_prev, pred_x0

方法的内容大致可以拆成三段:首先,方法调用U-Net self.model,使用CFG来计算除这一轮该去掉的噪声e_t。然后,方法预处理出DDIM的中间变量。最后,方法根据DDIM的公式,计算出这一轮去噪后的图片x_prev。我们着重看第一部分的代码。

不启用CFG时,方法直接通过self.model.apply_model(x, t, c)调用U-Net,算出这一轮的噪声e_t。而想启用CFG,需要输入空字符串的约束张量unconditional_conditioning,且CFG的强度unconditional_guidance_scale不为1。CFG的执行过程是:对U-Net输入不同的约束c,先用空字符串约束得到一个预测噪声e_t_uncond,再用输入的文本约束得到一个预测噪声e_t。之后令e_t = et_uncond + scale * (e_t - e_t_uncond)scale大于1,即表明我们希望预测噪声更加靠近有输入文本的那一个。直观上来看,scale越大,最后生成的图片越符合输入文本,越偏离空文本。下面这段代码正是实现了上述这段逻辑,只不过代码使用了一些数据拼接技巧,让空字符串约束下和输入文本约束下的结果在一次U-Net推理中获得

if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
    e_t = self.model.apply_model(x, t, c)
else:
    x_in = torch.cat([x] * 2)
    t_in = torch.cat([t] * 2)
    c_in = torch.cat([unconditional_conditioning, c])
    e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
    e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)

p_sample_ddim 方法的后续代码都是在实现下面这个DDIM采样公式。代码工工整整地计算了公式中的predicted_x0dir_xtnoise,非常易懂,没有需要特别注意的地方。

我们已经看完了p_sample_ddim的代码。该方法可以实现一步去噪操作。多次调用该方法去噪后,我们就能得到生成的隐空间图片。该图片会被返回到main函数里,被VAE的解码器解码成普通图片。至此,我们就学完了Stable Diffusion官方仓库的采样代码。

对照下面这份我们之前写的伪代码,我们再来梳理一下Stable Diffusion官方仓库的代码逻辑。官方仓库的采样代码一部分在main函数里,另一部分在ldm/models/diffusion/ddim.py里。main函数主要完成了编码约束文字、解码隐空间图像这两件事。剩下的DDIM采样以及各种Diffusion图像编辑功能都是在ldm/models/diffusion/ddim.py文件中实现的。

def ldm_text_to_image(image_shape, text, ddim_steps = 20, eta = 0)
  ddim_scheduler = DDIMScheduler()
  vae = VAE()
  unet = UNet()
  zt = randn(image_shape)
  eta = input()
  T = 1000
  timesteps = ddim_scheduler.get_timesteps(T, ddim_steps) # [1000, 950, 900, ...]

  text_encoder = CLIP()
  c = text_encoder.encode(text)

  for t = timesteps:
    eps = unet(zt, t, c)
    std = ddim_scheduler.get_std(t, eta)
    zt = ddim_scheduler.get_xt_prev(zt, t, eps, std)
  xt = vae.decoder.decode(zt)
  return xt

在学习代码时,要着重学习DDIM采样器部分的代码。大部分基于Diffusion的图像编辑技术都是在DDIM采样的中间步骤中做文章,只要学懂了DDIM采样的代码,学相关图像编辑技术就会非常轻松。除此之外,和LDM相关的文字约束编码、隐空间图像编码解码的接口函数也需要熟悉,不少技术会调用到这几项功能。

还有一些Diffusion相关工作会涉及U-Net的修改。接下来,我们就来看Stable Diffusion官方仓库中U-Net的实现。

2.4 U-Net

我们来回头看一下main函数和DDIM采样中U-Net的调用逻辑。和U-Net有关的代码如下所示。LDM模型类 model在主函数中通过load_model_from_config从配置文件里创建,随后成为了sampler的成员变量。在DDIM去噪循环中,LDM模型里的U-Net会在self.model.apply_model方法里被调用。

# main.py
config = 'configs/stable-diffusion/v1-inference.yaml'
config = OmegaConf.load(config)
model = load_model_from_config(config, ckpt)
sampler = DDIMSampler(model)

# ldm/models/diffusion/ddim.py
e_t = self.model.apply_model(x, t, c)

为了知道U-Net是在哪个类里定义的,我们需要打开配置文件 configs/stable-diffusion/v1-inference.yaml。该配置文件有这样一段话:

model:
  target: ldm.models.diffusion.ddpm.LatentDiffusion
  params:
    conditioning_key: crossattn
    unet_config:
        target: ldm.modules.diffusionmodules.openaimodel.UNetModel

根据这段话,我们知道LDM类定义在ldm/models/diffusion/ddpm.pyLatentDiffusion里,U-Net类定义在ldm/modules/diffusionmodules/openaimodel.pyUNetModel里。一个LDM类有一个U-Net类的实例。我们先简单看一看LatentDiffusion的实现。

ldm/models/diffusion/ddpm.py原本来自DDPM论文的官方仓库,内含DDPM类的实现。DDPM类维护了扩散模型公式里的一些变量,同时维护了U-Net类的实例。LDM的作者基于之前DDPM的代码进行开发,定义了一个继承自DDPMLatentDiffusion类。除了DDPM本身的功能外,LatentDiffusion还维护了VAE(self.first_stage_model),CLIP(self.cond_stage_model)。也就是说,LatentDiffusion主要维护了扩散模型中间变量、U-Net、VAE、CLIP这四类信息。这样,所有带参数的模型都在LatentDiffusion里,我们可以从一个checkpoint文件中读取所有的模型的参数。相关代码定义代码如下:

把所有模型定义在一起有好处也有坏处。好处在于,用户想使用Stable Diffusion时,只需要下载一个checkpoint文件就行了。坏处在于,哪怕用户只改了某个子模型(如U-Net),为了保存整个模型,他还是得把其他子模型一起存下来。这其中存在着信息冗余,十分不灵活。Diffusers框架没有把模型全存在一个文件里,而是放到了一个文件夹里。

class DDPM(pl.LightningModule):
    # classic DDPM with Gaussian diffusion, in image space
    def __init__(self,
                 unet_config,
                 ...):
        self.model = DiffusionWrapper(unet_config, conditioning_key)
        

class LatentDiffusion(DDPM):
    """main class"""
    def __init__(self,
                 first_stage_config,
                 cond_stage_config,
                 ...):

        self.instantiate_first_stage(first_stage_config)
        self.instantiate_cond_stage(cond_stage_config)

我们主要关注LatentDiffusion类的apply_model方法,它用于调用U-Net self.modelapply_model看上去有很长,但略过了我们用不到的一些代码后,整个方法其实非常短。一开始,方法对输入的约束信息编码cond做了一个前处理,判断约束是哪种类型。如论文里所描述的,LDM支持两种约束:将约束与输入拼接、将约束注入到交叉注意力层中。方法会根据self.model.conditioning_keyconcat还是crossattn,使用不同的约束方式。Stable Diffusion使用的是后者,即self.model.conditioning_key == crossattn。做完前处理后,方法执行了x_recon = self.model(x_noisy, t, **cond)。接下来的处理交给U-Net self.model来完成。

def apply_model(self, x_noisy, t, cond, return_ids=False):
    if isinstance(cond, dict):
        # hybrid case, cond is exptected to be a dict
        pass
    else:
        if not isinstance(cond, list):
            cond = [cond]
        key = 'c_concat' if self.model.conditioning_key == 'concat' else 'c_crossattn'
        cond = {key: cond}

    x_recon = self.model(x_noisy, t, **cond)

    if isinstance(x_recon, tuple) and not return_ids:
        return x_recon[0]
    else:
        return x_recon

现在,我们跳转到ldm/modules/diffusionmodules/openaimodel.pyUNetModel里。UNetModel只定义了神经网络层的运算,没有多余的功能。我们只需要看它的__init__方法和forward方法。我们先来看较为简短的forward方法。

def forward(self, x, timesteps=None, context=None, y=None,**kwargs):
    hs = []
    t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
    emb = self.time_embed(t_emb)

    h = x.type(self.dtype)
    for module in self.input_blocks:
        h = module(h, emb, context)
        hs.append(h)
    h = self.middle_block(h, emb, context)
    for module in self.output_blocks:
        h = th.cat([h, hs.pop()], dim=1)
        h = module(h, emb, context)
    h = h.type(x.dtype)
    return self.out(h)

forward方法的输入是x, timesteps, context,分别表示当前去噪时刻的图片、当前时间戳、文本约束编码。根据这些输入,forward会输出当前时刻应去除的噪声eps。一开始,方法会先对timesteps使用Transformer论文中介绍的位置编码timestep_embedding,得到时间戳的编码t_embt_emb再经过几个线性层,得到最终的时间戳编码emb。而context已经是CLIP处理过的编码,它不需要做额外的预处理。时间戳编码emb和文本约束编码context随后会注入到U-Net的所有中间模块中

def forward(self, x, timesteps=None, context=None, y=None,**kwargs):
    hs = []
    t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
    emb = self.time_embed(t_emb)

经过预处理后,方法开始处理U-Net的计算。中间结果h会经过U-Net的下采样模块input_blocks,每一个子模块的临时输出都会被保存进一个栈hs里。

 h = x.type(self.dtype)
for module in self.input_blocks:
    h = module(h, emb, context)
    hs.append(h)

接着,h会经过U-Net的中间模块。

h = self.middle_block(h, emb, context)

随后,h开始经过U-Net的上采样模块output_blocks。此时每一个编码器子模块的临时输出会从栈hs里弹出,作为对应解码器子模块的额外输入。额外输入hs.pop()会与中间结果h拼接到一起输入进子模块里。

for module in self.output_blocks:
    h = th.cat([h, hs.pop()], dim=1)
    h = module(h, emb, context)
h = h.type(x.dtype)

最后,h会被输出层转换成一个通道数正确的eps张量。

return self.out(h)

这段代码的数据连接图如下所示:

在阅读__init__前,我们先看一下待会会用到的另一个模块类TimestepEmbedSequential的定义。在PyTorch中,一系列输入和输出都只有一个变量的模块在串行连接时,可以用串行模块类nn.Sequential来把多个模块合并简化成一个模块。而在扩散模型中,多数模块的输入是x, t, c三个变量,输出是一个变量。为了也能用类似的串行模块类把扩散模型的模块合并在一起,代码中包含了一个TimestepEmbedSequential类。它的行为类似于nn.Sequential,只不过它支持x, t, c的输入。forward中用到的多数模块都是通过TimestepEmbedSequential创建的

class TimestepEmbedSequential(nn.Sequential, TimestepBlock):

    def forward(self, x, emb, context=None):
        for layer in self:
            if isinstance(layer, TimestepBlock):
                x = layer(x, emb)
            elif isinstance(layer, SpatialTransformer):
                x = layer(x, context)
            else:
                x = layer(x)
        return x

看完了数据的计算过程,我们回头来看各个子模块在__init__方法中是怎么被详细定义的。__init__的主要内容如下:

class UNetModel(nn.Module):
    def __init__(self, ...):

        self.time_embed = nn.Sequential(
            linear(model_channels, time_embed_dim),
            nn.SiLU(),
            linear(time_embed_dim, time_embed_dim),
        )

        self.input_blocks = nn.ModuleList(
            [
                TimestepEmbedSequential(
                    conv_nd(dims, in_channels, model_channels, 3, padding=1)
                )
            ]
        )

        for level, mult in enumerate(channel_mult):
            for _ in range(num_res_blocks):
                layers = [
                    ResBlock(...)]
                ch = mult * model_channels
                if ds in attention_resolutions:
                     layers.append(
                        AttentionBlock(...) if not use_spatial_transformer else SpatialTransformer(...))

                self.input_blocks.append(TimestepEmbedSequential(*layers))
            if level != len(channel_mult) - 1:
                out_ch = ch
                self.input_blocks.append(
                    TimestepEmbedSequential(
                        ResBlock(...)
                        if resblock_updown
                        else Downsample(...)
                    )
                )

        self.middle_block = TimestepEmbedSequential(
            ResBlock(...),
            AttentionBlock(...) if not use_spatial_transformer else SpatialTransformer(...),
            ResBlock(...),
        )

        self.output_blocks = nn.ModuleList([])
        for level, mult in list(enumerate(channel_mult))[::-1]:
            for i in range(num_res_blocks + 1):
                ich = input_block_chans.pop()
                layers = [
                    ResBlock(...)
                ]
                ch = model_channels * mult
                if ds in attention_resolutions:
                    layers.append(
                        AttentionBlock(...) if not use_spatial_transformer else SpatialTransformer(...)
                    )
                if level and i == num_res_blocks:
                    out_ch = ch
                    layers.append(
                        ResBlock(...)
                        if resblock_updown
                        else Upsample(...)
                    )
                    ds //= 2
                self.output_blocks.append(TimestepEmbedSequential(*layers))
    self.out = nn.Sequential(
            normalization(ch),
            nn.SiLU(),
            zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
        )

__init__方法的代码很长。在阅读这样的代码时,我们不需要每一行都去细读,只需要理解代码能拆成几块,每一块在做什么即可__init__方法其实就是定义了forward中用到的5个模块,我们一个一个看过去即可。

class UNetModel(nn.Module):
    def __init__(self, ...):

        self.time_embed = ...

        self.input_blocks = nn.ModuleList(...)
        for level, mult in enumerate(channel_mult):
            ...

        self.middle_block = ...

        self.output_blocks = nn.ModuleList([])
        for level, mult in list(enumerate(channel_mult))[::-1]:
            ...
    self.out = ...

先来看time_embed。回忆一下,在forward里,输入的整数时间戳会被正弦编码timestep_embedding(即Transformer中的位置编码)编码成一个张量。之后,时间戳编码处理模块time_embed用于进一步提取时间戳编码的特征。从下面的代码中可知,它本质上就是一个由两个普通线性层构成的模块。

self.time_embed = nn.Sequential(
            linear(model_channels, time_embed_dim),
            nn.SiLU(),
            linear(time_embed_dim, time_embed_dim),
        )

再来看U-Net最后面的输出模块out。输出模块的结构也很简单,它主要包含了一个卷积层,用于把中间变量的通道数从dims变成model_channels

self.out = nn.Sequential(
            normalization(ch),
            nn.SiLU(),
            zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
        )

接下来,我们把目光聚焦在U-Net的三个核心模块上:input_blocksmiddle_blockoutput_blocks。这三个模块的组成都很类似,都用到了残差块ResBlock和注意力块。稍有不同的是,input_blocks的每一大层后面都有一个下采样模块,output_blocks的每一大层后面都有一个上采样模块。上下采样模块的结构都很常规,与经典的U-Net无异。我们把学习的重点放在残差块和注意力块上。我们先看这两个模块的内部实现细节,再来看它们是怎么拼接起来的。

Stable Diffusion的U-Net中的ResBlock和原DDPM的U-Net的ResBlock功能完全一样,都是在普通残差块的基础上,支持时间戳编码的额外输入。具体来说,普通的残差块是由两个卷积模块和一条短路连接构成的,即y = x + conv(conv(x))。如果经过两个卷积块后数据的通道数发生了变化,则要在短路连接上加一个转换通道数的卷积,即y = conv(x) + conv(conv(x))

在这种普通残差块的基础上,扩散模型中的残差块还支持时间戳编码t的输入。为了把t和输入x的信息融合在一起,t会和经过第一个卷积后的中间结果conv(x)加在一起。可是,t的通道数和conv(x)的通道数很可能会不一样。通道数不一样的数据是不能直接加起来的。为此,每一个残差块中都有一个用于转换t通道数的线性层。这样,tconv(x)就能相加了。整个模块的计算可以表示成y=conv(x) + conv(conv(x) + linear(t))。残差块的示意图和源代码如下:

代码解析:

class ResBlock(TimestepBlock):
    def __init__(self, ...):
        super().__init__()
        ...

        self.in_layers = nn.Sequential(
            normalization(channels),
            nn.SiLU(),
            conv_nd(dims, channels, self.out_channels, 3, padding=1),
        )

        self.emb_layers = nn.Sequential(
            nn.SiLU(),
            linear(
                emb_channels,
                2 * self.out_channels if use_scale_shift_norm else self.out_channels,
            ),
        )
        self.out_layers = nn.Sequential(
            normalization(self.out_channels),
            nn.SiLU(),
            nn.Dropout(p=dropout),
            zero_module(
                conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
            ),
        )

        if self.out_channels == channels:
            self.skip_connection = nn.Identity()
        elif use_conv:
            self.skip_connection = conv_nd(
                dims, channels, self.out_channels, 3, padding=1
            )
        else:
            self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)

    def forward(self, x, emb):
        h = self.in_layers(x)
        emb_out = self.emb_layers(emb).type(h.dtype)
        while len(emb_out.shape) 

代码中的in_layers是第一个卷积模块,out_layers是第二个卷积模块。skip_connection是用于调整短路连接通道数的模块。若输入输出的通道数相同,则该模块是一个恒等函数,不对数据做任何修改。emb_layers是调整时间戳编码通道数的线性层模块。这些模块的定义都在ResBlock__init__里。它们的结构都很常规,没有值得注意的地方。我们可以着重阅读模型的forward方法。

如前文所述,在forward中,输入x会先经过第一个卷积模块in_layers,再与经过了emb_layers调整的时间戳编码emb相加后,输入进第二个卷积模块out_layers。最后,做完计算的数据会和经过了短路连接的原输入skip_connection(x)加在一起,作为整个残差块的输出。

def forward(self, x, emb):
    h = self.in_layers(x)
    emb_out = self.emb_layers(emb).type(h.dtype)
    while len(emb_out.shape) 

这里有一点实现细节需要注意。时间戳编码emb_out的形状是[n, c]。为了把它和形状为[n, c, h, w]的图片加在一起,需要把它的形状变成[n, c, 1, 1]后再相加(形状为[n, c, 1, 1]的数据在与形状为[n, c, h, w]的数据做加法时形状会被自动广播成[n, c, h, w])。在PyTorch中,x=x[..., None]可以在一个数据最后加一个长度为1的维度。比如对于形状为[n, c]tt[..., None]的形状就会是[n, c, 1]

残差块的内容到此结束。

我们接着来看注意力模块。在看模块的具体实现之前,我们先看一下源代码中有哪几种注意力模块。在U-Net的代码中,注意力模型是用以下代码创建的:

if ds in attention_resolutions:
    layers.append(
        AttentionBlock(...) if not use_spatial_transformer else SpatialTransformer(...)
    )

第一行if ds in attention_resolutions:用于控制在U-Net的哪几个大层。Stable Diffusion每一大层都用了注意力模块,可以忽略这一行。随后,代码根据是否设置use_spatial_transformer来创建AttentionBlock或是SpatialTransformerAttentionBlock是DDPM中采样的普通自注意力模块,而SpatialTransformer是LDM中提出的支持额外约束的标准Transfomer块。Stable Diffusion使用的是SpatialTransformer。我们就来看一看这个模块的实现细节。

如前所述,SpatialTransformer使用的是标准的Transformer块,它和Transformer中的Transformer块完全一致。输入x先经过一个自注意力层,再过一个交叉注意力层。在此期间,约束编码c会作为交叉注意力层的K, V输入进模块。最后,数据经过一个全连接层。每一层的输入都会和输出做一个残差连接。

当然,标准Transformer是针对一维序列数据的。要把Transformer用到图像上,则需要把图像的宽高拼接到同一维,即对张量做形状变换n c h w -> n c (h * w)。做完这个变换后,就可以把数据直接输入进Transformer模块了。 这些图像数据与序列数据的适配都是在SpatialTransformer里完成的。SpatialTransformer类并没有直接实现Transformer块的细节,仅仅是U-Net和Transformer块之间的一个过渡。Transformer块的实现在它的一个子模块里。我们来看它的实现代码。

SpatialTransformer有两个卷积层proj_inproj_out,负责图像通道数与Transformer模块通道数之间的转换。SpatialTransformertransformer_blocks才是真正的Transformer模块。

class SpatialTransformer(nn.Module):

    def __init__(self, in_channels, n_heads, d_head,
                 depth=1, dropout=0., context_dim=None):
        super().__init__()
        self.in_channels = in_channels
        inner_dim = n_heads * d_head
        self.norm = Normalize(in_channels)

        self.proj_in = nn.Conv2d(in_channels,
                                 inner_dim,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)

        self.transformer_blocks = nn.ModuleList(
            [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim)
                for d in range(depth)]
        )

        self.proj_out = zero_module(nn.Conv2d(inner_dim,
                                              in_channels,
                                              kernel_size=1,
                                              stride=1,
                                              padding=0))

forward中,图像数据在进出Transformer模块前后都会做形状和通道数上的适配。运算结束后,结果和输入之间还会做一个残差连接。context就是约束信息编码,它会接入到交叉注意力层上。

def forward(self, x, context=None):
    b, c, h, w = x.shape
    x_in = x
    x = self.norm(x)
    x = self.proj_in(x)
    x = rearrange(x, 'b c h w -> b (h w) c')
    for block in self.transformer_blocks:
        x = block(x, context=context)
    x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)
    x = self.proj_out(x)
    return x + x_in

每一个Transformer模块的结构完全符合上文的示意图。如果你之前学过Transformer,那这些代码你会十分熟悉。我们快速把这部分代码浏览一遍。

class BasicTransformerBlock(nn.Module):
    def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True):
        super().__init__()
        self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout)  # is a self-attention
        self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
        self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim,
                                    heads=n_heads, dim_head=d_head, dropout=dropout)  # is self-attn if context is none
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.norm3 = nn.LayerNorm(dim)
        self.checkpoint = checkpoint

    def forward(self, x, context=None):
        x = self.attn1(self.norm1(x)) + x
        x = self.attn2(self.norm2(x), context=context) + x
        x = self.ff(self.norm3(x)) + x
        return x

自注意力层和交叉注意力层都是用CrossAttention实现的。该模块与Transformer论文中的多头注意力机制完全相同。当forward的参数context=None时,模块其实只是一个提取特征的自注意力模块;而当context为约束文本的编码时,模块就是一个根据文本约束进行运算的交叉注意力模块。该模块用不到mask,相关的代码可以忽略。

class CrossAttention(nn.Module):
    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
        super().__init__()
        inner_dim = dim_head * heads
        context_dim = default(context_dim, query_dim)

        self.scale = dim_head ** -0.5
        self.heads = heads

        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, query_dim),
            nn.Dropout(dropout)
        )

    def forward(self, x, context=None, mask=None):
        h = self.heads

        q = self.to_q(x)
        context = default(context, x)
        k = self.to_k(context)
        v = self.to_v(context)

        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))

        sim = einsum('b i d, b j d -> b i j', q, k) * self.scale

        if exists(mask):
            ...

        # attention, what we cannot get enough of
        attn = sim.softmax(dim=-1)

        out = einsum('b i j, b j d -> b i d', attn, v)
        out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
        return self.to_out(out)

Transformer块的内容到此结束。看完了SpatialTransformerResBlock,我们可以回头去看模块之间是怎么拼接的了。先来看U-Net的中间块。它其实就是一个ResBlock接一个SpatialTransformer再接一个ResBlock

self.middle_block = TimestepEmbedSequential(
    ResBlock(...),
    SpatialTransformer(...),
    ResBlock(...),
)

下采样块input_blocks和上采样块output_blocks的结构几乎一模一样,区别只在于每一大层最后是做下采样还是上采样。这里我们以下采样块为例来学习一下这两个块的结构。

self.input_blocks = nn.ModuleList(
    [
        TimestepEmbedSequential(
            conv_nd(dims, in_channels, model_channels, 3, padding=1)
        )
    ]
)

for level, mult in enumerate(channel_mult):
    for _ in range(num_res_blocks):
        layers = [
            ResBlock(...)]
        ch = mult * model_channels
        if ds in attention_resolutions:
                layers.append(
                AttentionBlock(...) if not use_spatial_transformer else SpatialTransformer(...))

        self.input_blocks.append(TimestepEmbedSequential(*layers))
    if level != len(channel_mult) - 1:
        out_ch = ch
        self.input_blocks.append(
            TimestepEmbedSequential(
                ResBlock(...)
                if resblock_updown
                else Downsample(...)
            )
        )

上采样块一开始是一个调整输入图片通道数的卷积层,它的作用和self.out输出层一样。

self.input_blocks = nn.ModuleList(
    [
        TimestepEmbedSequential(
            conv_nd(dims, in_channels, model_channels, 3, padding=1)
        )
    ]
)

之后正式进行上采样块的构造。此处代码有两层循环,外层循环表示正在构造哪一个大层,内层循环表示正在构造该大层的哪一组模块。也就是说,共有len(channel_mult)个大层,每一大层都有num_res_blocks组相同的模块。在Stable Diffusion中,channel_mult=[1, 2, 4, 4]num_res_blocks=2

for level, mult in enumerate(channel_mult):
    for _ in range(num_res_blocks):
        ...

每一组模块由一个ResBlock和一个SpatialTransformer构成。

layers = [
    ResBlock(...)
]
ch = mult * model_channels
if ds in attention_resolutions:
    ...
    layers.append(
        SpatialTransformer(...)
    )
self.input_blocks.append(TimestepEmbedSequential(*layers))
...

构造完每一组模块后,若现在还没到最后一个大层,则添加一个下采样模块。Stable Diffusion有4个大层,只有运行到前3个大层时才会添加下采样模块。

for level, mult in enumerate(channel_mult):
    for _ in range(num_res_blocks):
        ...
    if level != len(channel_mult) - 1:
        out_ch = ch
        self.input_blocks.append(
            TimestepEmbedSequential(
                ResBlock(...)
                if resblock_updown
                else Downsample(...)
            )
        )
        ch = out_ch
        input_block_chans.append(ch)
        ds *= 2

至此,我们已经学完了Stable Diffusion的U-Net的主要实现代码。让我们来总结一下。U-Net是一种先对数据做下采样,再做上采样的网络结构。为了防止信息丢失,下采样模块和对应的上采样模块之间有残差连接。下采样块、中间块、上采样块都包含了ResBlockSpatialTransformer两种模块。ResBlock是图像网络中常使用的残差块,而SpatialTransformer是能够融合图像全局信息并融合不同模态信息的Transformer块。Stable Diffusion的U-Net的输入除了有图像外,还有时间戳t和约束编码ct会先过几个嵌入层和线性层,再输入进每一个ResBlock中。c会直接输入到所有Transformer块的交叉注意力块中。

Diffusers的源码会在下篇文章中解读。敬请期待!

THE END !

文章结束,感谢阅读。您的点赞,收藏,评论是我继续更新的动力。大家有推荐的公众号可以评论区留言,共同学习,一起进步。

文章来源于互联网:天才程序员周弈帆 | Stable Diffusion 解读(三):原版实现源码解读(篇幅略长,建议收藏!)

赞(0)
未经允许不得转载:5bei.cn大模型教程网 » 天才程序员周弈帆 | Stable Diffusion 解读(三):原版实现源码解读(篇幅略长,建议收藏!)

天才程序员周弈帆 | Stable Diffusion 解读(二):论文精读

本文来源公众号“天才程序员周弈帆”,仅用于学术分享,侵权删,干货满满。

原文链接:Stable Diffusion 解读(二):论文精读

【小小题外话】端午安康!

在上一篇文章天才程序员周弈帆 | Stable Diffusion 解读(一):回顾早期工作-CSDN博客中,我们梳理了基于自编码器(AE)的图像生成模型的发展脉络,并引出了Stable Diffusion的核心思想。简单来说,Stable Diffusion是一个两阶段的图像生成模型,它先用一个AE压缩图像,再在压缩图像所在的隐空间上用DDPM生成图像。在这篇文章中,我们来精读Stable Diffusion的论文:High-Resolution Image Synthesis with Latent Diffusion Models

注意:如果你从未学习过扩散模型,Stable Diffusion并不是你应该的读的第一篇论文。请参照我的 上一篇文章天才程序员周弈帆 | Stable Diffusion 解读(一):回顾早期工作早期工作总结,至少在学会了DDPM后再来学习Stable Diffusion。

1 摘要与引言

论文摘要的大意如下:扩散模型的生成效果很好,但是,在像素空间上训练和推理扩散模型的计算开销都很大。为了在不降低质量与易用性的前提下用较少的计算资源训练扩散模型,我们在一个预训练过的自编码器的隐空间上使用扩散模型。相较以往的工作,在这种表示下训练扩散模型首次在减少计算复杂度和维持图像细节间达到几近最优的平衡点,极大地提升了视觉保真度。通过向模型架构中引入交叉注意力层,我们把扩散模型变成了强大而灵活的带约束图像生成器,它支持常见的约束,如文字、边界框,且能够以纯卷积方式实现高分辨率的图像合成。我们的隐扩散模型(latent diffusion model, LDM) 在使用比像素扩散模型少得多的计算资源的前提下,在各项图像合成任务上取得最优成果或顶尖成果。

整理一下。论文提出了一种叫LDM的图像生成模型。论文想解决的问题是减少像素空间扩散模型的运算开销。为此,LDM借助了VQVAE「先压缩、再生成」的想法,把扩散模型用在AE的隐空间上,在几乎不降低生成质量的前提下减少了计算量。另外,LDM还支持带约束图像合成及纯卷积图像超分辨率。

在上一篇回顾LDM早期工作的文章天才程序员周弈帆 | Stable Diffusion 解读(一):回顾早期工作中,我们已经理解了LDM想解决的问题及解决问题的思路。因此,在读完摘要后,我们接下来读文章时只需要关注LDM的两个创新点

  1. LDM的AE是怎么设计以达到压缩比例与质量的平衡的

  2. LDM怎么实现带约束的图像合成

引言基本是摘要的扩写。首先,引言大致介绍了图像合成任务的背景,提及了扩散模型近期的突出表现。随后,引言介绍了本文想解决的主要问题:扩散模型的训练和推理太耗时了,需要在不降低效果的前提下减少扩散模型的运算量。最后,引言揭示了本工作的解决方法:使用类似VQGAN的两阶段图像生成方法

引言的前两部分没有什么关键信息,而最后一部分介绍了本工作改进扩散模型的动机,值得一读。如下图所示,DDPM的论文展示了从不同去噪时刻的同一个噪声图像开始的不同生成结果,比如x_750指从时刻t=750的去噪图像开始,多次以不同随机数执行DDPM的反向过程,生成的多幅图像。LDM作者认为,DDPM的这一实验表明,扩散模型的图像生成分两个阶段:先是对语义进行压缩,再是对图像的感知细节压缩。正因此,随机对早期的噪声图像去噪,生成图像的内容会更多样;而随机对后期的噪声图像去噪,生成图像只是在细节上有所不同。LDM的作者认为,扩散模型的大量计算都浪费在了生成整幅图像的细节上,不如只让扩散模型描述比较关键的语义压缩部分,而让自编码器(AE)负责感知细节压缩部分

引言在结尾总结了本工作的贡献

  1. 相比之前按序列处理图像的纯Transformer的方法,扩散模型能更好地处理二维数据。因此,LDM生成隐空间图像时不需要那么重的压缩比例(比如DIV2K数据集上,LDM只需要将图像下采样4倍,而之前的纯Transformer方法要下采样8倍或16倍),图像在压缩时能有更高的保真度,整套方法能更高效地生成高分辨率图像。

  2. 在大幅降低计算开销的前提下在多项图像生成任务上取得了顶尖成果

  3. 相比于之前同时训练图像压缩模型图像生成模型的方法,该方法分步训练两个模型,训练起来更加简单。

  4. 对于有着稠密约束的任务(如超分辨率、补全、语义生成),该方法的模型能换成一个纯卷积版本的,且能生成边长为1024的图像。

  5. 该工作设计了一种通用的约束机制,该机制基于交叉注意力支持多模态训练。作者训练了多种带约束的模型。

  6. 作者把工作开源了,并提供了预训练模型

我们来整理一下这些贡献。读论文时,可以忽略第6条。第2条是成果,与方法设计无关。第1、3条主要描述了提出两阶段图像生成建模方法的贡献。第4条是把方法拓展到稠密约束任务的贡献。第5条是提出了新约束机制的贡献。所以,在学习论文的方法时,我们还是主要关注摘要里就提过的那两个创新点。在读完引言后,我们可以把阅读目标再细化一下

  1. LDM的AE是怎么设计以达到压缩比例与质量的平衡的。与纯基于Transformer的VQGAN相比,它有什么不同。

  2. LDM怎么用交叉注意力机制实现带约束的图像生成。

2 相关工作

作者主要从两个角度回顾了早期工作:不同架构的图像生成模型两阶段的图像合成方法。其回顾逻辑与本系列的第一篇文章天才程序员周弈帆 | Stable Diffusion 解读(一):回顾早期工作类似,在此就不过多介绍了。除了介绍早期工作外,作者重申了引言中的对比结果,强调了LDM相对于扩散模型的创新和相对于两阶段图像生成模型的创新。

3 方法

在方法章节中,作者先是大致介绍了使用LDM这种两阶段图像生成架构的优点,再分三部分详细介绍了论文的实现细节:图像压缩AE的实现、LDM的实现、约束的实现。开头的介绍和AE的实现相对比较重要,我们放在一起详细阅读;相对于DDPM,LDM几乎没有做任何修改,只是把要拟合的图片从真实图片换成了压缩图片,这一部分我们会快速浏览一遍;而添加约束的方法有所创新,我们会详细阅读一遍。

3.1 AE与两阶段图像生成模型

我们来先读3.1节,看一看AE的具体实现方法,再回头读第3节开头介绍的两阶段图像生成模型的优点。

LDM配套的图像压缩模型(论文中称之为“感知压缩模型”和VQGAN几乎完全一样。该压缩模型的原型是一个AE。普通的AE会用原图像和重建图像的重建误差(L1误差或者L2误差)来训练。在普通的AE的基础上,该压缩模型参考了GAN的误差设置方法,使用感知误差代替重建误差,并添加了基于patch的对抗误差

但该图像压缩模型的输出与VQGAN有所不同。我们先回忆一下VQGAN的原理。VQGAN的输出会接到Transformer里,Transformer的输入必须是离散的。因此,VQGAN必须要额外完成两件事:1)让连续输出变成离散输出;2)用正则化方法防止过拟合。为此,VQGAN使用了VQVAE里的向量离散化操作,该操作能同时完成这两件事。

而LDM的压缩模型的输出会接入一个扩散模型里,扩散模型的输入是连续的。因此,LDM的压缩模型只需要额外完成使用正则化方法这一件事。该压缩模型不必像VQGAN一样非得用向量离散化来完成正则化。如我们在第一篇文章中讨论的,作者在LDM的压缩模型中使用了两种正则化方法:VQ正则化与KL正则化。前者来自于VQVAE,后者来自于VAE。

该压缩模型相较VQGAN有一项明显的优势。VQGAN的Transformer只能按一维序列来处理图像(通过把二维图像reshape成一维),且只能处理较小的压缩图像(16 x 16)。而本身用于二维图像生成的LDM能更好地利用二维信息,因此可以处理更大的压缩图像(64 x 64)。这样,LDM的压缩模型的压缩程度不必那么重,其保真度会比VQGAN高

看完了3.1节,我们来回头看第3节开头介绍了LDM的三项优点:1)通过规避在高维图像空间上训练扩散模型,作者开发出了一个因在低维空间上采样而计算效率大幅提升的扩散模型;2)作者发掘了扩散模型中来自U-Net架构的归纳偏置(inductive bias),使得它们能高效地处理有空间结构的数据(比如二维图像),避免像之前基于Transformer的方法一样使用激进、有损质量的压缩比例;3)本工作的压缩模型是通用的,它的隐空间能用来训练多种图像生成模型。第一个优点是相对于DDPM。第二个是优点是相对于使用Transformer的VQGAN,我们在上一段已经分析过了。第三个优点是相对于之前那些换一个任务就需要换一个压缩模型的两阶段图像生成模型。

归纳偏置可以简单理解为某个学习算法对一类数据的优势。比如CNN结构适合处理图像数据。

3.2 隐扩散模型(LDM)

3.3 约束机制

根据论文中实验的设计,对于作用于全局的约束,如文本描述,使用交叉注意力较好;对于有空间信息的约束,如语义分割图片,则用拼接的方式较好

4 实验

在这一章里,作者按照介绍方法的顺序,依次探究了图像压缩模型、无约束图像生成、带约束图像合成的实验结果。我们主要关心前两部分的实验结果。

4.1 感知压缩程度的折衷

4.2 图像生成效果

在这一节中,作者在几个常见的数据集上对比了LDM与其他模型的无约束图像生成效果。作者主要比较了两类指标:表示采样质量的FID表示数据分布覆盖率的精确率及召回率(Precision-and-Recall)

在介绍具体结果之前,先对这个不太常见的精确率及召回率指标做一个解释。精确率及召回率常用于分类等有确定答案的任务中,分别表示所有被分类为正的样本中有多少是分对了的、所有真值为正的样本中有多少是被成功分类成正的。而无约束图像生成中的精确率及召回率的解释可以参加论文Improved Precision and Recall Metric for Assessing Generative Models。如下图所示,设真实分布为蓝色,生成模型的分布为红色,则红色样本落在蓝色分布的比例为精确率,蓝色样本落在红色分布的比例为召回率。简单来说,精确率能描述采样质量,召回率能描述生成分布与真实分布的覆盖情况。

接下来,我们回头来看论文展示的无约束图像生成对比结果,如下图所示。整体上看,LDM的表现还不错。虽然在FID指标上无法超过GAN或其他扩散模型,但是在精确率和召回率上还是颇具优势。唯一没有被LDM战胜的是LSUN-Bedrooms上的ADM模型,但作者提到,相比ADM,LDM只用了一半的参数,且只需四分之一的训练资源。

4.3 带约束图像合成

这一节里,作者展示了LDM的文生图能力。论文中的LDM用了一个从头训练的基于Transformer的文本编码器,与后续使用CLIP的Stable Diffusion差别较大。这一部分的结果没那么重要,大致看一看就好。

本文的文生图模型是一个在LAION-400M数据集上训练的KL约束LDM。它的文本编码器是一个Transformer,编码后的特征会以交叉注意力的形式传入LDM。采样时,LDM使用了Classifier-Free Guidance。

Classifier-Free Guidance可以让输出图片更符合文本约束。这是一种适用于所有扩散模型的采样策略,并非要和LDM绑定,感兴趣可以去阅读相关论文。

LDM与其他模型的文生图效果对比如下图所示。虽然这个版本的LDM并没有显著优于其他模型,但它的参数量是最少的。

5 总结

论文末尾探讨了LDM的两大不足。首先,尽管LDM的计算需求比其他像素空间上的扩散模型要少得多,但受制于扩散模型本身的串行采样,它的采样速度还是比GAN慢上许多。其次,LDM使用了一个自编码器来压缩图像,重建图像带来的精度损失会成为某些需要精准像素值的任务的性能瓶颈。

论文最后再次总结了此方法的贡献。LDM的主要贡献其实只有两点:在不损失效果的情况下用两阶段的图像生成方法大幅提升了训练和采样效率、借助交叉注意力实现了各任务通用的约束机制。这两个贡献总结得非常精准。之后的Stable Diffusion之所以大受欢迎,第一就是因为它采样所需的计算资源不多,大众能使用消费级显卡完成图像生成,第二就是因为它强大的文字转图片生成效果。

我们再从知识学习的角度总结一下LDM。LDM的核心知识是DDPM和VQGAN。如果你能看懂之前这两篇论文,那你一下子就能明白LDM是的核心思想是什么,看论文时只需要精读交叉注意力约束机制那一段即可,其他实验内容在现在看来已经价值不大了。由于近两年有大量基于Stable Diffusion开发的工作,相比论文,阅读源代码的重要性会大很多。我们会在下一篇文章里详细学习Stable Diffusion的官方源码最常用的Stable Diffusion第三方实现——Diffusers框架

THE END !

文章结束,感谢阅读。您的点赞,收藏,评论是我继续更新的动力。大家有推荐的公众号可以评论区留言,共同学习,一起进步。

文章来源于互联网:天才程序员周弈帆 | Stable Diffusion 解读(二):论文精读

赞(0)
未经允许不得转载:5bei.cn大模型教程网 » 天才程序员周弈帆 | Stable Diffusion 解读(二):论文精读

天才程序员周弈帆 | Stable Diffusion 解读(四):Diffusers实现源码解读

本文来源公众号“天才程序员周弈帆”,仅用于学术分享,侵权删,干货满满。

原文链接:Stable Diffusion 解读(四):Diffusers实现源码解读

接上一篇文章[天才程序员周弈帆 | Stable Diffusion 解读(三):原版实现源码解读(篇幅略长,建议收藏!)-CSDN博客],我们来学习Stable Diffusion在Diffusers中的实现。

本文用到的Stable Diffusion版本是v1.5。Diffusers版本是0.25.0。为了提升可读性,本文对源代码做了一定的精简,部分不会运行到的分支会被略过。

1 Diffusers

Diffusers是由Hugging Face维护的一套Diffusion框架。这个库的代码被封装进了一个Python模块里,我们可以在安装了Diffusers的Python环境中用import diffusers随时调用该库。相比之下,Diffusers的代码架构更加清楚,且各类Stable Diffusion的新技术都会及时集成进Diffusers库中。

由于我们已经在上篇文章中学过了Stable Diffusion官方源码,在学习Diffusers代码时,我们只会大致过一过每一段代码是在做什么,而不会赘述Stable Diffusion的原理。

1.1 安装

安装该库时,不需要克隆仓库,只需要直接用pip即可。

pip install --upgrade diffusers[torch]

之后,随便在某个地方创建一个Python脚本文件,输入官方的示例项目代码。

from diffusers import DiffusionPipeline
import torch

pipeline = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
pipeline.to("cuda")
pipeline("An image of a squirrel in Picasso style").images[0].save('output.jpg')

运行代码后,”一幅毕加索风格的松鼠图片”的绘制结果会保存在output.jpg中。我得到的结果如下:

在Diffusers中,from_pretrained函数可以直接从Hugging Face的模型仓库中下载预训练模型。比如,示例代码中from_pretrained("runwayml/stable-diffusion-v1-5", ...)指的就是从模型仓库https://huggingface.co/runwayml/stable-diffusion-v1-5中获取模型。

如果在当前网络下无法从命令行中访问Hugging Face,可以先想办法在网页上访问上面的模型仓库,手动下载v1-5-pruned.ckpt。之后,克隆Diffusers的GitHub仓库,再用Diffusers的工具把Stable Diffusion原版模型文件转换成Diffusers支持的模型格式。

git clone git@github.com:huggingface/diffusers.git
cd diffusers
python scripts/convert_original_stable_diffusion_to_diffusers.py --checkpoint_path  --dump_path 

比如,假设你的模型文件存在ckpt/v1-5-pruned.ckpt,你想把输出的Diffusers的模型文件存在ckpt/sd15,则应该输入:

python scripts/convert_original_stable_diffusion_to_diffusers.py --checkpoint_path ckpt/v1-5-pruned.ckpt --dump_path ckpt/sd15 

之后修改示例脚本中的路径,就可以成功运行了。

from diffusers import DiffusionPipeline
import torch

pipeline = DiffusionPipeline.from_pretrained("ckpt/sd15", torch_dtype=torch.float16)
pipeline.to("cuda")
pipeline("An image of a squirrel in Picasso style").images[0].save('output.jpg')

对于其他的原版SD checkpoint(比如在civitai上下载的),也可以用同样的方式把它们转换成Diffusers兼容的版本。

1.2 采样

Diffusers使用Pipeline来管理一类图像生成算法。和图像生成相关的模块(如U-Net,DDIM采样器)都是Pipeline的成员变量。打开Diffusers版Stable Diffusion模型的配置文件model_index.json(在 https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/model_index.json 网页上直接访问或者在本地的模型文件夹中找到),我们能看到该模型使用的Pipeline:

{
  "_class_name": "StableDiffusionPipeline",
  ...
}

diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py中,我们能找到StableDiffusionPipeline类的定义。所有Pipeline类的代码都非常长,一般我们可以忽略其他部分,只看运行方法__call__里的内容。

def __call__(
    self,
    prompt: Union[str, List[str]] = None,
    height: Optional[int] = None,
    width: Optional[int] = None,
    num_inference_steps: int = 50,
    timesteps: List[int] = None,
    guidance_scale: float = 7.5,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    num_images_per_prompt: Optional[int] = 1,
    eta: float = 0.0,
    ...
):

    # 0. Default height and width to unet
    height = height or self.unet.config.sample_size * self.vae_scale_factor
    width = width or self.unet.config.sample_size * self.vae_scale_factor
    # to deal with lora scaling and other possible forward hooks

    # 1. Check inputs. Raise error if not correct
    self.check_inputs(...)

    # 2. Define call parameters
    batch_size = ...

    device = self._execution_device

    # 3. Encode input prompt


    prompt_embeds, negative_prompt_embeds = self.encode_prompt(...)

    # For classifier free guidance, we need to do two forward passes.
    # Here we concatenate the unconditional and text embeddings into a single batch
    # to avoid doing two forward passes
    if self.do_classifier_free_guidance:
        prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

    # 4. Prepare timesteps
    timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)

    # 5. Prepare latent variables
    num_channels_latents = self.unet.config.in_channels
    latents = self.prepare_latents(...)

    # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
    ...

    # 7. Denoising loop
    num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
    self._num_timesteps = len(timesteps)
    with self.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            # predict the noise residual
            noise_pred = self.unet(
                latent_model_input,
                t,
                encoder_hidden_states=prompt_embeds,
                ...
            )[0]

            # perform guidance
            if self.do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)

            if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
                # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
                noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)

            # compute the previous noisy sample x_t -> x_t-1
            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]


            # call the callback, if provided
            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                progress_bar.update()
                

    if not output_type == "latent":
        image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
            0
        ]
        image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
    else:
        image = latents
        has_nsfw_concept = None

    ...

    return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)

虽然这段代码很长,但代码中的关键内容和我们在上篇文章中写的伪代码完全一致。

def ldm_text_to_image(image_shape, text, ddim_steps = 20, eta = 0)
  ddim_scheduler = DDIMScheduler()
  vae = VAE()
  unet = UNet()
  zt = randn(image_shape)
  eta = input()
  T = 1000
  timesteps = ddim_scheduler.get_timesteps(T, ddim_steps) # [1000, 950, 900, ...]

  text_encoder = CLIP()
  c = text_encoder.encode(text)

  for t = timesteps:
    eps = unet(zt, t, c)
    std = ddim_scheduler.get_std(t, eta)
    zt = ddim_scheduler.get_xt_prev(zt, t, eps, std)
  xt = vae.decoder.decode(zt)
  return xt

我们可以对照着上面的伪代码来阅读这个方法。经过Diffusers框架本身的一些前处理后,方法先获取了约束文本的编码。

# 3. Encode input prompt
# c = text_encoder.encode(text)
prompt_embeds, negative_prompt_embeds = self.encode_prompt(...)

方法再从采样器里获取了要用到的时间戳,并随机生成了一个初始噪声。

# Preprocess
...

# 4. Prepare timesteps
# timesteps = ddim_scheduler.get_timesteps(T, ddim_steps)
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)

# 5. Prepare latent variables
# zt = randn(image_shape)
num_channels_latents = self.unet.config.in_channels
latents = self.prepare_latents(
    ...
)

做完准备后,方法进入去噪循环。循环一开始是用U-Net算出当前应去除的噪声noise_pred。由于加入了CFG,U-Net计算的前后有一些对数据形状处理的代码。

with self.progress_bar(total=num_inference_steps) as progress_bar:
    for i, t in enumerate(timesteps):
        # eps = unet(zt, t, c)

        # expand the latents if we are doing classifier free guidance
        latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
        latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

        # predict the noise residual
        noise_pred = self.unet(
            latent_model_input,
            t,
            encoder_hidden_states=prompt_embeds,
            ...
        )[0]

        # perform guidance
        if self.do_classifier_free_guidance:
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)

        if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
            # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
            noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)

有了应去除的噪声,方法会调用扩散模型采样器对当前的噪声图片进行更新。Diffusers把采样的逻辑全部封装进了采样器的step方法里。对于包括DDIM在内的所有采样器,都可以调用这个通用的接口,完成一步采样。eta等采样器参数会通过**extra_step_kwargs传入采样器的step方法里。

# std = ddim_scheduler.get_std(t, eta)
# zt = ddim_scheduler.get_xt_prev(zt, t, eps, std)

# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

经过若干次循环后,我们得到了隐空间下的生成图片。我们还需要调用VAE把隐空间图片解码成普通图片。代码中的self.vae.decode(latents / self.vae.config.scaling_factor, ...)用于解码图片。

if not output_type == "latent":
    image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
        0
    ]
    image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
else:
    image = latents
    has_nsfw_concept = None

...

return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)

就这样,我们很快就看完了Diffusers的采样代码。相比之下,Diffusers的封装确实更合理,主要的图像生成逻辑都写在Pipeline类的__call__里,剩余逻辑都封装在VAE、U-Net、采样器等各自的类里。

1.3 U-Net

接下来我们来看Diffusers中的U-Net实现。还是打开模型配置文件model_index.json,我们可以找到U-Net的类名。

{
  ...
  "unet": [
    "diffusers",
    "UNet2DConditionModel"
  ],
  ...
}

diffusers/models/unet_2d_condition.py文件中,我们可以找到类UNet2DConditionModel。由于Diffusers集成了非常多新特性,整个文件就像一锅大杂烩一样,掺杂着各种功能的实现代码。不过,这份U-Net的实现还是基于原版Stable Diffusion的U-Net进行开发的,原版代码的每一部分都能在这份代码里找到对应。在阅读代码时,我们可以跳过无关的功能,只看我们在Stable Diffusion官方仓库中见过的部分。

先看初始化函数的主要内容。初始化函数依然主要包括time_proj, time_embeddingdown_blocksmid_blockup_blocksconv_inconv_out这几个模块。

class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
    def __init__(...):
        ...
        self.conv_in = nn.Conv2d(
            in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
        )
        ...
        elif time_embedding_type == "positional":
            self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
        ...
        self.time_embedding = TimestepEmbedding(...)
        self.down_blocks = nn.ModuleList([])
        self.up_blocks = nn.ModuleList([])
        for i, down_block_type in enumerate(down_block_types):
            ...
            down_block = get_down_block(...)
        
        if mid_block_type == ...
            self.mid_block = ...

        for i, up_block_type in enumerate(up_block_types):
            up_block = get_up_block(...)

        self.conv_out = nn.Conv2d(...)

其中,较为重要的down_blocksmid_blockup_blocks都是根据模块类名称来创建的。我们可以在Diffusers的Stable Diffusion模型文件夹的U-Net的配置文件unet/config.json中找到对应的模块类名称。

{
    ...
    "down_block_types": [
    "CrossAttnDownBlock2D",
    "CrossAttnDownBlock2D",
    "CrossAttnDownBlock2D",
    "DownBlock2D"
  ],
  "mid_block_type": "UNetMidBlock2DCrossAttn",
  "up_block_types": [
    "UpBlock2D",
    "CrossAttnUpBlock2D",
    "CrossAttnUpBlock2D",
    "CrossAttnUpBlock2D"
  ],
  ...
}

diffusers/models/unet_2d_blocks.py中,我们可以找到这几个模块类的定义。和原版代码一样,这几个模块的核心组件都是残差块和Transformer块。在Diffusers中,残差块叫做ResnetBlock2D,Transformer块叫做Transformer2DModel。这几个类的执行逻辑和原版仓库的也几乎一样。比如CrossAttnDownBlock2D的定义如下:

class CrossAttnDownBlock2D(nn.Module):
    def __init__(...):
        for i in range(num_layers):
            resnets.append(ResnetBlock2D(...))
            if not dual_cross_attention:
                attentions.append(Transformer2DModel(...))

接着我们来看U-Net的forward方法。忽略掉其他功能的实现,该方法的主要内容如下:

def forward(
        self,
        sample: torch.FloatTensor,
        timestep: Union[torch.Tensor, float, int],
        encoder_hidden_states: torch.Tensor,
        ...):

    # 0. center input if necessary
    if self.config.center_input_sample:
        sample = 2 * sample - 1.0

    # 1. time
    timesteps = timestep
    t_emb = self.time_proj(timesteps)
    emb = self.time_embedding(t_emb, timestep_cond)

    # 2. pre-process
    sample = self.conv_in(sample)

    # 3. down
    down_block_res_samples = (sample,)
    for downsample_block in self.down_blocks:
        sample, res_samples = downsample_block(
            hidden_states=sample,
            temb=emb,
            encoder_hidden_states=encoder_hidden_states,
            ...)
        down_block_res_samples += res_samples
    # 4. mid
    sample = self.mid_block(
            sample,
            emb,
            encoder_hidden_states=encoder_hidden_states,
            ...)

    # 5. up
    for i, upsample_block in enumerate(self.up_blocks):
        res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
        down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
        sample = upsample_block(
            hidden_states=sample,
            temb=emb,
            res_hidden_states_tuple=res_samples,
            encoder_hidden_states=encoder_hidden_states,
            ...)

     # 6. post-process
    sample = self.conv_out(sample)

    return UNet2DConditionOutput(sample=sample)

该方法和原版仓库的实现差不多,唯一要注意的是栈相关的实现。在方法的下采样计算中,每个downsample_block会返回多个残差输出的元组res_samples,该元组会拼接到栈down_block_res_samples的栈顶。在上采样计算中,代码会根据当前的模块个数,从栈顶一次取出len(upsample_block.resnets)个残差输出。

down_block_res_samples = (sample,)
for downsample_block in self.down_blocks:
    sample, res_samples = downsample_block(...)
    down_block_res_samples += res_samples

for i, upsample_block in enumerate(self.up_blocks):
    res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
    down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
    sample = upsample_block(...)

现在,我们已经看完了Diffusers中U-Net的主要内容。可以看出,Diffusers的U-Net包含了很多功能,一般情况下是难以自己更改这些代码的。有没有什么办法能方便地修改U-Net的实现呢?由于很多工作都需要修改U-Net的Attention,Diffusers给U-Net添加了几个方法,用于精确地修改每一个Attention模块的实现。我们来学习一个修改Attention模块的示例。

U-Net类的attn_processors属性会返回一个词典,它的key是每个Attention运算类所在位置,比如down_blocks.0.attentions.0.transformer_blocks.0.attn1.processor,它的value是每个Attention运算类的实例。默认情况下,每个Attention运算类都是AttnProcessor,它的实现在diffusers/models/attention_processor.py文件中。

为了修改Attention运算的实现,我们需要构建一个格式一样的词典attn_processor_dict,再调用unet.set_attn_processor(attn_processor_dict),取代原来的attn_processors。假如我们自己实现了另一个Attention运算类MyAttnProcessor,我们可以编写下面的代码来修改Attention的实现:

attn_processor_dict = {}
for k in unet.attn_processors.keys():
    if we_want_to_modify(k):
        attn_processor_dict[k] = MyAttnProcessor()
    else:
        attn_processor_dict[k] = AttnProcessor()

unet.set_attn_processor(attn_processor_dict)

MyAttnProcessor的唯一要求是,它需要实现一个__call__方法,且方法参数与AttnProcessor的一致。除此之外,我们可以自由地实现Attention处理的细节。一般来说,我们可以先把原来AttnProcessor的实现代码复制过去,再对某些细节做修改。

2 总结

在这篇文章中,我们学习了Stable Diffusion的原版实现和Diffusers实现的主要内容:采样算法和U-Net。具体来说,在原版仓库中,采样的实现一部分在主函数中,一部分在DDIM采样器类中。U-Net由一个简明的PyTorch模块类实现,其中比较重要的子模块是残差块和Transformer块。相比之下,Diffusers实现的封装更好,功能更多。Diffusers用一个Pipeline类来维护采样过程。Diffusers的U-Net实现与原版完全相同,且支持更复杂的功能。此外,Diffusers还给U-Net提供了精确修改Attention计算的接口。

不管是哪个Stable Diffusion的框架,都会提供一些相同的原子操作。各种基于Stable Diffusion的应用都应该基于这些原子操作开发,而无需修改这些操作的细节。在学习时,我们应该注意这些操作在不同的框架下的写法是怎么样的。常用的原子操作包括:

  • VAE的解码和编码

  • 文本编码器(CLIP)的编码

  • 用U-Net预测当前图像应去除的噪声

  • 用采样器计算下一去噪迭代的图像

在原版仓库中,相关的实现代码如下:

# VAE的解码和编码
model.decode_first_stage(...)
model.encode_first_stage(...)

# 文本编码器(CLIP)的编码
model.get_learned_conditioning(...)

# 用U-Net预测当前图像应去除的噪声
model.apply_model(...)

# 用采样器计算下一去噪迭代的图像
p_sample_ddim(...)

在Diffusers中,相关的实现代码如下:

# VAE的解码和编码
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
latents = self.vae.encode(image).latent_dist.sample(generator) * self.vae.config.scaling_factor

# 文本编码器(CLIP)的编码
self.encode_prompt(...)

# 用U-Net预测当前图像应去除的噪声
self.unet(..., return_dict=False)[0]

# 用采样器计算下一去噪迭代的图像
self.scheduler.step(..., return_dict=False)[0]

如今zero-shot(无需训练)的Stable Diffusion编辑技术一般只会修改采样算法和Attention计算,需训练的编辑技术有时会在U-Net里加几个模块。只要我们熟悉了普通的Stable Diffusion是怎么样生成图像的,知道原来U-Net的结构是怎么样的,我们在阅读新论文的源码时就可以把这份代码与原来的代码进行对比,只看那些有修改的部分。相信读完了本文后,我们不仅加深了对Stable Diffusion本身的理解,以后学习各种新出的Stable Diffusion编辑技术时也会更加轻松。

THE END !

文章结束,感谢阅读。您的点赞,收藏,评论是我继续更新的动力。大家有推荐的公众号可以评论区留言,共同学习,一起进步。

文章来源于互联网:天才程序员周弈帆 | Stable Diffusion 解读(四):Diffusers实现源码解读

赞(0)
未经允许不得转载:5bei.cn大模型教程网 » 天才程序员周弈帆 | Stable Diffusion 解读(四):Diffusers实现源码解读

天才程序员周弈帆 | Stable Diffusion 解读(一):回顾早期工作

本文来源公众号“天才程序员周弈帆”,仅用于学术分享,侵权删,干货满满。

原文链接:Stable Diffusion 解读(一):回顾早期工作

在2022年的这波AI绘画浪潮中,Stable Diffusion无疑是最受欢迎的图像生成模型。究其原因,第一,Stable Diffusion通过压缩图像尺寸显著提升了扩散模型的运行效率,使得每个用户能在自己的商业级显卡上运行模型;第二,有许多基于Stable Diffusion的应用,比如Stable Diffusion自带的文生图、图像补全,以及ControlNet、LoRA、DreamBooth等插件式应用;第三,得益于前两点,Stable Diffusion已经形成了一个庞大的用户社群,大家互相分享模型,交流心得。

不仅是大众,Stable Diffusion也吸引了大量科研人员,很多本来研究GAN的人纷纷转来研究扩散模型。然而,许多人在学习Stable Diffusion时却犯了难:又是公式扎堆的扩散模型,又是VAE,又是U-Net,这该怎么学起呀?

其实,一上来就读Stable Diffusion是很难读懂的。而如果你把之前的一些更基础的文章读懂,再回头来读Stable Diffusion,就会畅行无阻了。在这篇及之后的几篇文章中,我将从科研的角度对Stable Diffusion做一个全面的解读。(1)在第一篇文章中,我将面向完全没接触过图像生成的读者,从头介绍Stable Diffusion是怎样从早期工作中一步一步诞生的;(2)在第二篇文章中,我将详细解读Stable Diffusion的论文;(3)在最后的第三篇文章中,我将带领大家阅读Stable Diffusion的官方源码,以及一些流行的开源库的Stable Diffusion实现。后续我还会写其他和Stable Diffusion相关的文章,比如ControlNet的介绍。

1 从自编码器谈起

包括Stable Diffusion在内,很多图像生成模型都可以看成是一种非常简单的模型——自编码器——的改进版。要谈Stable Diffusion是怎么逐渐诞生的,其实就是在谈自编码器是一步一步进化的。我们的学习就从自编码器开始。

尽管PNG、JPG等图像压缩方法已经非常成熟,但我们会想,会不会还有更好的图像压缩算法呢?图像压缩,其实就是找两个映射,一个把图片编码成压缩数据,另一个把压缩数据解码回图片。我们知道,神经网络理论上可以拟合任何映射。那我们干脆用两个神经网络来拟合两种映射,以实现一个图像压缩算法。负责编码的神经网络叫编码器(Encoder),负责解码的神经网络叫做解码器(Decoder)

光定义了神经网络还不够,我们还需要给两个神经网络设置一个学习目标。在运行过程中,神经网络应该满足一个显然的约束:编码再解码后的重建图像应该和原图像尽可能一致,即二者的均方误差应该尽可能小。这样,我们只需要随便找一张图片,通过编码器和解码器得到重建图像,就能训练神经网络了。我们不需要给图片打上标签,整个训练过程是自监督的。所以我们说,整套模型是一个自编码器(Autoencoder,AE)

图像压缩模型AE为什么会和图像生成扯上关系呢?你可以试着把AE的输入图像和编码器遮住,只看解码部分。把一个压缩数据解码成图像,换个角度看,不就是在根据某一数据生成图像嘛。

很可惜,AE并不是一个合格的图像生成模型。我们常说的图像生成,具体是指让程序生成各种各样的图片。为了让程序生成不同的图片,我们一般是让程序根据随机数(或是随机向量)来生成图片。而普通的AE会有过拟合现象,这导致AE的解码器只认得训练集里的图片经编码器解码出来的压缩数据,而不认得随机生成的压缩数据,进而也无法达到图像生成的要求。

所谓过拟合,就是指模型只能处理训练数据,而不能推广到一般的数据上。举一个极端的例子,如下图所示,编码器和解码器直接记忆了整个数据集,把所有图片压缩成了一个数字。也就是模型把编码器当成一个图片到数字的词典,把解码器当成一个数字到图片的词典。这样,不管数据集有多大,所有图片都可以被压缩成一个数字。这样的AE确实压缩能力很强,但它完全没用,因为它过拟合了,处理不了训练集以外的数据。

过拟合现象在普通版AE中是不可避免的。为了利用AE的解码器来生成图片,许多工作都在试图克服AE的过拟合现象。AE的改进思路很多,在这篇文章中,我们仅把AE的改进路线粗略地分成两种:解决过拟合问题以直接用AE做图像生成、用AE压缩图像间接实现图像生成。

2 第一条路线:VAE 和 DDPM

在第一条改进路线中,许多后续工作都试图用更高级的数学模型来解决AE的过拟合问题。变分自编码器(Variational Autoencoder, VAE) 就是其中的代表。

VAE对AE做了若干改动。第一,VAE让编码器的输出不再是一个确定的数据,而是一个正态分布中的一个随机数据。更具体一点,训练时,编码器会同时输出一个均值和方差。随后,模型会从这个均值和方差表达的正态分布里随机采样一个数据,作为解码器的输入。直观上看,这一改动就是在AE的基础上,让编码器多输出了一个方差,使得原AE编码器的输出发生了一点随机扰动。

这一改动可以缓解过拟合现象。这是为什么呢?我们可以这样想:原来的AE之所以会过拟合,是因为它强行记住了训练集里每一个数据的编码输出。现在,我们在VAE里让编码器不再输出一个固定值,而是随机输出一个在均值附近的值。这样的话,VAE就不能死记硬背了,必须要找出数据中的规律。

VAE的第二项改动是多添加一个学习目标让编码器的输出和标准正态分布尽可能相似。前面我们谈过,图像生成模型一般会根据一个随机向量来生成图像。最常用的产生随机向量的方法是去标准正态分布里采样。也就是说,在用VAE生成图像时,我们会抛掉编码器,用下图所示的流程来生成图像。如果我们不约束编码器的输出分布,不让它输出一个和标准正态分布很相近的分布的话,解码器就不能很好地根据来自标准正态分布的随机向量生成图像了。

综上,VAE对AE做了两项改进使编码器输出一个正态分布,且该分布要尽可能和标准正态分布相似。训练时,模型从编码器输出的分布里随机采样一个数据作为解码器的输入;图像采样(图像生成)时,模型从标准正态分布里随机采样一个数据作为解码器的输入。VAE的误差函数由两部分组成:原图像和重建图像的重建误差、编码器输出和标准正态分布之间的误差。VAE要最小化重建误差,最大化编码器输出与标准正态分布的相似度

分布与分布之间的误差可以用一个叫KL散度的指标表示。所以,在上面那个误差函数公式中,负的相似度应该被替换成KL散度。VAE的这两项改动本质上都是在解决AE的过拟合问题,所以,VAE的改动可以被看成一种正则化方法。我们可以把VAE的正则化方法简称为KL正则化。(在机器学习中,正则化方法就是「降低模型过拟合的方法」的简称。)

【补充学习】原文链接:机器学习_KL散度详解(全网最详细)_kl散度计算公式-CSDN博客

VAE确实能减轻AE的过拟合。然而,由于VAE只是让重建图像和原图像的均方误差(重建误差)尽可能小,而没有对重建图像的质量施加更多的约束,VAE的重建结果和图像生成结果都非常模糊。以下是VAE在CelebA数据集上图像生成结果。

在众多对VAE的改进方法中,一个叫做去噪扩散概率模型(Denoising Diffusion Probabilistic Model, DDPM) 的图像生成模型脱颖而出。DDPM正是当今扩散模型的开山鼻祖。我们来看一下DDPM是怎样基于VAE对图像生成建模的。

VAE之所以效果不好,很可能是因为它的约束太少了。VAE的编码和解码都是用神经网络表示的。神经网络是一个黑盒,我们不好对神经网络的中间步骤施加约束,只好在编码器的输出(某个正态分布)和解码器的输出(重建图像)上施加约束。能不能让VAE的编码和解码过程更可控一点呢?

DDPM的设计灵感来自热力学一个分布可以通过一系列简单的变化(如添加高斯噪声)逐渐变成另一个分布。恰好,VAE的编码器不正是想让来自训练集的图像(训练集分布)变成标准正态分布吗?既然如此,就不要用一个可学习的神经网络来表示VAE的编码器了,干脆用一些预定义好的加噪声操作来表示解码过程。可以从数学上证明,经过了多次加噪声操作后,最后的图像分布会是一个标准正态分布。

既然编码是加噪声,那解码时就应该去掉噪声。DDPM的解码器也不再是一个不可解释的神经网络,而是一个能预测若干个去噪结果的神经网络

相比只有两个约束条件的VAE,DDPM的约束条件就多得多了。在DDPM中,第t个去噪操作应该尽可能抵消掉第t个加噪操作。

总结一下,DDPM对VAE做了如下改动

  1. 编码器是一系列不可学习(固定)的加噪声操作

  2. 解码器是一系列可学习的去噪声操作

  3. 图像尺寸自始至终不变

相比于VAE,DDPM的编码过程和解码过程的定义更加明确,可以施加的约束更多。因此,如下图所示,它的生成效果会比VAE好很多。同时,DDPM和VAE类似,它在编码时会从分布里采样,而不是只输出一个固定值,不会出现AE的过拟合问题

DDPM的图像生成结果

DDPM的生成效果确实很好。但是,由于DDPM始终会对同一个尺寸的数据进行操作,图像的尺寸极大地影响了DDPM的运行速度,用DDPM生成高分辨率图像需要耗费大量计算资源。因此,想要用DDPM生成高质量图像,还得经过另一条路线。

3 第二条路线:VQVAE

在AE的第二条改进路线中,一些工作干脆放弃使用AE做图像生成,转而利用AE的图像压缩能力,把图像生成拆成两步来做:(1)先用AE的编码器把图像压缩成更小的图像,(2)再用另一个图像生成模型生成小图像,并用AE的解码器把小图像重建回真实图像。

为什么会有这么奇怪的图像生成方法呢?这得从另一类图像生成模型讲起。在机器翻译模型Transformer横空出世后的一段时间里,有很多工作都想把Transformer用在图像生成上。但是,原本用来生成文本的Transformer无法直接应用在图像上。在自然语言处理(NLP)中,一个句子可以用若干个单词表示。而每个单词又是用一个整数表示。所以,Transformer生成句子时,实际上是在生成若干个离散的整数,也就是生成一个离散向量。而在图像生成模型中,每个像素的颜色值是一个连续的浮点数。想把Transformer直接用在图像生成上,就得想办法把图像用离散向量表示。我们知道,AE可以把图像编码成一个连续向量。能不能做一些修改,让AE把图像编码成一个离散向量呢?

Vector Quantised-Variational AutoEncoder (VQVAE) 就是一个能把图像编码成离散向量的AE(虽然作者在取名时用了VAE)。我们来简单看一下VQVAE是怎样把图像编码成离散向量的。

假设我们有了一个能编码出离散向量的AE。

由于神经网络不能很好地处理离散数据,我们要引入NLP里的通常做法,加一个把离散向量映射成连续向量的嵌入层。

现在我们再回头讨论怎么让编码器输出一个离散向量。我们可以让AE的编码器保持不变,还是输出一个连续向量,再通过一个「向量离散化」操作,把连续向量变成离散向量。这个操作会把编码器的输出对齐到嵌入层的向量上,其原理类似于把0.99和1.01离散化成1,只不过它是对向量整体考虑,而不是对每一个数单独考虑。向量离散化操作的具体原理我们不在此处细究。

忽略掉实现细节,我们可以认为VQVAE能够把图像压缩成离散向量。更准确地说,VQVAE能把图像等比例压缩成离散的「小图像」。压缩成二维图像而不是一维向量,能够保留原图像的一些空间特性,为之后第二步图像生成铺路。

整理一下,VQVAE是一个能把图像压缩成离散小图像的AE。为了用VQVAE生成图像,需要执行一个两阶段的图像生成流程:

  • 训练时,先训练一个图像压缩模型(VQVAE),再训练一个生成压缩图像的模型(比如Transformer)

  • 生成时,先用第二个模型生成出一个压缩图像,再用第一个模型的解码器把压缩图像复原成真实图像

之所以要执行两阶段的图像生成流程,而不是只用第二个模型生成大图像,有两个原因。(1)第一个原因是前面提到的,Transformer等生成模型只支持生成离散图像,需要用另一个模型把连续的颜色值变成离散值以兼容这些模型。(2)第二个原因是为了减少模型的运算量。以Transformer为例,Transformer的运算次数大致与像素数的平方成正比,拿Transformer生成高分辨率图像的运算开销是不可接受的。而如果用一个AE把图像压缩一下的话,用Transformer就可行了。

VQVAE给后续工作带来了三条启发:(1)第一,可以用AE把图像压缩成离散向量;(2)第二,如果一个图像生成模型生成高分辨率的图像的计算代价太高,可以先用AE把图像压缩,再生成压缩图像。这两条启发对应上一段提到的使用VQVAE的两条动机。(3)而第三条启发就比较有意思了。在讨论VQVAE的过程中,我们完全没有考虑过拟合的事。这是因为经过了向量离散化操作后,解码器的输入已经不再是编码器的输出,而是嵌入层里的向量了。这种做法杜绝了AE的死记硬背,缓解了过拟合现象。这样,我们可以换一个角度看待VQVAE:编码器还是AE的编码器,编码器的输出是连续向量,后续的向量离散化操作和嵌入层全部都是解码器的一部分。从这个角度看,VQVAE其实提出了一个由向量离散化和嵌入层组成的正则化模块。这个模块和VAE的KL散度约束一样,都解决了AE的过拟合问题。我们把VQVAE的正则化方法叫做VQ正则化

VQVAE论文提出的图像生成方法效果一般。和普通的AE一样,VQVAE在训练时只用了重建误差来约束图像质量,重建图像的细节依然很模糊。且VQVAE配套的第二阶段图像生成模型不是较为强力的Transformer,而是一个基于CNN的图像生成模型。

后续的VQGAN论文对VQVAE进行了改进。对于一阶段的图像压缩模型,VQGAN在VQVAE的基础上引入了生成对抗网络(GAN)中一些监督误差,提高了图像压缩模型的重建质量;对于两阶段的图像生成模型,该方法使用了Transformer。凭借这些改动,VQGAN方法能够生成高质量的高清图片。并且,通过把额外的约束条件(如语义分割图像、文字)输入进Transformer,VQGAN方法能够实现带约束的图像生成。以下是VQGAN方法根据语义分割图像生成的高清图片。

图像生成模型可以是无约束或带约束的。无约束图像生成模型只需要输入一个随机向量,训练数据不需要任何标注,可以进行无监督训练。带约束图像生成模型会在无约束图像生成模型的基础上多加一些输入,并给每个训练图像打上描述约束的标签,执行监督训练。比如要训练文生图模型,就要给每个训练图片带上文字描述。

4 路线的交汇点——Stable Diffusion

看完上面这两条AE的改进路线,相信你已经能够猜出Stable Diffusion的核心思想了。让我们看看Stable Diffusion是怎么从这两条路径中汲取灵感的。

在发布了VQGAN后,德国的CompVis实验室开始探索起VQGAN的改进方法。VQGAN能把图像边长压缩16倍,而VQGAN配套的Transformer只能一次生成16 x 16的图片。也就是说,整套方法一次只能生成256 x 256的图片。为了生成分辨率更高的图片,VQGAN方法需要借助滑动窗口。能不能让模型一次性生成分辨率更高的图片呢?制约VQGAN方法生成分辨率的主要因素是Transformer。如果能把Transformer换成一个效率更高,能生成更高分辨率的图像的模型,不就能生成比256 x 256更大的图片了吗?CompVis实验室开始把目光着眼于DDPM上。

于是,在发布VQGAN的一年后,CompVis实验室又发布了名为High-Resolution Image Synthesis with Latent Diffusion Models的论文,提出了一种叫做隐扩散模型(latent diffusion model, LDM) 的图像生成模型。通过与AI公司Stability AI合作,借助他们庞大的算力资源训练LDM,CompVis实验室发布了商业名为Stable Diffusion的开源文生图AI绘画模型。

LDM其实就是在VQGAN方法的基础上,把图像生成模型从Transformer换成了DDPM。或者从另一个角度说,为了让DDPM生成高分辨率图像,LDM利用了VQVAE的第二条启发:先用AE把图像压缩,再用DDPM生成压缩图像。LDM的AE一般是把图像边长压缩8倍,DDPM生成64 x64的压缩图像,整套LDM能生成256 x 256的图像。

和Transformer不同,DDPM处理的图像是用连续向量表示的。因此,在LDM中使用VQGAN做图像压缩时,不一定需要向量离散化操作,只需要在AE的基础上加一点轻微的正则化就行。作者在实现LDM时讨论了两类正则化,一类是VAE的KL正则化,一类是VQ正则化(对应VQVAE的第三条启发),两种正则化都能取得不错的效果。

LDM依然可以实现带约束的图像生成。用DDPM替换掉Transformer后,额外的约束会输入进DDPM中。作者在论文中讨论了几种把约束输入进DDPM的方式。

在搞懂了早期工作后,理解Stable Diffusion的核心思想就是这么简单。让我们把Stable Diffusion的发展过程及主要结构总结一下。Stable Diffusion由两类AE的变种发展而来,一类是有强大生成能力却需要耗费大量运算资源的DDPM,一类是能够以较高保真度压缩图像的VQVAE。Stable Diffusion是一个两阶段的图像生成模型,它先用一个使用KL正则化或VQ正则化的VQGAN来实现图像压缩,再用DDPM生成压缩图像。可以把额外的约束(如文字)输入进DDPM以实现带约束图像生成。

5 相关论文

本文仅仅对Stable Diffusion的早期工作做了一个简单的梳理。要把Stable Diffusion吃透,还需要多读一些早期论文。我来把早期论文按重要性分个类。

5.1 图像生成必读文章

Neural Discrete Representation Learning (VQVAE): https://arxiv.org/abs/1711.00937

Taming Transformers for High-Resolution Image Synthesis (VQGAN): https://arxiv.org/abs/2012.09841

Denoising Diffusion Probabilistic Models (DDPM): https://arxiv.org/abs/2006.11239

5.2 图像生成选读文章

Auto-Encoding Variational Bayes (VAE): https://arxiv.org/abs/1312.6114 提出VAE的文章。数学公式较多,只需要了解VAE的大致结构就好,不需要详细阅读论文。

Pixel Recurrent Neural Networks (PixelCNN): https://arxiv.org/abs/1601.06759 提出了一种拟合离散分布的图像生成模型,自回归图像生成模型的代表。这是VQVAE使用的第二阶段图像生成模型。有兴趣可以了解一下。

Deep Unsupervised Learning using Nonequilibrium Thermodynamics: https://arxiv.org/abs/1503.03585 DDPM的前作,首个提出扩散模型思想的文章。其核心原理和DDPM几乎完全一致,但是模型结构和优化目标不够先进,生成效果没有改进后的DDPM好。数学公式较多,不必细读,可以在学习DDPM时对比着阅读。

Denoising Diffusion Implicit Models (DDIM): https://arxiv.org/abs/2010.02502 一种加速DDPM采样的方法,广泛运用在包含Stable Diffusion在内的扩散模型中。推荐阅读。

Classifier-Free Diffusion Guidance: https://arxiv.org/abs/2207.12598 一种让扩散模型的输出更加贴近约束的方法,广泛运用在包含Stable Diffusion在内的扩散模型中,用于生成更符合文字描述的图片。推荐阅读。

Generative Adversarial Networks (GAN): https://arxiv.org/abs/1406.2661 以及 A Style-Based Generator Architecture for Generative Adversarial Networks (StyleGAN): https://arxiv.org/abs/1812.04948 可以了解一下GAN是怎么确保图像生成质量的,并认识CelebAHQ和FFHQ这两个常用的人脸数据集。

5.3 其他必读文章

Deep Residual Learning for Image Recognition (ResNet): https://arxiv.org/abs/1512.03385 深度学习的经典文章。其中提出的残差连接被用到了DDPM中。

Attention Is All You Need (Transformer): https://arxiv.org/abs/1706.03762 深度学习的经典文章。其中提出的自注意力模块被用到了DDPM中。

5.4 其他选读文章

Learning Transferable Visual Models From Natural Language Supervision (CLIP): https://arxiv.org/abs/2103.00020 提出了对齐文本和图像的方法。绝大多数文生图模型的核心。

U-Net: Convolutional Networks for Biomedical Image Segmentation (U-Net): https://arxiv.org/abs/1505.04597 一种被广泛运用的神经网络架构。DDPM的神经网络的主架构。U-Net的结构很简单,可以不用去读论文,直接看代码。

THE END !

文章结束,感谢阅读。您的点赞,收藏,评论是我继续更新的动力。大家有推荐的公众号可以评论区留言,共同学习,一起进步。

文章来源于互联网:天才程序员周弈帆 | Stable Diffusion 解读(一):回顾早期工作

赞(0)
未经允许不得转载:5bei.cn大模型教程网 » 天才程序员周弈帆 | Stable Diffusion 解读(一):回顾早期工作
分享到: 更多 (0)

AI大模型,我们的未来

小欢软考联系我们