AI大模型教程
一起来学习

LLaMa系列模型详解(原理介绍、代码解读):LLaMA 3

LLaMA 3

2024年4月18日,Meta 重磅推出了Meta Llama 3,Llama 3是Meta最先进开源大型语言模型的下一代,包括具有80亿和700亿参数的预训练和指令微调的语言模型,能够支持广泛的应用场景。这一代Llama在一系列行业标准基准测试中展示了最先进的性能,并提供了新的功能,包括改进的推理能力。

版本和性能

新的 8B 和 70B 参数 Llama 3 模型是 Llama 2 的重大飞跃,并为这些规模的 LLM 模型建立了新的最先进技术。由于预训练和训练后的改进,模型是当今 8B 和 70B 参数规模的最佳模型。我训练后程序的改进大大降低了错误拒绝率,改善了一致性并增加了模型响应的多样性。我们还看到了推理、代码生成和指令跟踪等功能的极大改进,使 Llama 3 更加易于操控。

模型架构

从模型架构上看,LLaMA 3和LLaMA 2基本没有区别,同样使用了Transformer的Decoder-only架构,加入RMSNorm预归一化,使用 SwiGLU 激活函数和旋转位置嵌入,使用了改进的注意力机制GQA,增加了上下文长度。故本文不具体解释。

上述具体的技术和方法可以查看LLaMA 2的博客:点击此处

模型代码如下,代码来自LLaMA 3:https://github.com/meta-llama/llama3

# Copyright (c) Meta Platforms, Inc. and affiliates.  
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.  
  
import math  
from dataclasses import dataclass  
from typing import Optional, Tuple  
  
import fairscale.nn.model_parallel.initialize as fs_init  
import torch  
import torch.nn.functional as F  
from fairscale.nn.model_parallel.layers import (  
    ColumnParallelLinear,  
    RowParallelLinear,  
    VocabParallelEmbedding,  
)  
from torch import nn  
  
  
@dataclass  
class ModelArgs:  
    dim: int = 4096  
    n_layers: int = 32  
    n_heads: int = 32  
    n_kv_heads: Optional[int] = None  
    vocab_size: int = -1  
    multiple_of: int = 256  # make SwiGLU hidden layer size multiple of large power of 2  
    ffn_dim_multiplier: Optional[float] = None  
    norm_eps: float = 1e-5  
    rope_theta: float = 500000  
  
    max_batch_size: int = 32  
    max_seq_len: int = 2048  
  
  
class RMSNorm(torch.nn.Module):  
    def __init__(self, dim: int, eps: float = 1e-6):  
        super().__init__()  
        self.eps = eps  
        self.weight = nn.Parameter(torch.ones(dim))  
  
    def _norm(self, x):  
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)  
  
    def forward(self, x):  
        output = self._norm(x.float()).type_as(x)  
        return output * self.weight  
  
  
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):  
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))  
    t = torch.arange(end, device=freqs.device, dtype=torch.float32)  
    freqs = torch.outer(t, freqs)  
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64  
    return freqs_cis  
  
  
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):  
    ndim = x.ndim  
    assert 0  1  ndim  
    assert freqs_cis.shape == (x.shape[1], x.shape[-1])  
    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]  
    return freqs_cis.view(*shape)  
  
  
def apply_rotary_emb(  
    xq: torch.Tensor,  
    xk: torch.Tensor,  
    freqs_cis: torch.Tensor,  
) -> Tuple[torch.Tensor, torch.Tensor]:  
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))  
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))  
    freqs_cis = reshape_for_broadcast(freqs_cis, xq_)  
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)  
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)  
    return xq_out.type_as(xq), xk_out.type_as(xk)  
  
  
def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:  
    """torch.repeat_interleave(x, dim=2, repeats=n_rep)"""  
    bs, slen, n_kv_heads, head_dim = x.shape  
    if n_rep == 1:  
        return x  
    return (  
        x[:, :, :, None, :]  
        .expand(bs, slen, n_kv_heads, n_rep, head_dim)  
        .reshape(bs, slen, n_kv_heads * n_rep, head_dim)  
    )  
  
  
