背景
ChatGPT 已经问世一年+了,在训练 ChatGPT 中必不可少的一环是 RLHF 训练,目前开源社区已经有了不少 RLHF 训练框架比如,TRL, DeepSpeedChat 或者最近热门的 LLaMA Factory。这些框架往往是基于 ZeRO 等并行方式,将 RLHF 算法中的四个模型切片后放到同一个 GPU 上。在模型规模越来越大的今天,这种调度方式无法满足 70B+ 甚至仅 13B+模型的全量 RLHF 训练,必须通过合并 Actor Critic 模型或者使用 LoRA 等方式妥协内存使用量。而这些PEFT的方式往往意味着模型效果的妥协。
于是乎开源项目:
https://github.com/OpenLLMAI/OpenRLHF
诞生了,我们基于 Ray 和 vLLM 重新设计了模型调度方案:
- 对于 7B 这种小模型,我们将所有模型放到同一张GPU上
- 对于 13B~34B 的中等模型,我们基于 Ray 将 PPO 中的四个模型放到不同的GPU上实现全量微调
- 对于 34B+的大模型,我们用 vLLM 的 TP 并行加载 Actor 模型,其他模型仍然用 Ray 的方式分散在不同的GPU上
ZeRO2 + Adam Offload + Pinned Memory
我们针对小于 34B 的模型使用 ZeRO2 + Adam Offload + Pinned Memory 的优化方案,我们的基本想法是
- 我们发现 RLHF 训练流程中 80% 的时间都被用于 GPT 模型的样本生成和推理,这是因为 GPT 模型的自回归解码具有 O(n^2) 复杂度,并且通常是 Memory Bound 的。
- 最简单的提升推理效率的方式是避免通过加大矩阵乘法的尺寸来避免 Memory Bound 和增强 GPU 计算效率,但大的矩阵乘法意味着大的batch_size,导致KV Cache对内存需求很大。
- 所以我们想到通过 Optimizer 的 Offload 将 Adam 优化器权重放到 CPU 内存中来节省内存,并且通过 Pinned Memory 避免梯度聚合时候的GPU-CPU通信效率问题。此时我们不仅可以用节省的内存来加大batch_size,而且可以用 ZeRO2 来避免模型切片造成的极大通信开销。
- 对于 13B+ 的模型我们会发现基于 ZeRO2 在 A100 的 80G 内存上无法塞下四个模型,所以我们基于 Ray 将模型分别放到不同的 GPU上。不过对于 Actor 我们会分配更多的GPU来减少 GPU 空闲。
通过这种优化策略后优化后,我们在13B模型上做测试,发现我们实现了 4倍于 DeepSpeedChat 的训练效率。
Ray + vLLM 方案架构
但是对于 34B+ 的模型我们发现即使用 Ray 把模型放到不同的卡上也没有办法放得下去
所以我们想到对于 Actor 推理模块我们基于 vLLM 的 TP 并行和 Dynamic Batching 能力做了分布式推理的优化,然后其他模块(即 Actor/Critic的训练模块和Reward/RefActor的推理模块)因为只参一次 forward 或者 backward 我们采用 ZeRO3 的方式进行并行训练。
每次 PPO 训练,vLLM 推理引擎都会收到 DeepSpeed ZeRO3 训练框架更新后的权重,我们通过 NVIDIA NCCL 高性能通信实现了这个过程。鉴于 vLLM 的高性能推理能力,我们实现的不错的性能收益。更进一步,我们可以融合 Actor 的训练节点和推理节点实现节点复用来避免 GPU 空闲,因为这两个模块并不会同时工作。
至此我们通过 Ray 和 vLLM 实现了 70B+ 模型的 RLHF训练方案,并且我们的方案是无缝兼容 Huggingface Transformers 库的,无需像 Megatron-LM 一样手动修改模型结构。
PPO Implementation Tricks
除了系统架构方面的优化,我们进一步整合了 RLHF 算法方面的优化。根据两篇 PPO 经论文:
https://arxiv.org/abs/2005.12729
https://iclr-blog-track.github.io/2022/03/25/ppo-implementation-details/
PPO算法在实现细节方面有非常多的讲究和调参技巧,我们在 蜗牛在花园跑酷:如何正确复现 Instruct GPT / RLHF? 一文中论述过部分的实现细节和优化技巧。在 OpenRLHF 中我们集成了这些所有 Implementation Tricks,从而实现了 PPO 训练算法的稳定训练和收敛。
多种对齐算法支持
我们不仅实现了 PPO,而且提供了 DPO/Rejection Sampling/Conditonal SFT 等 Alignemnt 算法的支持。
详情参考 OpenRLHF 项目 Readme.md
Quick Start 快速教程
我们只需要安装好环境依赖后,使用 Ray 提交训练任务即可。OpenRLHF 的模型和数据集完美兼容 HuggingFace 格式,包括热门的 MoE 模型 Mixtral 8*7b,只需要指定模型名字或者本地目录地址即可。
址即可。
# 启动 Ray
nohup ray start --head --node-ip-address 0.0.0.0 --num-gpus 8 --block &> ray.log &
# 提交 Ray 任务
ray job submit --address="http://127.0.0.1:8265" \
--runtime-env-json='{"working_dir": "/openrlhf", "pip": "/openrlhf/requirements.txt"}' \
--no-wait \
-- python3 examples/train_ppo_ray.py \
--ref_num_nodes 1 \ # ref policy 节点数量
--ref_num_gpus_per_node 2 \ # ref policy GPU数量
--reward_num_nodes 1 \ # reward model 节点数量
--reward_num_gpus_per_node 2 \ # reward model GPU数量
--critic_num_nodes 1 \ # critic 节点数量
--critic_num_gpus_per_node 4 \ # critic GPU数量
--actor_num_nodes 1 \ # actor 训练节点数量
--actor_num_gpus_per_node 4 \ # actor 训练GPU数量
--vllm_num_engines 2 \ # actor 推理节点数量
--vllm_tensor_parallel_size 2 \ # actor 推理GPU数量
--pretrain meta-llama/Llama-2-70b-chat-hf \ # Actor 预训练模型
--reward_pretrain meta-llama/Llama-2-70b-chat-hf \ # Reward 预训练模型
--save_path /mnt/bn/wuxibin/cache/ckpt/llama_70b \ # 模型保存路径
--micro_train_batch_size 1 \
--train_batch_size 128 \
--micro_rollout_batch_size 2 \
--rollout_batch_size 1024 \
--max_epochs 1 \
--prompt_max_len 1024 \
--generate_max_len 1024 \
--zero_stage 3 \
--bf16 \
--actor_learning_rate 5e-7 \
--critic_learning_rate 9e-6 \
--init_kl_coef 0.01 \
--prompt_data Open-Orca/OpenOrca,Dahoas/full-hh-rlhf,tasksource/oasst1_pairwise_rlhf_reward \ # 数据集
--prompt_data_probs 0.4,0.5,0.1 \ # 数据集混合概率
--max_samples 80000 \ # 最大样本数量
--normalize_reward \ # Reward Normalization
--actor_init_on_gpu \
--adam_offload \
--flash_attn \
--gradient_checkpointing
对于 SFT/Reward 模型的训练,我们也提供了相应的实现。只需要直接运行 deepspeed 命令即可
# Reward Model training
deepspeed ./train_rm.py \
--save_path ./ckpt/7b_llama \
--save_steps -1 \
--logging_steps 1 \
--eval_steps -1 \
--train_batch_size 128 \
--micro_train_batch_size 1 \
--pretrain OpenLLMAI/Llama-2-7b-sft-model-ocra-500k \
--bf16 \
--max_epochs 1 \
--max_len 2048 \
--zero_stage 3 \
--learning_rate 9e-6 \
--dataset Anthropic/hh-rlhf,tasksource/oasst1_pairwise_rlhf_reward,lmsys/chatbot_arena_conversations,openai/webgpt_comparisons \
--dataset_probs 0.72,0.08,0.12,0.08 \
--flash_attn \
--gradient_checkpointing
# SFT model training
deepspeed ./train_sft.py \
--max_len 2048 \
--dataset Open-Orca/OpenOrca \
--dataset_probs 1.0 \
--train_batch_size 128 \
--micro_train_batch_size 2 \
--max_samples 500000 \
--pretrain meta-llama/Llama-2-7b-hf \
--save_path ./ckpt/7b_llama \
--save_steps -1 \
--logging_steps 1 \
--eval_steps -1 \
--zero_stage 2 \
--max_epochs 1 \
--bf16 \
--flash_attn \
--learning_rate 5e-6 \
--gradient_checkpointing