Giles' blog

从零开始编写 LLM,第 8 部分 -- 可训练的 Self-Attention

发布于 2025 年 3 月 4 日,分类:AI, Python, LLM from scratch, TIL deep dives

这是我阅读 Sebastian Raschka 的著作 "Build a Large Language Model (from Scratch)" 的第八篇博客。我写博客记录下那些吸引我兴趣的点,以及那些让我绞尽脑汁才能弄明白的东西,以此来理清自己的思路,同时也希望能帮助到也在学习这本书的其他人。 距离我的 上次更新 已经快一个月了 -- 如果你怀疑我是在 写博客来逃避写博客,并且花时间让 LaTeX 在这个网站上正常工作,是因为下一个章节会很难,那么你就完全正确了! 好消息是 -- 就像这些事情经常发生的那样 -- 当我真正开始着手时,它并没有想象中那么困难。 动力重新找回了。

如果你通过写博客相关的博客发现了这篇博客,欢迎!那些文章并不是典型的内容,但我希望你会喜欢这次回归我正常风格的文章。

这次我将介绍 3.4 节,"使用可训练权重实现 self-attention"。 当我们查看句子中的其他词时,如何创建一个系统来学习如何解释对单词的关注程度 -- 例如,学习到在 "the fat cat sat on the mat" 中,当你查看 "cat" 时,单词 "fat" 很重要,但是当你查看 "mat" 时,"fat" 就不那么重要了?

在深入研究之前,特别是考虑到距离上一篇文章已经过去很长时间,让我们从 GPT 类型的、仅解码器的、基于 Transformer 的 LLM(以下简称 "LLM",以避免我患上 RSI)的工作原理的宏观角度开始。 对于每个步骤,我都链接到了我详细介绍过的帖子。

完成所有这些之后,我们得到一个 context vectors 序列,每个 context vectors 应该以某种方式表示其在输入中各自 token 的含义,包括它从所有其他 token 中获得的含义。 因此,例如,"cat" 的 context vector 将包含一些关于其肥胖程度的提示。

那些 context vectors 发生了什么,让 LLM 可以使用它们来预测下一个 token 可能是什么? 这一点仍有待解释,因此我们将不得不拭目以待。 但是首先要学习的是,我们如何创建一个可训练的 attention 机制,它可以获取 input vectors 并生成 attention scores,以便我们可以首先计算出那些 context vectors。

Raschka 在本节中给出的答案称为 scaled dot product attention。 他给出了清晰的代码运行过程,但我不得不花一个周末的时间来解决它,才能获得一个扎实的心理模型。 因此,我将提供我自己的解释来代替逐节浏览代码,这样可以避免我将来在尝试记住它时撞头,也许也可以避免其他人的额头遭受同样的命运。

提前总结

我一直是 Pimsleur 式语言课程的 长期粉丝,他们每次教程都以一分钟左右的你正在尝试学习的语言的对话开始,然后说 "在 30 分钟后,你会再次听到它,你就会理解它"。 你完成了课程,他们再次播放对话,而你确实理解了它。

因此,这是对 self-attention 如何工作的一个压缩总结,用我自己的话说,基于 Raschka 的解释。 现在它可能看起来像一堵术语墙,但是(希望)当你读完这篇博客文章后,你将能够重新阅读它,并且一切都会变得有意义。

我们有一个长度为 n 的 token 输入序列。 我们已将其转换为 input embeddings 序列,每个 embeddings 都是长度为 d 的向量 -- 其中的每一个都可以被视为 d 维空间中的一个点。 让我们用这样的值来表示 embeddings 序列:x1,x2,x3,...xn。 我们的目标是生成一个长度为 n 的由 context vectors 组成的序列,每个 context vectors 代表各自的输入 token 在整个输入上下文中的含义。 这些 context vectors 的长度均为 c(在实践中通常等于 d,但理论上可以是任何长度)。

我们定义三个矩阵,分别是 query weights matrix Wq,key weights matrix Wk 和 value weights matrix Wv。 这些由可训练的权重组成; 它们的大小均为 d×c。 由于这些维度,我们可以将它们视为将长度为 d 的向量(即 d 维空间中的一个点)投影到长度为 c 的向量(即 c 维空间中的一个点)的操作。 我们将这些投影空间称为 key spacequery spacevalue space。 例如,要将输入向量 xm 转换为 query space,我们只需将其乘以 Wq,例如 qm=xmWq。

