在 Stable Diffusion 的 U-Net 中,Cross Attention 是将文本提示与图像特征对齐融合的关键模块,本文将结合一段 Python 实现代码,逐行解释其原理。
Cross Attention 是什么?
Cross Attention(交叉注意力)是指:让一组 Query 向量(比如图像特征)去 attend 另一组 Key-Value 向量(比如文本上下文),以融合跨模态信息。
在 Stable Diffusion 中,它的作用就是:
让图像特征与文本 Prompt 对齐,从而生成符合描述的图像。
代码实现
下面是一个精简版的 cross attention 实现(基于 PyTorch):
import torch
import torch.nn as nn
import torch.nn.functional as F
def cross_attention(x: 'b c h w', context: 'b len dim'):
# 获取输入张量 x 的基本信息
batch_size, channels, height, width = x.shape
# 根据 context 的最后一个维度得到 dim
dim = context.shape[-1]
# 将 x reshape 后交换维度,使其形状为 (b, num_tokens, channels)
x_flat = x.view(batch_size, channels, -1).permute(0, 2, 1) # (b, h*w, c)
# 对 x_flat 和 context 分别做线性变换
q = nn.Linear(channels, dim)(x_flat) # (b, h*w, dim)
k = nn.Linear(dim, dim)(context) # (b, len, dim)
v = nn.Linear(dim, dim)(context) # (b, len, dim)
# 计算注意力分数,并缩放
attn_scores = torch.matmul(q, k.transpose(-2, -1)) / (dim ** 0.5) # (b, h*w, len)
# 对分数做 softmax 操作
attn_weights = F.softmax(attn_scores, dim=-1) # (b, h*w, len)
# 用注意力权重加权求和得到输出
attn_output = torch.matmul(attn_weights, v) # (b, h*w, dim)
# 再做一次线性变换,将通道数还原
out = nn.Linear(dim, channels)(attn_output) # (b, h*w, c)
# 恢复原来的空间形状
out = out.permute(0, 2, 1).view(batch_size, channels, height, width)
return out
if __name__ == '__main__':
batch_size = 2
channels = 32
height = 16
width = 16
len_context = 10
dim_contex = 64
x = torch.randn(batch_size, channels, height, width)
context = torch.randn(batch_size, len_context,dim_contex)
output = corss_attention(x, context)
print("output shape:", output.shape)
每一行在干什么?
1️⃣ 输入格式
• x: 图像特征图,形状为 (batch_size, channels, height, width)。
• context: 上下文特征,通常是文本 Prompt 编码,形状为 (batch_size, seq_len, dim)。
2️⃣ 将图像特征展平
x_flat = x.view(batch_size, channels, -1).permute(0, 2, 1) # (b, h*w, c)
将二维空间的图像特征展平成一个序列,即把每个像素位置看作一个 token。
3️⃣ 构造 Query / Key / Value
q = nn.Linear(channels, dim)(x_flat) # 查询向量,图像发出请求
k = nn.Linear(dim, dim)(context) # 键向量,文本提供信息索引
v = nn.Linear(dim, dim)(context) # 值向量,文本的实际信息内容
这里 q 是来自图像,k 和 v 是来自文本。
4️⃣ 计算注意力得分
attn_scores = torch.matmul(q, k.transpose(-2, -1)) / (dim ** 0.5)
这是标准的 Scaled Dot-Product Attention:
对每个图像位置的 q 与所有文本位置的 k 做点积。
5️⃣ softmax 得到注意力权重
attn_weights = F.softmax(attn_scores, dim=-1)
将得分转化为概率分布,得到每个图像位置对文本各 token 的注意力程度。
6️⃣ 加权求和融合文本特征
attn_output = torch.matmul(attn_weights, v) # (b, h*w, dim)
使用注意力权重对文本的值向量 v 做加权,得到融合文本信息后的图像 token。
7️⃣ 映射回原图像空间
out = nn.Linear(dim, channels)(attn_output)
out = out.permute(0, 2, 1).view(batch_size, channels, height, width)
将融合后的特征映射回原始的通道维度,并 reshape 成 (b, c, h, w) 格式。
Cross Attention 中的维度说明与 softmax 理解(知识点补充)
attn_weights = F.softmax(attn_scores, dim=-1) # (b, h*w, len)
在实现 cross attention 的过程中,正确理解各个张量的维度是至关重要的,尤其是 softmax(dim=-1) 背后的语义逻辑。下面我们来系统梳理一下。
各维度语义解释:
| 维度 | 含义 |
|---|---|
| b | batch size(样本数) |
| h*w | 图像的空间位置个数(flatten 后的 token 数) |
| len | 文本 token 数量(context 长度) |
| dim | 注意力计算的投影维度(hidden_dim) |
张量维度总览
| 张量名 | Shape | 含义 |
|---|---|---|
x_flat |
(b, h*w, c) |
将图像特征图展平后,每个像素位置作为一个 token |
context |
(b, len, dim) |
上下文特征,通常是文本 Prompt 的输出 |
q |
(b, h*w, dim) |
图像 token 对应的查询向量 |
k |
(b, len, dim) |
文本 token 对应的键向量 |
v |
(b, len, dim) |
文本 token 对应的值向量 |
attn_scores |
(b, h*w, len) |
每个图像位置对所有文本 token 的注意力得分 |
attn_weights |
(b, h*w, len) |
softmax 后的注意力权重 |
attn_output |
(b, h*w, dim) |
每个图像位置融合文本信息后的表示 |
out |
(b, c, h, w) |
最终恢复成图像形状的输出特征图 |
为什么使用 softmax(dim=-1)?
attn_scores 的 shape 是 (b, h*w, len):
-
b: batch size -
h*w: 每张图像中的所有位置(flatten 后) -
len: 文本 token 数
当我们执行:
attn_weights = F.softmax(attn_scores, dim=-1)
就是对最后一维 len 做 softmax,含义是:
对于图像中每个位置(每个 query),我们希望它能独立地对所有文本 token(keys)计算注意力权重。
因此:
• dim=-1 ✅ 正确:在文本维度上做 softmax,是 query → key 的正常注意力机制。
• dim=1 ❌ 错误:会在图像位置之间做归一化,打乱注意力语义。
一句话总结:
softmax(dim=-1) 是为了让每个 图像 token(query)对所有 文本 token(key)进行注意力分配,而不是在别的维度上归一化。
💬 如果你觉得这篇整理有帮助,欢迎点赞收藏!
5bei.cn大模型教程网










