AI大模型教程
一起来学习

【大模型与AIGC】VLM基础知识汇总

LLM输入时的理解

1. Tokenizer的实现:Word极大似然估计

LLM推理:关于Attention mask的理解

1. CausalModel 与 AttentionMask

先看一个简易的attention的实现

import torch

# 随机生成Q,K,V仿真输入
batch, max_length, dim = 2, 4, 3
Q, K, V = torch.rand(3, batch, max_length, dim)

# LLM处于prefilling阶段,输入的token数为3个要推理第四个了
cur_length = 3  # prefilling stage cur_length 

# 只需要前3个token参与运算即可
W = torch.einsum('bmd,bnd->bmn', Q[:, :cur_length], K[:, :cur_length])
Y = torch.einsum('bmn,bnd->bmd', W, V[:, :cur_length])

添加attention_mask

# 构建attention_mask,由于使用了cur_length作为切片,所以只要构建一个全是1的mask即可
atten_mask = torch.ones(max_length)
cur_length = 3  # prefilling stage cur_length 
atten_mask = atten_mask[:cur_length].repeat(cur_length, 1)
print("attention mask:n", atten_mask)

# 只需要前3个token参与运算即可
W = torch.einsum('bmd,bnd->bmn', Q[:, :cur_length], K[:, :cur_length])
W = torch.softmax((W * atten_mask[None, :, :]), dim=-1)
Y = torch.einsum('bmn,bnd->bmd', W, V[:, :cur_length])

当前的LLM模型往往采用CausalModel,它的mask构建如下,即计算Y[i]时不会有Q[i+t]/K[i+t]/V[i+t] (t>0)引入attention,Q[i]不会与K[i+t]/V[i+t]计算

idx = torch.arange(cur_length)
causal_mask = (idx[None, :]  idx[:, None]).float()
print("causal_mask:n", causal_mask)

输出如下:

attention mask:
 tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.]])
causal_mask:
 tensor([[1., 0., 0.],
        [1., 1., 0.],
        [1., 1., 1.]])

2. attention mask乘法变成加法

原理如下:

exp(x * 0) = exp(x + (-inf)) ~= exp(x + min)
exp(x * 1) = exp(x + 0)

因此下面两个实现是等价的:

idx = torch.arange(cur_length)
causal_mask = (idx[None, :]  idx[:, None]).float()
print("causal_mask:n", causal_mask)

W = torch.einsum('bmd,bnd->bmn', Q[:, :cur_length], K[:, :cur_length])
W = torch.softmax((W * causal_mask[None, :, :]), dim=-1)
Y = torch.einsum('bmn,bnd->bmd', W, V[:, :cur_length])
print(Y.shape)
print(Y)
neg_inf = torch.finfo(Q.dtype).min
m = torch.zeros_like(causal_mask)
m[causal_mask == 0] = neg_inf
causal_mask = m

W = torch.einsum('bmd,bnd->bmn', Q[:, :cur_length], K[:, :cur_length])
W = torch.softmax((W + causal_mask[None, :, :]), dim=-1)
Y = torch.einsum('bmn,bnd->bmd', W, V[:, :cur_length])
print(Y.shape)
print(Y)

输出如下:

causal_mask:
 tensor([[1., 0., 0.],
        [1., 1., 0.],
        [1., 1., 1.]])
torch.Size([2, 3, 3])
tensor([[[0.4954, 0.8203, 0.4379],
         [0.5537, 0.8697, 0.4225],
         [0.4435, 0.7512, 0.4051]],

        [[0.4347, 0.9123, 0.5548],
         [0.4604, 0.9270, 0.6272],
         [0.4866, 0.9341, 0.6768]]])
torch.Size([2, 3, 3])
tensor([[[0.5419, 0.9893, 0.6665],
         [0.6266, 0.9406, 0.4199],
         [0.4435, 0.7512, 0.4051]],

        [[0.2483, 0.8655, 0.2131],
         [0.4129, 0.9380, 0.6106],
         [0.4866, 0.9341, 0.6768]]])

3. 参考代码

causal_mask有两处实现,一个是huggingface的transformers中的基本类PreTrainedModel,它的一个负类ModuleUtilsMixin中,这个代码只有在config中is_decoder为True时才被使用,而往往这个是False(默认也是)

