系列文章目录
第一章 DiT详解
提示:本文搬运:https://zhuanlan.zhihu.com/p/684068402
Stable Diffusion核心基础内容
Stable Diffusion3在多主题提示词的控制编辑一致性能力、文字渲染控制能力以及图像生成的整体质量三个维度都有很大的提升
以下是本篇文章正文内容
1.Stable Diffusion整体架构初始
Stable Diffusion仍然是一个end-to-end模型,最大的亮点是扩散模型部分使用了全新的MM-DiT架构,同时采用优化改进的flow-matching技术训练SD3模型。目前官方开源了2B参数量的Stable Diffusion 3 medium版本,在FP16精度下Stable Diffusion 3 medium模型大小为15.8G(FP32:33.6G,FP8:10.9G),其中MM-DiT大小为4.17G(参数量约2B),VAE模型大小为168M(参数量约80M),CLIP ViT-L大小为246M(参数量约124M),OpenCLIP ViT-bigG大小为1.39G(参数量约695M),T5-XXL Encoder在FP16精度下大小为9.79G(参数量约4.7B,FP8精度下大小为4.89G)
Stable Diffusion3整体架构如下
2.VAE模型
在SD3中VAE不仅将像素级图像编码成Latent特征,因为SD3的扩散模型部分全部由Transformer架构构成,还需要将Latent特征转化为patches特征,再送到扩散模型部分进行处理。
之前的SD系列使用的VAE模型是将H✖️W✖️3的图像编码为H/8✖️W/8✖️d的Latent特征,在8倍采样下同时设置d=4,这样的话存在一定的压缩损失,产生的直接影响是对Latent特征重建时容易产生小物体畸变(比如人眼崩溃、文字畸变)。这是因为在压缩过程中,图像的细节信息被大量丢弃。对于小物体而言,它们在图像中原本占据的像素数量就较少,经过 8 倍下采样后,小物体可能仅由几个像素来表示,很多关键的细节信息丢失。例如,人眼这样的小物体,其复杂的结构和纹理在低分辨率下难以准确保留,重建时就容易出现形状和结构的畸变。当图像被下采样时,小物体的边缘信息也会变得模糊。因为下采样是一种平均化的过程,会将相邻像素的信息混合在一起,导致小物体的边缘无法清晰界定。在重建时,由于缺乏准确的边缘信息,模型难以恢复小物体的真实形状,从而产生畸变。
所以SD3模型通过提升d来增强VAE的重建能力,不同d的对比实验如下
当设置d=16时,VAE模型的整体性能(FID指标降低、Perceptual Similarity指标降低、SSIM指标提升、PSNR指标提升)比d=4时有较大的提升,所以SD 3确定使用了 d=16(16通道)的VAE模型。
但是通道数的增加也带来了拟合内容的增加,这样我们就需要增加整体参数量级来提升模型容量。
下图是原作者整理的VAE完整结构图
从上图可以看到,SD 3VAE中有三个基础组件:
1.GSC组件:GroupNorm+SiLU+Conv
2.Downsample:padding+Conv
3.Upsample:Interpolate+Conv
同时还有两个核心组件:ResNetBlock模块和SelfAttention模块
SD 3 VAE Encoder部分包含了三个DownBlock模块、一个ResNetBlock模块以及一个MidBlock模块,将输入图像压缩到Latent空间,转换成为Gaussian Distribution。
而VAE Decoder部分正好相反,其输入Latent特征,并重建成为像素级图像作为输出。其包含了三个UpBlock模块、一个ResNetBlock模块以及一个MidBlock模块。
3.MMDiT
MMDiT的一个核心关键是对图像的Laten Tokens和文本的Tokens设置了两套独立的参数,并在Attention机制前拼接在一起,再送入Attention机制进行注意力机制的计算
图像和文本属于两个不同的模态,SD3中采用两套独立的权重参数来处理学习这两个不同模态的特征,两种模态特征在所有Transformer层的权重参数并不是共享的,只通过Self-Attention机制来实现特征的交互融合。这相当于使用了两个独立的Transformer模型来处理文本和图像信息,这也是SD 3技术报告中称这个结构为MM-DiT的本质原因,这是一个多模态扩散模型。
这是原作者梳理的SD3完整的MM-DiT结构图
SD3 MMDiT主要包含了以下的核心模块
1.MM-DiT Block:一共有24个MM-DiT Blocks构成了MM-DiT架构的主体,每个MM-DiT Block中包含了两个AdapterLayerNormZero层+MM-DiT Attention层+两个LayerNorm层+两个FeedForward层
2.MM-DiT Attention Structure:MM-DiT Block中的核心组件,用于将图像特征和文本特征进行同等级别的Attention机制。
3.FeedForward:由GELU+Dropout+Linear组成。
MM-DiT和原生DiT模型一样在Latent空间中将图像的Latent特征转成patches特征,这里的patch size=2×2,和原生DiT的默认配置一致。接着和ViT一样,将得到的Patch Embedding与Positional Embedding相加(add)一起输入到Transformer的主架构中。
4.Text Encoder
SD系列模型的版本迭代中,相比于之前的版本,SD3进一步增加了Text Encoder的数量,加入了一个参数量更大的T5-XXL Encoder模型
SD 3 ViT-L CLIP Text Encoder是只包含Transformer结构的模型,一共由12个CLIPEncoderLayer模块组成。同时每个CLIPEncoderLayer模块包含一个Self-Attention层和MLP层。
SD 3 ViT-bigG CLIP Text Encoder同样只包含Transformer结构的模型,一共由32个CLIPEncoderLayer模块组成。同时每个CLIPEncoderLayer模块同样包含一个Self-Attention层和MLP层。

