Giles' blog

从零编写 LLM,第 13 部分——Attention 的“为什么”,或者:Attention Heads 并不聪明

发布于 2025 年 5 月 8 日,分类:AIPythonLLM from scratchTIL deep dives

在完成了 Sebastian Raschka 的书 "Build a Large Language Model (from Scratch)" 的第三章之后,并在上一篇文章中介绍了 multi-head attention,我觉得有必要停下来,在进入第四章之前进行一些总结。

我想涵盖两件事,self-attention 的“为什么”,以及关于上下文长度的一些想法。这篇文章是关于“为什么”的——也就是说,书中描述的特定矩阵乘法集为什么能实现我们想要的功能?

像往常一样,我主要做这件事是为了让自己头脑更清晰——可能还会给其他人带来额外的用处。当然,我会让多个 LLM 来检查一下,以确保我发布的不是完全的胡说八道,但请读者自行判断!

让我们开始吧。正如我在本系列的第 8 部分中写道:

我认为值得注意的是,[书中的内容]在很大程度上是一种“机械”的解释——它说明了我们如何进行这些计算,但没有说明为什么。我认为“为什么”实际上超出了本书的范围,但这是我非常感兴趣的事情,我很快会写一篇关于它的博客。

那个“很快”就是现在 :-)

Attention Heads 并不聪明

我认为我在理解这些公式为什么有效时遇到的核心问题是,我高估了单个 attention head 所能做的事情。在第 6 部分中,我写了关于短语 "the fat cat sat on the mat":

因此,虽然 "cat" 的输入 embedding 只是表示 "第 3 个位置的 cat",但这句话中 "cat" 的上下文向量也包含着一些关于它是一只坐着的猫的含义,可能不太强烈地暗示着它是一只特定的猫("the" 而不是 "a"),并暗示它坐在垫子上。

我没有理解的是,在某种程度上,这是正确的,但只适用于整个 attention 机制的输出——_而不是_单个 attention head。

每个单独的 attention head 实际上都很笨,它所做的事情比这简单得多!

使整个机制变得智能的两件事是 multi-head attention 和分层。这本书已经详细介绍了 multi-head attention,所以让我们深入研究第二部分。

一开始,在第 1 部分中,我写道:

另一件让感到困惑的事情是,Raschka 提到最初的 transformer 架构有六个 encoder 和六个 decoder 块,而 GPT-3 有 96 个 transformer 层。这与我对这一切如何运作的理解不太吻合。encoder 和 decoder 似乎都是独立的,接受输入(tokens/embeddings)并产生输出(embeddings/tokens)。你会如何使用多个层级的它们?

现在我们已经介绍了 attention 的工作原理,这一点变得更加清楚了。一个 multi-head attention 块获得一组输入 embeddings(输入序列中每个 token 一个),并产生一组相同数量的上下文向量。没有什么能阻止我们将这些上下文向量视为另一个 attention 块的输入 embeddings,并再次执行相同的操作。

(这也解释了为什么 Raschka 提到上下文向量中的维度数量通常与输入 embeddings 中的数量相匹配;这使得为每一层使用相同“形状”的 multi-head attention 计算更容易。)

在我看来,这类似于图像处理网络——例如 CNN——的工作方式。在这些网络中,第一层可能检测边缘,第二层可能检测特定方向的线条,下一层可能检测特定形状,然后在稍后的某个位置,第 n 层可能会识别狗的脸。

因此,我上面描述的 "cat" token 的表示不会是单个 attention head 的输出的一部分,甚至 attention 机制的第一层可能也没有那么丰富的内容。但它可能是 multi-head attention 的第三层或第四层的输出。

到了 GPT-3 中的第 96 层,上下文中表示的内容将非常丰富,并且在不同的 tokens 之间分布着大量信息。意识到这一点对我来说也是一种顿悟。

不再有固定长度的瓶颈

如果你回想一下第 5 部分,没有 attention 机制的 encoder/decoder RNN 的一个大问题是固定长度的瓶颈。你会将输入序列输入到一个 encoder RNN 中,它会尝试将其含义表示在其隐藏状态中——一个具有特定固定长度的向量——准备好将其传递给 decoder。对于短输入来说很容易,但随着输入越来越长,它变得越来越困难,最终变得不可能,因为你试图将越来越多的信息打包到相同的“空间”中。

