如未额外说明,所有图片来自于论文。
今年,以 ChatGPT 为首的大语言模型(Large Language Models, LLMs) 在各个方面大放光彩,由此引发了学术界和商业界对 GPU 等计算资源的需求剧增。
左图来自 DALL・E3,右图来自 DALL・E3
比如监督训练地调优 (supervised fine-tuning, SFT) 一个 Llama2-7B 的模型,需要消耗 80GB 以上的内存。而这往往不够,为了和人类对齐(alignment),大语言模型还要经过 RLHF (reinforcement learning from human feedback) 的训练。RLHF 的 GPU 消耗往往是 SFT 的 2 倍以上,训练时间更能达到 6 倍以上。
近日,美国政府宣布限制英伟达 GPU 产品 H100, H800等进入中国市场。这项条款无疑为中国发展大语言模型(LLMs) 和人工智能增添了很多阻力。减小 RLHF 的训练成本(GPU 消耗和训练时间)对 LLMs 的发展非常重要。
RLHF 包含三个阶段:
1. 监督式地调优(Supervised Fine-Tuning, SFT)。
2. 从对比数据中学习奖励模型(reward model)。
3. 利用强化学习(RL)算法来最大化奖励。
图片来源自 InstructGPT 论文
我们发现 RLHF 的主要计算开销来源于第三阶段(奖励最大化)。这一点可以从 DeepSpeed-Chat 的报告里看到,第三阶段的训练时间是前两个阶段时间总和的 4 倍以上。而且,根据我们的经验,第三阶段的 GPU 消耗是前两阶段的 2 倍以上。
图片来自 DeepSpeed-Chat 技术报告
目前 RLHF 第 3 阶段的主要计算瓶颈是什么?
我们发现该阶段的计算瓶颈主要来源用来目前使用的 RL 算法:PPO 算法。PPO 算法是用来解决普适 RL 问题的最流行的算法之一,有非常多成功的案例。我们在这里省略 PPO 的技术细节,着重介绍 PPO 的一个关键组件:价值模型 (The value model)。价值模型是一个需要被训练的神经网络,能够有效地估计给定策略的预期长期回报。尽管价值模型为 PPO 带来了良好的性能,但它在 RLHF 任务中也引入了沉重的计算开销。例如,为了更好地与人类偏好对齐,PPO 中的价值模型通常与 LLM 大小相似,这使存储需求翻了一番。此外,价值模型的训练需要存储其梯度、激活和优化器状态,这进一步增加了近 4 倍的 GPU 存储需求。总结来说,PPO 和它的价值模型(以及其训练相关部分)已成为 RLHF 奖励最大化阶段的主要计算障碍。
相比 PPO,ReMax 是轻量级算法
是否有可能找到比 PPO 更适配 RLHF 的算法?
我们得出的答案是肯定的。这是因为 PPO 和价值模型是为通用 RL 问题设计的,而不是针对像 RLHF 这样的特定问题(RLHF 只是 RL 问题中的一个子类)。有趣的是,我们发现 RLHF 具有三个在 PPO 中未使用的重要结构:
1. 快速模拟(fast simulation): 轨迹(即 LLM 中的整个响应)可以在很短的时间内迅速执行(小于 1s),几乎没有时间开销。
2. 确定性转移(deterministic transitions):上下文确定性依赖于过去的标记和当前生成的标记。
3. 轨迹级奖励(trajectory-level rewards):奖励模型只在响应完成时提供一个奖赏值。
通过这三个观察,我们不难发现 value model 在 RLHF 的问题中是 “冗余” 的。这是因为 value model 设计的初衷是为了随机环境下的样本效率和慢仿真环境的计算效率。然而这在 RLHF 中是不需要的。
ReMax 是针对 RLHF 设计的算法,PPO 则是为通用 RL 设计的算法
ReMax
ReMax 算法基于一个古老的策略梯度算法 REINFORCE,REINFORCE 使用的策略梯度估计器如下图所示:
REINFORCE 梯度估计器
REINFORCE可以在计算层面利用好RLHF任务的三个性质,因为REINFORCE直接利用一个响应的奖励来进行优化,不需要像一般的RL算法一样需要知道中间步骤的奖励和值函数。然而,由于策略的随机性, REINFORCE梯度估计器存在高方差问题(在Richard Sutton的RL书里有指出),这一问题会影响模型训练的有效性,因此REINFORCE在RLHF任务中的效果较差,见下面两张图片。
REINFORCE 的计算代价小,但性能差
REINFORCE 的(随机)梯度值远远大于 ReMax
为解决这一问题,ReMax 使用贪婪生成的回答(greedy response)的奖励作为基准值(baseline value)来构建梯度估计器,具体公式如下:
ReMax 梯度估计器
注意到,贪婪回复的奖励可以看作为期望奖励
的好的近似。在理想情形下(
),对于随机变量
,
,因此我们能够期望估计器
具有更小的方差。
下图展示了 ReMax 的算法流程,红色方框中的是核心算法改变。
ReMax 算法流程
理论保证
我们证明了 ReMax 使用的梯度估计器仍然是真实策略梯度的一个无偏估计器。
详细理论介绍见论文。
算法优点
有效性
在 OPT-1.3B 上,ReMax 可以有效地最大化奖励
在 OPT-1.3B 上,ReMax 的训练非常稳定
GPT4 打分显示 ReMax 得到的模型会更好
高效性
在 Llama2-7B 上,ReMax 可以节省近 50% 的 GPU 内存
通用性
除了 RLHF 任务,作为一个 RL 算法,ReMax 对于经典的 NLP 任务也适用。本文考虑了在 GPT-2 上进行一个电影评论续写的任务,这里奖励模型不是从对比数据学习的。实验观测到,ReMax 可以实现 2.2 倍的训练加速和 60% 的 GPU 内存节省。
在经典的 NLP 任务(文本续写)上,ReMax 相比 PPO 实现了 2.2 倍加速
最后,我们从实验中简要总结了 ReMax 相对于 PPO 的主要优势。