AI大模型教程
一起来学习

Stable Diffusion 中 Cross Attention 实现原理解析(含代码讲解)

Stable Diffusion 的 U-Net 中,Cross Attention 是将文本提示与图像特征对齐融合的关键模块,本文将结合一段 Python 实现代码,逐行解释其原理。

Cross Attention 是什么?

Cross Attention(交叉注意力)是指:让一组 Query 向量(比如图像特征)去 attend 另一组 Key-Value 向量(比如文本上下文),以融合跨模态信息。

在 Stable Diffusion 中,它的作用就是:

让图像特征与文本 Prompt 对齐,从而生成符合描述的图像。


代码实现

下面是一个精简版的 cross attention 实现(基于 PyTorch):

import torch
import torch.nn as nn
import torch.nn.functional as F

def cross_attention(x: 'b c h w', context: 'b len dim'):
	 # 获取输入张量 x 的基本信息
    batch_size, channels, height, width = x.shape
    # 根据 context 的最后一个维度得到 dim
    dim = context.shape[-1]
    # 将 x reshape 后交换维度,使其形状为 (b, num_tokens, channels)
    x_flat = x.view(batch_size, channels, -1).permute(0, 2, 1)  # (b, h*w, c)
	
	# 对 x_flat 和 context 分别做线性变换
    q = nn.Linear(channels, dim)(x_flat)      # (b, h*w, dim)
    k = nn.Linear(dim, dim)(context)          # (b, len, dim)
    v = nn.Linear(dim, dim)(context)          # (b, len, dim)
	
	# 计算注意力分数,并缩放
    attn_scores = torch.matmul(q, k.transpose(-2, -1)) / (dim ** 0.5)  # (b, h*w, len)
     # 对分数做 softmax 操作
    attn_weights = F.softmax(attn_scores, dim=-1)                      # (b, h*w, len)
	# 用注意力权重加权求和得到输出
    attn_output = torch.matmul(attn_weights, v)        # (b, h*w, dim)
    # 再做一次线性变换,将通道数还原
    out = nn.Linear(dim, channels)(attn_output)        # (b, h*w, c)
    # 恢复原来的空间形状
    out = out.permute(0, 2, 1).view(batch_size, channels, height, width)
    return out
    
if __name__ == '__main__':
    batch_size = 2
    channels = 32
    height = 16
    width = 16
    len_context = 10
    dim_contex = 64
    x = torch.randn(batch_size, channels, height, width)
    context = torch.randn(batch_size, len_context,dim_contex)
    output = corss_attention(x, context)
    print("output shape:", output.shape)

每一行在干什么?

1️⃣ 输入格式
• x: 图像特征图,形状为 (batch_size, channels, height, width)。
• context: 上下文特征,通常是文本 Prompt 编码,形状为 (batch_size, seq_len, dim)。
2️⃣ 将图像特征展平

x_flat = x.view(batch_size, channels, -1).permute(0, 2, 1)  # (b, h*w, c)

将二维空间的图像特征展平成一个序列,即把每个像素位置看作一个 token。
3️⃣ 构造 Query / Key / Value

q = nn.Linear(channels, dim)(x_flat)  # 查询向量,图像发出请求
k = nn.Linear(dim, dim)(context)      # 键向量,文本提供信息索引
v = nn.Linear(dim, dim)(context)      # 值向量,文本的实际信息内容

这里 q 是来自图像,k 和 v 是来自文本。
4️⃣ 计算注意力得分

attn_scores = torch.matmul(q, k.transpose(-2, -1)) / (dim ** 0.5)

这是标准的 Scaled Dot-Product Attention:
对每个图像位置的 q 与所有文本位置的 k 做点积。
5️⃣ softmax 得到注意力权重

attn_weights = F.softmax(attn_scores, dim=-1)

将得分转化为概率分布,得到每个图像位置对文本各 token 的注意力程度。
6️⃣ 加权求和融合文本特征

attn_output = torch.matmul(attn_weights, v)  # (b, h*w, dim)

使用注意力权重对文本的值向量 v 做加权,得到融合文本信息后的图像 token。
7️⃣ 映射回原图像空间

out = nn.Linear(dim, channels)(attn_output)
out = out.permute(0, 2, 1).view(batch_size, channels, height, width)

将融合后的特征映射回原始的通道维度,并 reshape 成 (b, c, h, w) 格式。

Cross Attention 中的维度说明与 softmax 理解(知识点补充)

   attn_weights = F.softmax(attn_scores, dim=-1)                      # (b, h*w, len)

在实现 cross attention 的过程中,正确理解各个张量的维度是至关重要的,尤其是 softmax(dim=-1) 背后的语义逻辑。下面我们来系统梳理一下。


各维度语义解释:

维度 含义
b batch size(样本数)
h*w 图像的空间位置个数(flatten 后的 token 数)
len 文本 token 数量(context 长度)
dim 注意力计算的投影维度(hidden_dim)

张量维度总览

张量名 Shape 含义
x_flat (b, h*w, c) 将图像特征图展平后,每个像素位置作为一个 token
context (b, len, dim) 上下文特征,通常是文本 Prompt 的输出
q (b, h*w, dim) 图像 token 对应的查询向量
k (b, len, dim) 文本 token 对应的键向量
v (b, len, dim) 文本 token 对应的值向量
attn_scores (b, h*w, len) 每个图像位置对所有文本 token 的注意力得分
attn_weights (b, h*w, len) softmax 后的注意力权重
attn_output (b, h*w, dim) 每个图像位置融合文本信息后的表示
out (b, c, h, w) 最终恢复成图像形状的输出特征图

为什么使用 softmax(dim=-1)

attn_scores 的 shape 是 (b, h*w, len)

  • b: batch size
  • h*w: 每张图像中的所有位置(flatten 后)
  • len: 文本 token 数

当我们执行:

attn_weights = F.softmax(attn_scores, dim=-1)

就是对最后一维 len 做 softmax,含义是:

对于图像中每个位置(每个 query),我们希望它能独立地对所有文本 token(keys)计算注意力权重。

因此:
• dim=-1 ✅ 正确:在文本维度上做 softmax,是 query → key 的正常注意力机制。
• dim=1 ❌ 错误:会在图像位置之间做归一化,打乱注意力语义。

一句话总结

softmax(dim=-1) 是为了让每个 图像 token(query)对所有 文本 token(key)进行注意力分配,而不是在别的维度上归一化。


💬 如果你觉得这篇整理有帮助,欢迎点赞收藏!

文章来源于互联网:Stable Diffusion 中 Cross Attention 实现原理解析(含代码讲解)

赞(0)
未经允许不得转载:5bei.cn大模型教程网 » Stable Diffusion 中 Cross Attention 实现原理解析(含代码讲解)
分享到: 更多 (0)

AI大模型,我们的未来

小欢软考联系我们