但是有了 attention,从最后一个 attention 层输出的输入序列的这种超丰富和组合的表示,其长度与输入中 tokens 的数量成正比!当然,你仍然受到可用内存的限制(以及其他因素——请参阅下一篇文章),但是你拥有的 tokens 越多,上下文向量的这种“隐藏状态”就越大。

这太酷了。

因此,使用 multi-head attention 加上层允许我们构建复杂的表示,即使每个单独的 attention head 都很笨。但是,回到这篇文章的核心,这些笨 attention heads 为什么 使用它们所做的特定计算?

为什么笨 Attention Heads 有效

让我们使用一个例子。

首先是一个提示/警告:attention heads 正在学习它们自己的表示和模式,作为深度学习梯度下降的一部分——因此无论他们学到什么,都可能很奇怪,与我们理解的语法没有任何关系。但对于这个例子,让我们假设情况并非如此,并且我们有一个 attention head,它已经学会了如何将冠词(如 "a"、"an" 和 "the")与其相关的名词匹配起来。

它是如何工作的?让我们以 "the fat cat sat on the mat" 为例,忽略除了两个 "the" 和名词 "cat" 和 "mat" 之外的所有内容。我们将假设我们的 attention head 想要为 "cat" 生成一个上下文向量,将其与第一个 "the" 结合起来(这意味着它将包含我们正在谈论一只特定猫而不是仅仅是 "a" 猫的概念),类似地,它希望将第二个 "the" 融入 "mat" 中。

现在,请记住,我们的输入序列是一系列输入 embeddings,它是 token embeddings(空间中的向量,指向 tokens 的一些抽象“含义”)和位置 embeddings(表示它们在序列中的位置)的组合。

以 "mat" 为例,我们投影它的输入 embedding,这意味着 "位置 7 的 token 'mat'"1 进入 query space。对我来说的突破是,query space 是另一个 embedding space,就像输入 embeddings 的原始空间一样,但对于值的表示不同。

假设在这个新的 embedding space 中,表示要简单得多——它们没有原始空间那么多的细节。它只是表示 "这是一个冠词" 或 "这不是一个冠词",以及一些关于定位的信息——也就是说,位置 1 的冠词的 embedding 接近位置 2 的 embedding,但与位置 69,536 的 embedding 相差甚远。其他不是冠词的东西会在更远的地方。

在这个例子中,也许我们的 attention head 已经学会的投影会将 "'mat' 在位置 7" 映射到一个 embedding,指向 "位置 6 或更低处的某个冠词——the 或 a——可能非常接近" 的方向。换句话说,投影到 query space 将 token 的输入 embedding 转换为 attention head 在处理该 token 时正在寻找的东西。同样,"'cat' 在位置 2" 将被投影到一个 embedding 向量,意思是 "位置 1 或更低处的某个冠词,可能非常接近"。

现在,除了将输入 embeddings 投影到 query space 之外,我们还将它们投影到 key space 中。在这种情况下,我们假设的冠词匹配 head 将创建一个投影,将第一个 "the" 变成意味着 "位置 1 的冠词",将第二个 "the" 变成意味着 "位置 6 的冠词"。

因此,query weights 已将我们的输入 embeddings 投影到这个 "低分辨率" 的 embedding space 中,以指向意味着 "这是我感兴趣的东西" 的方向,并且 key weights 已将输入 embeddings 投影到相同的 embedding space 中,以指向意味着 "这是我" 的方向。

这意味着当我们进行点积时,"mat" 的 query 向量将指向与第二个 "the" 的 key 向量非常相似的方向,因此点积将很高——请记住,只要向量的长度大致相同,点积就表示它们有多相似。

重要的是,query 和 key 向量使用的共享 embedding space 实际上可能比输入 embeddings 使用的丰富空间更贫乏。在我们的例子中,head 关心的只是 tokens 是名词、冠词还是其他东西,以及它们的位置。

让我们举个例子。这是我想象的 attention 机制可能在第 6 部分中提出的假想的 attention 分数集(经过修改以使其具有因果关系,因此 tokens 不会关注它们“未来”中的 tokens):

