Implementing DeepSeek R1's GRPO algorithm from scratch
从零实现 DeepSeek R1 的 GRPO 算法
policy-gradient/GRPO-Zero
本项目旨在以最小的依赖实现 GRPO 训练。 我们几乎从头开始实现所有内容,仅依赖 tokenizers
进行分词和 pytorch
进行训练。
- 没有
transformers
和vLLM
依赖! - 默认配置设置为在单个 A40 GPU (48GB VRAM) 上运行几个小时以获得良好的结果。(如果您从 RunPod 租用,A40 每小时花费
$0.44
。) - 我们支持对 DAPO 项目 中的原始 GRPO 算法的几项改进,包括:
- Token-level policy gradient loss:每个 token 在策略梯度损失中具有相同的权重。
- Removing KL Divergence:KL 散度未在策略梯度损失中使用。这减少了 GPU 内存使用,因为我们不再需要参考策略网络。
- Overlong episode filtering:跳过超出上下文长度限制的未完成 episode。这稳定了训练。 虽然我们默认禁用它以观察模型在有限上下文长度下的学习情况。 将
skip_unfinished_episodes
设置为true
以启用它。
算法
Group Relative Policy Optimization (GRPO) 是 Deepseek 提出的一种使用强化学习训练大型语言模型的算法。这个想法很简单:对于每个问题,我们随机抽取多个答案。然后,答案的优势被定义为归一化奖励。这摆脱了价值估计网络。 特别是,我们实现以下算法:
- 对于每个训练步骤,随机抽取 $N$ 个问题 $q_1, q_2, \cdots, q_N$。
- 对于每个问题 $q_i$,抽取 $M$ 个答案 $a_{i,1}, a_{i,2}, \cdots, a_{i,M}$。
- 计算每个答案 $a_{i,j}$ 的奖励 $r_{i,j}$。
- 计算每个问题 $q_i$ 的奖励的平均值和标准差。
$$ \begin{aligned} \mu_i &\leftarrow \text{mean}(r_{i,1}, r_{i,2}, \cdots, r_{i,M}) \\ \sigma_i &\leftarrow \text{std}(r_{i,1}, r_{i,2}, \cdots, r_{i,M}) \end{aligned} $$
- 对于答案 $a_{i,j}$ 中的每个 token $t$,计算优势如下
$$A_{i,j}[t] \leftarrow \frac{r_{i,j} - \mu_i}{\sigma_i}$$
- 使用 PPO 替代目标计算策略梯度。 为简单起见,我们每次迭代只进行一次策略更新,其中 PPO 目标的梯度等效于以下 vanilla policy gradient estimation (每个 token)。
$$ \nabla_\theta \log \pi_\theta(a_{i,j}[t]) \cdot A_{i,j}[t] $$
- 使用梯度更新策略网络 $\pi(\theta)$。 返回步骤 1。
CountDown Task
我们将使用 Qwen2.5 模型在 CountDown task 上进行训练。 给定 3 个或 4 个数字的列表和一个目标数字,模型需要生成一个使用简单算术运算(+、-、*、/)计算到目标数字的数学表达式。 例如:
Question: Given 1 2 3 4 and a target number 11. Show an expression that evaluates to 11.
Answer: 1 + (2 * 3) + 4
Reward Function
为了解决 CountDown task,我们将使用 GRPO 算法来训练模型,使其在生成最终表达式之前生成思维链推理。 具体来说,该模型经过训练以遵循以下格式:
<think>Model step by step reasoning</think>
<answer>Final answer</answer>
奖励是两个组成部分的总和:
- Format Reward:当模型正确遵循指定的格式(带有 thinking 和 answer 标签)时,模型获得
0.1
的奖励,否则为0
。 - Answer Reward:如果模型的最终答案恰好使用每个提供的数字一次并且正确评估为目标值,则模型获得
1
的奖励,否则获得0
。
Training
我们使用 Qwen2.5-3B-Instruct
模型进行训练。 要训练模型,请运行以下命令:
# initialize the environment
pip install uv
uv sync
# install git-lfs
apt update; apt install git-lfs -y; git lfs install
# download the dataset
git clone https://huggingface.co/datasets/Jiayi-Pan/Countdown-Tasks-3to4
# download the pretrained model
git clone https://huggingface.co/Qwen/Qwen2.5-3B-Instruct
# train the model
uv run train.py
Acknowledgements
该项目建立在几个优秀项目的工作之上:
- DeepSeekMath 开创了 GRPO 算法。
- DAPO 增强了原始 GRPO 算法。
- TinyZero 实现了 GRPO 并创建了 CountDown-Tasks-3to4 数据集。
- nano-aha-moment 对 GRPO 算法进行了清晰的实现和教程。
- Qwen2.5 开发了本项目中使用的高质量预训练模型。