Deepdive-llama3-from-scratch
Deepdive-llama3-from-scratch 是一个专为深度学习爱好者打造的开源教育项目,旨在通过从零构建代码的方式,手把手带你掌握 Llama3 大模型的推理原理。针对许多开发者在研读大模型源码时面临的“黑盒”困境——即只知代码运行结果却不懂背后数学推导与维度变化的痛点,该项目对原有教程进行了全面升级。
它特别适合希望深入理解 Transformer 架构的 AI 开发者、研究人员以及计算机专业学生。不同于普通的调用指南,Deepdive-llama3-from-scratch 的核心亮点在于其极致的透明度:不仅重构了学习路径使逻辑更清晰,还为每一行代码添加了详尽注释,并全程追踪矩阵维度的变化,让复杂的计算过程一目了然。项目更深入剖析了“为什么要这样做”,补充了包括 KV-Cache 机制在内的核心原理推导,帮助用户从根本上吃透模型设计思想。此外,项目提供高质量的原生中英文双语文档与代码,避免了机器翻译的歧义,让中文用户能更顺畅地开启大模型底层技术的学习之旅。
使用场景
某高校 AI 实验室的研究生团队正试图从零复现 Llama3 推理过程,以深入理解大模型底层机制并完成课程项目。
没有 Deepdive-llama3-from-scratch 时
- 面对原始代码中复杂的矩阵运算,学生难以追踪维度变化,常因形状不匹配导致调试失败,耗费大量时间在排查基础错误上。
- 缺乏对“为什么这样做”的原理解释,只能机械地复制代码,无法真正掌握注意力机制和 RMS 归一化等核心设计思想。
- 关于 KV-Cache 的资料零散且晦涩,团队在推导缓存优化逻辑时陷入瓶颈,难以理解其在加速推理中的具体作用。
- 英文文档配合机器翻译的代码注释存在表达歧义,初学者阅读障碍大,团队协作沟通成本极高。
使用 Deepdive-llama3-from-scratch 后
- 代码中详尽的维度追踪注释让每一步矩阵变换清晰可见,学生能迅速定位计算逻辑,调试效率提升数倍。
- 丰富的原理推导章节不仅展示了代码实现,更深度剖析了设计初衷,帮助团队成员从根本上吃透了模型架构。
- 新增的 KV-Cache 专属推导章将抽象概念具象化,团队顺利完成了从理论推导到代码落地的全过程,掌握了加速推理的关键。
- 提供地道的中英文双语代码与文档,消除了语言隔阂,使得组内不同英语水平的成员都能无障碍协作,学习曲线显著平缓。
Deepdive-llama3-from-scratch 通过“知其然更知其所以然”的深度教学,将黑盒般的模型推理转化为透明、可掌控的学习路径。
运行环境要求
- 未说明
需要 NVIDIA GPU 以加载和运行 Meta-Llama-3-8B 模型(具体显存需求取决于模型大小,8B 模型通常建议 16GB+ 显存),CUDA 版本未说明
建议 16GB+(用于加载 8B 参数模型及中间计算)

