论文概览

论文标题:Enhancing Reasoning through Process Supervision with Monte Carlo Tree Search

研究机构:AI Lab

基座模型:Llama-3.1-8B-Instruct, DeepSeek-Math-7B-Instruct

论文地址:https://arxiv.org/abs/2501.01478

核心创新:将蒙特卡洛树搜索与自训练相结合,通过过程监督提升推理能力

主要成果:在MATH数据集上提升4.85%,GSM8K上提升5.03%

MCTS过程监督框架

核心贡献

本文提出了一个创新的自训练框架,巧妙地将蒙特卡洛树搜索与过程监督结合:

  1. MCTS引导的数据生成:利用蒙特卡洛树搜索系统性地探索推理路径,生成高质量训练数据
  2. 过程级监督信号:不仅关注最终答案,更重视每个推理步骤的质量评估
  3. 迭代自我提升:通过多轮生成-训练循环,模型能力逐步提升
  4. 跨任务泛化:训练得到的推理能力能够有效迁移到不同任务

问题背景:推理训练的挑战

传统自训练的困境

虽然自训练(Self-Training)是提升模型能力的有效方法,但在推理任务中面临独特挑战:

1. 探索效率低下

  • 随机采样很难找到高质量的推理路径
  • 大部分生成的推理过程都是错误的
  • 缺乏系统性的探索策略

2. 监督信号稀疏

  • 只有最终答案的对错反馈信息量有限
  • 无法区分"歪打正着"和"正确推理"
  • 难以识别推理过程中的关键错误步骤

3. 数据质量参差不齐

  • 错误的推理步骤会误导模型学习
  • 缺乏有效的质量筛选机制
  • 可能加强模型的错误推理模式

核心洞察:树搜索遇上过程监督

问题的关键在于:如何系统性地探索推理空间,并对每个推理步骤给出准确的质量评估?

本文的创新思路是将问题建模为树搜索过程:

  • 每个推理步骤对应树中的一个节点
  • MCTS算法系统性地探索最有希望的推理路径
  • 过程监督为每个步骤提供细粒度的质量评估

技术方法详解

数据生成:MCTS驱动的推理探索

问题建模:推理即树搜索

我们将推理过程建模为一棵树:

  • 根节点:问题本身
  • 中间节点:部分推理过程
  • 叶子节点:完整的解答
  • :从一个推理状态到下一个状态的转换

MCTS搜索过程示例

MCTS四步曲:系统性探索推理空间

1. 选择(Selection) 从根节点开始,使用UCB(Upper Confidence Bound)策略选择最有希望的子节点:

  • 平衡探索(exploration)和利用(exploitation)
  • 优先选择既有潜力又未充分探索的路径
  • 一直向下搜索直到找到未完全扩展的节点

2. 扩展(Expansion) 在选中的节点处生成新的推理步骤:

  • 使用LLM以非零温度采样下一步推理
  • 每次扩展增加一个新的子节点
  • 确保探索的多样性

3. 模拟(Simulation) 从新扩展的节点继续推理直到得到最终答案:

  • 使用LLM完成剩余的推理过程
  • 评估最终答案的正确性
  • 获得本次探索的奖励信号

4. 回传(Backpropagation) 将奖励信号向上传播,更新路径上所有节点的统计信息:

  • 正确答案:奖励 = 1.0
  • 错误答案:奖励 = 0.0
  • 更新访问次数和累积奖励

过程级质量评估

对于每个推理步骤,我们计算其质量分数:

$$r_{j,k}^i = \alpha \times \frac{Q(v_{j,k}^i)}{N(v_{j,k}^i)}$$

其中:

  • $Q(v_{j,k}^i)$:节点的累积奖励
  • $N(v_{j,k}^i)$:节点的访问次数
  • $\alpha$:缩放常数

这个分数反映了该推理步骤的**“相对正确度”**,即在多次探索中,通过这个步骤最终得到正确答案的概率。

训练数据构建

经过MCTS搜索后,我们得到高质量的训练数据:

$$D = {(x_i, p_j^i, s_j^i, c_j^i)}$$

其中:

  • $x_i$:第$i$个问题
  • $p_j^i$:问题$x_i$的第$j$个部分解
  • $s_j^i$:部分解$p_j^i$的下一步推理
  • $c_j^i$:分配给$s_j^i$的质量分数