class ModuleUtilsMixin:    
    @staticmethod
    def create_extended_attention_mask_for_decoder(input_shape, attention_mask, device=None):
        if device is not None:
            warnings.warn(
                "The `device` argument is deprecated and will be removed in v5 of Transformers.", FutureWarning
            )
        else:
            device = attention_mask.device
        batch_size, seq_length = input_shape
        seq_ids = torch.arange(seq_length, device=device)
        causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1)  seq_ids[None, :, None]
        # in case past_key_values are used we need to add a prefix ones mask to the causal mask
        # causal and attention masks must have same type with pytorch version 
        causal_mask = causal_mask.to(attention_mask.dtype)

        if causal_mask.shape[1]  attention_mask.shape[1]:
            prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1]
            causal_mask = torch.cat(
                [
                    torch.ones((batch_size, seq_length, prefix_seq_len), device=device, dtype=causal_mask.dtype),
                    causal_mask,
                ],
                axis=-1,
            )

        extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
        return extended_attention_mask

if self.config.is_decoder:
    extended_attention_mask = ModuleUtilsMixin.create_extended_attention_mask_for_decoder(
        input_shape, attention_mask, device
    )

因此,很多代码是自己实现这个mask,以InterLM为例子

causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)

4. KV cache

KV cache原理

#mermaid-svg-06gLi8YpGYVPK50Z {font-family:”trebuchet ms”,verdana,arial,sans-serif;font-size:16px;fill:#333;}#mermaid-svg-06gLi8YpGYVPK50Z .error-icon{fill:#552222;}#mermaid-svg-06gLi8YpGYVPK50Z .error-text{fill:#552222;stroke:#552222;}#mermaid-svg-06gLi8YpGYVPK50Z .edge-thickness-normal{stroke-width:2px;}#mermaid-svg-06gLi8YpGYVPK50Z .edge-thickness-thick{stroke-width:3.5px;}#mermaid-svg-06gLi8YpGYVPK50Z .edge-pattern-solid{stroke-dasharray:0;}#mermaid-svg-06gLi8YpGYVPK50Z .edge-pattern-dashed{stroke-dasharray:3;}#mermaid-svg-06gLi8YpGYVPK50Z .edge-pattern-dotted{stroke-dasharray:2;}#mermaid-svg-06gLi8YpGYVPK50Z .marker{fill:#333333;stroke:#333333;}#mermaid-svg-06gLi8YpGYVPK50Z .marker.cross{stroke:#333333;}#mermaid-svg-06gLi8YpGYVPK50Z svg{font-family:”trebuchet ms”,verdana,arial,sans-serif;font-size:16px;}#mermaid-svg-06gLi8YpGYVPK50Z .label{font-family:”trebuchet ms”,verdana,arial,sans-serif;color:#333;}#mermaid-svg-06gLi8YpGYVPK50Z .cluster-label text{fill:#333;}#mermaid-svg-06gLi8YpGYVPK50Z .cluster-label span{color:#333;}#mermaid-svg-06gLi8YpGYVPK50Z .label text,#mermaid-svg-06gLi8YpGYVPK50Z span{fill:#333;color:#333;}#mermaid-svg-06gLi8YpGYVPK50Z .node rect,#mermaid-svg-06gLi8YpGYVPK50Z .node circle,#mermaid-svg-06gLi8YpGYVPK50Z .node ellipse,#mermaid-svg-06gLi8YpGYVPK50Z .node polygon,#mermaid-svg-06gLi8YpGYVPK50Z .node path{fill:#ECECFF;stroke:#9370DB;stroke-width:1px;}#mermaid-svg-06gLi8YpGYVPK50Z .node .label{text-align:center;}#mermaid-svg-06gLi8YpGYVPK50Z .node.clickable{cursor:pointer;}#mermaid-svg-06gLi8YpGYVPK50Z .arrowheadPath{fill:#333333;}#mermaid-svg-06gLi8YpGYVPK50Z .edgePath .path{stroke:#333333;stroke-width:2.0px;}#mermaid-svg-06gLi8YpGYVPK50Z .flowchart-link{stroke:#333333;fill:none;}#mermaid-svg-06gLi8YpGYVPK50Z .edgeLabel{background-color:#e8e8e8;text-align:center;}#mermaid-svg-06gLi8YpGYVPK50Z .edgeLabel rect{opacity:0.5;background-color:#e8e8e8;fill:#e8e8e8;}#mermaid-svg-06gLi8YpGYVPK50Z .cluster rect{fill:#ffffde;stroke:#aaaa33;stroke-width:1px;}#mermaid-svg-06gLi8YpGYVPK50Z .cluster text{fill:#333;}#mermaid-svg-06gLi8YpGYVPK50Z .cluster span{color:#333;}#mermaid-svg-06gLi8YpGYVPK50Z div.mermaidTooltip{position:absolute;text-align:center;max-width:200px;padding:2px;font-family:”trebuchet ms”,verdana,arial,sans-serif;font-size:12px;background:hsl(80, 100%, 96.2745098039%);border:1px solid #aaaa33;border-radius:2px;pointer-events:none;z-index:100;}#mermaid-svg-06gLi8YpGYVPK50Z :root{–mermaid-font-family:”trebuchet ms”,verdana,arial,sans-serif;}
T=1
T=2
T=3
X1_1
X2_1
X2_2
X1_2
X2_3
X1_3
X3_1
X3_2
X3_3