class Attention(nn.Module):  
    def __init__(self, args: ModelArgs):  
        super().__init__()  
        self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads  
        model_parallel_size = fs_init.get_model_parallel_world_size()  
        self.n_local_heads = args.n_heads // model_parallel_size  
        self.n_local_kv_heads = self.n_kv_heads // model_parallel_size  
        self.n_rep = self.n_local_heads // self.n_local_kv_heads  
        self.head_dim = args.dim // args.n_heads  
  
        self.wq = ColumnParallelLinear(  
            args.dim,  
            args.n_heads * self.head_dim,  
            bias=False,  
            gather_output=False,  
            init_method=lambda x: x,  
        )  
        self.wk = ColumnParallelLinear(  
            args.dim,  
            self.n_kv_heads * self.head_dim,  
            bias=False,  
            gather_output=False,  
            init_method=lambda x: x,  
        )  
        self.wv = ColumnParallelLinear(  
            args.dim,  
            self.n_kv_heads * self.head_dim,  
            bias=False,  
            gather_output=False,  
            init_method=lambda x: x,  
        )  
        self.wo = RowParallelLinear(  
            args.n_heads * self.head_dim,  
            args.dim,  
            bias=False,  
            input_is_parallel=True,  
            init_method=lambda x: x,  
        )  
  
        self.cache_k = torch.zeros(  
            (  
                args.max_batch_size,  
                args.max_seq_len,  
                self.n_local_kv_heads,  
                self.head_dim,  
            )  
        ).cuda()  
        self.cache_v = torch.zeros(  
            (  
                args.max_batch_size,  
                args.max_seq_len,  
                self.n_local_kv_heads,  
                self.head_dim,  
            )  
        ).cuda()  
  
    def forward(  
        self,  
        x: torch.Tensor,  
        start_pos: int,  
        freqs_cis: torch.Tensor,  
        mask: Optional[torch.Tensor],  
    ):  
        bsz, seqlen, _ = x.shape  
        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)  
  
        xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)  
        xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)  
        xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)  
  
        xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)  
  
        self.cache_k = self.cache_k.to(xq)  
        self.cache_v = self.cache_v.to(xq)  
  
        self.cache_k[:bsz, start_pos : start_pos + seqlen] = xk  
        self.cache_v[:bsz, start_pos : start_pos + seqlen] = xv  
  
        keys = self.cache_k[:bsz, : start_pos + seqlen]  
        values = self.cache_v[:bsz, : start_pos + seqlen]  
  
        # repeat k/v heads if n_kv_heads 
        keys = repeat_kv(  
            keys, self.n_rep  
        )  # (bs, cache_len + seqlen, n_local_heads, head_dim)  
        values = repeat_kv(  
            values, self.n_rep  
        )  # (bs, cache_len + seqlen, n_local_heads, head_dim)  
  
        xq = xq.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)  
        keys = keys.transpose(1, 2)  # (bs, n_local_heads, cache_len + seqlen, head_dim)  
        values = values.transpose(  
            1, 2  
        )  # (bs, n_local_heads, cache_len + seqlen, head_dim)  
        scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)  
        if mask is not None:  
            scores = scores + mask  # (bs, n_local_heads, seqlen, cache_len + seqlen)  
        scores = F.softmax(scores.float(), dim=-1).type_as(xq)  
        output = torch.matmul(scores, values)  # (bs, n_local_heads, seqlen, head_dim)  
        output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)  
        return self.wo(output)  
  
  
class FeedForward(nn.Module):  
    def __init__(  
        self,  
        dim: int,  
        hidden_dim: int,  
        multiple_of: int,  
        ffn_dim_multiplier: Optional[float],  
    ):  
        super().__init__()  
        hidden_dim = int(2 * hidden_dim / 3)  
        # custom dim factor multiplier  
        if ffn_dim_multiplier is not None:  
            hidden_dim = int(ffn_dim_multiplier * hidden_dim)  
        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)  
  
        self.w1 = ColumnParallelLinear(  
            dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x  
        )  
        self.w2 = RowParallelLinear(  
            hidden_dim, dim, bias=False, input_is_parallel=True, init_method=lambda x: x  
        )  
        self.w3 = ColumnParallelLinear(  
            dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x  
        )  
  
    def forward(self, x):  
        return self.w2(F.silu(self.w1(x)) * self.w3(x))  
  
  
class TransformerBlock(nn.Module):  
    def __init__(self, layer_id: int, args: ModelArgs):  
        super().__init__()  
        self.n_heads = args.n_heads  
        self.dim = args.dim  
        self.head_dim = args.dim // args.n_heads  
        self.attention = Attention(args)  
        self.feed_forward = FeedForward(  
            dim=args.dim,  
            hidden_dim=4 * args.dim,  
            multiple_of=args.multiple_of,  
            ffn_dim_multiplier=args.ffn_dim_multiplier,  
        )  
        self.layer_id = layer_id  
        self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)  
        self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)  
  
    def forward(  
        self,  
        x: torch.Tensor,  
        start_pos: int,  
        freqs_cis: torch.Tensor,  
        mask: Optional[torch.Tensor],  
    ):  
        h = x + self.attention(self.attention_norm(x), start_pos, freqs_cis, mask)  
        out = h + self.feed_forward(self.ffn_norm(h))  
        return out  
  
  
class Transformer(nn.Module):  
    def __init__(self, params: ModelArgs):  
        super().__init__()  
        self.params = params  
        self.vocab_size = params.vocab_size  
        self.n_layers = params.n_layers  
  
        self.tok_embeddings = VocabParallelEmbedding(  
            params.vocab_size, params.dim, init_method=lambda x: x  
        )  
  
        self.layers = torch.nn.ModuleList()  
        for layer_id in range(params.n_layers):  
            self.layers.append(TransformerBlock(layer_id, params))  
  
        self.norm = RMSNorm(params.dim, eps=params.norm_eps)  
        self.output = ColumnParallelLinear(  
            params.dim, params.vocab_size, bias=False, init_method=lambda x: x  
        )  
  
        self.freqs_cis = precompute_freqs_cis(  
            params.dim // params.n_heads,  
            params.max_seq_len * 2,  
            params.rope_theta,  
        )  
  
    @torch.inference_mode()  
    def forward(self, tokens: torch.Tensor, start_pos: int):  
        _bsz, seqlen = tokens.shape  
        h = self.tok_embeddings(tokens)  
        self.freqs_cis = self.freqs_cis.to(h.device)  
        freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen]  
  
        mask = None  
        if seqlen > 1:  
            mask = torch.full((seqlen, seqlen), float("-inf"), device=tokens.device)  
  
            mask = torch.triu(mask, diagonal=1)  
  
            # When performing key-value caching, we compute the attention scores  
            # only for the new sequence. Thus, the matrix of scores is of size            # (seqlen, cache_len + seqlen), and the only masked entries are (i, j) for            # j > cache_len + i, since row i corresponds to token cache_len + i.            mask = torch.hstack(  
                [torch.zeros((seqlen, start_pos), device=tokens.device), mask]  
            ).type_as(h)  
  
        for layer in self.layers:  
            h = layer(h, start_pos, freqs_cis, mask)  
        h = self.norm(h)  
        output = self.output(h).float()  
        return output

