home blog microblog favorites pl resources bread recipes rss

从零开始构建一个简单的搜索引擎*

2025年5月20日Chris Gregory 共同创作! *如果你包括了 word2vec 的话。 几天前,Chris 和我花了几个小时,从“零”开始为我的博客创建了一个搜索引擎。主要是他引导我完成的,因为在这个实验之前,我对 word2vec 的了解还很模糊。

我们构建的搜索引擎基于 word embeddings(词嵌入)。它指的是某种函数,它接受一个词,并将其映射到 N 维空间(在本例中,N=300),其中每个维度模糊地对应于某种意义轴。Word2vec from Scratch 是一篇不错的博客文章,展示了如何训练你自己的迷你 word2vec 并解释其内部原理。

搜索引擎背后的想法是将我的每篇文章都嵌入到这个空间中,方法是将文章中单词的 embeddings 相加。对于给定的搜索,我们将以相同的方式嵌入搜索词。然后,我们可以根据所有文章与查询的 cosine similarities (余弦相似度) 对它们进行排序。

下面的公式看起来可能很吓人,但它是在说余弦相似度,即两个向量 cos(theta) 之间的角度的余弦,定义为点积除以每个向量的幅度的乘积。我们将详细介绍所有这些。

来自 Wikimedia 的 Cosine similarity 页面中的公式。

余弦距离可能是比较查询 embedding 和文档 embeddings 以对文档进行排序的最简单方法。另一种直观的选择可能是欧几里得距离,它会测量两个向量在空间中的距离(而不是它们之间的角度)。

我们更喜欢余弦距离,因为它保留了我们的直觉,即如果两个向量具有相同比例的每个 embedding 维度,则它们具有相似的含义。如果你有两个指向相同方向的向量,但一个非常长,一个非常短,则应将它们视为具有相同的含义。(如果两个文档都关于猫,但一个文档中“cat”这个词出现的频率更高,那么它们仍然都只是关于猫的)。

让我们打开 word2vec 并嵌入我们的第一个单词。

Embedding

我们理所当然地使用了这个包含前 10,000 个最流行的词 embeddings 的数据库,它是一个 12MB 的 pickle 文件,大致如下所示:

couch [0.23, 0.05, ..., 0.10]
banana [0.01, 0.80, ..., 0.20]
...

Chris 通过互联网发给了我。如果你 unpickle 它,它实际上是一个 NumPy 数据结构:一个将字符串映射到 numpy.float32 数组的字典。我编写了一个脚本,将这个 pickle 文件转换为普通的 Python floats 和 lists,因为我想手动完成所有这些操作。

加载代码很简单:使用 pickle 库。通常的安全注意事项适用,但我信任 Chris。

import pickle
def load_data(path):
  with open(path, "rb") as f:
    return pickle.load(f)
word2vec = load_data("word2vec.pkl")

你可以打印出 word2vec,但会产生大量的输出。我从惨痛的经历中学到了这一点。也许可以打印 word2vec["cat"]。这将打印出 embedding。

要嵌入一个单词,我们只需要在庞大的字典中查找它。但是,一个无意义或不常见的单词可能不在其中,因此我们返回 None,而不是引发错误。

def embed_word(word2vec, word):
  return word2vec.get(word)

要嵌入多个单词,我们分别嵌入每个单词,然后成对地将 embeddings 相加。如果给定的单词无法嵌入,则忽略它。只有当我们无法理解任何单词时,才会出现问题。

def vec_add(a, b):
  return [x + y for x, y in zip(a, b)]
def embed_words(word2vec, words):
  result = [0.0] * len(next(iter(word2vec.values())))
  num_known = 0
  for word in words:
    embedding = word2vec.get(word)
    if embedding is not None:
      result = vec_add(result, embedding)
      num_known += 1
  if not num_known:
    raise SyntaxError(f"I can't understand any of {words}")
  return result

这就是 embeddings 的基础知识:它是一个字典查找和向量加法。

embed_words([a, b]) == vec_add(embed_word(a), embed_word(b))

现在让我们制作我们的“搜索引擎索引”,或者说是我所有帖子的 embeddings。

嵌入所有帖子

