forked from huggingface/trl
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
PPO / Reinforce Trainers (huggingface#1540)
* Add ppov2 trainer * make eos trick optional, remove unused args * quick fix * precommit * update debugging script * fix out of bound `drop_last=True`; use built-in scheduler * Add PPO examples * push changes * quick change * quick change * various bug fixes * remove unnecessary grad accumulation setting * push new changes * fix DS3 model saving * update ppo.py * refactor * quick change * refactor * update ppo trainer * refactor * quick test * add ds2 /ds3 7 processes config * add vllm trainer * quick change * experiment with reward normalization * push changes * quick push * push changes * push various changes * refactor to use ModelConfig * quick change * refactor * refactor * Simplify DS logic * quick update * remove unnecessary files * precommit * deepspeed fix; handle edge case when eos_token_id = 0 * add PPO tldr example * add TL;DR example * fix undefined var * utilize all samples in rloo * quick setting * remove the unnecessary `value_model` * use exact_div * allow saving the deepspeed model * refactor * remove dead code * Use some shared utilities * add some end-to-end test cases * add PPOv2 docs and RLOO docs / tests * update docs * quikc push * fix ci * fix type annotation for ci * quick update * update trainer docs
- Loading branch information
Showing
23 changed files
with
3,114 additions
and
11 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Large diffs are not rendered by default.
Oops, something went wrong.
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,89 @@ | ||
import shlex | ||
import subprocess | ||
import sys | ||
from collections import defaultdict | ||
from dataclasses import dataclass | ||
|
||
import pandas as pd | ||
from datasets import load_dataset | ||
from gpt_tldr_judge import LLMJudgeConfig, llm_judge | ||
from transformers import AutoTokenizer, HfArgumentParser | ||
from vllm import SamplingParams, SingleGPULLM | ||
|
||
|
||
""" | ||
python -i examples/scripts/evals/generate_tldr.py \ | ||
--model_name_or_path vwxyzjn/rloo_tldr \ | ||
--output_path examples/scripts/minimal/evals/rloo_tldr.csv \ | ||
--n 1000 | ||
python -i examples/scripts/evals/generate_tldr.py \ | ||
--model_name_or_path vwxyzjn/ppo_tldr \ | ||
--output_path examples/scripts/minimal/evals/ppo_tldr.csv \ | ||
--n 1000 | ||
""" | ||
|
||
|
||
@dataclass
class Args:
    """Command-line arguments for the TL;DR generation script."""

    # Destination CSV for the generated summaries.
    output_path: str
    # Model name or path; used for both the tokenizer and the generation engine.
    model_name_or_path: str
    # Revision of the model to load.
    model_revision: str = "main"
    # Number of test prompts to generate summaries for.
    n: int = 1000
|
||
|
||
def run_command(command: str):
    """Run `command` as a subprocess, forwarding its output to this process.

    The command string is split with `shlex.split` and executed without a
    shell. The child's stdout/stderr are attached to this process's
    `sys.stdout`/`sys.stderr`.

    Args:
        command: the command line to execute.

    Returns:
        The `subprocess.CompletedProcess` for the finished command, so callers
        can inspect `returncode`. A non-zero exit status does not raise.
    """
    command_list = shlex.split(command)
    print(f"running {command}")
    return subprocess.run(command_list, stderr=sys.stderr, stdout=sys.stdout)
|
||
|
||
# Script body: generate summaries for the test prompts, save them next to the
# reference summaries, then have an OpenAI model judge model vs. reference.
MAX_TOKENS = 200  # a very generous max token length
parser = HfArgumentParser(Args)
args = parser.parse_args_into_dataclasses()[0]
tokenizer = AutoTokenizer.from_pretrained(
    args.model_name_or_path,
    revision=args.model_revision,
)
raw_datasets = load_dataset("trl-internal-testing/tldr-preference-sft-trl-style")
prompts = raw_datasets["test"]["prompt"]
# The reference summary is the content of the last message of each conversation.
reference_summaries = [message[-1]["content"] for message in raw_datasets["test"]["messages"]]
if args.n is not None:
    # Truncate both lists together so prompts and references stay aligned.
    prompts = prompts[: args.n]
    reference_summaries = reference_summaries[: args.n]
sampling_params = SamplingParams(temperature=0.0, top_p=0.95, max_tokens=MAX_TOKENS)
llm = SingleGPULLM(
    model=args.model_name_or_path,
    revision=args.model_revision,
    tensor_parallel_size=1,
    device="cuda:0",
)
outputs = llm.generate(prompts, sampling_params)
table = defaultdict(list)

# Collect the outputs, one row per prompt.
for output, reference_response in zip(outputs, reference_summaries):
    prompt = output.prompt
    generated_text = output.outputs[0].text
    table["prompt"].append(prompt)
    table["model_response"].append(generated_text.strip())  # need `strip()` because of the leading space
    table["model_response_len"].append(len(output.outputs[0].token_ids))
    table["reference_response"].append(reference_response)
    table["reference_response_len"].append(
        len(tokenizer(f" {reference_response}")["input_ids"])
    )  # prepend leading space

df = pd.DataFrame(table)
df.to_csv(args.output_path)

#####
# GPT as a judge
####
df["response0"] = df["model_response"]
df["response1"] = df["reference_response"]
judged_df = llm_judge(
    LLMJudgeConfig(
        # the judge treats n == -1 as "judge every row"
        n=-1 if args.n is None else args.n,
        model="gpt-3.5-turbo-0125",
    ),
    df,
)
# Build the judged path from the trailing ".csv" only. `str.replace` would
# leave a path without ".csv" unchanged (overwriting the generations file
# written above) and would also rewrite ".csv" inside directory names.
if args.output_path.endswith(".csv"):
    judged_path = f"{args.output_path[: -len('.csv')]}_judged.csv"
else:
    judged_path = f"{args.output_path}_judged.csv"
judged_df.to_csv(judged_path)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,141 @@ | ||
# you can download the CSV from https://wandb.ai/costa-huang/tldr_summarize/runs/gb2dian5 | ||
|
||
import asyncio | ||
import random | ||
import time | ||
from dataclasses import dataclass | ||
from typing import Optional | ||
|
||
import pandas as pd | ||
from openai import AsyncOpenAI | ||
from tqdm.asyncio import tqdm_asyncio | ||
from transformers import HfArgumentParser | ||
|
||
|
||
@dataclass
class LLMJudgeConfig:
    """Settings for the OpenAI-based pairwise judge."""

    # Number of rows to judge; `llm_judge` treats -1 as "every row".
    n: int = 64
    # OpenAI chat model used as the judge.
    model: str = "gpt-3.5-turbo-0125"
    # Cap on concurrent requests. When left as None it is filled in from the
    # model family below; it stays None for models that match neither family.
    max_parallel_requests: Optional[int] = None

    def __post_init__(self):
        """Pick a per-model request cap unless the caller supplied one."""
        if self.max_parallel_requests is not None:
            # Respect an explicit setting instead of silently overwriting it.
            return
        if "gpt-3.5" in self.model:
            # gpt-3.5 generates so fast that it will exceed the
            # tokens-per-minute limit
            self.max_parallel_requests = 11
        elif "gpt-4" in self.model:
            self.max_parallel_requests = 13
|
||
|
||
@dataclass
class Args:
    """Command-line arguments for running the judge on an existing CSV."""

    # Input CSV holding the responses to judge.
    csv: str = "trained_response.csv"
    # Where to write the judged CSV.
    output_path: Optional[str] = None
    # NOTE(review): spelling looks like a typo for "num_trials"; kept as-is
    # because the field name is the command-line flag.
    num_trails: int = 1
|
||
|
||
# Prompt sent to the judge. `{{post}}`, `{{response0}}` and `{{response1}}` are
# literal placeholders filled in with `str.replace`. The instruction sentence
# is kept on one line: inside a raw string a trailing backslash is NOT a line
# continuation, so it would reach the model as a stray "\" plus a line break.
TEMPLATE = r"""
Which of the following summaries does a better job of summarizing the most important points in the given forum post, without including unimportant or irrelevant details? Judge based on accuracy, coverage, and coherence.
### Post:
{{post}}
### Summary A:
{{response0}}
### Summary B:
{{response1}}
### Instructions:
FIRST provide a one-sentence comparison of the two summaries, explaining which you prefer and why. SECOND, on a new line, state only "A" or "B" to indicate your choice. Your response should use the format:
Comparison: <one-sentence comparison and explanation>
Preferred: <"A" or "B">
"""
|
||
|
||
def llm_judge(ljc: LLMJudgeConfig, df: pd.DataFrame):
    """Have an OpenAI chat model pick the better of `response0`/`response1` per row.

    For each judged row the two responses are presented as "Summary A" and
    "Summary B" in a random order, and the model's verdict is mapped back to
    `response0` or `response1`.

    Args:
        ljc: judge settings. The first `ljc.n` rows are judged (`-1` means all
            rows). `ljc.max_parallel_requests` caps the number of concurrent
            requests; `None` means no cap.
        df: must have columns `prompt`, `response0`, `response1`. It is
            modified in place.

    Returns:
        `df` with four added columns: `explanation`, `preferred`
        (`"response0"` or `"response1"`), `shuffled_index` and
        `entire_conversation`. Rows that were not judged hold `None`.
    """

    def parse_reply(reply: str):
        """Return `(comparison, verdict)` from a judge reply; verdict is "A" or "B".

        Raises IndexError/ValueError when the reply does not follow the format.
        """
        comparison = reply.split("Comparison:")[1].split("Preferred:")[0].strip()
        verdict_lines = reply.split("Preferred:")[1].strip().splitlines()
        first_line = verdict_lines[0] if verdict_lines else ""
        # The prompt shows the format as `<"A" or "B">`, so replies such as
        # `"A"` or `B.` are common. Without normalising, they would fail the
        # exact comparison and always be credited to `response1`.
        verdict = first_line.strip("\"'`*.<> \t").upper()
        if verdict not in ("A", "B"):
            raise ValueError(f"unrecognised verdict {first_line!r}")
        return comparison, verdict

    async def process_text(async_client, limiter, post: str, response0: str, response1: str, i: int):
        """Query the judge for one row and return `(comparison, verdict, i, transcript)`."""
        text = TEMPLATE.replace("{{post}}", post)
        text = text.replace("{{response0}}", response0)
        text = text.replace("{{response1}}", response1)

        async with limiter:
            response = None
            r = ""
            # Retry until the request succeeds.
            while response is None:
                try:
                    response = await async_client.chat.completions.create(
                        model=ljc.model,
                        messages=[
                            {"role": "system", "content": "You are a helpful assistant."},
                            {"role": "user", "content": text},
                        ],
                    )
                    r = response.choices[0].message.content or ""
                except Exception as e:
                    print(f"error in {i}: {e}")
                    response = None
                    # deal with rate limit; a blocking `time.sleep` here would
                    # freeze every other in-flight request on the event loop
                    await asyncio.sleep(30)
                    continue

            try:
                comparison, preferred = parse_reply(r)
                return comparison, preferred, i, text + r
            except (IndexError, ValueError) as e:
                # Unparseable reply: fall back to a coin flip.
                print(f"error in {i} {e}")
                return "", random.choice(["A", "B"]), i, text + r

    async def main(ljc: LLMJudgeConfig, df: pd.DataFrame):
        """`df` should have columns: `prompt`, `response0`, `response1`"""
        total = len(df)
        count = total if ljc.n == -1 else min(ljc.n, total)

        # Create the semaphore and client inside the running loop: before
        # Python 3.10 asyncio primitives bind to the loop that is current at
        # construction time, which is not the loop `asyncio.run` creates.
        max_parallel = ljc.max_parallel_requests
        if max_parallel is None:
            max_parallel = max(count, 1)  # no cap: every request may run at once
        limiter = asyncio.Semaphore(max_parallel)
        async_client = AsyncOpenAI()

        # Results are gathered by row position and written back as whole
        # columns, so the frame's index labels never matter.
        explanations = [None] * total
        preferred_labels = [None] * total
        shuffled_indices = [None] * total
        conversations = [None] * total

        tasks = []
        for i in range(count):
            post = df["prompt"].iloc[i].strip()
            # shuffle the order to avoid the judge's preference bias for a position
            shuffled_index = random.randint(0, 1)
            shuffled_indices[i] = shuffled_index
            responses = [
                df["response0"].iloc[i].strip(),
                df["response1"].iloc[i].strip(),
            ]
            response0 = responses[shuffled_index]
            response1 = responses[1 - shuffled_index]
            task = asyncio.create_task(process_text(async_client, limiter, post, response0, response1, i))
            tasks.append(task)

        results = await tqdm_asyncio.gather(*tasks)

        for comparison, preferred, i, entire_conversation in results:
            explanations[i] = comparison
            conversations[i] = entire_conversation
            # "A" was `responses[shuffled_index]`, so undo the shuffle here.
            preferred_labels[i] = (
                "response0"
                if (shuffled_indices[i] == 0 and preferred == "A")
                or (shuffled_indices[i] == 1 and preferred == "B")
                else "response1"
            )

        df["explanation"] = explanations
        df["preferred"] = preferred_labels
        df["shuffled_index"] = shuffled_indices
        df["entire_conversation"] = conversations
        print(df["preferred"].value_counts())
        return df

    return asyncio.run(main(ljc, df))
|
||
|
||
if __name__ == "__main__":
    # Judge an existing CSV of model responses against the reference responses.
    args, ljc = HfArgumentParser((Args, LLMJudgeConfig)).parse_args_into_dataclasses()

    # `DataFrame.to_csv(None)` returns the CSV text instead of writing a file,
    # so the default `output_path=None` would silently drop the results.
    # Fall back to a path derived from the input CSV.
    output_path = args.output_path
    if output_path is None:
        stem = args.csv[: -len(".csv")] if args.csv.endswith(".csv") else args.csv
        output_path = f"{stem}_judged.csv"

    df = pd.read_csv(args.csv)
    # keep only the text before the end-of-text marker
    df["reference_response"] = df["reference_response"].map(lambda x: x.split("<|endoftext|>")[0].strip())
    df["prompt"] = df["query"].map(lambda x: x.strip())
    df["response0"] = df["model_response"].map(lambda x: x.strip())
    df["response1"] = df["reference_response"].map(lambda x: x.strip())
    judge_df = llm_judge(ljc, df)
    print(f"saving judged results to {output_path}")
    judge_df.to_csv(output_path)
Oops, something went wrong.