Tokenizer

LLaMA3 改进了Tokenizer,使得对长文本的处理更快。

# Copyright (c) Meta Platforms, Inc. and affiliates.  
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.  
  
import os  
from logging import getLogger  
from pathlib import Path  
from typing import (  
    AbstractSet,  
    cast,  
    Collection,  
    Dict,  
    Iterator,  
    List,  
    Literal,  
    Sequence,  
    TypedDict,  
    Union,  
)  
  
import tiktoken  
from tiktoken.load import load_tiktoken_bpe  
  
  
logger = getLogger(__name__)  
  
  
Role = Literal["system", "user", "assistant"]  
  
  
class Message(TypedDict):  
    role: Role  
    content: str  
  
  
Dialog = Sequence[Message]  
  
  
class Tokenizer:  
    """  
    Tokenizing and encoding/decoding text using the Tiktoken tokenizer.    """  
    special_tokens: Dict[str, int]  
  
    num_reserved_special_tokens = 256  
  
    pat_str = r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^rnp{L}p{N}]?p{L}+|p{N}{1,3}| ?[^sp{L}p{N}]+[rn]*|s*[rn]+|s+(?!S)|s+"  # noqa: E501  
  
    def __init__(self, model_path: str):  
        """  
        Initializes the Tokenizer with a Tiktoken model.  
        Args:            model_path (str): The path to the Tiktoken model file.        """        assert os.path.isfile(model_path), model_path  
  
        mergeable_ranks = load_tiktoken_bpe(model_path)  
        num_base_tokens = len(mergeable_ranks)  
        special_tokens = [  
            "",  
            "",  
            "",  
            "",  
            "",  
            "",  
            "",  
            "",  
            "",  
            "",  # end of turn  
        ] + [  
            f"{i}|>"  
            for i in range(5, self.num_reserved_special_tokens - 5)  
        ]  
        self.special_tokens = {  
            token: num_base_tokens + i for i, token in enumerate(special_tokens)  
        }  
        self.model = tiktoken.Encoding(  
            name=Path(model_path).name,  
            pat_str=self.pat_str,  
            mergeable_ranks=mergeable_ranks,  
            special_tokens=self.special_tokens,  
        )  
        logger.info(f"Reloaded tiktoken model from {model_path}")  
  
        self.n_words: int = self.model.n_vocab  
        # BOS / EOS token IDs  
        self.bos_id: int = self.special_tokens[""]  
        self.eos_id: int = self.special_tokens[""]  
        self.pad_id: int = -1  
        self.stop_tokens = {  
            self.special_tokens[""],  
            self.special_tokens[""],  
        }  
        logger.info(  
            f"#words: {self.n_words} - BOS ID: {self.bos_id} - EOS ID: {self.eos_id}"  
        )  
  
    def encode(  
        self,  
        s: str,  
        *,  
        bos: bool,  
        eos: bool,  
        allowed_special: Union[Literal["all"], AbstractSet[str]] = set(),  
        disallowed_special: Union[Literal["all"], Collection[str]] = (),  
    ) -> List[int]:  
        """  
        Encodes a string into a list of token IDs.  
        Args:            s (str): The input string to be encoded.            bos (bool): Whether to prepend the beginning-of-sequence token.            eos (bool): Whether to append the end-of-sequence token.            allowed_tokens ("all"|set[str]): allowed special tokens in string            disallowed_tokens ("all"|set[str]): special tokens that raise an error when in string  
        Returns:            list[int]: A list of token IDs.  
        By default, setting disallowed_special=() encodes a string by ignoring        special tokens. Specifically:        - Setting `disallowed_special` to () will cause all text corresponding          to special tokens to be encoded as natural text (insteading of raising          an error).        - Setting `allowed_special` to "all" will treat all text corresponding          to special tokens to be encoded as special tokens.        """        assert type(s) is str  
  
        # The tiktoken tokenizer can handle 
        # pyo3_runtime.PanicException.        TIKTOKEN_MAX_ENCODE_CHARS = 400_000  
  
        # https://github.com/openai/tiktoken/issues/195  
        # Here we iterate over subsequences and split if we exceed the limit        # of max consecutive non-whitespace or whitespace characters.        MAX_NO_WHITESPACES_CHARS = 25_000  
  
        substrs = (  
            substr  
            for i in range(0, len(s), TIKTOKEN_MAX_ENCODE_CHARS)  
            for substr in self._split_whitespaces_or_nonwhitespaces(  
                s[i : i + TIKTOKEN_MAX_ENCODE_CHARS], MAX_NO_WHITESPACES_CHARS  
            )  
        )  
        t: List[int] = []  
        for substr in substrs:  
            t.extend(  
                self.model.encode(  
                    substr,  
                    allowed_special=allowed_special,  
                    disallowed_special=disallowed_special,  
                )  
            )  
        if bos:  
            t.insert(0, self.bos_id)  
        if eos:  
            t.append(self.eos_id)  
        return t  
  
    def decode(self, t: Sequence[int]) -> str:  
        """  
        Decodes a list of token IDs into a string.  
        Args:            t (List[int]): The list of token IDs to be decoded.  
        Returns:            str: The decoded string.        """        # Typecast is safe here. Tiktoken doesn't do anything list-related with the sequence.  
        return self.model.decode(cast(List[int], t))  
  
    @staticmethod  
    def _split_whitespaces_or_nonwhitespaces(  
        s: str, max_consecutive_slice_len: int  
    ) -> Iterator[str]:  
        """  
        Splits the string `s` so that each substring contains no more than `max_consecutive_slice_len`        consecutive whitespaces or consecutive non-whitespaces.        """        current_slice_len = 0  
        current_slice_is_space = s[0].isspace() if len(s) > 0 else False  
        slice_start = 0  
  
        for i in range(len(s)):  
            is_now_space = s[i].isspace()  
  
            if current_slice_is_space ^ is_now_space:  
                current_slice_len = 1  
                current_slice_is_space = is_now_space  
            else:  
                current_slice_len += 1  
                if current_slice_len > max_consecutive_slice_len:  
                    yield s[slice_start:i]  
                    slice_start = i  
                    current_slice_len = 1  
        yield s[slice_start:]  
  
  