嵌入所有帖子是一个递归的目录遍历,我们在其中构建一个将路径名映射到 embedding 的字典。

import os
def load_post(pathname):
  with open(pathname, "r") as f:
    contents = f.read()
  return normalize_text(contents).split()
def load_posts():
  # Walk _posts looking for *.md files
  posts = {}
  for root, dirs, files in os.walk("_posts"):
    for file in files:
      if file.endswith(".md"):
        pathname = os.path.join(root, file)
        posts[pathname] = load_post(pathname)
  return posts
post_embeddings = {pathname: embed_words(word2vec, words)
          for pathname, words in posts.items()}
with open("post_embeddings.pkl", "wb") as f:
  pickle.dump(post_embeddings, f)

我们还做了另一件事:normalize_text。这是因为博客文章很混乱,包含标点符号、大写字母和各种其他无意义的东西。为了获得最佳匹配,我们希望将诸如“CoMpIlEr”和“compiler”之类的内容放在同一个 bucket 中。

import re
def normalize_text(text):
  return re.sub(r"[^a-zA-Z]", r" ", text).lower()

我们也将对每个查询执行相同的操作。说到这,我们应该测试一下。让我们做一个小的搜索 REPL。

一个小的搜索 REPL

我们将从使用 Python 的内置 REPL 创建库 code 开始。我们可以创建一个定义 runsource 方法的子类。它真正需要做的就是处理 source 输入并返回一个 falsy 值(否则它会等待更多输入)。

import code
class SearchRepl(code.InteractiveConsole):
  def __init__(self, word2vec, post_embeddings):
    super().__init__()
    self.word2vec = word2vec
    self.post_embeddings = post_embeddings
  def runsource(self, source, filename="<input>", symbol="single"):
    for result in self.search(source):
      print(result)

然后我们可以定义一个 search 函数,它将我们现有的函数组合在一起。就这样,我们有了一个搜索功能:

class SearchRepl(code.InteractiveConsole):
  # ...
  def search(self, query_text, n=5):
    # Embed query
    words = normalize_text(query_text).split()
    try:
      query_embedding = embed_words(self.word2vec, words)
    except SyntaxError as e:
      print(e)
      return
    # Cosine similarity
    post_ranks = {pathname: vec_cosine_similarity(query_embedding,
                           embedding) for pathname,
           embedding in self.post_embeddings.items()}
    posts_by_rank = sorted(post_ranks.items(),
                reverse=True,
                key=lambda entry: entry[1])
    top_n_posts_by_rank = posts_by_rank[:n]
    return [path for path, _ in top_n_posts_by_rank]

是的,我们必须进行余弦相似度计算。值得庆幸的是,维基百科上的数学片段几乎可以 1:1 地翻译成 Python 代码:

import math
def vec_norm(v):
  return math.sqrt(sum([x*x for x in v]))
def vec_cosine_similarity(a, b):
  assert len(a) == len(b)
  a_norm = vec_norm(a)
  b_norm = vec_norm(b)
  dot_product = sum([ax*bx for ax, bx in zip(a, b)])
  return dot_product/(a_norm*b_norm)

最后,我们可以创建并运行 REPL。

sys.ps1 = "QUERY. "
sys.ps2 = "...... "
repl = SearchRepl(word2vec, post_embeddings)
repl.interact(banner="", exitmsg="")

这就是与之交互的样子:

QUERY.type inference
_posts/2024-10-15-type-inference.md
_posts/2025-03-10-lattice-bitset.md
_posts/2025-02-24-sctp.md
_posts/2022-11-07-inline-caches-in-skybison.md
_posts/2021-01-14-inline-caching.md
QUERY.

这是来自一个非常小的数据集(我的博客)的示例查询。这是一个非常好的搜索结果,但它可能不能代表整体搜索质量。Chris 说我应该挑挑拣拣,“因为人工智能领域的每个人都这样做”。

好的,这真的很简洁。但是,大多数想在我的网站上查找内容的人都不会运行他们的终端。虽然我的网站明确设计为可以在诸如 Lynx 之类的终端浏览器中良好运行,但大多数人已经在图形 Web 浏览器中。因此,让我们制作一个搜索前端。

一个小的 Web 搜索