Token | ω("The") | ω("fat") | ω("cat") | ω("sat") | ω("on") | ω("the") | ω("mat")
---|---|---|---|---|---|---|---
The | 1 | 0 | 0 | 0 | 0 | 0 | 0
fat | 0.2 | 1 | 0 | 0 | 0 | 0 | 0
cat | 0.6 | 0.8 | 1 | 0 | 0 | 0 | 0
sat | 0.1 | 0 | 0.85 | 1 | 0 | 0 | 0
on | 0 | 0.1 | 0.4 | 0.6 | 1 | 0 | 0
the | 0 | 0 | 0 | 0 | 0.1 | 1 | 0
mat | 0 | 0 | 0.2 | 0.8 | 0.7 | 0.6 | 1

每一行都是对于第一列中的 token,所有其他单词的 attention 分数。它基于我个人对单词重要性的直觉,并且是你可能想象一个聪明的 attention head 可能会提出的东西。(请记住,ω 是我们用来表示 attention 分数的变量。)

但是我们更接近真实世界的冠词-名词匹配 head 的例子非常笨,因此它可能会提出更像这样的东西:

Token | ω("The") | ω("fat") | ω("cat") | ω("sat") | ω("on") | ω("the") | ω("mat")
---|---|---|---|---|---|---|---
The | 0 | 0 | 0 | 0 | 0 | 0 | 0
fat | 0 | 0 | 0 | 0 | 0 | 0 | 0
cat | 0.8 | 0 | 1 | 0 | 0 | 0 | 0
sat | 0 | 0 | 0 | 0 | 0 | 0 | 0
on | 0 | 0 | 0 | 0 | 0 | 0 | 0
the | 0 | 0 | 0 | 0 | 0 | 0 | 0
mat | 0.1 | 0 | 0 | 0 | 0 | 0.8 | 1

它所做的只是决定在考虑名词时关注 "the" ——它甚至在考虑 "mat" 时有点关注第一个 "the",因为它不知道它必须是它匹配的最近的 "the"。2

现在,正如我之前所说,真正的 attention heads,经过数十亿 tokens 的梯度下降训练,可能会学到一些奇怪和抽象的东西,与我们对语言、语法和词性的思考方式无关。

但单独来看,它们将非常笨,因为该等式正在做一些非常简单的事情:在考虑特定类型的事物时,寻找另一种类型的事物。每个 token 都通过 query weights 投影到共享 embedding space 中(“我在寻找什么”),并通过 key weights 投影到同一空间中,但这次以使其指向它在相同意义上“是什么”。然后点积将它们匹配起来,以便我们可以将输入 embeddings 彼此关联,以计算我们的 attention 分数。

当然,这并不意味着我们丢失了任何信息。这种贫乏的 embedding space 仅用于进行匹配以计算我们的 attention 分数。当我们计算上下文向量时,我们使用投影到 value space 中,它可以像我们喜欢的那么丰富。

值得注意的是,尽管 Raschka 在书中使用的示例对 query 和 key 向量的共享空间以及 value 向量的空间具有相同的维度,但实际上没有必要这样做。我见过 LLM 的规范,其中 QK 空间具有更少的维度——这至少对于这个简单的例子来说是有意义的。

还值得注意的是,在这个例子中,这个 key/query space 是贫乏的,但在一个真正的“外星人”学习的例子中,它实际上可能非常复杂和丰富——但比这个例子更难理解。最终,该 embedding space 的性质将以与所有其他事物相同的方式学习,并将匹配 head 所学到的任何事物。

笨 Attention 的优雅

因此,这就是(现在)我对 scaled dot product attention 如何运作的理解。我们只是在做简单的模式匹配,其中每个 token 的输入 embedding 都通过 query weights 投影到一个(学习的)embedding space 中,该空间能够以某种方式表示它“正在寻找”的内容。它也通过 key weights 投影到同一空间中,但这次以使其指向它在相同意义上“是”的方式。然后点积将这些匹配起来,以便我们可以将输入 embeddings 彼此关联,以计算我们的 attention 分数。

所有这些在我的脑海中都说得通,我希望至少在其他一些人中也能说得通:-)

我将在这里结束这篇文章;下次我将发布我对目前为止我们在本书中所经历的内容对上下文长度意味着什么的理解。我们已经看到了随着输入序列增长而增长的隐藏状态的优势——缺点是什么?

  1. 我将为此建立一个索引。
  2. 我无法想到单个 head 可以做到的任何方法,TBH。它并行考虑所有其他tokens,所以当它看第一个 "the" 时,它不知道还有另一个更近的。

« Writing an LLM from scratch, part 12 -- multi-head attention Copyright (c) 2006-2025 by Giles Thomas. This work is licensed under a Creative Commons Attribution 4.0 International License.