class ChatFormat:  
    def __init__(self, tokenizer: Tokenizer):  
        self.tokenizer = tokenizer  
  
    def encode_header(self, message: Message) -> List[int]:  
        tokens = []  
        tokens.append(self.tokenizer.special_tokens[""])  
        tokens.extend(self.tokenizer.encode(message["role"], bos=False, eos=False))  
        tokens.append(self.tokenizer.special_tokens[""])  
        tokens.extend(self.tokenizer.encode("nn", bos=False, eos=False))  
        return tokens  
  
    def encode_message(self, message: Message) -> List[int]:  
        tokens = self.encode_header(message)  
        tokens.extend(  
            self.tokenizer.encode(message["content"].strip(), bos=False, eos=False)  
        )  
        tokens.append(self.tokenizer.special_tokens[""])  
        return tokens  
  
    def encode_dialog_prompt(self, dialog: Dialog) -> List[int]:  
        tokens = []  
        tokens.append(self.tokenizer.special_tokens[""])  
        for message in dialog:  
            tokens.extend(self.encode_message(message))  
        # Add the start of an assistant message for the model to complete.  
        tokens.extend(self.encode_header({"role": "assistant", "content": ""}))  
        return tokens
  • 为了防止因字符串过长而产生的性能问题,encode 方法使用一个循环来处理不超过 400,000 字符的子字符串。这种方法可以避免运行时错误,例如在 Python 的外部库(如 C 或 Rust 写的库)中可能发生的内存错误。
  • 使用 _split_whitespaces_or_nonwhitespaces 方法来处理可能的大量连续空格或非空格字符,限制每个片段的最大长度为 25,000 字符。这样做既保证了处理的灵活性,也避免了处理过长片段可能带来的问题。

训练数据

为了训练最佳的语言模型,收集一个大规模、高质量的训练数据集至关重要。Meta AI在预训练数据上投入了大量资金。Llama 3在超过15T的token上进行预训练,所有数据都来自公开可用的来源。我们的训练数据集比用于Llama 2的数据集大了七倍,并且包括了四倍的代码。为了准备即将到来的多语言用例,超过5%的Llama 3预训练数据集由高质量的非英语数据组成,覆盖了超过30种语言。然而,我们不期望在这些语言中达到与英语相同的性能水平。

为了确保Llama 3训练的数据质量最高,我们开发了一系列数据过滤管道。这些管道包括使用启发式过滤器、NSFW过滤器、语义去重方法和文本分类器来预测数据质量。我们发现,Llama的前几代在识别高质量数据方面出奇地好,因此我们使用Llama 2生成了为Llama 3提供动力的文本质量分类器的训练数据。

为了在Llama 3模型中有效利用我们的预训练数据,我们投入了大量精力来扩大预训练规模。具体来说,我们为下游基准评估开发了一系列详细的扩展法则。这些扩展法则使我们能够选择最佳数据混合方案,并就如何最佳利用我们的训练计算资源做出明智的决策。重要的是,扩展法则允许我们在实际训练模型之前预测我们最大模型在关键任务上的性能。这帮助我们确保最终模型在各种使用场景和能力上的强劲性能。

在Llama 3的开发过程中,我们对扩展行为做出了几项新的观察。例如,虽然对于80亿参数模型来说,Chinchilla最优的训练计算量对应于约2000亿个token,但我们发现即使模型在数据量增加两个数量级后,模型性能仍然在持续提升。在我们的80亿和700亿参数模型经过高达15T个token的训练后,它们的性能继续以对数线性方式提升。大型模型可以在较少的训练计算量下匹配这些小型模型的性能,但通常更倾向于使用小型模型,因为它们在推理过程中效率更高。

为了训练我们最大的Llama 3模型,我们结合了三种类型的并行化:数据并行化、模型并行化和流水线并行化。我们最有效的实现方式在同时训练16K个GPU时,每个GPU的计算利用率超过400 TFLOPS。我们在两个定制构建的24K GPU集群上执行了训练运行。为了最大化GPU的运行时间,我们开发了一个新的高级训练堆栈,自动化了错误检测、处理和维护。我们还大大提高了硬件的可靠性和检测机制,用于静默数据损坏,并开发了新的可扩展存储系统,减少了检查点和回滚的开销。这些改进使得整体有效训练时间超过了95%。综合来看,这些改进将Llama 3的训练效率提高了约三倍,与Llama 2相比。