到目前为止,我们一直在我的本地机器上运行,在那里我不介意有一个 12MB 的权重文件。既然我们要迁移到 Web,我宁愿不让休闲浏览者承担意外的大型下载负担。因此,我们需要变得聪明。

幸运的是,Chris 和我都看到了这篇非常酷的博客文章,它讨论了在 GitHub Pages 上托管 SQLite 数据库。该博客文章详细介绍了作者如何:

这非常酷,但再次:SQLite 虽然小,但对于此项目来说相对较大。我们想从头开始构建东西。幸运的是,我们可以模仿主要的想法。

我们可以给 word2vec 字典一个稳定的顺序,并将其拆分为两个文件。一个文件可以只包含 embeddings,没有名称。另一个文件,即索引,可以将每个单词映射到该单词的权重的字节开始和字节长度(我们认为 start&length 可能比 start&end 在线上传输更小)。

#vecs.jsonl[0.23,0.05,...,0.10][0.01,0.80,...,0.20]...
#index.json{"couch":[0,20],"banana":[20,30],...}

这样做的好处是 index.json 比 word2vec blob 小得多,只有 244KB。既然它不会经常更改(word2vec 多久更改一次?),我不会对用户急切地下载整个索引感到难过。同样,post_embeddings.json 只有 388KB。它们甚至可以缓存。并且由服务器和浏览器自动(解)压缩(分别为 84KB 和 140KB)。如果我们选择二进制格式,两者都会更小,但出于本文的目的,我们将推迟这样做。

然后,我们可以向服务器发出 HTTP Range 请求,并且只下载我们需要的权重部分。甚至可以将所有范围捆绑到一个请求中(称为多部分范围)。不幸的是,GitHub Pages 似乎不支持多部分,因此我们改为在单独的请求中下载每个单词的范围。

这是相关的 JS 代码,省略了(简短、非常熟悉的)向量函数:

(async function() {
 // Download stuff
 async function get_index() {
  const req = await fetch("index.json");
  return req.json();
 }
 async function get_post_embeddings() {
  const req = await fetch("post_embeddings.json");
  return req.json();
 }
 const index = new Map(Object.entries(await get_index()));
 const post_embeddings = new Map(Object.entries(await get_post_embeddings()));
 // Add search handler
 search.addEventListener("input", debounce(async function(value) {
  const query = search.value;
  // TODO(max): Normalize query
  const words = query.split(/\s+/);
  if (words.length === 0) {
   // No words
   return;
  }
  const requests = words.reduce((acc, word) => {
   const entry = index.get(word);
   if (entry === undefined) {
    // Word is not valid; skip it
    return acc;
   }
   const [start, length] = entry;
   const end = start+length-1;
   acc.push(fetch("vecs.jsonl", {
    headers: new Headers({
     "Range": `bytes=${start}-${end}`,
    }),
   }));
   return acc;
  }, []);
  if (requests.length === 0) {
   // None are valid words :(
   search_results.innerHTML = "No results :(";
   return;
  }
  const responses = await Promise.all(requests);
  const embeddings = await Promise.all(responses.map(r => r.json()));
  const query_embedding = embeddings.reduce((acc, e) => vec_add(acc, e));
  const post_ranks = {};
  for (const [path, embedding] of post_embeddings) {
   post_ranks[path] = vec_cosine_similarity(embedding, query_embedding);
  }
  const sorted_ranks = Object.entries(post_ranks).sort(function(a, b) {
   // Decreasing
   return b[1]-a[1];
  });
  // Fun fact: HTML elements with an `id` attribute are accessible as JS
  // globals by that same name.
  search_results.innerHTML = "";
  for (let i = 0; i < 5; i++) {
   search_results.innerHTML += `<li>${sorted_ranks[i][0]}</li>`;
  }
 }));
})();

你可以看看实际的 搜索页面。特别是,打开浏览器控制台的网络请求选项卡。惊叹于它只下载几个 4KB 的 embeddings 块。

那么我们的搜索技术效果如何?让我们尝试构建一个客观的评估。

评估

我们将设计一个指标,粗略地告诉我们何时我们的搜索引擎比没有词 embeddings 的幼稚方法更好或更差。