当我们考虑输入 xm 时,我们要计算其对于序列中每个输入(包括其自身)的 attention weights。 第一步是计算 attention score,当考虑另一个输入 xp 时,通过获取 xm 投影到 query space 的点积与 xp 投影到 key space 的点积来计算。 对所有输入执行此操作,可为我们提供 xm 的每个其他 token 的 attention score。 然后,我们将这些值除以我们正在投影到的空间维度的平方根 c,然后通过 softmax 函数运行结果列表,以使它们全部加起来等于 1。 此列表是 xm 的 attention weights。 这个过程称为 scaled dot product attention

下一步是为 xm 生成一个 context vector。 这仅仅是所有输入投影到 value space 的向量之和,每个向量都乘以其相关的 attention weight。

通过对每个输入向量执行这些操作,我们可以生成一个长度为 n 的由长度为 c 的 context vectors 组成的列表,每个向量代表输入 token 在整个输入上下文中的含义。

重要的是,通过巧妙地使用矩阵乘法,所有这些都可以针对序列中的所有输入完成,从而仅需五个矩阵乘法和一个转置,即可为每个输入生成一个 context vector。

现在让我们解释一下

首先,如果有人在没有已经了解 attention 机制的工作方式的情况下理解了所有这些,那么我向你致敬! 这非常密集,我希望它读起来不像我的朋友 Jonathan 对 难以理解的使用 git 的指南 的模仿。 对我来说,我花了八遍地阅读 Raschka(极其清晰易懂)的解释才能达到我感觉自己理解它的水平。 我认为值得注意的是,这在很大程度上是一种 "机械式" 解释 -- 它说明了我们如何进行这些计算,而没有说明原因。 我认为 "原因" 实际上超出了本书的范围,但这令我着迷,我将在不久后写博客介绍它。 但是,为了理解 "原因",我认为我们需要对 "如何" 有一个扎实的基础,因此让我们深入研究一下这篇文章。

直到本书的本节为止,我们一直在通过获取 input embeddings 之间的点积来计算 attention scores -- 也就是说,当你查看 xm 时,xp 的 attention score 只是 xm·xp。 我之前怀疑 Raschka 为他的 "toy" self-attention 使用该特定操作的原因是实际实现与之类似,事实证明这是正确的,因为我们在这里进行缩放的点积运算。 但是我们所做的是首先调整它们 -- 我们正在考虑的 xm 首先乘以 query weights matrix Wq,而另一个 xp 首先乘以 key weights matrix Wk。 Raschka 将其称为投影,对我来说,这是一个非常好的看待它的方式。 但是他的参考只是顺便提及,对我来说,需要更多的挖掘。

矩阵作为空间之间的投影

如果你的矩阵数学有点生疏 -- 就像我一样 -- 并且你还没有阅读 我上周发布的入门读物,那么你可能想现在查看一下。

从你的学生时代,你可能还记得矩阵可用于应用几何变换。 例如,如果你采用表示一个点的向量,则可以将其乘以一个矩阵,以使该点绕原点旋转。 你可以使用这样的矩阵将事物逆时针旋转 θ 度:

[xy][cosθ−sinθsinθcosθ]=[x.cosθ+y.sinθx.−sinθ+y.cosθ]

由于这是矩阵乘法,因此你可以添加更多点 -- 也就是说,如果第一个矩阵具有更多行,其中每一行都是你想要旋转的点,则相同的乘法将使它们都旋转 θ。 因此,你可以将矩阵视为将点集映射到其旋转等效物的函数。 这也适用于更高的维度 -- 像这样的 2×2 矩阵可以表示 2 个维度的变换,但是例如,在 3d 图形中,人们使用 3×3 矩阵对构成 3d 对象的点进行类似的变换。2

看待此 2×2 矩阵的另一种方法是,它是将点从一个二维空间投影到另一个空间的函数,目标空间是第一个空间逆时针旋转 θ 度的空间。 对于像这样的简单 2d 示例,甚至对于 3d 示例,这不一定是更好的看待它的方式。 这是一种哲学上的差异,而不是一种实践上的差异。

但是,假设矩阵不是正方形 -- 也就是说,它的行数与列数不同。 如果你有 3×2 矩阵,则可以使用它来乘以 3d 空间中的向量矩阵并生成 2d 空间中的矩阵。 记住矩阵乘法的规则:n×3 矩阵乘以 3×2 矩阵将得到 n×2 矩阵。

这实际上非常有用;如果你进行过任何 3d 图形处理,你可能会记得 视锥体 矩阵,该矩阵用于将你正在处理的 3d 点转换为屏幕上的 2d 点。 无需过多介绍,它允许你通过单个矩阵乘法将这些 3d 点投影到 2d 空间中。

因此:一个 d×c 矩阵可以看作是将代表 d 维空间中一个点的向量投影到代表不同 c 维空间中一个点的向量的一种方法。