SD 3一共需要提取输入文本的全局语义和文本细粒度两个层面的信息特征。
首先需要提取CLIP ViT-L和OpenCLIP ViT-bigG的Pooled Text Embeddings,它们代表了输入文本的全局语义特征,维度大小分别是768和1280,两个embeddings拼接(concat操作)得到2048的embeddings,然后经过一个MLP网络并和Timestep Embeddings相加(add操作)
接着我们需要提取输入文本的细粒度特征。这里首先分别提取CLIP ViT-L和OpenCLIP ViT-bigG的倒数第二层的特征,拼接在一起得到77×2048维度的CLIP Text Embeddings;再从T5-XXL Encoder中提取最后一层的T5 Text Embeddings特征,维度大小是77×4096(这里也限制token长度为77)。紧接着对CLIP Text Embeddings使用zero-padding得到和T5 Text Embeddings相同维度的编码特征。最后,将padding后的CLIP Text Embeddings和T5 Text Embeddings在token维度上拼接在一起,得到154×4096维度的混合Text Embeddings。这个混合Text Embeddings将通过一个linear层映射到与图像Latent的Patch Embeddings特征相同的维度大小,最终和Patch Embeddings拼接在一起送入MM-DiT中
SD 3采用CLIP ViT-L + OpenCLIP ViT-bigG + T5-XXL Encoder的组合带来了文字渲染和文本一致性等方面的效果增益,但是也限制了T5-XXL Encoder的能力。因为CLIP ViT-L和OpenCLIP ViT-bigG都只能默认编码77 tokens长度的文本,这让原本能够编码512 tokens的T5-XXL Encoder在SD 3中也只能处理77 tokens长度的文本。而SD系列的“友商”模型DALL-E 3由于只使用了T5-XXL Encoder一个语言模型作为Text Encoder模块,所以可以输入512 tokens的文本,从而发挥T5-XXL Encoder的全部能力。
?有没有办法可以让这个限制解除
5.SD3使用优化的RF采样方法
二、使用步骤
1.引入库
代码如下(示例):
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import warnings
warnings.filterwarnings('ignore')
import ssl
ssl._create_default_https_context = ssl._create_unverified_context
2.读入数据
代码如下(示例):
data = pd.read_csv(
'https://labfile.oss.aliyuncs.com/courses/1283/adult.data.csv')
print(data.head())
该处使用的url网络请求的数据。
总结
提示:这里对文章进行总结:
例如:以上就是今天要讲的内容,本文仅仅简单介绍了pandas的使用,而pandas提供了大量能使我们快速便捷地处理数据的函数和方法。
文章来源于互联网:深入浅出完整理解Stable Diffusion3核心基础知识
5bei.cn大模型教程网