快速开始
从零开始深度解析 Llama3
[ 查看英文 | 中文版文档点这里 ]
本项目是在 naklecha/llama3-from-scratch 的基础上进行的增强版本。在原项目的基础上,我们对其进行了全面的改进和优化,旨在帮助大家更轻松地理解并掌握 Llama3 模型的实现原理及其详细的推导过程。感谢原作者的贡献 :)
以下是本项目的几项核心改进:
结构优化
重新梳理了内容的呈现顺序,并调整了目录结构,使学习流程更加清晰合理,便于大家逐步理解代码。代码注释
添加了大量的详细注释,手把手教你理解每一行代码的作用,即使是初学者也能轻松上手。维度追踪
对每一步计算中矩阵维度的变化都进行了完整标注,让你能够更直观地把握整个计算流程。原理讲解
增加了丰富的原理性说明和大量细致的推导过程,不仅告诉你“怎么做”,还深入解释“为什么这么做”,帮助你从根本上掌握模型的设计思想。KV 缓存详解
新增了 KV 缓存的推导章节,涵盖了其核心概念、原理推导以及在注意力机制中的应用过程,让你能够从根源上理解 KV 缓存的每一个细节与设计哲学。双语文档
提供了中英文双语版本的代码文件,其中中文翻译为本地化编写,避免了机器翻译可能导致的表达不准确问题。
目录
- 加载模型
- 将输入文本转换为嵌入向量
- 构建第一个Transformer块
- 一切都准备好了,让我们完成所有32个Transformer块的计算吧。祝阅读愉快:)
- 让我们完成最后一步,预测下一个标记
- 让我们深入探讨一下,不同的嵌入方式或标记掩码策略可能会如何影响预测结果:)
- 需要预测多个标记吗?只需使用KV缓存即可!(这真的让我费了很大的劲才弄清楚。Orz)
- 感谢大家。感谢你们持续的学习。爱你们:)
- LICENSE
现在,让我们正式开始学习吧!
在这个文件中,我从零开始逐次实现Llama3,每次只进行一次张量和矩阵的乘法。
此外,我将直接从Meta为Llama3提供的模型文件中加载张量(Meta-Llama-3-8B),在运行此文件之前,你需要先下载权重。以下是下载权重的官方链接:https://llama.meta.com/llama-downloads/
注1:本项目采用了基于Huggingface的模型文件下载方法。你将在下面的加载模型部分看到相关内容。同样,你也可以直接从官方网站、ModelScope或其他模型下载平台下载模型,而无需运行下方的模型下载代码。
注2:本项目使用的是原始模型文件,即下载的模型文件中“original”文件夹内的模型。
请注意!图中有一个小错误:
在每个Transformer块中,第二个“add”操作的输入应该是前馈层的输出和第一个“add”操作的输出,而不是归一化后的结果。
如果我们将多头自注意力和前馈层视为同一类型的操作(都用于特征变换),那么这两个“归一化 - 特征变换 - 残差连接(add)”的形式和流程是完全相同的。
如果我们将多头自注意力和前馈层视为同一类型的操作(都用于特征变换),那么这两个“归一化 - 特征变换 - 残差连接(add)”的形式和流程是完全相同的。
加载模型
加载分词器
分词器用于将输入的文本字符串分割成一系列子词,从而更容易输入到模型中。
我不会自己实现一个BPE分词器(不过Andrej Karpathy有一个非常简洁的实现),
他的实现链接:https://github.com/karpathy/minbpe
BPE分词器加载步骤总结:
- 加载常规词汇:加载本地分词器模型字典(仅包含常规子词,不包含特殊标记)。
- 定义特殊标记:手动定义特殊标记(可以使用现成的,也可以在现成的基础上进行修改)。
- 定义文本粗切规则:定义用于文本粗切的正则表达式(直接使用现成的)。输入文本将经过两步处理:先根据正则表达式进行粗切,再根据BPE算法进行细切,最终得到完整的分词结果。
- 创建分词器:基于OpenAI开源的tiktoken库创建文本编码解码对象(该库可以根据BPE算法对粗切结果进一步细分)。
# 加载基于BPE的分词器
# 导入相关库
from pathlib import Path # 用于从文件路径中获取文件名/模型名
import tiktoken # OpenAI开发的开源文本编码解码库(文本与标记ID之间的相互转换)
from tiktoken.load import load_tiktoken_bpe # 加载BPE模型
import torch # 用于构建模型和矩阵计算
import json # 用于加载配置文件
import matplotlib.pyplot as plt # 用于绘制图表
tokenizer_path = "Meta-Llama-3-8B/original/tokenizer.model" # 分词器模型路径
# 常规字典之外的特殊标记。
# 这些特殊标记存在于“Meta-Llama-3-8B/”路径下的‘tokenizer.json’和‘tokenizer_config.json’文件的‘added_tokens’字段中
special_tokens = [
"<|begin_of_text|>",
"<|end_of_text|>",
"<|reserved_special_token_0|>", # 保留的特殊标记,编号从0到250
"<|reserved_special_token_1|>",
"<|reserved_special_token_2|>",
"<|reserved_special_token_3|>",
"<|start_header_id|>", # 标头信息开始,用于标记包裹结构化数据的标头信息,如元数据
"<|end_header_id|>", # 标头信息结束
"<|reserved_special_token_4|>",
"<|eot_id|>", # 轮次结束,用于标记多轮对话中当前轮次的结束
] + [f"<|reserved_special_token_{i}|>" for i in range(5, 256 - 5)]
# 加载BPE模型(实际上是一个字典)
# 字典由子词(字节类型,用utf-8解码)和排名(id)组成,共有128000个词,不包括上述256个特殊标记,
# 因此模型字典的总大小在后续操作中将达到128256个条目(但此处未计算)。
# 排名值是从0开始的递增序列,用于确定子词单元合并的优先级顺序,
# 优先级越高,合并越早。因此这里的变量名为“mergeable_ranks”,而不是类似BPE或词汇表之类的名称。
# 特殊标记未被加入字典,可能是为了灵活性考虑,
# 这样在面对不同模型架构或需要不同特殊标记的任务时,可以方便地添加特定标记,同时保持字典大小不变。
mergeable_ranks = load_tiktoken_bpe(tokenizer_path)
# 创建文本编码解码对象
# pat_str大致可分为三类:带有缩写的单词及普通单词、中文片段、1-3位数字及其他特殊字符
tokenizer = tiktoken.Encoding(
name=Path(tokenizer_path).name, # 编码器名称,在调试和日志记录时使用不同编码器会更方便
pat_str=r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+", # 用于初步将文本粗略分割为标记序列的正则表达式
mergeable_ranks=mergeable_ranks, # 传入已加载的BPE模型
special_tokens={token: len(mergeable_ranks) + i for i, token in enumerate(special_tokens)}, # 用于添加特殊标记-ID对的字典
)
# 测试创建是否成功,即编码解码器是否能正常运行
print(tokenizer.decode(tokenizer.encode("create tokenizer successed!")))
# 下面是一个案例测试,用于测试pat_str的粗切与分词器的细切之间的效果和差异。
# pat_str的正则表达式只提供初步的分割,
# 一些长句子或中文文本可能无法被分割,这些部分会在分词器中根据BPE算法进一步细化。
import regex # 由于pat_str中使用了\p{L}等Unicode语法,因此不能使用re库
## 创建正则表达式
pat_str=r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"
pattern = regex.compile(pat_str)
## 文本分割
text = "Hello world! It's a test. 这是一个测试. alongwords. a long words. 123 456 789." # 测试字符串
re_tokens = pattern.findall(text) # 使用正则表达式分割字符串
merge_tokens_id = tokenizer.encode(text) # 使用分词器分割字符串
merge_tokens = [tokenizer.decode([i]) for i in merge_tokens_id] # 将分词器分割结果的ID序列转换为实际的子词序列
## 输出结果
print("原始字符串: ", text)
print("正则表达式分割结果: ", re_tokens)
print("分词器分割结果: ", merge_tokens)
print("分词器分割结果ID: ", list(zip(merge_tokens, merge_tokens_id)))
## 从结果可以看出,所有单词前面的空格都被保留了下来,而不是被合并为单个空格标记或直接删除。
## 这有利于模型正确理解单词之间的边界信息,例如示例中的‘alongwords’。
创建分词器成功!
原始字符串: 你好世界!这是一个测试。这是一个测试. alongwords. a long words. 123 456 789.
正则表达式分割结果: ['你好', ' 世界', '!', ' 这', "'s", ' a', ' test', '.', ' 这是一个测试', '.', ' alongwords', '.', ' a', ' long', ' words', '.', ' ', '123', ' ', '456', ' ', '789', '.']
分词器分割结果: ['你好', ' 世界', '!', ' 这', "'s", ' a', ' test', '.', ' 这', '是一个', '测试', '.', ' along', 'words', '.', ' a', ' long', ' words', '.', ' ', '123', ' ', '456', ' ', '789', '.']
分词器分割结果对应的ID: [('你好', 9906), (' 世界', 1917), ('!', 0), (' 这', 1102), ("'s", 596), (' a', 264), (' test', 1296), ('.', 13), (' 这', 122255), ('是一个', 122503), ('测试', 82805), ('.', 13), (' along', 3235), ('words', 5880), ('.', 13), (' a', 264), (' long', 1317), (' words', 4339), ('.', 13), (' ', 220), ('123', 4513), (' ', 220), ('456', 10961), (' ', 220), ('789', 16474), ('.', 13)]
读取模型文件和配置文件
通常,读取模型文件取决于其模型类的编写方式以及其中的变量名。
然而,由于我们是从零开始实现Llama3,我们将逐个读取张量文件。
# 加载模型,即一个字典,如{"网络层名称": 张量类型参数}
model = torch.load("Meta-Llama-3-8B/original/consolidated.00.pth")
# 打印前20个网络层的名字,以验证模型是否正确加载。
print(json.dumps(list(model.keys())[:20], indent=4))
[
"tok_embeddings.weight",
"layers.0.attention.wq.weight",
"layers.0.attention.wk.weight",
"layers.0.attention.wv.weight",
"layers.0.attention.wo.weight",
"layers.0.feed_forward.w1.weight",
"layers.0.feed_forward.w3.weight",
"layers.0.feed_forward.w2.weight",
"layers.0.attention_norm.weight",
"layers.0.ffn_norm.weight",
"layers.1.attention.wq.weight",
"layers.1.attention.wk.weight",
"layers.1.attention.wv.weight",
"layers.1.attention.wo.weight",
"layers.1.feed_forward.w1.weight",
"layers.1.feed_forward.w3.weight",
"layers.1.feed_forward.w2.weight",
"layers.1.attention_norm.weight",
"layers.1.ffn_norm.weight",
"layers.2.attention.wq.weight"
]
# 加载配置文件。
# 每个配置的具体含义将在下一节中说明。
with open("Meta-Llama-3-8B/original/params.json", "r") as f:
config = json.load(f)
config
{'dim': 4096,
'n_layers': 32,
'n_heads': 32,
'n_kv_heads': 8,
'vocab_size': 128256,
'multiple_of': 1024,
'ffn_dim_multiplier': 1.3,
'norm_eps': 1e-05,
'rope_theta': 500000.0}
利用配置文件推断模型细节
| 配置项 | 配置值 | 含义 |
|---|---|---|
| dim | 4096 | 隐藏层的维度,即每个标记的向量表示具有4096维。 |
| n_layers | 32 | 模型层数,即该模型包含32个Transformer层或Transformer块。 |
| n_heads | 32 | 多头注意力机制中的头数,即每个多头注意力块有32个头。所谓多头,是指同时使用多个独立的注意力机制来捕捉输入数据的不同特征或信息。 |
| n_kv_heads | 8 | 键值注意力中的头数,用于分组查询注意力(GQA)。也就是说,键值注意力有8个头,而查询有n_heads=32个头。每4个查询头会共享一组键值对。 |
| vocab_size | 128256 | 词汇表大小,包括128000个普通标记和256个特殊标记。 |
| multiple_of | 1024 | 隐藏层维度的倍数约束。也就是说,为了优化计算效率,模型的隐藏层维度应为1024的倍数。 |
| ffn_dim_multiplier | 1.3 | 前馈网络层隐藏层维度的乘数因子,用于计算FFN的隐藏层维度。具体计算过程见相应部分。 |
| norm_eps | 1e-05 | 层归一化计算中分母上添加的常数,用于防止除以零并确保数值稳定性。 |
| rope_theta | 500000.0 | 旋转位置编码(RoPE)中的基本频率缩放因子,控制位置编码的周期性和分辨率,从而影响模型捕捉不同长度序列及位置关系的能力。 |
根据配置详情,可以推断出给定输入时注意力的内部计算过程如下:
输入(L, 4096) -> query_proj(L, 128, 32)
-> key_proj(L, 128, 8)
-> value_proj(L, 128, 8)
-> group_query_attention(L, 128, 32)
-> output_proj(L, 4096)
-> 输出(L, 4096)
# 记录这些配置,后续将逐步使用。
dim = config["dim"]
n_layers = config["n_layers"]
n_heads = config["n_heads"]
n_kv_heads = config["n_kv_heads"]
vocab_size = config["vocab_size"]
multiple_of = config["multiple_of"]
ffn_dim_multiplier = config["ffn_dim_multiplier"]
norm_eps = config["norm_eps"]
rope_theta = torch.tensor(config["rope_theta"])
将输入文本转换为嵌入
在将字符串形式的文本输入到网络层之前,需要将其转换为向量形式以便进行数学计算。
所需步骤是:使用分词器将输入文本拆分为子词序列 -> 将子词转换为向量表示。
将文本转换为标记ID序列
这里我们使用tiktoken(OpenAI提供的库)作为分词器。
# 将输入提示转换为标记ID序列
prompt = "生命、宇宙以及一切问题的终极答案是 " # 输入文本
tokens = [128000] + tokenizer.encode(prompt) # 进行子词分割,并在文本开头添加一个表示文本开始的特殊标记<|begin_of_text|>。维度:[17]
print(tokens) # 检查分割结果
tokens = torch.tensor(tokens) # 转换为张量类型,以便后续进行矩阵计算。[17]
# 将标记 ID 转换为特定的标记子词序列,这仅用于显示目的,并非实际所需
prompt_split_as_tokens = [tokenizer.decode([token]) for token in tokens]
print(prompt_split_as_tokens)
[128000, 1820, 4320, 311, 279, 17139, 3488, 315, 2324, 11, 279, 15861, 11, 323, 4395, 374, 220]
['<|begin_of_text|>', 'the', ' answer', ' to', ' the', ' ultimate', ' question', ' of', ' life', ',', ' the', ' universe', ',', ' and', ' everything', ' is', ' ']
将标记 ID 序列转换为嵌入
抱歉,这是该代码库中唯一使用内置神经网络模块的部分。
简而言之,我们原来的 [17×1] 标记序列现在变成了 [17×4096],即 17 个长度为 4096 的嵌入(每个标记对应一个)。
注意:请留意这个张量形状的变化,这将有助于你更好地理解整个过程(我也会在所有步骤中标注形状的变化)。
# 创建一个嵌入层,用于将离散的标记 ID 映射到连续的向量空间
embedding_layer = torch.nn.Embedding(vocab_size, dim)
# 使用 Llama3 中的预训练参数更新嵌入层的参数
embedding_layer.weight.data.copy_(model["tok_embeddings.weight"])
# 使用嵌入层将输入的标记 ID 序列转换为向量表示
# 嵌入层只是根据 ID 在字典中查找对应的向量,不涉及标记之间的交互。
# [17] -> [17×4096]
token_embeddings_unnormalized = embedding_layer(tokens).to(torch.bfloat16) # 默认是全精度的 float32,这里改为半精度格式以减少内存占用。
token_embeddings_unnormalized.shape
torch.Size([17, 4096])
构建第一个 Transformer 块
从下面所示的第一个 Transformer 块所涉及的预训练参数来看,它包括:
- 两个归一化层(attention_norm 和 ffn_norm)
- 注意力机制的实现(4 个 attention.w)
- 前馈网络层的实现(3 个 feed_forward.w)
- (当然,还包括两个不需要预训练参数的残差连接操作)
一般来说,Transformer 块中的操作流程如下:
归一化 -> 多头自注意力 -> 残差连接 -> 归一化 -> 前馈神经网络 -> 残差连接
# 展示第一个 Transformer 块的所有权重参数及其形状
for k, v in model.items():
if not k.startswith('layers'):
continue
if k.startswith('layers.1'):
break
print(k, v.shape)
layers.0.attention.wq.weight torch.Size([4096, 4096])
layers.0.attention.wk.weight torch.Size([1024, 4096])
layers.0.attention.wv.weight torch.Size([1024, 4096])
layers.0.attention.wo.weight torch.Size([4096, 4096])
layers.0.feed_forward.w1.weight torch.Size([14336, 4096])
layers.0.feed_forward.w3.weight torch.Size([14336, 4096])
layers.0.feed_forward.w2.weight torch.Size([4096, 14336])
layers.0.attention_norm.weight torch.Size([4096])
layers.0.ffn_norm.weight torch.Size([4096])
这里需要注意两点:
- 神经网络权重矩阵的形状是 (输出维度, 输入维度)。在计算时,参数矩阵 W 会先转置为 (输入维度, 输出维度),然后与输入 X 相乘,即输出 Y = XW.T。这一点你会在后续计算中看到。
- 由于 Llama3 使用分组注意力机制,每 4 个查询头会共享一组键值向量(详情请参阅上面关于配置文件细节的部分)。因此,键值权重矩阵的维度是 [1024, 4096], 是查询权重矩阵 [4096, 4096] 的 1/4。
归一化
归一化操作旨在约束数据中的尺度差异,避免因向量数值差异过大而导致训练过程不稳定等问题。
归一化后,张量的形状仍保持为 [17×4096],与嵌入的形状相同。
对嵌入使用 RMS 归一化
Llama3 使用均方根(RMS)归一化方法,其计算公式如图所示。
需要注意的是,我们需要一个 norm_eps 参数(来自配置),以防止 RMS 被意外地设为 0,从而导致除以零的错误。
公式如下:
此外,你可能已经注意到公式中的 gi 参数。这是一个在模型训练过程中学习到的缩放因子,用于再次缩放每个维度的归一化结果,以增强模型的表达能力。它的维度与嵌入的特征维度相同,即 [4096]。
# 定义 RMS 归一化的计算函数
# 每个标记将被独立归一化
# norm_weights 是预训练的缩放因子(即公式中的 gi),用于增强模型的表征能力。可以从模型文件中加载,具有 4096 个维度
# torch.rsqrt 用于计算张量平方根的倒数,即 1/RMS(a)
def rms_norm(tensor, norm_weights):
return (tensor * torch.rsqrt(tensor.pow(2).mean(-1, keepdim=True) + norm_eps)) * norm_weights
# 对输入进行归一化
token_embeddings = rms_norm(token_embeddings_unnormalized, model["layers.0.attention_norm.weight"]) # [17×4096] & [4096] -> [17×4096]
model["layers.0.attention_norm.weight"].shape, token_embeddings.shape
(torch.Size([4096]), torch.Size([17, 4096]))
从零开始实现单头注意力机制
在每一层的多头注意力计算中,涉及32个头。然而,这些头的计算过程是完全相同且相互独立的。因此,在本节中,我们将首先实现单头注意力的计算过程,并在下一节将其扩展到多头计算。
注意力机制的核心计算公式如图所示。
- 我们需要通过对输入嵌入进行线性映射,得到查询、键和值向量。
- 随后,基于查询和键向量,我们计算出各个标记之间的注意力权重,即对于每个标记,其他标记对其的重要性或相关性的得分。
- 最后,根据注意力权重对值向量进行加权,得到每个标记对应的注意力结果。
回到正题。让我们先加载第一层Transformer的注意力头。
> 当我们从模型中加载查询、键、值以及输出权重矩阵时(输出权重用于融合多个头的信息,以生成最终的注意力输出),我们会发现它们的形状分别是:[4096×4096]、[1024×4096]、[1024×4096]、[4096×4096]。
> 初看起来这似乎有些奇怪,因为理想情况下,我们希望每个头的q、k、v彼此独立(在这种情况下,它们的形状应该是:32×[128×4096]、8×[128×4096]、8×[128×4096])。
> 代码作者将它们捆绑在一起,是因为这样有助于并行化注意力头的乘法计算。
> 但我们将会把这一切展开...
# 显示当前q、k、v和o的注意力权重矩阵的形状。
print(
model["layers.0.attention.wq.weight"].shape, # [4096×4096]
model["layers.0.attention.wk.weight"].shape, # [1024×4096]
model["layers.0.attention.wv.weight"].shape, # [1024×4096]
model["layers.0.attention.wo.weight"].shape # [4096×4096]
)
torch.Size([4096, 4096]) torch.Size([1024, 4096]) torch.Size([1024, 4096]) torch.Size([4096, 4096])
获取输入标记对应的QKV向量
在本节中,我们将把输入的标记嵌入转换为查询、键和值向量,以便进行注意力机制的计算。
获取查询向量
展开查询权重矩阵
我们首先将来自多个注意力头的查询展开,最终的形状将是[32×128×4096]。
这里,32是Llama3中的注意力头数量,128是查询头的向量维度,而4096则是标记嵌入的维度(嵌入维度位于最后一个维度的原因是,在进行输入与权重相乘时,通常是X*W.T,即与权重的转置相乘)。
# 加载并修改第0层的查询权重矩阵的形状,以将其展开为多头形式
q_layer0 = model["layers.0.attention.wq.weight"] # 默认形状为[4096×4096]
head_dim = q_layer0.shape[0] // n_heads # 注意力头的维度,4096/32 = 128
q_layer0 = q_layer0.view(n_heads, head_dim, dim) # 展开后的维度,[32×128×4096]
q_layer0.shape
torch.Size([32, 128, 4096])
获取第一个头
在这里,我访问了第一层查询权重矩阵的第一个头。该查询权重矩阵的形状是[128×4096]。
# 提取第一个头的权重
q_layer0_head0 = q_layer0[0] # [32×128×4096] -> [128×4096]
q_layer0_head0.shape
torch.Size([128, 4096])
将标记嵌入与查询权重相乘,得到标记对应的查询向量
在这里,你可以看到结果的形状是[17×128]。这是因为我们有17个标记,而对于每个标记,都有一个长度为128的查询向量。
# 计算第一个查询头上的输入查询值
# Q0_head0 = XW0_Q_head0.T
q_per_token = torch.matmul(token_embeddings, q_layer0_head0.T) # [17×4096] x [4096×128] = [17×128]
q_per_token.shape
torch.Size([17, 128])
获取键向量(几乎与查询向量相同)
我想偷个懒,所以不再详细说明键向量的计算过程了。Orz。你只需要记住一点:
> 键同样会生成一个128维的向量。
> 键的权重矩阵参数数量仅为查询的四分之一,这是因为每个键的权重由4个头同时共享,从而减少了所需的计算量。
# 加载并修改第0层的键权重矩阵的形状,使其以多头形式展开
# 与查询权重矩阵不同,键有8个注意力头,因此其参数数量是查询矩阵的四分之一
k_layer0 = model["layers.0.attention.wk.weight"] # [1024×4096]
k_layer0 = k_layer0.view(n_kv_heads, k_layer0.shape[0] // n_kv_heads, dim) # [8×128×4096]
k_layer0.shape
torch.Size([8, 128, 4096])
# 提取第一个头的权重
k_layer0_head0 = k_layer0[0] # [8×128×4096] -> [128×4096]
k_layer0_head0.shape
torch.Size([128, 4096])
# 计算第一个头对应的输入键向量
# K0_head0 = XW0_K_head0.T
k_per_token = torch.matmul(token_embeddings, k_layer0_head0.T) # [17×4096] x [4096×128] = [17×128]
k_per_token.shape
torch.Size([17, 128])
获取值向量(几乎与键向量相同)
> 类似于键的权重,值的权重也是由每4个注意力头共享的(以节省计算量)。
> 因此,值权重矩阵的形状是[8×128×4096]。
# 加载并修改第0层的值权重矩阵的形状,使其以多头形式展开
# 与键权重矩阵类似,值也有8个注意力头,因此其参数数量同样是查询矩阵的四分之一
v_layer0 = model["layers.0.attention.wv.weight"] # [1024×4096]
v_layer0 = v_layer0.view(n_kv_heads, v_layer0.shape[0] // n_kv_heads, dim) # [1024×4096] -> [8×128×4096]
v_layer0.shape
torch.Size([8, 128, 4096])
# 提取第一个头的权重
v_layer0_head0 = v_layer0[0] # [8×128×4096] -> [128×4096]
v_layer0_head0.shape
torch.Size([128, 4096])
# 计算第一个头对应的输入值向量
# V0_head0 = XW0_V_head0.T
v_per_token = torch.matmul(token_embeddings, v_layer0_head0.T) # [17x4096] x [4096x128] = [17x128]
v_per_token.shape
torch.Size([17, 128])
将位置信息添加到查询和键向量中
- 对于自然语言来说,词语之间的顺序关系和相对位置极其重要。例如,“The dog bites the man”和“The man bites the dog”具有完全不同的语义信息。此外,我们的直觉也告诉我们,距离较近的词语之间的相关性通常大于距离较远的词语。
- 因此,在注意力计算过程中,我们需要为每个标记提供位置信息,以便模型能够更好地捕捉序列中的依赖关系。
- 为什么要把位置信息加到查询和键向量上?因为查询和键向量用于计算注意力权重,即每个标记对其他标记的重要性。这就要求它们在计算相似度时,能够同时知道任意两个标记的位置及其相对位置关系。
- 为什么不把位置信息加到值向量上呢?因为值向量只用于加权求和。位置信息已经在查询和键的交互中被考虑到了,因此值向量只需要提供内容信息即可。
我们将使用RoPE(旋转位置编码)来为这些向量添加位置信息。
旋转位置编码(RoPE)
你可以观看这个视频来详细了解它的数学原理(这也是我观看过的视频):
https://www.youtube.com/watch?v=o29P0Kpobz0&t=530s
RoPE 的基本思想是将向量视为处于复数空间中,然后根据位置生成特定的旋转矩阵。通过将向量与旋转矩阵相乘,可以在复数空间中实现旋转,从而将相对位置信息添加到向量中。(也就是说,将输入向量之间的位置关系看作是在复数空间中以不同角度进行的旋转。)
(类似于机器人运动学中通过基于三角函数的矩阵乘法来实现平面位置坐标绕轴的旋转。)
RoPE 通常应用于自注意力机制中的查询和键向量。在计算注意力分数时,首先会根据 RoPE 的相应旋转矩阵对查询和键向量进行旋转,然后再进行点积计算和 softmax 归一化等操作。这样,Transformer 在计算注意力时就能考虑到位置信息,从而更好地捕捉文本中的依赖关系。
RoPE 的具体计算过程如下:
- 将每个向量的维度分成若干对(因为高维旋转矩阵的推导较为复杂,且维度过高会显著增加计算复杂度,而二维旋转的公式相对成熟且简单,易于计算)。
- 对每一对,计算 $\Large \theta=\frac{1}{rope\_theta^{i/D}}$,其中 $i$ 是第 $i$ 对,$D$ 是总对数。这表示当前维度对在向量中的位置信息。
- 对于每个向量,计算 $\Large m$,表示该向量对应于第 $m$ 个标记。即当前向量在整个输入向量中的位置信息。
- 对于每一对,
,其中 $res$ 是向量对在复数空间中旋转 $m\theta$ 度后的结果。 - 对所有向量的所有维度对重复上述计算,得到最终的 RoPE 结果。
在实际代码实现中,为了简化计算过程,上述基于旋转矩阵的计算(步骤 4)会被转换为复数域内的计算。其原理如下:
- 直角坐标 $(x, y)$ 可以被视为复数 $\large x + yi$ 在复平面上的表示。
- 复数的极坐标形式可以表示为 $\large re^{i\theta}$,其中 $r$ 是模长,$\theta$ 是角度。
- 极坐标下的乘法运算 $\large r_1e^{i\theta_1} \times r_2e^{i\theta_2} = r_1r_2e^{i(\theta_1 + \theta_2)}$ 可以理解为将坐标_1 的长度放大 $r_2$ 倍,并将其旋转 $\theta_2$ 度。
- 因此,如果想要将坐标旋转 $m\theta$ 度,可以定义一个模长为 1、角度为 $m\theta$ 的旋转因子 $\large e^{im\theta}$。将其与坐标相乘,就相当于基于旋转矩阵的旋转方法。
- 此外,根据欧拉公式,我们有 $\large re^{i\theta} = r\cos\theta + r\sin{\theta i} = x + yi$,以及 $\large e^{im\theta} = \cos{m\theta} + \sin{m\theta i}$。
- 因此,将二维坐标 $(x, y)$ 旋转 $m\theta$ 度可以通过 $\large re^{i\theta^\prime} \times e^{im\theta} = (x + yi) \times (\cos{m\theta} + sin{m\theta i})$ 来实现(两个复数的乘积)。
向查询向量中添加位置信息
在接下来的步骤中,我们将首先把查询向量按维度方向分成若干对,然后按照上述步骤对每一对进行角度旋转。
现在我们有一个形状为 [17x64x2] 的向量。这是通过将提示中每个标记对应的 128 维查询向量分成 64 对,并对每一对旋转 $m\theta$ 度得到的。
# 按维度方向将查询向量分成若干对。
# .float() 是为了切换回双精度,以确保后续三角函数计算的精度和数值稳定性。
# [17x128] -> [17x64x2]
q_per_token_split_into_pairs = q_per_token.float().view(q_per_token.shape[0], -1, 2)
q_per_token_split_into_pairs.shape
torch.Size([17, 64, 2])
开始获取旋转矩阵的复数域表示。
# 计算 θ。第一步:计算 i/D。
# [64]
zero_to_one_split_into_64_parts = torch.tensor(range(64))/64 # 每个特征分割后有64对维度,因此需要64个θ值
zero_to_one_split_into_64_parts
tensor([0.0000, 0.0156, 0.0312, 0.0469, 0.0625, 0.0781, 0.0938, 0.1094, 0.1250,
0.1406, 0.1562, 0.1719, 0.1875, 0.2031, 0.2188, 0.2344, 0.2500, 0.2656,
0.2812, 0.2969, 0.3125, 0.3281, 0.3438, 0.3594, 0.3750, 0.3906, 0.4062,
0.4219, 0.4375, 0.4531, 0.4688, 0.4844, 0.5000, 0.5156, 0.5312, 0.5469,
0.5625, 0.5781, 0.5938, 0.6094, 0.6250, 0.6406, 0.6562, 0.6719, 0.6875,
0.7031, 0.7188, 0.7344, 0.7500, 0.7656, 0.7812, 0.7969, 0.8125, 0.8281,
0.8438, 0.8594, 0.8750, 0.8906, 0.9062, 0.9219, 0.9375, 0.9531, 0.9688,
0.9844])
# 计算θ。步骤2:获取θ。
# rope_theta用于控制位置编码的周期性等信息。
# 详情请参阅配置信息部分。
freqs = 1.0 / (rope_theta ** zero_to_one_split_into_64_parts) # [64]
freqs
tensor([1.0000e+00, 8.1462e-01, 6.6360e-01, 5.4058e-01, 4.4037e-01, 3.5873e-01,
2.9223e-01, 2.3805e-01, 1.9392e-01, 1.5797e-01, 1.2869e-01, 1.0483e-01,
8.5397e-02, 6.9566e-02, 5.6670e-02, 4.6164e-02, 3.7606e-02, 3.0635e-02,
2.4955e-02, 2.0329e-02, 1.6560e-02, 1.3490e-02, 1.0990e-02, 8.9523e-03,
7.2927e-03, 5.9407e-03, 4.8394e-03, 3.9423e-03, 3.2114e-03, 2.6161e-03,
2.1311e-03, 1.7360e-03, 1.4142e-03, 1.1520e-03, 9.3847e-04, 7.6450e-04,
6.2277e-04, 5.0732e-04, 4.1327e-04, 3.3666e-04, 2.7425e-04, 2.2341e-04,
1.8199e-04, 1.4825e-04, 1.2077e-04, 9.8381e-05, 8.0143e-05, 6.5286e-05,
5.3183e-05, 4.3324e-05, 3.5292e-05, 2.8750e-05, 2.3420e-05, 1.9078e-05,
1.5542e-05, 1.2660e-05, 1.0313e-05, 8.4015e-06, 6.8440e-06, 5.5752e-06,
4.5417e-06, 3.6997e-06, 3.0139e-06, 2.4551e-06])
# 计算mθ
# 'outer'用于计算外积,'arange(17)'表示每个向量对应的m值(由于输入有17个token,因此需要17个m值)。
# 结果的形状为[17x64],这意味着每个token对应的向量都有64个mθ值,这些值用于计算64对维度的旋转。
freqs_for_each_token = torch.outer(torch.arange(17), freqs) # [17] & [64] -> [17x64]
# 获取(cos mθ + sin mθ i),即把mθ转换为复数形式
# 将旋转角度mθ视为模为1的极坐标形式,然后将其转换为复数表示
# 'polar'的两个输入分别表示模(设置为1,意味着只改变角度而不影响长度)和角度(即mθ)
freqs_cis = torch.polar(torch.ones_like(freqs_for_each_token), freqs_for_each_token) # [17x64] -> [17x64]
print(freqs_cis.shape)
# 查看freqs_cis在某些位置的值,仅用于展示
token_to_show = [1, 3, 5] # 查看第2、4、6行
fig, axs = plt.subplots(1, len(token_to_show), figsize=(5 * len(token_to_show), 4)) # 生成一个包含3个子图的单行图窗
for i, index in enumerate(token_to_show):
value = freqs_cis[index]
for j, element in enumerate(value):
# 从原点到坐标点画一条蓝色线,实部作为x坐标,虚部作为y坐标。
axs[i].plot([0, element.real], [0, element.imag], color='blue', linewidth=1, label=f"Index: {j}")
# 用红色数字标注表示第i对维度。
axs[i].annotate(f"{j}", xy=(element.real, element.imag), color='red')
axs[i].set_xlabel('Real')
axs[i].set_ylabel('Imaginary')
axs[i].set_title(f'Plot of {index + 1}th of freqs_cis')
plt.show()
"""
注意:如图所示,位置靠后的token具有更大的旋转角度,但在单个token内,较早的向量维度具有更大的旋转角度。
如果感兴趣,可以进一步探索这背后是否存在数学原因。X_X
"""
torch.Size([17, 64])

'\n注意:如图所示,位置靠后的token具有更大的旋转角度,但在单个token内,较早的向量维度具有更大的旋转角度。\n 如果感兴趣,可以进一步探索这背后是否存在数学原因。X_X\n'
现在我们已经为每个token对应的查询向量的每一对维度提供了一个复数(一个改变角度的向量)。
现在我们可以将我们的查询(已分成对的)转换为复数,然后通过点积计算来旋转这些查询。 :)
# 获取(x + yi)
# 即将维度对转换为复数。转换后,维度的形状将由[17x64x2]变为[17x64]。
q_per_token_as_complex_numbers = torch.view_as_complex(q_per_token_split_into_pairs) # [17x64x2] -> [17x64]
q_per_token_as_complex_numbers.shape
torch.Size([17, 64])
# 计算(x + yi) * (cos mθ + sin mθ i)
# 即执行旋转操作以得到最终结果。
q_per_token_as_complex_numbers_rotated = q_per_token_as_complex_numbers * freqs_cis # [17x64] * [17x64] = [17x64]
q_per_token_as_complex_numbers_rotated.shape
torch.Size([17, 64])
获取旋转后的向量(恢复形状)。
我们可以将复数再次表示为实数,从而以维度对的形式获得查询结果。
# 将复数结果转换回实数维度对的形式。
q_per_token_split_into_pairs_rotated = torch.view_as_real(q_per_token_as_complex_numbers_rotated) # [17x64] -> [17x64x2]
q_per_token_split_into_pairs_rotated.shape
torch.Size([17, 64, 2])
合并旋转后的维度。这样,我们就得到了一个新的查询向量(旋转后的查询向量),其形状为[17x128],其中17表示token的数量,128表示查询向量的维度。
# 将维度对的结果恢复为原始的查询向量形式,得到最终的查询向量。
q_per_token_rotated = q_per_token_split_into_pairs_rotated.view(q_per_token.shape) # [17x64x2] -> [17x128]
q_per_token_rotated.shape
torch.Size([17, 128])
为键向量添加位置信息(与查询相同)
# 沿维度方向将键向量分成对,形成维度对(修改形状)。
k_per_token_split_into_pairs = k_per_token.float().view(k_per_token.shape[0], -1, 2) # [17x128] -> [17x64x2]
k_per_token_split_into_pairs.shape
torch.Size([17, 64, 2])
# 获取(x + yi)
# 即将维度对转换为复数。转换后,维度的形状将从 [17x64x2] 变为 [17x64]。
k_per_token_as_complex_numbers = torch.view_as_complex(k_per_token_split_into_pairs) # [17x64x2] -> [17x64]
k_per_token_as_complex_numbers.shape
torch.Size([17, 64])
# 计算 (x + yi) * (cosmθ + sinmθi)
# 即执行旋转操作以得到最终结果。
# 然后将结果转换回实数形式。
k_per_token_split_into_pairs_rotated = torch.view_as_real(k_per_token_as_complex_numbers * freqs_cis) # [17x64] * [17x64] = [17x64] -> [17x64x2]
k_per_token_split_into_pairs_rotated.shape
torch.Size([17, 64, 2])
# 将维度对的结果恢复为原始的键向量形式,从而得到最终的键向量。
k_per_token_rotated = k_per_token_split_into_pairs_rotated.view(k_per_token.shape) # [17x64x2] -> [17x128]
k_per_token_rotated.shape
torch.Size([17, 128])
至此,我们已经得到了每个 token 对应的旋转后的查询向量和键向量。
每个查询和键向量的形状仍然是 [17x128]。
一切准备就绪,现在开始计算 token 之间的注意力权重。
这将涉及三个步骤:
- 计算注意力分数:score = Q x K
- 掩码未来 token:score = mask(score)
- 计算注意力权重:res = softmax(score)
让我们开始吧! :)
将查询向量和键向量相乘,得到注意力分数。
通过这种方式,我们将得到每个 token 与所有其他 token 之间的分数值。
这些分数表示每个 token 的查询与所有其他 token 的键之间的相关性强度。
这就是自注意力机制!
这个注意力分数矩阵(qk_per_token)的形状是 [17x17],其中 17 是输入 prompt 中的 token 数量。
# 计算注意力分数
# 同时进行归一化处理,以防止后续的 softmax 计算结果过于偏向 0 或 1,
# (当维度较大时,点积值可能会过大),
# 这可能导致梯度消失或梯度爆炸,从而保持数值稳定性。
qk_per_token = torch.matmul(q_per_token_rotated, k_per_token_rotated.T)/(head_dim)**0.5 # [17x128] x [128x17] = [17x17]
qk_per_token.shape
torch.Size([17, 17])
现在我们需要对未来的查询-键分数进行掩码处理。
在 Llama 3 的训练过程中,未来 token 的 QK 分数会被掩码掉。
为什么呢?因为在训练时,我们只学习如何利用过去的 token 来预测当前的 token。如果不进行掩码处理,就会导致预测信息的泄露。
因此,在推理过程中,我们也需要将未来 token 的分数置为 0(以确保训练和推理过程的一致性)。
当然,如果你和我一样好奇不进行掩码会有什么后果,可以在学完本节内容后查看我在最后一节中进行的额外实验结果。(^_<)
# 首先看一下掩码前的分数矩阵
def display_qk_heatmap(qk_per_token):
_, ax = plt.subplots() # 创建一个绘图窗口
# `imshow` 常用于以二维数组或矩阵的形式显示数据,
# 它会将矩阵元素映射为灰度或颜色值,因此可以用来绘制热图。
# 先将张量转换回全精度,然后从计算图中分离出来,以避免潜在的梯度计算和存储问题。
# 指定使用 'viridis' 颜色映射方案来显示图像(蓝色 -> 绿色 -> 黄色)。
im = ax.imshow(qk_per_token.to(float).detach(), cmap='viridis')
# 设置 x 轴和 y 轴刻度的数量和标签,以确保正确的一一对应关系。
ax.set_xticks(range(len(prompt_split_as_tokens)))
ax.set_yticks(range(len(prompt_split_as_tokens)))
ax.set_xticklabels(prompt_split_as_tokens)
ax.set_yticklabels(prompt_split_as_tokens)
# 在旁边添加一个颜色条。
# 指定 `im` 来识别正确的颜色映射和取值范围。
# 指定它所属的子图为 `ax`(如果有多个子图,则为 `ax = ax[i]`)。
ax.figure.colorbar(im, ax=ax)
display_qk_heatmap(qk_per_token)