我们在 self-attention 中所做的是获取组成输入 embeddings 序列的 d 维向量,然后将它们投影到三个不同的 c 维空间中,并使用投影版本。 为什么我们这样做? 这就是我想在未来关于 "原因" 的帖子中探讨的问题,但是现在,我认为比较清楚的一件事是,由于这些投影是作为训练的一部分学习的(请记住,我们用于投影的三个矩阵由可训练的权重组成),因此它在混合中放入了一些间接性,而我们之前使用的简单点积 attention 没有这种间接性。

如何进行投影 input embeddings 的点积运算

现在,让我们坚持这种机械的观点 -- "如何" 而不是 "为什么" -- 让我们看一下计算以及矩阵乘法如何使它们高效。 我将大致按照 Raschka 的解释进行,但是使用数学符号而不是代码,因为(对于我这个职业技术人员来说很不寻常)我发现这样更容易掌握正在发生的事情。

我们将坚持考虑 token xm 并尝试计算其对 xp 的 attention score 的情况。 我们要做的第一件事是将 xm 投影到 query space 中,我们通过将其乘以 query weights matrix Wq 来完成:

qm=xmWq

现在,让我们通过将 xp 乘以 key weights matrix Wk 来将其投影到 key space 中:

kp=xpWk

我们的 attention score 定义为这两个向量的点积:

ωm,p=qm.kp

因此,我们可以编写一个简单的循环,该循环迭代一次所有输入 x1...xn,为每个输入生成到 query space 的投影,然后在该循环内第二次迭代 x1...xn,将它们投影到 key space,进行点积运算,并将这些存储为 attention scores。

但是,那将是浪费! 我们正在进行矩阵乘法,因此我们可以批量处理事情。 让我们首先考虑输入到 key space 的投影; 每次循环时,这些投影将始终相同。 因此,我们可以一气呵成。 让我们将输入序列视为一个矩阵 X,如下所示:

[x1(1)x1(2)x1(3)x2(1)x2(2)x2(3)...xn(1)xn(2)xn(3)]

我们输入序列 x1, x2 等中的每个 input embedding 都有一行,该行由该 embedding 中的元素组成。 因此,它有 n 行,每个输入序列有一个元素,d 列,每个输入 embeddings 中有一个维度,因此它是 n×d。(我在这里使用 d=3 作为示例,就像 Raschka 在书中一样。)

就像上面旋转矩阵示例中的点矩阵一样,因此我们可以通过将它乘以 Wk 来一次将它投影到 key space 中。 让我们将结果称为 K:

K=XWk

它看起来像这样(同样,就像 Raschka 一样,我使用 2 维 key space -- 也就是说,c=2 -- 因此很容易看到矩阵是在原始 3d input embeddings 空间还是在 2d 投影空间中):

[k1(1)k1(2)k2(1)k2(2)...kn(1)k2(2)]

...其中每一行都是输入 xn 到 key space 的投影。 只是所有投影都堆叠在一起。

现在,让我们考虑一下点积 -- 前面的这一点:

ωm,p=qm.kp

我们现在有一个包含我们所有 kn 值的矩阵 K。当你进行矩阵乘法时,值 Mi,j -- 也就是输出矩阵中第 i 行第 j 列的元素 -- 是第一个矩阵中第 i 行的点积(用作向量)与第二个矩阵中第 j 列的点积(也视为向量)。

听起来我们可以利用它来批量完成我们所有的点积。 让我们将 qm(我们第 m 个输入 token 到 query space 的投影)视为一个单行矩阵。 我们可以将 key 矩阵乘以它,像这样

qmK

...?

不幸的是,不能。 qm 是一个单行矩阵(大小为 1×c),而 K 是我们的 n×c key 矩阵。 使用矩阵乘法,第一个矩阵中的列数(在本例中为 c)需要与第二个矩阵中的行数(即 n)匹配。 但是,如果我们转置 K,则基本上交换行和列:

qmKT

...然后我们得到一个 1×c 矩阵乘以一个 c×n 矩阵,这确实有意义 -- 甚至更好的是,它是对于所有 p 的所有 (qm, kp) 对的每个点积 -- 也就是说,通过两次矩阵乘法 -- 计算 K 的矩阵乘法和这个矩阵乘法,以及一个转置,我们已经计算了输入序列中元素 xm 的所有 attention scores。

但是它变得更好了!

首先,让我们像投影输入序列到 key space 一样,将其投影到 query space 中。 我们计算了 K=XWk 来计算 key 矩阵,因此我们可以使用相同的方式计算 query 矩阵 Q=XWq。 就像 K 是所有输入向量投影到 key space 并 "堆叠" 在彼此之上的矩阵一样,Q 是所有输入向量投影到 query space 的矩阵。