指令微调

为了充分释放我们预训练模型在聊天用例中的潜力,我们对指令调整方法也进行了创新。我们的后训练方法是监督式微调(SFT)、拒绝采样、近端策略优化(PPO)和直接策略优化(DPO)的组合。用于SFT的提示质量和用于PPO和DPO的偏好排名对对齐模型的性能有巨大影响。我们在模型质量上的一些最大改进来自于仔细筛选这些数据,并对人类标注者提供的多轮质量保证进行多次审查。

通过PPO和DPO从偏好排名中学习也大大提高了Llama 3在推理和编码任务上的性能。我们发现,如果你问一个模型一个它难以回答的推理问题,模型有时会产生正确的推理轨迹:模型知道如何产生正确的答案,但它不知道如何选择它。在偏好排名上进行训练使模型学会了如何选择它。

文章来源于互联网:LLaMa系列模型详解(原理介绍、代码解读):LLaMA 3

相关推荐: 抓住AIGC行业的未来:现在正是进入的最佳时机

引言 在当今信息爆炸和技术迅猛发展的时代,人工智能和生成内容(AIGC)行业正迅速崛起,成为创新和创业的新热点。AIGC技术正在改变我们获取、处理和创造信息的方式,无论是在新闻、娱乐还是教育领域。如果你对技术和创新充满热情,那么现在正是进入AIGC行业的最好时…

赞(0)
未经允许不得转载:5bei.cn大模型教程网 » LLaMa系列模型详解(原理介绍、代码解读):LLaMA 3

LLaMa系列模型详解(原理介绍、代码解读):LLaMa

LLaMA详解

LLaMA(Large Language Model Meta AI)是由Meta(前身为Facebook)开发的一种大规模语言模型,旨在提高自然语言处理(NLP)任务的性能。LLaMA基于变换器(Transformer)架构,并经过大规模数据训练,以便在多种语言任务中表现出色。

Meta AI认为:对于给定的计算预算,最佳性能不是通过最大的模型实现的,而是通过在更多数据上训练的较小模型实现的。

模型结构

与GPT等生成模型类似,LLaMA也只使用了Transformer的解码器,但基于Transformer进行了三个改进:

  1. 使用了GPT3的预标准化。为了提高训练稳定性,对每个Transformer子层的输入进行归一化,而不是对输出进行归一化。使用由RMSNorm 归一化函数。
  2. 用 SwiGLU 激活函数替换 ReLU 非线性,以提高性能。使用

    2

    3

    4

    d

    frac{2}{3}4d

    324d的维度代替PaLM中的

    4

    d

    4d

    4d

  3. 类似GPTNeo,删除了绝对位置嵌入,而是添加了旋转位置嵌入(RoPE)。

下面逐一介绍这三个改进:

RMSNorm

RMSNorm(Root Mean Square Normalization)是一种归一化技术,用于稳定和加速神经网络的训练过程。与其他归一化方法(如BatchNorm和LayerNorm)不同,RMSNorm通过计算输入张量的均方根(RMS)来进行归一化。RMSNorm公式如下:

RMSNorm

(

x

)

=

x

1

d

i

=

1

d

x

i

2

+

ϵ

γ

text{RMSNorm}(x) = frac{x}{sqrt{frac{1}{d} sum_{i=1}^{d} x_i^2 + epsilon}} cdot gamma

RMSNorm(x)=d1i=1dxi2+ϵ
x
γ

其中

x

x

x是输入向量,

d

d

d 是输入向量的维度,

ϵ

epsilon

ϵ是一个小常数,用于避免除零错误,

γ

gamma

γ是一个可学习的缩放参数。

LLaMa中的实现如下:

class RMSNorm(torch.nn.Module):  
    def __init__(self, dim: int, eps: float = 1e-6):  
        super().__init__()  
        self.eps = eps  
        self.weight = nn.Parameter(torch.ones(dim))  
  
    def _norm(self, x):  
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)  
  
    def forward(self, x):  
        output = self._norm(x.float()).type_as(x)  
        return output * self.weight

SwiGLU激活函数

SwiGLU (Swish-Gated Linear Unit) 是一种用于神经网络的激活函数,它结合了Swish激活函数和门控机制,能够有效地增强模型的表达能力和性能。公式如下:

SwiGLU

(

x

)

=

Swish

(

x

)

(

Gated Linear Unit

(

x

)

)

text{SwiGLU}(x) = text{Swish}(x) cdot (text{Gated Linear Unit}(x))

SwiGLU(x)=Swish(x)(Gated Linear Unit(x))

Swish

(

x

)

=

x

σ

(

x

)

text{Swish}(x) = x cdot sigma(x)

Swish(x)=xσ(x)

Gated Linear Unit

(

x

)

=

Linear

1

(

x

)

σ

(

Linear

2

(

x

)

)

text{Gated Linear Unit}(x) = text{Linear}_1(x) cdot sigma(text{Linear}_2(x))

Gated Linear Unit(x)=Linear1(x)σ(Linear2(x))

σ

(

x

)

=

1

1

+

e

x

sigma(x) = frac{1}{1 + e^{-x}}