通过上面的分析,可以知道,对于CausalModel的LLM,第t个词的结果在整过过程是不变的,且不依赖于后面时刻的输入,所以可以使用KV cache,把之前的结果缓存下来,只预测新的token的结果。

为什么需要KV cache | KV cache工作机制 | 有几个阶段 |估计KV cache 显存占用 | KV cache为何成为长上下文瓶颈 |优化方案 | 最新技术
sumary:

  • 因果结构的LLM在attention计算时候存在重复计算,且计算时间随着token数O(N^2)复杂度增长
  • 存储前面token结果的K/V,只完成新的token的Q到结果的计算,从而将负责度降为线性增长。
  • Prefill和decoding两个阶段,prefill将输入的prompt计算并为每个token存储K/V,生成第一个token (gemm矩阵);decoding就是完成第二个到最后一个token的生成,同时记录新token的K/V (gemv向量)。
  • KV cache显存:层数 × KV注意力头的数量 × 注意力头的维度 × (位宽/8) × 2(K/V),一般使用bfloat16/float16,也就是位宽为16。
    • Llama 3 8B:32 × 8 × 128 × 2 × 2 = 131,072B ~= 0.1M, (Llama 3 8B 有 32 个注意力头,不过由于 GQA 的存在,只有 8 个注意力头用于键和值)。
    • Llama 3 70B: 80 × 8 × 128 × 2 × 2 = 327,680B ~=0.3M
  • 如果上下文大小是 8192,Llama 3 8B 131,072x 8192 ~= 1.G,Llama 3 80B ~=2.5G;如果上下文大小是 128k,Llama 3 8B 131,072x 128000~=15.6G,Llama 3 80B ~=39.1G
  • 优化方案:量化,4bit甚至2bit,但量化过程会减慢解码过程,并且可能显著影响 LLM 的准确性。使用时需要调整量化超参来减轻影响。
  • 最新发展:层内KV Cache 共享(MQA/GQA)发展为层间KV-cache(YOCO)

完全避免KV cache问题,可能依赖于架构的调整,比如换成 Mamba (Falcon Mamba 7B)

如何加速量化模型推理:将乘法计算转换为LUT查表

LLM推理:生成的策略有哪些?

1. 生成下一个token时候的sample策略

严昕:解读transformers库generate接口的解码策略

2. input length 以及超出长度后如何处理

MoE: DeepSeek-V2是怎么实现的?

用PyTorch从零开始编写DeepSeek-V2

lora的Perf代码实现与使用

模型量化bitsandbytes代码实现使用

deepspeed模型切片分卡代码实现与使用

最新研究进展了解:

视觉-语言模型(Vision-Language Model)方面有哪些值得关注的学者?

文章来源于互联网:【大模型与AIGC】VLM基础知识汇总

相关推荐: SD教程|从零开始,手把手教你本地部署Stable Diffusion Webui AI绘画

StableDiffusion是一款基于[深度学习 的图像生成模型,它能够在没有任何人类指导的情况下生成高质量、逼真的图像。想要在自己的电脑上体验StableDiffusion的强大功能吗?本文将带你一步步了解如何在本地部署Stable Diffusion,让…

赞(0)
未经允许不得转载:5bei.cn大模型教程网 » 【大模型与AIGC】VLM基础知识汇总
分享到: 更多 (0)

AI大模型,我们的未来

小欢软考联系我们