现在,如果我们将其乘以转置的 key 矩阵会发生什么?

QKT

好吧,我们的 Q 矩阵是每个输入一行,每个投影空间维度一列,因此它是 n×c。 而且,如我们所知,转置的 K 矩阵为 c×n。 因此,我们的结果是 n×n -- 并且由于矩阵乘法是根据点积定义的,因此它包含 Q 中每一行的点积(转换为 query space)与 KT 中每一列的点积(转换为 key space)。

该计划是通过计算精确的那些点积来生成 attention scores!

因此,通过三个矩阵乘法,我们就完成了:

Q=XWq K=XWk Ω=QKT

...其中我使用大写的 Ω 表示一个矩阵,其中每行表示序列中的一个输入,并且行中的每列表示该输入的 attention weight。 元素 Ωm,p 表示当你想计算 xm 的 context vector 时,应该对输入 xp 给予多少关注。 并且它通过计算 xm 投影到 query space 的点积和 xp 投影到 key space 的点积来完成。

这就是 "scaled dot product attention" 的 "点积" 部分完成的 :-)

归一化

因此,我们已经计算出了我们的 attention scores。 我们需要做的下一件事是使它们归一化; 过去我们使用了 softmax 函数。 该函数获取一个列表并调整其中的值,以使它们全部加起来为 1,但是会提高较高的数字并降低较小的数字。 我认为它之所以被称为 "soft" "max",是因为它类似于找到最大值,但是从某种意义上说更柔和,因为它将其他较小的数字留在那里并降低了它们。

Raschka 解释说,当我们处理大量的维度时 -- 在真实的 LLM 中,d 和 c 很容易达到数千个 -- 使用纯 softmax 会导致小的梯度 -- 他说它会开始表现得 "像阶跃函数",我将其理解为意味着你会发现除了列表中最大的数字之外的所有数字都被缩放到非常小的数字,而最大的数字占据主导地位。 因此,作为一种解决方法,我们将这些数字除以我们投影到的空间 c 中的维度数的平方根,然后才通过 softmax 运行结果。3

请记住,Ω 是 attention scores 的矩阵,每个输入 token 一行,因此我们需要分别将 softmax 函数应用于每一行。 这是我们最终得到的结果:

A=softmax(Ωc, axis=1)

(axis=1 并不是真正的数学符号,它只是我从 PyTorch 中借用的东西,以表示我们正在逐行对矩阵应用 softmax。)

完成此操作后,我们就有了归一化的 attention scores -- 也就是 attention weights。 下一步,也是最后一步,是使用它们来计算 context vectors。

创建 context vectors

让我们重申一下我们如何计算 context vectors。 在之前的 toy 示例中,对于每个 token,我们获取 input embeddings,将每个 input embeddings 乘以其 attention weight,按元素将结果相加,结果就是结果。 现在我们做着相同的事情,但是首先将 input embeddings 投影到另一个空间 -- value space。 因此,让我们首先进行投影作为简单的矩阵乘法,就像我们对其他空间所做的一样:

V=XWv

现在,从上面我们得到了 attention weights 矩阵 A,其中在第 m 行中包含了输入序列中每个 token 对于输入 xm 的 attention weights -- 也就是说,在 Am,p 处,我们有输入 p 的 attention weight,当我们计算输入 m 的 context vector 时。 这意味着对于长度为 n 的输入序列,它是 n×n 矩阵。

在我们的 value 矩阵 V 中,我们每个输入也有一行。 在第 m 行中的值(视为向量)是输入 xm 投影到 value space 的向量。 因此,它是 n×c 矩阵。

如果我们进行矩阵乘法会发生什么

AV

...? 根据矩阵乘法的规则,我们将得到一个 n×c 矩阵,但它意味着什么?

重申一下,矩阵乘法的规则是,值 Mi,j -- 也就是输出矩阵中第 i 行第 j 列的元素 -- 是第一个矩阵中第 i 行的点积(用作向量)与第二个矩阵中第 j 列的点积(也视为向量)。

因此,在位置 (1,1) -- 第一行,第一列,我们有 A 中第一行的点积 -- 在我们考虑第一个 token 时,输入序列中每个 token 的 attention weight -- 以及 V 中第一列,它是每个 input embedding 的第一个元素,投影到 value space。 因此,它是每个 input embedding 的第一个元素乘以第一个 token 的 attention weights。 或者,换句话说,它是第一个 token 的 context vector 的第一个元素!