我们首先收集一个 (document, query) 对的评估数据集。从一开始,我们将通过自己收集此数据集来偏向此评估,但希望它仍然可以帮助我们了解搜索的质量。在这种情况下,查询只是我们认为应该成功检索文档的几个搜索词。

sample_documents = {
 "_posts/2024-10-27-on-the-universal-relation.md": "database relation universal tuple function",
 "_posts/2024-08-25-precedence-printing.md": "operator precedence pretty print parenthesis",
 "_posts/2019-03-11-understanding-the-100-prisoners-problem.md": "probability strategy game visualization simulation",
 # ...
}

现在我们已经收集了我们的数据集,让我们实现一个 top-k 准确性指标。此指标衡量给定其对应查询,文档出现在前 k 个搜索结果中的百分比。

def compute_top_k_accuracy(
  # Mapping of post to sample search query (already normalized)
  # See sample_documents above
  eval_set: dict[str, str],
  max_n_keywords: int,
  max_top_k: int,
  n_query_samples: int,
) -> list[list[float]]:
  counts = [[0] * max_top_k for _ in range(max_n_keywords)]
  for n_keywords in range(1, max_n_keywords + 1):
    for post_id, keywords_str in eval_set.items():
      for _ in range(n_query_samples):
        # Construct a search query by sampling keywords
        keywords = keywords_str.split(" ")
        sampled_keywords = random.choices(keywords, k=n_keywords)
        query = " ".join(sampled_keywords)
        # Determine the rank of the target post in the search results
        ids = search(query, n=max_top_k)
        rank = safe_index(ids, post_id)
        # Increment the count of the rank
        if rank is not None and rank < max_top_k:
          counts[n_keywords - 1][rank] += 1
  accuracies = [[0.0] * max_top_k for _ in range(max_n_keywords)]
  for i in range(max_n_keywords):
    for j in range(max_top_k):
      # Divide by the number of samples to get the average across samples and
      # divide by the size of the eval set to get accuracy over all posts.
      accuracies[i][j] = counts[i][j] / n_query_samples / len(eval_set)
      # Accumulate accuracies because if a post is retrieved at rank i,
      # it was also successfully retrieved at all ranks j > i.
      if j > 0:
        accuracies[i][j] += accuracies[i][j - 1]
  return accuracies

让我们首先评估一个基线搜索引擎。此实现根本不使用词 embeddings。我们只是规范化文本,并计算每个查询词在文档中出现的次数,然后按查询词出现次数对文档进行排序。绘制各种 k 值的 top-k 准确率会给出以下图表。请注意,当我们增加 k 时,我们会获得更高的准确率 - 在极限情况下,当 k 接近我们的文档数量时,我们接近 100% 的准确率。

您可能还会注意到,随着我们增加关键字的数量,准确率也会提高。我们还可以看到,随着关键字数量的增加,线条越来越接近,这表明每个新关键字的边际收益递减。

这些兆字节的词 embeddings 实际上是否可以提高我们的搜索效果?我们必须与基线进行比较。也许该基线是将每个文档中所有关键字的计数相加来对其进行排序。我们将其作为读者的练习,因为我们的时间已经不多了:)

看看更大的 word2vec 如何帮助提高准确性也会很有趣。在为 top-k 采样时,会产生很多错误输出(I can't understand any of ['prank', ...])。这些未知的单词将从搜索中删除。更大的 word2vec(超过 10,000 个单词)可能包含这些不太常见的单词,因此搜索效果更好。

总结

您只需一百行左右的代码就可以从“头”构建一个小型搜索引擎。请参阅完整的 search.py,其中包括用于评估和绘图的一些额外功能。

未来的想法

我们可以比简单的余弦相似度更花哨。让我们想象一下,我们所有的文档都讨论计算机,但只有一个文档讨论编译器(那会很悲伤)。如果我们的搜索词之一是“计算机”,那么它实际上无助于缩小搜索范围,并且是我们的 embeddings 中的噪声。为了减少噪声,我们可以采用一种称为 TF-IDF(词频逆文档频率)的技术,我们在其中分解跨文档的常见词,并更多地关注每个文档特有的词。

这个博客是开源的。看到错误了吗?继续提出更改