σ(x)=1+ex1

Linear

1

text{Linear}_1

Linear1

Linear

2

text{Linear}_2

Linear2是两个单独的线性变换。

LLaMa代码中使用

F

.

s

i

l

u

(

x

)

F.silu(x)

F.silu(x)添加SwiGLU激活函数

RoPE

旋转位置嵌入(Rotary Position Embedding, RoPE)是一种为序列模型(如Transformer)提供位置编码的方法。RoPE通过将输入向量在复数域进行旋转变换,来编码序列中位置的信息。与传统的位置编码方法(如正弦-余弦位置编码)相比,RoPE能够更好地捕捉序列中的相对位置信息,提高模型的表现力。

旋转位置嵌入(RoPE)是一种为序列模型提供位置编码的方法。其通过将输入向量在复数域进行旋转变换来编码位置信息。以下是RoPE的具体实现步骤:

  1. 频率向量的计算:

    f

    i

    =

    1

    θ

    2

    i

    d

    f_i = frac{1}{theta^{frac{2i}{d}}}

    fi=θd2i1
    其中

    θ

    theta

    θ是一个常数(通常取 10000),

    i

    i

    i是向量维度的索引。

  2. 旋转角度的计算:

    angle

    (

    t

    )

    =

    t

    f

    i

    text{angle}(t) = t cdot f_i

    angle(t)=tfi
    其中

    t

    t

    t是位置索引。

  3. 应用旋转变换:
    对每个位置

    t

    t

    t的输入向量

    x

    t

    x_t

    xt,在复数域进行旋转变换:

    x

    t

    =

    x

    t

    e

    j

    angle

    (

    t

    )

    x_t’ = x_t cdot e^{j cdot text{angle}(t)}

    xt=xtejangle(t)
    对于位置编码,常规的做法是在计算 query,key 和 value 向量之前,会计算一个位置编码向量 加到词嵌入上,位置编码向量同样也是维向量,然后再乘以对应的变换矩阵。

RoPE 的 self-attention 操作的流程是:对于 token 序列中的每个词嵌入向量,首先计算其对应的 query 和 key 向量,然后对每个 token 位置都计算对应的旋转位置编码,接着对每个 token 位置的 query 和 key 向量的元素按照两两一组应用旋转变换,最后再计算 query 和 key 之间的内积得到 self-attention 的计算结果。

下图很直观的展示了旋转变换的过程:

旋转编码 RoPE 可以有效地保持位置信息的相对关系,即相邻位置的编码之间有一定的相似性,而远离位置的编码之间有一定的差异性。 这样可以增强模型对位置信息的感知和利用。这一点是其他绝对位置编码方式(如正弦位置编码、学习的位置编码等)所不具备的,因为它们只能表示绝对位置,而不能表示相对位置。

为什么旋转位置嵌入有效?

  1. 捕捉相对位置信息:传统的位置嵌入方法通常仅编码绝对位置,这可能在处理长序列或需要捕捉相对位置信息的任务中表现不佳。而RoPE通过旋转变换自然地引入了相对位置信息,使得模型能够更好地理解序列中各个位置之间的相对关系。
  2. 由于RoPE通过复数域的旋转变换来编码位置,这种变换能够捕捉更加丰富的位置信息。相比于简单的线性变换,旋转变换提供了更强的非线性表达能力,使得模型在处理复杂任务时具有更好的表现力。
  3. RoPE的计算相对简单,不需要复杂的矩阵运算。预计算频率向量和应用旋转变换的过程可以高效地实现,适合在实际应用中大规模部署。
  4. RoPE能够无缝集成到现有的Transformer架构中,不需要对模型结构进行大的修改。这种兼容性使得RoPE成为一种易于应用和推广的位置编码方法。
  5. 在长序列处理任务中,传统的位置编码方法可能会遇到信息稀释或计算复杂度增加的问题。RoPE通过引入旋转变换,可以更好地保持长序列中的位置信息,使得模型在长序列任务中表现更加稳定和高效。
  6. (这一点是我的猜想)在高维向量中,方向是比模长更重要的量,常规位置编码直接在词嵌入上加上位置编码,相当于改变了模长,旋转位置编码改变了方向,实际上比常规位置编码多获得了一部分信息。

下面这篇文章给出了公式原理和推导,讲解十分详细:点击此处

在LLaMA中,RoPE使用下面的方式实现:

def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):  
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))  
    t = torch.arange(end, device=freqs.device)  # type: ignore  
    freqs = torch.outer(t, freqs).float()  # type: ignore  
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64  
    return freqs_cis  
  
  
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):  
    ndim = x.ndim  
    assert 0  1  ndim  
    assert freqs_cis.shape == (x.shape[1], x.shape[-1])  
    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]  
    return freqs_cis.view(*shape)  
  
  
def apply_rotary_emb(  
    xq: torch.Tensor,  
    xk: torch.Tensor,  
    freqs_cis: torch.Tensor,  
) -> Tuple[torch.Tensor, torch.Tensor]:  
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))  
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))  
    freqs_cis = reshape_for_broadcast(freqs_cis, xq_)  
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)  
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)  
    return xq_out.type_as(xq), xk_out.type_as(xk)

下面的代码给出了加入旋转位置嵌入的注意力机制:

class Attention(nn.Module):  
    def __init__(self, args: ModelArgs):  
        super().__init__()  
  
        self.n_local_heads = args.n_heads // fs_init.get_model_parallel_world_size()  
        self.head_dim = args.dim // args.n_heads  
  
        self.wq = ColumnParallelLinear(  
            args.dim,  
            args.n_heads * self.head_dim,  
            bias=False,  
            gather_output=False,  
            init_method=lambda x: x,  
        )  
        self.wk = ColumnParallelLinear(  
            args.dim,  
            args.n_heads * self.head_dim,  
            bias=False,  
            gather_output=False,  
            init_method=lambda x: x,  
        )  
        self.wv = ColumnParallelLinear(  
            args.dim,  
            args.n_heads * self.head_dim,  
            bias=False,  
            gather_output=False,  
            init_method=lambda x: x,  
        )  
        self.wo = RowParallelLinear(  
            args.n_heads * self.head_dim,  
            args.dim,  
            bias=False,  
            input_is_parallel=True,  
            init_method=lambda x: x,  
        )  
  
        self.cache_k = torch.zeros(  
            (args.max_batch_size, args.max_seq_len, self.n_local_heads, self.head_dim)  
        ).cuda()  
        self.cache_v = torch.zeros(  
            (args.max_batch_size, args.max_seq_len, self.n_local_heads, self.head_dim)  
        ).cuda()  
  
    def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]):  
        bsz, seqlen, _ = x.shape  
        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)  
  
        xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)  
        xk = xk.view(bsz, seqlen, self.n_local_heads, self.head_dim)  
        xv = xv.view(bsz, seqlen, self.n_local_heads, self.head_dim)  
  
        xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)  
  
        self.cache_k = self.cache_k.to(xq)  
        self.cache_v = self.cache_v.to(xq)  
  
        self.cache_k[:bsz, start_pos : start_pos + seqlen] = xk  
        self.cache_v[:bsz, start_pos : start_pos + seqlen] = xv  
  
        keys = self.cache_k[:bsz, : start_pos + seqlen]  
        values = self.cache_v[:bsz, : start_pos + seqlen]  
  
        xq = xq.transpose(1, 2)  
        keys = keys.transpose(1, 2)  
        values = values.transpose(1, 2)  
        scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)  
        if mask is not None:  
            scores = scores + mask  # (bs, n_local_heads, slen, cache_len + slen)  
        scores = F.softmax(scores.float(), dim=-1).type_as(xq)  
        output = torch.matmul(scores, values)  # (bs, n_local_heads, slen, head_dim)  
        output = output.transpose(  
            1, 2  
        ).contiguous().view(bsz, seqlen, -1)  
  
        return self.wo(output)

接下来给出LLaMA实现的全部代码:

# Copyright (c) Meta Platforms, Inc. and affiliates.  
# This software may be used and distributed according to the terms of the GNU General Public License version 3.  
  
from typing import Optional, Tuple  
from dataclasses import dataclass  
import math  
  
import torch  
from torch import nn  
import torch.nn.functional as F  
  
import fairscale.nn.model_parallel.initialize as fs_init  
from fairscale.nn.model_parallel.layers import (  
    ParallelEmbedding,  
    RowParallelLinear,  
    ColumnParallelLinear,  
)  
  
  
@dataclass  
class ModelArgs:  
    dim: int = 512  
    n_layers: int = 8  
    n_heads: int = 8  
    vocab_size: int = -1  # defined later by tokenizer  
    multiple_of: int = 256  # make SwiGLU hidden layer size multiple of large power of 2  
    norm_eps: float = 1e-5  
  
    max_batch_size: int = 32  
    max_seq_len: int = 2048  
  
  
class RMSNorm(torch.nn.Module):  
    def __init__(self, dim: int, eps: float = 1e-6):  
        super().__init__()  
        self.eps = eps  
        self.weight = nn.Parameter(torch.ones(dim))  
  
    def _norm(self, x):  
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)  
  
    def forward(self, x):  
        output = self._norm(x.float()).type_as(x)  
        return output * self.weight  
  
  
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):  
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))  
    t = torch.arange(end, device=freqs.device)  # type: ignore  
    freqs = torch.outer(t, freqs).float()  # type: ignore  
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64  
    return freqs_cis  
  
  
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):  
    ndim = x.ndim  
    assert 0  1  ndim  
    assert freqs_cis.shape == (x.shape[1], x.shape[-1])  
    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]  
    return freqs_cis.view(*shape)  
  
  
def apply_rotary_emb(  
    xq: torch.Tensor,  
    xk: torch.Tensor,  
    freqs_cis: torch.Tensor,  
) -> Tuple[torch.Tensor, torch.Tensor]:  
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))  
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))  
    freqs_cis = reshape_for_broadcast(freqs_cis, xq_)  
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)  
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)  
    return xq_out.type_as(xq), xk_out.type_as(xk)  
  
  