数据筛选策略

  • 只保留质量分数大于0的推理步骤
  • 避免错误推理步骤污染训练数据
  • 确保训练数据的高质量

迭代训练:螺旋上升的能力提升

损失函数设计

我们使用加权损失函数进行训练:

$$\mathcal{L}(\theta_i) = -\sum_{(x,p,s,c) \in D} c \cdot \log \pi_{\theta_i}(s|x,p) + \beta \cdot D_{KL}(\pi_{\theta_i}(\cdot|x,p) || \pi_{\theta_{i-1}}(\cdot|x,p))$$

第一项:加权负对数似然

  • 高质量的推理步骤获得更大的训练权重
  • 让模型更多地学习正确的推理模式

第二项:KL散度正则化

  • 防止模型在训练过程中发生剧烈的分布漂移
  • 保持与上一轮模型的合理距离
  • 借鉴了RLHF中的稳定性技巧

迭代流程

  1. 第0轮:从预训练模型开始
  2. 第i轮
    • 使用第$(i-1)$轮的模型进行MCTS数据生成
    • 训练得到第$i$轮模型
    • 评估性能提升
  3. 重复:直到性能收敛

这种迭代方式让模型能够:

  • 自我纠错:发现并修正之前的推理错误
  • 能力扩展:探索更复杂的推理路径
  • 知识巩固:强化正确的推理模式

实验结果:数据说话

主要性能提升

在两个主要数学推理基准上的结果:

MATH数据集(Llama-3.1-8B-Instruct):

  • 基础模型:47.07%
  • 第1轮训练:50.04% (+2.97%)
  • 第4轮训练:51.92% (+4.85%)

GSM8K数据集(DeepSeek-Math-7B-Instruct):

  • 基础模型:80.77%
  • 第1轮训练:85.35% (+4.58%)
  • 第2轮训练:85.80% (+5.03%)

与基线方法对比

方法MATHGSM8K
Zero-shot-CoT47.07 ± 0.1580.77 ± 0.25
RFT47.96 ± 0.2183.33 ± 0.25
Step-level DPO48.48 ± 0.2782.55 ± 0.17
我们的方法51.92 ± 0.2085.80 ± 0.22

关键洞察

  • 相比传统的拒绝采样微调(RFT),我们的方法提升更显著
  • 相比Step-level DPO,MCTS的系统性探索带来更大收益
  • 迭代训练确实能够持续提升模型能力

跨任务泛化能力

我们还验证了方法的泛化能力:

GSM8K → MATH

  • 在GSM8K上训练的模型在MATH上仍有提升
  • 第2轮:48.74% (vs 基础47.07%)

MATH → GSM8K

  • 在MATH上训练的模型在GSM8K上表现稳定
  • 第2轮:85.72% (vs 基础80.77%)

这说明我们的方法学到的是通用的推理能力,而非任务特化的技巧。

技术创新点

1. MCTS与自训练的有机结合

  • 首次将蒙特卡洛树搜索应用于大语言模型的自训练
  • 系统性地探索推理空间,避免随机采样的低效性

2. 过程级监督信号

  • 不仅关注最终答案,更重视推理过程的质量
  • 为每个推理步骤提供细粒度的质量评估

3. 迭代能力提升机制

  • 通过多轮生成-训练循环实现螺旋上升
  • KL正则化确保训练过程的稳定性

4. 通用性强

  • 适用于不同的基础模型和推理任务
  • 具有良好的跨任务泛化能力

局限性与未来方向

当前挑战

  1. 计算成本较高:MCTS搜索需要大量的前向推理计算
  2. 超参数敏感:UCB参数、温度设置等需要仔细调优
  3. 评估方式限制:目前主要依赖最终答案的正确性评估

未来展望

  1. 搜索效率优化:开发更高效的搜索策略和剪枝技术
  2. 更丰富的奖励信号:结合过程奖励模型提供更细粒度的反馈
  3. 扩展到更多任务:将方法推广到代码生成、科学推理等领域

论文翻译:https://dppemvhuzp.feishu.cn/docx/W84qdm5AAoienmxFl17cU7OOnkg?from=from_copylink