# 生成掩码矩阵
# 将需要掩码的位置设置为负无穷,不需要掩码的位置设置为 0。
# 然后将其加到分数矩阵上,以实现掩码效果(计算 softmax 时,负无穷会趋近于 0)。
# `torch.full` 用于生成具有指定形状和填充值的张量。
# 这里首先生成一个充满负无穷的 [17x17] 矩阵。
# 指定该矩阵的设备与之前 token 的设备相同,以确保后续计算不会出错,
# 例如,如果之前的 token 在 GPU 上,而这里没有指定设备,那么 `mask` 就会在 CPU 上重新创建,相加时就会出错。
mask = torch.full((len(tokens), len(tokens)), float("-inf"), device=tokens.device) # [17x17]
# `torch.triu` 用于返回矩阵的上三角部分,并将其余部分置零(使用 `torch.tril` 获取下三角部分)。
# `diagonal` 是对角线的偏移量。当它为 1 时,表示取主对角线上方 1 个位置开始的上三角部分,以避免掩码掉当前 token 自身。
mask = torch.triu(mask, diagonal=1) # [17x17]
mask, mask.shape
(tensor([[0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
[0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
[0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
[0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
[0., 0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
[0., 0., 0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
[0., 0., 0., 0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
[0., 0., 0., 0., 0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
[0., 0., 0., 0., 0., 0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., -inf, -inf, -inf, -inf],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., -inf, -inf, -inf],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., -inf, -inf],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., -inf],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]),
torch.Size([17, 17]))
# 掩码未来 token 的得分
qk_per_token_after_masking = qk_per_token + mask # [17x17] + [17x17] = [17x17]
display_qk_heatmap(qk_per_token_after_masking) # 显示掩码后的注意力得分

计算最终的注意力权重,即对得分进行 softmax 操作。
# 计算注意力权重
# 即计算得分的 softmax 值。
# `dim = 1` 表示按行进行 softmax 计算,结果转换为半精度以与后续的值向量保持一致。
qk_per_token_after_masking_after_softmax = torch.nn.functional.softmax(qk_per_token_after_masking, dim=1).to(torch.bfloat16) # [17x17] -> [17x17]
display_qk_heatmap(qk_per_token_after_masking_after_softmax)

终于!计算单头注意力机制的最终结果!
原理:利用之前的注意力权重(范围在 0 到 1 之间),确定每个 token 应该使用每个值向量的多少比例(即对值向量进行加权)。
示例:如果输入包含 3 个 token,第一个 token 的注意力结果可能是:res = 0.6 * value_1 + 0.3 * value_2 + 0.1 * value_3
权重矩阵与值矩阵相乘后,注意力结果的形状为 [17x128]。
# 计算单头注意力的最终结果
qkv_attention = torch.matmul(qk_per_token_after_masking_after_softmax, v_per_token) # [17x17] x [17x128] = [17x128]
qkv_attention.shape
torch.Size([17, 128])
计算多头注意力机制(通过简单循环重复上述过程)
我们现在得到了第一层第一个头的注意力值。
接下来需要通过一个循环,对第一层中的每一个头执行与上一单元格完全相同的数学运算。
值得注意的是,在 官方 Llama3 代码实现中,多头注意力的计算采用了一次性矩阵乘法的方式,而不是耗时的 for 循环计算。其一般流程如下:
- 基于矩阵并行性,计算 QKV 向量:[17x4096] × [4096x4096] 或 [4096x1024] = [17x4096] 或 [17x1024],然后将其重塑为 [32x17x128] 或 [8x17x128]。
- 获得 QKV 向量后,将 K 和 V 向量的内部部分复制,使其形状与 Q 向量一致。此时三者的形状均为 [32x17x128]。
- 在计算得分时,使用转置方法交换张量最后两个维度的位置,完成矩阵乘法。例如,
torch.matmul(q, k.transpose(1,2)) / head_dim ** 0.5。此时为 [32x17x128] × [32x128x17] = [32x17x17]。 - 其他矩阵计算也遵循同样的原理。
注:上述过程中每一步的矩阵形状变化都是简化版本,仅用于说明以便理解,与官方 Llama3 实现中的变化过程有所不同(官方实现涉及大量的形状变换操作)。
计算每个头的结果
# 计算多头注意力结果
# 即,上一次单头注意力计算过程的循环
qkv_attention_store = []
for head in range(n_heads):
# 提取当前头对应的QKV权重矩阵
q_layer0_head = q_layer0[head] # [32x128x4096] -> [128x4096]
k_layer0_head = k_layer0[head//4] # 每4个头共享一个键权重,[8x128x4096] -> [128x4096]
v_layer0_head = v_layer0[head//4] # 每4个头共享一个值权重,[8x128x4096] -> [128x4096]
# 计算XW以得到QKV向量
# [17x4096] x [4096x128] = [17x128]
q_per_token = torch.matmul(token_embeddings, q_layer0_head.T)
k_per_token = torch.matmul(token_embeddings, k_layer0_head.T)
v_per_token = torch.matmul(token_embeddings, v_layer0_head.T)
# 将位置信息添加到查询向量(RoPE)
q_per_token_split_into_pairs = q_per_token.float().view(q_per_token.shape[0], -1, 2) # 沿维度方向将向量分成成对的形式,形成维度对。[17x128] -> [17x64x2]
q_per_token_as_complex_numbers = torch.view_as_complex(q_per_token_split_into_pairs) # 转换为复数表示,(x,y) -> (x+yi)。[17x64x2] -> [17x64]
q_per_token_as_complex_numbers_rotated = q_per_token_as_complex_numbers * freqs_cis[:len(tokens)] # 计算(x+yi)*(cosmθ+sinmθi),完成旋转操作。[17x64] * [17x64] = [17x64]
q_per_token_split_into_pairs_rotated = torch.view_as_real(q_per_token_as_complex_numbers_rotated) # 将结果转换回实数表示,(x+yi) -> (x,y)。[17x64] -> [17x64x2]
q_per_token_rotated = q_per_token_split_into_pairs_rotated.view(q_per_token.shape) # 将结果恢复为原始向量形状,得到最终的查询向量。[17x64x2] -> [17x128]
# 将位置信息添加到键向量(RoPE)
k_per_token_split_into_pairs = k_per_token.float().view(k_per_token.shape[0], -1, 2) # 沿维度方向将向量分成成对的形式,形成维度对。[17x128] -> [17x64x2]
k_per_token_as_complex_numbers = torch.view_as_complex(k_per_token_split_into_pairs) # 转换为复数表示,(x,y) -> (x+yi)。[17x64x2] -> [17x64]
k_per_token_as_complex_numbers_rotated = k_per_token_as_complex_numbers * freqs_cis[:len(tokens)] # 计算(x+yi)*(cosmθ+sinmθi),完成旋转操作。[17x64] * [17x64] = [17x64]
k_per_token_split_into_pairs_rotated = torch.view_as_real(k_per_token_as_complex_numbers_rotated) # 将结果转换回实数表示,(x+yi) -> (x,y)。[17x64] -> [17x64x2]
k_per_token_rotated = k_per_token_split_into_pairs_rotated.view(k_per_token.shape) # 将结果恢复为原始向量形状,得到最终的键向量。[17x64x2] -> [17x128]
# 同时计算注意力分数并对其进行归一化(即Q×K/sqrt(dim))
qk_per_token = torch.matmul(q_per_token_rotated, k_per_token_rotated.T)/(head_dim)**0.5 # [17x128] x [128x17] = [17x17]
# 对未来 tokens 的分数进行掩码处理
mask = torch.full(qk_per_token.shape, float("-inf"), device=tokens.device) # 创建与注意力分数相同形状的矩阵,填充负无穷,并存储在与其他向量相同的设备上,以防止后续计算出错。[17x17]
mask = torch.triu(mask, diagonal=1) # 保留上三角部分的负无穷,其余部分置为零(即上三角区域代表需要被掩码的未来 tokens)。对角线偏移为1,以避免掩码当前 token 自身。[17x17]
qk_per_token_after_masking = qk_per_token + mask # 将注意力分数与掩码矩阵相加,使分数矩阵的上三角部分变为负无穷,这将在后续的 softmax 操作后趋近于零。[17x17]
# 计算注意力权重(即 softmax(score))
# 同时将其转换回半精度(因为稍后会与值向量 v_per_token 相乘,因此数据类型需要一致)。
qk_per_token_after_masking_after_softmax = torch.nn.functional.softmax(qk_per_token_after_masking, dim=1).to(torch.bfloat16) # 按行计算 softmax。[17x17]
# 计算注意力机制的最终结果(即 softmax(score) × V)
qkv_attention = torch.matmul(qk_per_token_after_masking_after_softmax, v_per_token) # [17x17] × [17x128] = [17x128]
# 记录该头的结果
qkv_attention_store.append(qkv_attention)
len(qkv_attention_store)
32
将各头的结果合并为一个大矩阵
我们几乎完成了注意力层的计算 :)
# 合并多头注意力矩阵
stacked_qkv_attention = torch.cat(qkv_attention_store, dim=-1) # 沿第二个维度拼接,32x[17x128] -> [17x4096]
stacked_qkv_attention.shape
torch.Size([17, 4096])
头与头之间的信息交互(线性映射),自注意力层的最后一道工序!
# 加载 layer0 的输出权重矩阵
w_layer0 = model["layers.0.attention.wo.weight"] # [4096x4096]
w_layer0.shape
torch.Size([4096, 4096])
这只是一个简单的线性层,所以我们只需要进行矩阵乘法。
# 对注意力矩阵进行线性映射
# 这就是注意力层的最终输出
embedding_delta = torch.matmul(stacked_qkv_attention, w_layer0.T) # [17x4096] x [4096x4096] = [17x4096]
embedding_delta.shape
torch.Size([17, 4096])
进行残差运算(相加)
# 将注意力层的输出与原始输入相加,完成残差运算
embedding_after_edit = token_embeddings_unnormalized + embedding_delta # [17x4096] + [17x4096] = [17x4096]
embedding_after_edit.shape
torch.Size([17, 4096])
进行第二次归一化操作
# 对残差操作的结果进行归一化
embedding_after_edit_normalized = rms_norm(embedding_after_edit, model["layers.0.ffn_norm.weight"]) # [17x4096] & [4096] -> [17x4096]
embedding_after_edit_normalized.shape
torch.Size([17, 4096])
执行前馈神经网络(FFN)层的计算
在Llama3中,他们使用了SwiGLU前馈网络。这种网络架构可以在模型需要时有效增强非线性特性。
如今,这类前馈网络架构在大型语言模型中非常常见。
为什么引入非线性层:
- 非线性是神经网络模型能够被视为“通用函数逼近器”的核心原因。在传统的神经网络模型中,我们通过使用非线性激活函数(如Sigmoid、ReLU等)来增加模型的表达能力,使其能够拟合训练数据中隐藏的复杂模式。
- 然而,在Transformer中,注意力机制本质上是对值向量的线性加权求和(尽管权重是通过Softmax函数的非线性计算得到的,但它仍然是对值的线性加权)。因此,虽然它可以捕捉全局依赖关系,但其输出仍然只是输入的线性组合。此时,Transformer模型实际上缺乏非线性能力。
- 因此,在自注意力层之后添加一个FFN网络,为模型引入非线性变换能力,从而提升模型对复杂语义关系的建模能力,是十分必要的。
通常,引入非线性层可以起到以下作用:
- 为模型增加非线性能力,以促进模型的学习和训练。
- 增强模型的信息抽象能力,使模型能够在逐层学习过程中表示不同层次的数据特征和模式。例如,较低层的网络可以识别基本的语言结构(如词性),而较高层的网络则可以理解更复杂的语义信息(如情感、意图)。
- 此外,目前有一种观点认为,注意力层主要用于处理输入上下文交互,而FFN层则是大型语言模型在训练过程中主要存储和记忆通用知识的地方(由于其非线性表示能力),以便在回答输入问题时能够从这些通用知识中找到答案。
SwiGLU网络结构:
- 对输入进行线性变换:$X^\prime = XW_3$
- 门控单元:$GATE = Activation\_Function(XW_1)$,用于有选择地传递信息。也就是说,假设$X^\prime$中的信息具有不同的重要性,那么应根据门控单元的得分对信息进行加权并传递,从而提高模型的表达能力。
- 使用的激活函数是Swish激活函数(因此该网络被称为SwiGLU,它是Swish激活函数与门控线性单元(GLU)的结合)。公式为:$Swish = X \cdot \sigma(\beta X)$,其中$\sigma$为Sigmoid激活函数。在SwiGLU中,$\beta$被设置为1(在原始公式中,它是一个可学习的参数)。
- 因此,门控单元的具体计算为:$GATE = XW_1 \cdot \sigma(XW_1)$。在PyTorch中,这个激活函数被称为silu,即$GATE = silu(XW_1)$。
- 应用门控机制:$X^\prime = X^\prime \cdot GATE$
- 再次进行线性变换:$Y = X^\prime W_2$
前馈层隐藏层维度大小的计算(基于Llama3的官方实现过程):
- 输入维度为dim = 4096
- hidden_dim = 4 * dim = 16384 # 首先将其放大四倍。在初始化Transformer块中的前馈层时,输入的hidden_dim会被乘以四。
- hidden_dim = int(2 * hidden_dim / 3) = 10922 # 然后将其放大2/3倍。这种缩放首先在前馈层内部进行。
- hidden_dim = int(ffn_dim_multiplier * hidden_dim) = int(1.3 * 10922) = 14198 # 接着再按ffn_dim_multiplier倍数放大。模型配置文件中将ffn_dim_multiplier定义为1.3。
- hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) = 1024 * ((14198 + 1024 - 1) // 1024) = 14336 # 最后调整为multiple_of的整数倍。模型配置文件中将multiple_of定义为1024,以确保模型中所有隐藏层的维度都是1024的倍数,从而提高计算效率。
- 最终,我们得到隐藏层的维度大小为14336。
# 计算前馈网络层
# 隐藏层的维度大小为14336
w1 = model["layers.0.feed_forward.w1.weight"] # [14336x4096]
w3 = model["layers.0.feed_forward.w3.weight"] # [14336x4096]
w2 = model["layers.0.feed_forward.w2.weight"] # [4096x14336]
print(w1.shape, w3.shape, w2.shape)
# output = (silu(XW1) * XW3)W2
# [17x4096] x [4096x14336] x [14336x4096] = [17x4096]
output_after_feedforward = torch.matmul(torch.functional.F.silu(torch.matmul(embedding_after_edit_normalized, w1.T)) * torch.matmul(embedding_after_edit_normalized, w3.T), w2.T)
output_after_feedforward.shape
torch.Size([14336, 4096]) torch.Size([14336, 4096]) torch.Size([4096, 14336]) torch.Size([17, 4096])
再次执行残差操作(最终我们得到了Transformer块的最终输出!)
# 将前馈层的输出加到原始输入上,完成残差操作
# 这就是Transformer块的最终结果
layer_0_embedding = embedding_after_edit+output_after_feedforward # [17x4096] + [17x4096] = [17x4096]
layer_0_embedding.shape
torch.Size([17, 4096])
终于,我们得到了经过第一层处理后的每个token的新嵌入。
接下来只需再完成31层即可(只需要一个循环)。
你可以想象,这个经过处理的嵌入包含了第一层中提出的token的所有信息。
现在,每一层都会对问题中提出的查询进行更复杂的编码。直到最后,我们将得到一个能够掌握下一个所需token所有信息的嵌入。
一切都在这里。让我们完成所有32个Transformer块的计算吧。祝阅读愉快 :)
是的,就是这样。我们之前所做的所有工作都将在这里一次性呈现出来,以完成每一层的计算。
# 现在,让我们开始完成所有32个Transformer块的计算吧!
# 使用输入标记的嵌入作为初始输入。
final_embedding = token_embeddings_unnormalized # [17×4096]
# 对32层Transformer块进行逐层计算
for layer in range(n_layers):
#########################################################################################################################
################### 第一轮:归一化 - 特征变换 - 残差操作 ###############################
########################### 第一次归一化 ###################################################
# 第一次归一化
layer_embedding_norm = rms_norm(final_embedding, model[f"layers.{layer}.attention_norm.weight"]) # [17×4096] & [4096] -> [17×4096]
################ 第一次特征变换 - 多头自注意力 ########################
# 获取当前层注意力机制的qkv权重矩阵
q_layer = model[f"layers.{layer}.attention.wq.weight"] # [4096×4096]
q_layer = q_layer.view(n_heads, q_layer.shape[0] // n_heads, dim) # [32×128×4096]
k_layer = model[f"layers.{layer}.attention.wk.weight"] # [1024×4096]
k_layer = k_layer.view(n_kv_heads, k_layer.shape[0] // n_kv_heads, dim) # [8×128×4096]
v_layer = model[f"layers.{layer}.attention.wv.weight"] # [1024×4096]
v_layer = v_layer.view(n_kv_heads, v_layer.shape[0] // n_kv_heads, dim) # [8×128×4096]
# 用于存储每个头的注意力机制计算结果
qkv_attention_store = []
# 计算每个头的注意力机制结果
for head in range(n_heads):
# 提取当前头对应的QKV权重矩阵
q_layer_head = q_layer[head] # [32×128×4096] -> [128×4096]
k_layer_head = k_layer[head//4] # 每4个头共享一个键权重,[8×128×4096] -> [128×4096]
v_layer_head = v_layer[head//4] # 每4个头共享一个值权重,[8×128×4096] -> [128×4096]
# 计算XW以得到QKV向量
# [17×4096] × [4096×128] = [17×128]
q_per_token = torch.matmul(layer_embedding_norm, q_layer_head.T)
k_per_token = torch.matmul(layer_embedding_norm, k_layer_head.T)
v_per_token = torch.matmul(layer_embedding_norm, v_layer_head.T)
# 为查询向量添加位置信息(RoPE)
q_per_token_split_into_pairs = q_per_token.float().view(q_per_token.shape[0], -1, 2) # 沿维度方向将向量分成对,形成维度对。[17×128] -> [17×64×2]
q_per_token_as_complex_numbers = torch.view_as_complex(q_per_token_split_into_pairs) # 转换为复数表示,(x,y) -> (x+yi)。[17×64×2] -> [17×64]
q_per_token_as_complex_numbers_rotated = q_per_token_as_complex_numbers * freqs_cis # 计算(x+yi)×(cosθ+sinθi),完成旋转操作。[17×64] × [17×64] = [17×64]
q_per_token_split_into_pairs_rotated = torch.view_as_real(q_per_token_as_complex_numbers_rotated) # 将结果转回实数表示,(x+yi) -> (x,y)。[17×64] -> [17×64×2]
q_per_token_rotated = q_per_token_split_into_pairs_rotated.view(q_per_token.shape) # 将结果转回原始向量形状,得到最终的查询向量。[17×64×2] -> [17×128]
# 为键向量添加位置信息(RoPE)
k_per_token_split_into_pairs = k_per_token.float().view(k_per_token.shape[0], -1, 2) # 沿维度方向将向量分成对,形成维度对。[17×128] -> [17×64×2]
k_per_token_as_complex_numbers = torch.view_as_complex(k_per_token_split_into_pairs) # 转换为复数表示,(x,y) -> (x+yi)。[17×64×2] -> [17×64]
k_per_token_as_complex_numbers_rotated = k_per_token_as_complex_numbers * freqs_cis # 计算(x+yi)×(cosθ+sinθi),完成旋转操作。[17×64] × [17×64] = [17×64]
k_per_token_split_into_pairs_rotated = torch.view_as_real(k_per_token_as_complex_numbers_rotated) # 将结果转回实数表示,(x+yi) -> (x,y)。[17×64] -> [17×64×2]
k_per_token_rotated = k_per_token_split_into_pairs_rotated.view(k_per_token.shape) # 将结果转回原始向量形状,得到最终的键向量。[17×64×2] -> [17×128]
# 计算注意力分数并同时对分数进行归一化(即Q×K/√dim)
qk_per_token = torch.matmul(q_per_token_rotated, k_per_token_rotated.T)/(128)**0.5 # [17×128] × [128×17] = [17×17]
# 掩码未来token的分数
mask = torch.full(qk_per_token.shape, float("-inf"), device=qk_per_token.device) # 创建与注意力分数相同形状的矩阵,填充负无穷,并存储在与其他向量相同的设备上,以避免后续计算出错。[17×17]
mask = torch.triu(mask, diagonal=1) # 保留上三角部分的负无穷,其余设为0(即上三角区域代表需要掩码的未来token)。对角线偏移1,以避免掩码当前token本身。[17×17]
qk_per_token_after_masking = qk_per_token + mask # 将注意力分数与掩码矩阵相加,使分数矩阵的上三角部分变为负无穷,在后续的softmax操作后会趋近于0。[17×17]
# 计算注意力权重(即softmax(分数))
# 同时将其转换为半精度(因为后续会与值向量v_per_token相乘,数据类型需一致)。
qk_per_token_after_masking_after_softmax = torch.nn.functional.softmax(qk_per_token_after_masking, dim=1).to(torch.bfloat16) # 按行计算softmax。[17×17]
# 计算注意力机制的最终结果(即softmax(分数) × V)
qkv_attention = torch.matmul(qk_per_token_after_masking_after_softmax, v_per_token) # [17×17] × [17×128] = [17×128]
# 记录该头的结果
qkv_attention_store.append(qkv_attention)
# 合并多头注意力结果
stacked_qkv_attention = torch.cat(qkv_attention_store, dim=-1) # 沿第二维度合并,即32×[17×128] -> [17×4096]
# 对结果进行线性映射,生成最终的多头自注意力机制结果
o_layer = model[f"layers.{layer}.attention.wo.weight"]
embedding_delta = torch.matmul(stacked_qkv_attention, o_layer.T) # [17×4096] × [4096×4096] = [17×4096]
########################### 第一个残差操作 ##############################################
# 第一个残差操作
# 将注意力层的输出与原始输入相加,完成残差连接
embedding_after_edit = final_embedding + embedding_delta # [17x4096] + [17x4096] = [17x4096]
#########################################################################################################################
#################### 第二轮:归一化 - 特征变换 - 残差操作 ##############################
########################### 第二次归一化 ##################################################
# 第二次归一化
embedding_after_edit_normalized = rms_norm(embedding_after_edit, model[f"layers.{layer}.ffn_norm.weight"]) # [17x4096] & [4096] -> [17x4096]
################## 第二次特征变换 - 前馈网络 ##########################
# 加载前馈网络(SwiGLU)的参数矩阵
w1 = model[f"layers.{layer}.feed_forward.w1.weight"] # [14336x4096]
w3 = model[f"layers.{layer}.feed_forward.w3.weight"] # [14336x4096]
w2 = model[f"layers.{layer}.feed_forward.w2.weight"] # [4096x14336]
# 计算前馈网络的结果(输出 = (silu(XW1) * XW3)W2)
# [17x4096] x [4096x14336] x [14336x4096] = [17x4096]
output_after_feedforward = torch.matmul(torch.functional.F.silu(torch.matmul(embedding_after_edit_normalized, w1.T)) * torch.matmul(embedding_after_edit_normalized, w3.T), w2.T)
########################### 第二次残差操作 ##############################################
# 第二次残差操作,得到当前Transformer块的最终输出结果
# 将前馈层的输出与原始输入相加,完成残差连接
final_embedding = embedding_after_edit+output_after_feedforward # [17x4096] + [17x4096] = [17x4096]
让我们完成最后一步,预测下一个标记
现在我们已经得到了最终的嵌入表示,其中包含了预测下一个标记所需的所有信息。
这个嵌入的形状与输入标记嵌入的形状相同,都是[17x4096],其中17是标记的数量,4096是嵌入的维度。
首先,对最后一个Transformer层的输出进行最后一次归一化
# 在整个模型中执行最后一次归一化
final_embedding = rms_norm(final_embedding, model["norm.weight"]) # [17x4096] & [4096] -> [17x4096]
final_embedding.shape
torch.Size([17, 4096])
然后,基于最后一个标记对应的嵌入进行预测(通过线性映射到词汇表维度)
我们将使用输出解码器(一个线性映射层)将最后一个标记的嵌入向量转换为下一个标记的预测结果(维度为词汇表大小。如果我们对结果应用softmax函数,每个维度的值就代表下一个标记属于该词的概率)。
为什么我们只使用最后一个标记的输出向量来预测下一个标记呢?
因为在训练过程中,模型的目标是根据当前标记及其之前的所有标记来预测下一个标记。因此,每个标记对应的输出向量用于预测它自己之后的下一个标记,而不是整个输入序列的下一个标记。
在我们的示例中,我们希望答案是42 :)
注:42是《银河系漫游指南》一书中“生命、宇宙以及任何事情的终极问题的答案”。大多数现代大型语言模型都会回答42,这将验证我们整个代码的正确性!祝我们好运 :)
# 执行最后一次线性映射,将嵌入映射到词汇表维度大小,作为下一个标记的预测
logits = torch.matmul(final_embedding[-1], model["output.weight"].T) # [17x4096] -> [4096] -> [4096] x [4096x128256] = [128256]
logits.shape
torch.Size([128256])
这就是预测结果!
# 提取概率最高的维度对应的id,
# 就是预测的下一个标记的id
next_token = torch.argmax(logits, dim=-1) # 获取最大值对应的索引,即预测的下一个标记id。[128256] -> [1]
next_token
tensor(2983)
# 根据预测的id,还原为具体的预测值
tokenizer.decode([next_token.item()])
'42'
让我们深入探讨一下,不同的嵌入或标记掩码策略可能会如何影响预测结果 :)
现在我们已经得到了最终的预测结果。如果你仍然感兴趣,不妨探索一下之前提到的一些问题~
我们将简要探讨三种情况:
- 除了top-1结果之外,当前预测中还预测了哪些内容,即top-k结果?
- 如果我们使用其他标记的输出嵌入来进行预测,会得到什么结果?
- 如果在之前的注意力计算中没有对未来的标记进行掩码,预测结果会有什么不同?
# 首先来看看top-k预测结果
logits_sort, logits_idx = torch.sort(logits, dim=-1, descending=True) # 将预测概率最高的标记放在最前面,[128256]
[tokenizer.decode([i]) for i in logits_idx[:10]] # 查看概率最高的前10个结果
['42', '6', '43', '41', '4', '1', '45', '3', '2', '46']
# 接下来,让我们看看使用其他标记的嵌入进行预测能得到什么
logits_all_token = torch.matmul(final_embedding, model["output.weight"].T) # 将嵌入映射到与词汇表相同大小,[17x4096] x [4096x128256] = [17x128256]
logits_all_token_sort, logits_all_token_idx = torch.sort(logits_all_token, dim=-1, descending=True) # 将预测概率最高的标记放在最前面,[17x128256]
print('输入标记:', prompt_split_as_tokens) # 显示输入标记,[17]
# 根据每个标记的输出嵌入显示下一个标记预测的结果
for i in range(len(final_embedding)):
print(f'基于第{i+1}个标记的预测结果:', [tokenizer.decode([j]) for j in logits_all_token_idx[i][:10]]) # 输出概率最高的前10个结果
_="""
可以看出,当基于每个标记进行预测时,预测结果是“当前标记”之后下一个标记的可能结果,
而不是对整个完整输入的预测结果。
因此,在实际预测中,只会使用最后一个标记的嵌入来进行预测。
"""
输入标记: ['<|begin_of_text|>', 'the', ' answer', ' to', ' the', ' ultimate', ' question', ' of', ' life', ',', ' the', ' universe', ',', ' and', ' everything', ' is', ' ']
基于第1个标记的预测结果: ['Question', 'def', '#', 'The', 'import', 'Tags', 'A', 'package', 'Home', 'I']
基于第2个标记的预测结果: [' ', ' best', ' first', ' most', ' new', ' world', ' last', ' same', ' way', ' number']
基于第3个标记的预测结果: [' to', ' is', ' was', ' of', ' lies', ',', ' for', ' you', ' key', ' will']
基于第4个标记的预测结果: [' the', ' this', ' your', ' all', ' that', ' a', ' my', ' life', ' "', ' everything']
基于第5个标记的预测结果: [' question', ' problem', ' above', ' ultimate', ' first', ' r', ' following', ' questions', ' most', ' previous']
基于第6个标记的预测结果: [' question', ' questions', ' mystery', '\xa0', ' quest', '\n', ' life', ' philosophical', ' qu', ' problem']
基于第7个标记的预测结果: [' of', '\n', ' to', ' is', '?\n', ',', '.\n', ':', '...']
基于第8个标记的预测结果: [' life', ' Life', ' the', '\xa0', ' everything', ' existence', '\n', ' LIFE', ' all', ' human']
基于第9个标记的预测结果: [',', ' the', '\n', ' and', ' is', ',\n', '.\n', '?\n', '...']
基于第10个标记的预测结果: [' the', ' universe', ' and', ' etc', '\xa0', ' is', ' death', ' of', ' or', ' everything']
基于第11个标记的预测结果: [' universe', ' Universe', '\n', '\xa0', ' un', ' univers', ' uni', ' cosmos', ' universal', ' u']
基于第12个标记的预测结果: [',', ' and', ' &', '\n', ',\n', ' ,', '...', ',and', '...', '\xa0']
基于第13个标记的预测结果: [' and', ' everything', ' &', ' the', ' etc', '\xa0', ' is', ' or', ' ...\n', ' an']
基于第14个标记的预测结果: [' everything', '\xa0', ' the', ' every', '\n', ' ever', ' all', ' Everything', ' EVERY', '...']
基于第15个标记的预测结果: ['\n', ' is', '.\n', '.', '?\n', ',', ' (', '\n\n', '...', '\n', ' in']
基于第16个标记的预测结果: [' ', '\n', '...', '...', ':', ' forty', ' not', ' "', '…', ' a']
基于第17个标记的预测结果: ['42', '6', '43', '41', '4', '1', '45', '3', '2', '46']
# 最后,让我们看看在计算注意力时不屏蔽未来标记时,预测结果会是什么样子
# 此时,基于每个标记的预测结果如下
# 可以看出,由于可以看到未来的标记,每个标记的嵌入能够更准确地预测“它后面的下一个标记”(有点像“作弊”)
_="""
输入标记: ['<|begin_of_text|>', 'the', ' answer', ' to', ' the', ' ultimate', ' question', ' of', ' life', ',', ' the', ' universe', ',', ' and', ' everything', ' is', ' ']
基于第1个标记的预测结果: ['://', '.Forms', '_REF', ' Angeles', '.swing', '', 'php', 'во', 'ysics', '']
基于第2个标记的预测结果: [' answer', ' Hitch', ' universe', ' question', ' ultimate', ' meaning', ' hitch', ' Universe', ' Answer', ' reason']
基于第3个标记的预测结果: [' to', ' is', ',', ':', ' was', '\n', ' ', ' (', '\n\n', ' of']
基于第4个标记的预测结果: [' the', ' life', ' this', ' which', ' everything', ' that', ' how', ' why', ' ', ' all']
基于第5个标记的预测结果: [' ultimate', ' question', ' great', ' meaning', ' universe', ' Ultimate', ' everything', ' life', ' holy', ' greatest']
基于第6个标记的预测结果: [' question', ' answer', ' is', ' was', '\n', ' questions', ' mystery', '\n\n', ' what', ' Question']
基于第7个标记的预测结果: [' of', ' is', '\n', ',', ' about', ':', ' to', ' in', ' (', '<|end_of_text|>']
基于第8个标记的预测结果: [' life', ' existence', ' everything', ' Life', ' the', ' death', ' time', ' all', ' why', ' which']
基于第9个标记的预测结果: [',', ' is', ' the', '\n', ':', ' (', '...', ' and', ' ,', ' -']
基于第10个标记的预测结果: [' the', ' and', ' is', ' death', ' The', ' which', ' or', '\xa0', ' existence', ' don']
基于第11个标记的预测结果: [' universe', ' answer', ' cosmos', ' world', ' existence', ' Universe', ' everything', ' un', ' meaning', ' question']
基于第12个标记的预测结果: [',', ' and', ' is', ' &', '\n', ' ,', '.', '...', ' (', ' ']
基于第13个标记的预测结果: [' and', ' &', ' don', ' the', ' is', ' a', ' or', ' Douglas', '\xa0', '<|end_of_text|>']
基于第14个标记的预测结果: [' everything', ' dough', ' don', ' ever', ' deep', ' Douglas', ' the', ' every', ' all', ' death']
基于第15个标记的预测结果: ['\n', ' is', ',', '.', ' ', ' (', ':', '<|end_of_text|>', '\n\n', '.\n']
基于第16个标记的预测结果: [' ', '\n', ' forty', '...', ' "', '42', ' the', ':', '\xa0', ' to']
基于第17个标记的预测结果: ['42', '6', '4', '41', '1', '2', '3', '7', '5', '43']
"""
需要预测多个标记吗?只需使用KV缓存即可!(这真的让我费了很大的劲才弄清楚。Orz)
如何连续预测多个标记
现在,我们已经完成了对输入文本的下一个词的预测。但如果我们的预期输出需要多个标记呢?
例如,在实际的大模型应用中,模型通常不会只输出一个词,而是常常会输出一段文字,甚至是非常长的一段文本。这种能力是如何实现的呢?
其实很简单:我们只需要反复调用大模型的预测过程,逐步生成完整的句子或段落。
这个过程就像“滚雪球”一样:每预测出一个词,我们就把这个词添加到当前的输入序列中,然后再次调用模型进行新一轮的预测。当遇到停止符号(在Llama3中是一个特殊的标记“<|end_of_text|>”)或者达到最大长度限制(超参数max_seq_len)时,预测就会停止。
这听起来效率不高吗?确实如此!
这就是为什么会有像KV缓存这样广为人知的优化机制。通过缓存历史标记的KV向量,我们可以减少每次输入和计算的负担,从而大幅提升推理效率。
得益于缓存机制,当我们使用大型模型进行推理时,你可能会注意到:等待第一个标记输出往往是最耗时的阶段。但一旦第一个标记被输出,后续标记的输出速度就会显著加快。
KV缓存的优缺点
优点:在连续预测时,我们每次只需输入新标记,而不需要输入整个文本序列。这大大提高了推理过程中的计算速度。
缺点:由于引入了缓存机制,推理过程中会占用更多的内存资源。
KV缓存的原理推导
KV缓存源于对上述矩阵计算过程的观察与分析。通过分析每个输入标记的计算步骤,我们可以发现:在大多数计算环节中,各个标记的计算其实是相对独立的,很少涉及与其他标记的交互。只有在计算注意力机制时,才会出现标记之间的相互作用,因此需要缓存历史的KV向量。
以下是KV缓存的具体推导逻辑:
- 前提:要预测下一个标记,我们只需要获取最后一个标记的输出结果(就像我们在预测章节中所做的那样)。
- 非注意力部分只需计算新标记:除了注意力计算之外,其他所有部分的计算都是独立于各个标记的。因此,我们只需要计算新标记,而无需输入历史标记(下文将进一步展开分析)。
- 注意力部分也只需计算新标记:在注意力层中,由于掩码机制的作用,历史标记的输出结果不会受到未来新标记的影响。因此,它们在每一层的输入和输出都是固定的,也就是说,历史标记的QKV向量不会因为新标记的加入而改变。所以,我们只需要计算新标记的注意力即可。
- 计算新标记的注意力机制:注意力层的作用是让标记获取历史标记的上下文信息。因此,对于每一个新标记,我们需要使用所有标记的值向量进行加权求和。这就要求我们必须存储历史标记的值向量。
- 计算新标记的注意力权重:正如第4点所述,我们还需要先获得新标记与历史标记之间的重要性信息,即注意力权重。为此,我们需要将新标记的键向量与所有标记的键向量相乘。因此,我们也需要存储历史标记的键向量。
- KV缓存的形成:由第4和第5点可知,我们需要存储历史标记的KV向量。由于查询向量并未被使用,因此无需存储。这就是KV缓存的由来。
- KV缓存的效率:根据第3点,历史的KV向量不会发生变化。因此,在连续预测过程中,它们可以被增量更新,而无需修改历史内容。这样一来,每次预测时,我们只需输入并计算新添加标记的结果,而不必再以完整序列作为输入,从而极大地提升了推理效率。
补充:KV缓存中标记计算独立性的分析
除注意力层外的所有组件(彼此之间无交互):
- 两次归一化:每个标记向量都在其自身的特征维度上进行归一化,不涉及其他标记。
- 两次残差连接(加法):每个标记向量将其自身输出结果加回到自己身上,也不涉及其他标记。
- 前馈网络(FFN):每个标记向量都乘以相同的权重矩阵W1、W2、W3来得到结果,过程中并不使用其他标记。设想如果输入标记的数量为17个,那么FFN的计算可以简化为:[17×4096] × [4096×14336] × [14336×4096] = [17×4096]。这实际上等同于每次输入一个标记,然后将17个结果拼接成一个矩阵,即:17次([1×4096] × [4096×14336] × [14336×4096] = [1×4096])= 17×[1×4096] => [17×4096]。因此,在前馈层中,每个标记的计算实际上并没有与其他标记发生交互。
注意力层(仅存在新标记与历史标记之间的单向交互):
- 计算QKV向量:每个标记向量都乘以相同的QKV权重矩阵来得到结果,不涉及其他标记。
- 向QK向量中添加位置信息:每个标记向量都基于自己的位置独立进行旋转操作,不依赖于其他标记的具体内容。
- 计算注意力权重:注意力权重表示每个标记与其之前所有历史标记之间的相关性,且与未来的标记无关。因此,历史标记的结果不受新标记的影响。而新标记则需要历史标记的键向量缓存。
- 计算注意力机制的结果:注意力机制根据注意力权重对值向量进行加权求和。因此,与上一点的结论类似,历史标记的结果同样不受新标记的影响。而新标记则需要历史标记的值向量缓存。
基于KV缓存的注意力计算流程
为了清晰地展示计算过程,我们仅推导单头场景(将其扩展到多头场景的原理和过程与之前的多头注意力实现完全相同):
- 假设历史输入标记为 $S_1$,长度为 N。基于 KV 缓存,我们将存储每个头的 KV 结果矩阵。单个头的形状为 [Nxhead_dim] = [Nx128]。
- 假设新添加的输入标记为 $S_2$,长度为 M(可以是新预测的标记、新一轮用户对话的输入,或其他任何场景)。
- 计算新标记的 QKV 向量:$Q,K,V = S_2W_{Q,K,V}$ => [Mx4096] x [4096x128] = [Mx128]。
- 为 QK 向量加入位置信息:新标记的位置应从 N + 1 开始,而非从 0 开始。[Mx128] -> [Mx128]。
- 将新的 KV 值添加到 KV 缓存中,得到更新后的 KV 矩阵,即 [Nx128] -> [(N + M)x128]。
- 计算新标记的注意力权重:Attention_weight = softmax(QK/sqrt(d) + mask) => [Mx128] x [128x(N + M)] = [Mx(N + M)]。
- 计算新标记的注意力机制最终结果:Attention_weight x V => [Mx(N + M)] x [(N + M)x128] = [Mx128]。
- 将每个头的结果拼接起来,并进行线性映射,得到注意力层的最终输出,其形状为 32x[Mx128] -> [Mx4096]。
由于我们之前的学习过程已经相当全面,这里不再实现优化方案的代码(如果你感兴趣,可以参考 Llama 3 的官方代码,其实现相对简单)。就像之前提到的多头注意力并行计算一样,知道计算过程可以被优化就足够了~
感谢大家。感谢你们持续的学习。爱你们 :)
我们的学习到这里就结束了。希望你也享受了这段阅读的过程!
来自我
如果你看到了这篇作品,感谢你的信任,也感谢你一直学到这里。我很高兴能对你有所帮助~
如果你想支持我的工作
- 给它点个赞⭐~ :)
- 请我喝杯咖啡~ https://ko-fi.com/therealoliver
来自前作作者
如果你想支持我的工作
- 在 Twitter 上关注我 https://twitter.com/naklecha
- 或者,请我喝杯咖啡 https://www.buymeacoffee.com/naklecha
说实话,如果你能看到这里,就已经让我很开心了 :)
是什么激励着我呢?
我和朋友们有一个使命——让科研更加普及!我们创建了一个名为 A10 的研究实验室——AAAAAAAAAA.org
A10 的 Twitter 账号——https://twitter.com/aaaaaaaaaaorg
我们的理念如下:
再次感谢原作者提供的基础代码和插图,它们也让我学到了很多。
许可证
版权所有 (c) 2025 张金龙 (https://github.com/therealoliver)
版权所有 (c) 2024 尼山特·阿克莱查
MIT
相似工具推荐
stable-diffusion-webui
stable-diffusion-webui 是一个基于 Gradio 构建的网页版操作界面,旨在让用户能够轻松地在本地运行和使用强大的 Stable Diffusion 图像生成模型。它解决了原始模型依赖命令行、操作门槛高且功能分散的痛点,将复杂的 AI 绘图流程整合进一个直观易用的图形化平台。 无论是希望快速上手的普通创作者、需要精细控制画面细节的设计师,还是想要深入探索模型潜力的开发者与研究人员,都能从中获益。其核心亮点在于极高的功能丰富度:不仅支持文生图、图生图、局部重绘(Inpainting)和外绘(Outpainting)等基础模式,还独创了注意力机制调整、提示词矩阵、负向提示词以及“高清修复”等高级功能。此外,它内置了 GFPGAN 和 CodeFormer 等人脸修复工具,支持多种神经网络放大算法,并允许用户通过插件系统无限扩展能力。即使是显存有限的设备,stable-diffusion-webui 也提供了相应的优化选项,让高质量的 AI 艺术创作变得触手可及。
everything-claude-code
everything-claude-code 是一套专为 AI 编程助手(如 Claude Code、Codex、Cursor 等)打造的高性能优化系统。它不仅仅是一组配置文件,而是一个经过长期实战打磨的完整框架,旨在解决 AI 代理在实际开发中面临的效率低下、记忆丢失、安全隐患及缺乏持续学习能力等核心痛点。 通过引入技能模块化、直觉增强、记忆持久化机制以及内置的安全扫描功能,everything-claude-code 能显著提升 AI 在复杂任务中的表现,帮助开发者构建更稳定、更智能的生产级 AI 代理。其独特的“研究优先”开发理念和针对 Token 消耗的优化策略,使得模型响应更快、成本更低,同时有效防御潜在的攻击向量。 这套工具特别适合软件开发者、AI 研究人员以及希望深度定制 AI 工作流的技术团队使用。无论您是在构建大型代码库,还是需要 AI 协助进行安全审计与自动化测试,everything-claude-code 都能提供强大的底层支持。作为一个曾荣获 Anthropic 黑客大奖的开源项目,它融合了多语言支持与丰富的实战钩子(hooks),让 AI 真正成长为懂上
ComfyUI
ComfyUI 是一款功能强大且高度模块化的视觉 AI 引擎,专为设计和执行复杂的 Stable Diffusion 图像生成流程而打造。它摒弃了传统的代码编写模式,采用直观的节点式流程图界面,让用户通过连接不同的功能模块即可构建个性化的生成管线。 这一设计巧妙解决了高级 AI 绘图工作流配置复杂、灵活性不足的痛点。用户无需具备编程背景,也能自由组合模型、调整参数并实时预览效果,轻松实现从基础文生图到多步骤高清修复等各类复杂任务。ComfyUI 拥有极佳的兼容性,不仅支持 Windows、macOS 和 Linux 全平台,还广泛适配 NVIDIA、AMD、Intel 及苹果 Silicon 等多种硬件架构,并率先支持 SDXL、Flux、SD3 等前沿模型。 无论是希望深入探索算法潜力的研究人员和开发者,还是追求极致创作自由度的设计师与资深 AI 绘画爱好者,ComfyUI 都能提供强大的支持。其独特的模块化架构允许社区不断扩展新功能,使其成为当前最灵活、生态最丰富的开源扩散模型工具之一,帮助用户将创意高效转化为现实。
NextChat
NextChat 是一款轻量且极速的 AI 助手,旨在为用户提供流畅、跨平台的大模型交互体验。它完美解决了用户在多设备间切换时难以保持对话连续性,以及面对众多 AI 模型不知如何统一管理的痛点。无论是日常办公、学习辅助还是创意激发,NextChat 都能让用户随时随地通过网页、iOS、Android、Windows、MacOS 或 Linux 端无缝接入智能服务。 这款工具非常适合普通用户、学生、职场人士以及需要私有化部署的企业团队使用。对于开发者而言,它也提供了便捷的自托管方案,支持一键部署到 Vercel 或 Zeabur 等平台。 NextChat 的核心亮点在于其广泛的模型兼容性,原生支持 Claude、DeepSeek、GPT-4 及 Gemini Pro 等主流大模型,让用户在一个界面即可自由切换不同 AI 能力。此外,它还率先支持 MCP(Model Context Protocol)协议,增强了上下文处理能力。针对企业用户,NextChat 提供专业版解决方案,具备品牌定制、细粒度权限控制、内部知识库整合及安全审计等功能,满足公司对数据隐私和个性化管理的高标准要求。
ML-For-Beginners
ML-For-Beginners 是由微软推出的一套系统化机器学习入门课程,旨在帮助零基础用户轻松掌握经典机器学习知识。这套课程将学习路径规划为 12 周,包含 26 节精炼课程和 52 道配套测验,内容涵盖从基础概念到实际应用的完整流程,有效解决了初学者面对庞大知识体系时无从下手、缺乏结构化指导的痛点。 无论是希望转型的开发者、需要补充算法背景的研究人员,还是对人工智能充满好奇的普通爱好者,都能从中受益。课程不仅提供了清晰的理论讲解,还强调动手实践,让用户在循序渐进中建立扎实的技能基础。其独特的亮点在于强大的多语言支持,通过自动化机制提供了包括简体中文在内的 50 多种语言版本,极大地降低了全球不同背景用户的学习门槛。此外,项目采用开源协作模式,社区活跃且内容持续更新,确保学习者能获取前沿且准确的技术资讯。如果你正寻找一条清晰、友好且专业的机器学习入门之路,ML-For-Beginners 将是理想的起点。
ragflow
RAGFlow 是一款领先的开源检索增强生成(RAG)引擎,旨在为大语言模型构建更精准、可靠的上下文层。它巧妙地将前沿的 RAG 技术与智能体(Agent)能力相结合,不仅支持从各类文档中高效提取知识,还能让模型基于这些知识进行逻辑推理和任务执行。 在大模型应用中,幻觉问题和知识滞后是常见痛点。RAGFlow 通过深度解析复杂文档结构(如表格、图表及混合排版),显著提升了信息检索的准确度,从而有效减少模型“胡编乱造”的现象,确保回答既有据可依又具备时效性。其内置的智能体机制更进一步,使系统不仅能回答问题,还能自主规划步骤解决复杂问题。 这款工具特别适合开发者、企业技术团队以及 AI 研究人员使用。无论是希望快速搭建私有知识库问答系统,还是致力于探索大模型在垂直领域落地的创新者,都能从中受益。RAGFlow 提供了可视化的工作流编排界面和灵活的 API 接口,既降低了非算法背景用户的上手门槛,也满足了专业开发者对系统深度定制的需求。作为基于 Apache 2.0 协议开源的项目,它正成为连接通用大模型与行业专有知识之间的重要桥梁。