class Attention(nn.Module):  
    def __init__(self, args: ModelArgs):  
        super().__init__()  
  
        self.n_local_heads = args.n_heads // fs_init.get_model_parallel_world_size()  
        self.head_dim = args.dim // args.n_heads  
  
        self.wq = ColumnParallelLinear(  
            args.dim,  
            args.n_heads * self.head_dim,  
            bias=False,  
            gather_output=False,  
            init_method=lambda x: x,  
        )  
        self.wk = ColumnParallelLinear(  
            args.dim,  
            args.n_heads * self.head_dim,  
            bias=False,  
            gather_output=False,  
            init_method=lambda x: x,  
        )  
        self.wv = ColumnParallelLinear(  
            args.dim,  
            args.n_heads * self.head_dim,  
            bias=False,  
            gather_output=False,  
            init_method=lambda x: x,  
        )  
        self.wo = RowParallelLinear(  
            args.n_heads * self.head_dim,  
            args.dim,  
            bias=False,  
            input_is_parallel=True,  
            init_method=lambda x: x,  
        )  
  
        self.cache_k = torch.zeros(  
            (args.max_batch_size, args.max_seq_len, self.n_local_heads, self.head_dim)  
        ).cuda()  
        self.cache_v = torch.zeros(  
            (args.max_batch_size, args.max_seq_len, self.n_local_heads, self.head_dim)  
        ).cuda()  
  
    def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]):  
        bsz, seqlen, _ = x.shape  
        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)  
  
        xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)  
        xk = xk.view(bsz, seqlen, self.n_local_heads, self.head_dim)  
        xv = xv.view(bsz, seqlen, self.n_local_heads, self.head_dim)  
  
        xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)  
  
        self.cache_k = self.cache_k.to(xq)  
        self.cache_v = self.cache_v.to(xq)  
  
        self.cache_k[:bsz, start_pos : start_pos + seqlen] = xk  
        self.cache_v[:bsz, start_pos : start_pos + seqlen] = xv  
  
        keys = self.cache_k[:bsz, : start_pos + seqlen]  
        values = self.cache_v[:bsz, : start_pos + seqlen]  
  
        xq = xq.transpose(1, 2)  
        keys = keys.transpose(1, 2)  
        values = values.transpose(1, 2)  
        scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)  
        if mask is not None:  
            scores = scores + mask  # (bs, n_local_heads, slen, cache_len + slen)  
        scores = F.softmax(scores.float(), dim=-1).type_as(xq)  
        output = torch.matmul(scores, values)  # (bs, n_local_heads, slen, head_dim)  
        output = output.transpose(  
            1, 2  
        ).contiguous().view(bsz, seqlen, -1)  
  
        return self.wo(output)  
  
  
class FeedForward(nn.Module):  
    def __init__(  
        self,  
        dim: int,  
        hidden_dim: int,  
        multiple_of: int,  
    ):  
        super().__init__()  
        hidden_dim = int(2 * hidden_dim / 3)  
        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)  
  
        self.w1 = ColumnParallelLinear(  
            dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x  
        )  
        self.w2 = RowParallelLinear(  
            hidden_dim, dim, bias=False, input_is_parallel=True, init_method=lambda x: x  
        )  
        self.w3 = ColumnParallelLinear(  
            dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x  
        )  
  
    def forward(self, x):  
        return self.w2(F.silu(self.w1(x)) * self.w3(x))  
  
  
class TransformerBlock(nn.Module):  
    def __init__(self, layer_id: int, args: ModelArgs):  
        super().__init__()  
        self.n_heads = args.n_heads  
        self.dim = args.dim  
        self.head_dim = args.dim // args.n_heads  
        self.attention = Attention(args)  
        self.feed_forward = FeedForward(  
            dim=args.dim, hidden_dim=4 * args.dim, multiple_of=args.multiple_of  
        )  
        self.layer_id = layer_id  
        self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)  
        self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)  
  
    def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]):  
        h = x + self.attention.forward(self.attention_norm(x), start_pos, freqs_cis, mask)  
        out = h + self.feed_forward.forward(self.ffn_norm(h))  
        return out  
  
  
class Transformer(nn.Module):  
    def __init__(self, params: ModelArgs):  
        super().__init__()  
        self.params = params  
        self.vocab_size = params.vocab_size  
        self.n_layers = params.n_layers  
  
        self.tok_embeddings = ParallelEmbedding(  
            params.vocab_size, params.dim, init_method=lambda x: x  
        )  
  
        self.layers = torch.nn.ModuleList()  
        for layer_id in range(params.n_layers):  
            self.layers.append(TransformerBlock(layer_id, params))  
  
        self.norm = RMSNorm(params.dim, eps=params.norm_eps)  
        self.output = ColumnParallelLinear(  
            params.dim, params.vocab_size, bias=False, init_method=lambda x: x  
        )  
  
        self.freqs_cis = precompute_freqs_cis(  
            self.params.dim // self.params.n_heads, self.params.max_seq_len * 2  
        )  
  
    @torch.inference_mode()  
    def forward(self, tokens: torch.Tensor, start_pos: int):  
        _bsz, seqlen = tokens.shape  
        h = self.tok_embeddings(tokens)  
        self.freqs_cis = self.freqs_cis.to(h.device)  
        freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen]  
  
        mask = None  
        if seqlen > 1:  
            mask = torch.full((1, 1, seqlen, seqlen), float("-inf"), device=tokens.device)  
            mask = torch.triu(mask, diagonal=start_pos + 1).type_as(h)  
  
        for layer in self.layers:  
            h = layer(h, start_pos, freqs_cis, mask)  
        h = self.norm(h)  
        output = self.output(h[:, -1, :])  # only compute last logits  
        return output.float()

文章来源于互联网:LLaMa系列模型详解(原理介绍、代码解读):LLaMa

赞(0)
未经允许不得转载:5bei.cn大模型教程网 » LLaMa系列模型详解(原理介绍、代码解读):LLaMa
分享到: 更多 (0)

AI大模型,我们的未来

小欢软考联系我们