Skip to content

Commit

Permalink
update readme
Browse files Browse the repository at this point in the history
  • Loading branch information
Joyce94 committed Aug 23, 2023
1 parent 5f2cd4c commit 620866d
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 21 deletions.
70 changes: 51 additions & 19 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,40 @@
### 主要内容:
- 支持指令微调Alpaca模型
- 支持训练Reward模型
- 支持PPO算法训练RL模型([PPO算法实现细节](https://zhuanlan.zhihu.com/p/649665766)

- 支持PPO算法训练RL模型
- 支持基于两个基模型,两个lora的适配器,同时加载RM、SFT、Actor、Critic四个模型,支持accelerate分布式训练 ([PPO算法实现细节](https://zhuanlan.zhihu.com/p/649665766)
- 支持基于一个基模型,两个lora适配器,同时加载RM、SFT、Actor、Critic四个模型,支持accelerate、deepspeed训练
- 支持基于一个基模型,一个lora适配器,Actor、Critic共享base model,同时实现RM、SFT、Actor、Critic四个模型功能,支持accelerate、deepspeed训练
- 支持DPO算法训练模型

### 更新
[23/8/10] 支持LLaMA模型训练
[23/8/23] 支持LLaMA2模型训练;支持DPO训练;支持基于一个基模型、选择一个或两个lora适配器训练PPO、支持accelerate、deepspeed训练
[23/8/10] 支持LLaMA模型训练;支持基于两个基模型、两个lora的适配器训练PPO;支持accelerate分布式训练


### 功能
与开源的RLHF训练框架的功能进行对比
| 框架 | SFT Train | RM Train | PPO Train | DPO Train |
| ------------------ | ------------------ | ------------------ | ------------------ | ------------------ |
| Our | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
| [Deepspeed-chat]() | :white_check_mark: | :white_check_mark: | :white_check_mark: | |
| [trl]() | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
| [MOSS-RLHF]() | | | :white_check_mark: | |


##### PPO Train
| 框架 | Accelerate | Deepspeed | Multi LORA | 最低模型参数量 (7B为例) |
| ------------------ | ------------------ | ------------------ | ------------------ | ------------------ |
| Our | :white_check_mark: | :white_check_mark: | :white_check_mark: | single model size ~ 7B |
| [Deepspeed-chat]() | | :white_check_mark: | | sft+rm+actor+critic ~ 28B |
| [trl]() | :white_check_mark: | | | single model size(not use ref model)~ 7B |
| [MOSS-RLHF]() | actor model、critic model | sft model、rm model | | sft+rm+actor+critic ~ 28B |



## 使用指引

### 使用指引
##### 环境搭建
#### 环境搭建
```
accelerate==0.21.0
datasets==2.13.1
Expand All @@ -25,36 +50,43 @@ wandb=0.15.8
peft==0.4.0
torch==2.0.1
trl==0.5.0
deepspeed==0.10.0
```

##### 支持模型
#### 支持模型
- LLaMA
- LLaMA2

##### 支持训练方式
#### 支持训练方式
- LoRA

### 训练细节
##### 指令微调模型
## 训练细节
#### 指令微调模型
- [训练指南](https://github.com/Joyce94/LLM-RLHF-Tuning/wiki/%E6%8C%87%E4%BB%A4%E5%BE%AE%E8%B0%83%E6%A8%A1%E5%9E%8B)


##### 训练奖励模型
#### 训练奖励模型
- [训练指南](https://github.com/Joyce94/LLM-RLHF-Tuning/wiki/%E8%AE%AD%E7%BB%83%E5%A5%96%E5%8A%B1%E6%A8%A1%E5%9E%8B)

##### PPO训练
- [训练指南](https://github.com/Joyce94/LLM-RLHF-Tuning/wiki/PPO%E8%AE%AD%E7%BB%83)
- [PPO算法实现细节](https://zhuanlan.zhihu.com/p/649665766)
#### PPO训练
- 训练指南
- [基于两个基模型](https://github.com/Joyce94/LLM-RLHF-Tuning/wiki/PPO%E8%AE%AD%E7%BB%83%E2%80%90%E5%9F%BA%E4%BA%8E%E4%B8%A4%E4%B8%AA%E5%9F%BA%E6%A8%A1%E5%9E%8B)
- [PPO算法实现细节](https://zhuanlan.zhihu.com/p/649665766)

- [基于一个基模型](https://github.com/Joyce94/LLM-RLHF-Tuning/wiki/PPO%E8%AE%AD%E7%BB%83%E2%80%90%E5%9F%BA%E4%BA%8E%E4%B8%80%E4%B8%AA%E5%9F%BA%E6%A8%A1%E5%9E%8B)

#### DPO训练
- [训练指南](https://github.com/Joyce94/LLM-RLHF-Tuning/wiki/DPO%E8%AE%AD%E7%BB%83)

### TODO
- 支持DeepSpeed训练
- PPO部分提升训练稳定性
- 支持LLaMA-2模型
## TODO
- PPO提升训练稳定性,实现ppo-max
- 支持BLOOM模型
- 支持Baichuan模型
- 支持QLoRA训练


欢迎加群讨论 [WeChat](assets/RLHF讨论群.jpeg)





1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,4 @@ wandb=0.15.8
peft==0.4.0
torch==2.0.1
trl==0.5.0
deepspeed==0.10.0
4 changes: 2 additions & 2 deletions script/ppo_co/run_ppo_co_multi_adapters.sh
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ accelerate launch --config_file ds_config.yaml run_ppo_with_peft.py \
--lr_scheduler_type cosine \
--learning_rate 1e-4 \
--weight_decay 0 \
--logging_steps 1 \
--save_steps 1 \
--logging_steps 100 \
--save_steps 100 \
--dataloader_num_workers 16 \
--block_size 256 \
--max_prompt_length 256 \
Expand Down

0 comments on commit 620866d

Please sign in to comment.