在位置 (1,2) -- 第一行,第二列 -- 我们将进行相同的计算,但是对于每个 input embedding 的第二个元素。 这是第一个 token 的 context vector 的第二个元素。

...依此类推,对于其余的列。 到第一行结束时,我们将获得(视为向量)所有 input embeddings 的总和,乘以第一个输入的权重。 这是该输入的 context vector!

当然,对于每一行都重复相同的操作。 单个矩阵乘法的结果是一个矩阵,其中行 m 是输入 xm 的 context vector。

我们完成了!

将所有内容整合在一起

让我们把这些步骤放在一起。 我们从输入矩阵 X 开始,它是我们之前为长度为 n 的 token 序列生成的 input embeddings。 每行是一个 embedding,并且有 d 列,其中 d 是 embeddings 的维数。

我们还有权重矩阵,用于将 input embeddings 映射到不同的空间:query weights matrix Wq,key weights matrix Wk 和 value weights matrix Wv。

因此,我们通过三个矩阵乘法将输入矩阵投影到这些空间中:

Q=XWq K=XWk V=XWv

...以获得我们的 query 矩阵、key 矩阵和 value 矩阵。

然后,我们通过进一步的矩阵乘法和一个转置来计算 attention scores,以计算点积:

Ω=QKT

我们通过按 c 的平方根缩放这些值,然后应用 softmax 将它们归一化为 attention weights:

A=softmax(Ωc, axis=1)

...然后,我们使用最后一个矩阵乘法来使用它来计算 context vectors:

C=AV

这就是我们的 self-attention 机制 :-)

现在,如果你 回到开头的解释,那么希望它会有所意义。

回到书中

本书的 3.4 节通过 PyTorch 代码进行了上述操作,并得出了一个不错的简单 nn.Module 子类,该子类完全执行这些矩阵运算。 然后对其进行改进 -- 第一个版本对三个权重矩阵使用通用的 nn.Parameter 对象,第二个版本对更有效的训练使用 nn.Linear。 这方面相对容易理解。 因此,我们总结了我认为是 "Build a Large Language Model (from scratch)" 中最难的部分:使用可训练权重实现 self-attention。

下一步

现在我们已经克服了这个难题,第 3 章的其余部分要容易得多。 我们将介绍两件事:

因此,我想我可能会先写这些内容,然后再回过头来讨论这种形式的 self-attention 的 "原因"。 我们可以完成所有这些 -- 投影到不同维度的空间,在这些空间中获取每个 token 的输入 embeddings 之间的点积,并通过我们生成的权重对投影的输入 token 进行加权 -- 仅使用五个矩阵乘法,这真是太神奇了。 但是为什么我们专门这样做呢?

使用的矩阵的名称(query、key 和 value)以隐喻的方式暗示了它们所扮演的角色; Raschka 在侧边栏中说,这是对数据库等信息检索系统的致敬。 但是,它与数据库的实际工作方式有很大的不同,我无法完全建立联系。 我相信随着时间的推移,它将会实现。

我还想(可能在另一篇文章中)考虑批处理对所有这些的影响。 通过 普通的神经网络,我们在考虑给定输入时所有的激活都是单行或单列矩阵(取决于我们方程的顺序)。 扩展到批处理只是意味着移动到普通的多行、多列矩阵。

但是自从我们 第一次 引入 attention scores 矩阵以来,很明显,即使使用单个输入序列通过我们的 LLM,我们已经在使用完整的矩阵。 我们如何处理并行处理多个输入序列的批处理? 看来我们需要使用某种高阶张量 -- 如果标量是零阶张量,向量是一阶张量,矩阵是二阶张量,那么我们将需要开始考虑至少三阶张量。 这需要一番思考!

但是现在,就这些了 -- 下次再见! 请在下面评论 -- 当然,任何想法、问题或建议都将非常欢迎,但是即使你只是觉得这篇文章有用,我也很想知道 :-)

  1. 值得注意的是,这是绝对位置 embeddings -- 还有相对的位置,但本书中没有介绍。
  2. 当然,这是 GPU(用于加速游戏中 3d 图形的 GPU)对于神经网络如此有用的原因之一。 它们旨在非常高效地进行矩阵乘法,以便游戏开发人员可以轻松地操纵和转换 3d 和 2d 空间中的对象,但是它们的效率是一般性的 -- 它不仅仅与图形所需的那些矩阵乘法相关联。
  3. 感觉最好通过使用和不使用缩放进行一些训练运行并查看会发生什么来理解这一点 -- 这是一种工程修复,而不是数学上显而易见的事情。

« 在 AI 时代仍然值得写博客