Standardize dataset_num_proc usage (#1925)
* uniform dataset_num_proc

* num_proc in shuffle

* Update examples/datasets/anthropic_hh.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update examples/scripts/ppo.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update examples/scripts/ppo.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

---------

Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
3 people authored Aug 13, 2024
1 parent a9a7565 commit 54f806b
Showing 25 changed files with 94 additions and 59 deletions.
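The same edit repeats across the example scripts below: drop the hard-coded `multiprocessing.cpu_count()` / `num_proc=8` worker counts, expose a `dataset_num_proc` option that defaults to `None`, and thread it into the scripts' `map` and `filter` calls. A minimal standalone sketch of the resulting pattern (the `imdb` dataset and the toy `process` function are illustrative, not from the commit):

```python
from dataclasses import dataclass, field
from typing import Optional

from datasets import load_dataset
from transformers import HfArgumentParser


@dataclass
class ScriptArguments:
    dataset_num_proc: Optional[int] = field(
        default=None, metadata={"help": "The number of workers to use for dataset processing"}
    )


if __name__ == "__main__":
    args = HfArgumentParser(ScriptArguments).parse_args_into_dataclasses()[0]
    ds = load_dataset("imdb", split="train")

    def process(row):
        # toy transform standing in for the scripts' formatting logic
        row["text"] = row["text"].strip()
        return row

    # None (the default) keeps processing single-process; an integer fans the
    # work out over that many workers, replacing the hard-coded values removed below.
    ds = ds.map(process, num_proc=args.dataset_num_proc)
    ds = ds.filter(lambda x: len(x["text"]) > 0, num_proc=args.dataset_num_proc)
```

Since `HfArgumentParser` turns each dataclass field into a CLI flag, the option surfaces as `--dataset_num_proc`, so callers choose a worker count per run rather than per script.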
6 changes: 4 additions & 2 deletions examples/datasets/anthropic_hh.py
@@ -1,4 +1,3 @@
-import multiprocessing
 import sys
 from dataclasses import dataclass, field
 from typing import Optional
@@ -32,6 +31,9 @@ class ScriptArguments:
         default=True, metadata={"help": "Update the main revision of the repository"}
     )
     push_to_hub: Optional[bool] = field(default=False, metadata={"help": "Push the dataset to the Hugging Face Hub"})
+    dataset_num_proc: Optional[int] = field(
+        default=None, metadata={"help": "The number of workers to use for dataset processing"}
+    )


 # GPT-4 generated 😄 Define a function to process the input and extract the dialogue into structured format
@@ -79,8 +81,8 @@ def process(row):

     ds = ds.map(
         process,
-        num_proc=1 if args.debug else multiprocessing.cpu_count(),
         load_from_cache_file=False,
+        num_proc=args.dataset_num_proc,
     )
     if args.push_to_hub:
         revisions = ["main"] if args.update_main_revision else []
6 changes: 5 additions & 1 deletion examples/datasets/sentiment_descriptiveness.py
@@ -41,6 +41,9 @@ class ScriptArguments:
     )
     push_to_hub: Optional[bool] = field(default=False, metadata={"help": "Push the dataset to the Hugging Face Hub"})
     task: str = field(default="sentiment", metadata={"help": "The task of the dataset"})
+    dataset_num_proc: Optional[int] = field(
+        default=None, metadata={"help": "The number of workers to use to tokenize the data"}
+    )


 task_to_filename = {
@@ -106,7 +109,7 @@ def filter(row):
         return True

     print("=== Before filtering ===", ds)
-    ds = ds.filter(filter, load_from_cache_file=False)
+    ds = ds.filter(filter, load_from_cache_file=False, num_proc=args.dataset_num_proc)
     print("=== After filtering ===", ds)

     # here we simply take the preferred sample as the chosen one and the first non-preferred sample as the rejected one
@@ -147,6 +150,7 @@ def process(row):
         process,
         batched=True,
         load_from_cache_file=False,
+        num_proc=args.dataset_num_proc,
     )
     for key in ds:  # reorder columns
         ds[key] = ds[key].select_columns(["prompt", "chosen", "rejected"])
16 changes: 11 additions & 5 deletions examples/datasets/tldr_preference.py
@@ -1,4 +1,3 @@
-import multiprocessing
 import sys
 from dataclasses import dataclass, field
 from typing import Optional
@@ -35,6 +34,9 @@ class ScriptArguments:
         default=True, metadata={"help": "Update the main revision of the repository"}
     )
     push_to_hub: Optional[bool] = field(default=False, metadata={"help": "Push the dataset to the Hugging Face Hub"})
+    dataset_num_proc: Optional[int] = field(
+        default=None, metadata={"help": "The number of workers to use to tokenize the data"}
+    )


 if __name__ == "__main__":
@@ -53,8 +55,12 @@ class ScriptArguments:
             ds[key] = ds[key].select(range(50))
     cnndm_batches = ["batch0_cnndm", "cnndm0", "cnndm2"]
     if not args.debug:
-        ds["validation_cnndm"] = ds["validation"].filter(lambda x: x["batch"] in cnndm_batches)
-        ds["validation"] = ds["validation"].filter(lambda x: x["batch"] not in cnndm_batches)
+        ds["validation_cnndm"] = ds["validation"].filter(
+            lambda x: x["batch"] in cnndm_batches, num_proc=args.dataset_num_proc
+        )
+        ds["validation"] = ds["validation"].filter(
+            lambda x: x["batch"] not in cnndm_batches, num_proc=args.dataset_num_proc
+        )

     tldr_format_str = "SUBREDDIT: r/{subreddit}\n\nTITLE: {title}\n\nPOST: {post}\n\nTL;DR:"
     cnndm_format_str = "Article:\n{article}\n\nTL;DR:"
@@ -72,8 +78,8 @@ def process(row):

     ds = ds.map(
         process,
-        num_proc=1 if args.debug else multiprocessing.cpu_count(),
         load_from_cache_file=False,
+        num_proc=args.dataset_num_proc,
     )
     for key in ds:  # reorder columns
         ds[key] = ds[key].select_columns(
@@ -141,8 +147,8 @@ def sft_process(row):

     sft_ds = sft_ds.map(
         sft_process,
-        num_proc=1 if args.debug else multiprocessing.cpu_count(),
         load_from_cache_file=False,
+        num_proc=args.dataset_num_proc,
     )
     for key in sft_ds:  # reorder columns
         sft_ds[key] = sft_ds[key].select_columns(["prompt", "messages", "id", "subreddit", "title", "post", "summary"])
6 changes: 4 additions & 2 deletions examples/datasets/tokenize_ds.py
@@ -1,4 +1,3 @@
-import multiprocessing
 from dataclasses import dataclass, field
 from typing import Optional

@@ -19,6 +18,9 @@ class ScriptArguments:
         default="trl-internal-testing/hh-rlhf-helpful-base-trl-style", metadata={"help": "The dataset to load"}
     )
     model: str = field(default="gpt2", metadata={"help": "The model to use for tokenization"})
+    dataset_num_proc: Optional[int] = field(
+        default=None, metadata={"help": "The number of workers to use to tokenize the data"}
+    )


 if __name__ == "__main__":
@@ -38,7 +40,7 @@ def process(row):

     ds = ds.map(
         process,
-        num_proc=1 if args.debug else multiprocessing.cpu_count(),
         load_from_cache_file=False,
+        num_proc=args.dataset_num_proc,
     )
     print(ds["train"][0]["chosen"])
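The `default=None` used throughout relies on `datasets` semantics: `num_proc=None` keeps `map` and `filter` single-process, so the old debug path (`num_proc=1`) is effectively the new default and parallelism is opt-in. A quick check of that behavior (toy data, not from the commit):

```python
from datasets import Dataset

ds = Dataset.from_dict({"text": ["alpha", "beta", "gamma", "delta"]})

ds_single = ds.map(lambda row: {"text": row["text"].upper()}, num_proc=None)  # one process
ds_multi = ds.map(lambda row: {"text": row["text"].upper()}, num_proc=2)  # two workers
assert ds_single["text"] == ds_multi["text"]
```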
@@ -198,7 +198,8 @@ def preprocess_function(examples):
     remove_columns=original_columns,
 )
 train_dataset = train_dataset.filter(
-    lambda x: len(x["input_ids_j"]) <= script_args.max_length and len(x["input_ids_k"]) <= script_args.max_length
+    lambda x: len(x["input_ids_j"]) <= script_args.max_length and len(x["input_ids_k"]) <= script_args.max_length,
+    num_proc=num_proc,
 )

 eval_dataset = eval_dataset.map(
@@ -208,7 +209,8 @@ def preprocess_function(examples):
     remove_columns=original_columns,
 )
 eval_dataset = eval_dataset.filter(
-    lambda x: len(x["input_ids_j"]) <= script_args.max_length and len(x["input_ids_k"]) <= script_args.max_length
+    lambda x: len(x["input_ids_j"]) <= script_args.max_length and len(x["input_ids_k"]) <= script_args.max_length,
+    num_proc=num_proc,
 )

@@ -154,7 +154,7 @@ def preprocess_function(examples):
         num_proc=num_proc,
         remove_columns=original_columns,
     )
-    ds = ds.filter(lambda x: len(x["input_ids"]) < 512, batched=False)
+    ds = ds.filter(lambda x: len(x["input_ids"]) < 512, batched=False, num_proc=num_proc)

     ds.set_format(type="torch")
     return ds
@@ -167,14 +167,16 @@ def return_prompt_and_responses(samples) -> Dict[str, str]:
     train_dataset = get_stack_exchange_paired(data_dir="data/rl", sanity_check=script_args.sanity_check)
     train_dataset = train_dataset.filter(
         lambda x: len(x["prompt"]) + len(x["chosen"]) <= script_args.max_length
-        and len(x["prompt"]) + len(x["rejected"]) <= script_args.max_length
+        and len(x["prompt"]) + len(x["rejected"]) <= script_args.max_length,
+        num_proc=script_args.num_proc,
     )

     # 3. Load evaluation dataset
     eval_dataset = get_stack_exchange_paired(data_dir="data/evaluation", sanity_check=True)
     eval_dataset = eval_dataset.filter(
         lambda x: len(x["prompt"]) + len(x["chosen"]) <= script_args.max_length
-        and len(x["prompt"]) + len(x["rejected"]) <= script_args.max_length
+        and len(x["prompt"]) + len(x["rejected"]) <= script_args.max_length,
+        num_proc=script_args.num_proc,
     )

     # 4. initialize training arguments:
22 changes: 12 additions & 10 deletions examples/scripts/bco.py
@@ -55,7 +55,7 @@
 import logging
 from dataclasses import dataclass
 from functools import partial
-from typing import Literal
+from typing import Literal, Optional

 import torch
 import torch.nn.functional as F
@@ -76,7 +76,7 @@ class ScriptArguments:
     llm_name: Literal["gpt-3.5-turbo", "llama-2-7b-chat", "llama-2-70b-chat"] = "gpt-3.5-turbo"


-def build_helpfulness_dataset(llm_name: str) -> Dataset:
+def build_helpfulness_dataset(llm_name: str, num_proc: Optional[int] = None) -> Dataset:
     """
     Filter `llm_name` completions and binarize given their helpfulness score.
     If helpfulness score is 5, it is desirable. Otherwise, it is undesirable.
@@ -100,34 +100,36 @@ def get_model_response(example, llm_name: str):

     dataset = load_dataset("openbmb/UltraFeedback")["train"]

-    ds = dataset.filter(lambda example: llm_name in example["models"], batched=False, num_proc=8)
-    ds = ds.filter(lambda example: len(example["models"]) == len(example["completions"]), batched=False, num_proc=8)
+    ds = dataset.filter(lambda example: llm_name in example["models"], batched=False, num_proc=num_proc)
+    ds = ds.filter(
+        lambda example: len(example["models"]) == len(example["completions"]), batched=False, num_proc=num_proc
+    )

     METRIC = "helpfulness"

     ds = ds.map(
         get_model_rating,
         batched=False,
-        num_proc=8,
         fn_kwargs={"metric": METRIC, "llm_name": llm_name},
+        num_proc=num_proc,
     )

     ds = ds.map(
         get_model_response,
         batched=False,
-        num_proc=8,
         fn_kwargs={"llm_name": llm_name},
+        num_proc=num_proc,
     )

     ds = ds.select_columns(["source", "instruction", "response", "helpfulness"])

     ds = ds.rename_columns({"instruction": "prompt", "response": "completion"})
-    ds = ds.map(lambda example: {"label": example["helpfulness"] >= 5}, batched=False, num_proc=8)
+    ds = ds.map(lambda example: {"label": example["helpfulness"] >= 5}, batched=False, num_proc=num_proc)

     ds = ds.map(
         lambda example: {"prompt": [{"role": "user", "content": example["prompt"]}]},
         batched=False,
-        num_proc=8,
+        num_proc=num_proc,
     )
     dataset = ds.train_test_split(test_size=0.05, seed=42)

@@ -182,7 +184,7 @@ def mean_pooling(model_output, attention_mask):
         model, tokenizer = setup_chat_format(model, tokenizer)

     # Load the dataset
-    dataset = build_helpfulness_dataset(script_args.llm_name)
+    dataset = build_helpfulness_dataset(script_args.llm_name, num_proc=bco_args.dataset_num_proc)

     # Apply chat template
     def format_dataset(example):
@@ -192,7 +194,7 @@ def format_dataset(example):
         return example

     with PartialState().local_main_process_first():
-        formatted_dataset = dataset.map(format_dataset, batched=False, num_proc=8)
+        formatted_dataset = dataset.map(format_dataset, batched=False, num_proc=bco_args.dataset_num_proc)

     accelerator = Accelerator()
     embedding_model = AutoModel.from_pretrained(
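The last bco.py hunk keeps its `map` inside `PartialState().local_main_process_first()`. Under `accelerate`, the local main process runs the body first, so with a cache-backed dataset the transform is computed (and cached) once per node and the remaining ranks replay it from the cache, which is what keeps multi-worker `num_proc` preprocessing safe in multi-GPU runs. A reduced sketch of the guard (toy in-memory data; `format_row` is illustrative):

```python
from accelerate import PartialState
from datasets import Dataset

dataset = Dataset.from_dict({"prompt": [" hi ", " hello ", " hey ", " yo "]})

def format_row(example):
    example["prompt"] = example["prompt"].strip()
    return example

# The local main process maps first; other ranks wait, then run the same map
# (hitting the datasets cache when one exists).
with PartialState().local_main_process_first():
    dataset = dataset.map(format_row, batched=False, num_proc=2)
```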
3 changes: 1 addition & 2 deletions examples/scripts/cpo.py
@@ -52,7 +52,6 @@
     --lora_alpha=16
 """

-import multiprocessing
 from dataclasses import dataclass, field

 from datasets import load_dataset
@@ -102,8 +101,8 @@ def process(row):

     ds = ds.map(
         process,
-        num_proc=1 if cpo_args.debug else multiprocessing.cpu_count(),
         load_from_cache_file=False,
+        num_proc=cpo_args.dataset_num_proc,
     )
     train_dataset = ds["train"]
     eval_dataset = ds["test"]
2 changes: 1 addition & 1 deletion examples/scripts/dpo.py
@@ -161,8 +161,8 @@ def process(row):

     ds = ds.map(
         process,
-        num_proc=multiprocessing.cpu_count(),
         load_from_cache_file=False,
+        num_proc=training_args.dataset_num_proc,
     )
     train_dataset = ds[args.dataset_train_split]
     eval_dataset = ds[args.dataset_test_split]
2 changes: 1 addition & 1 deletion examples/scripts/kto.py
@@ -102,7 +102,7 @@ def format_dataset(example):
         example["completion"] = tokenizer.apply_chat_template(example["completion"], tokenize=False)
         return example

-    formatted_dataset = dataset.map(format_dataset)
+    formatted_dataset = dataset.map(format_dataset, num_proc=kto_args.dataset_num_proc)

     # Initialize the KTO trainer
     kto_trainer = KTOTrainer(
8 changes: 4 additions & 4 deletions examples/scripts/online_dpo.py
@@ -59,7 +59,7 @@ class ScriptArguments:
     max_length: int = 512


-def prepare_dataset(dataset, tokenizer, dataset_text_field):
+def prepare_dataset(dataset, tokenizer, dataset_text_field, num_proc):
     """pre-tokenize the dataset before training; only collate during training"""

     def tokenize(element):
@@ -73,7 +73,7 @@ def tokenize(element):
         tokenize,
         remove_columns=dataset.column_names,
         batched=True,
-        num_proc=4,  # multiprocessing.cpu_count(),
+        num_proc=num_proc,
         load_from_cache_file=False,
     )

@@ -105,11 +105,11 @@ def tokenize(element):
         for key in raw_datasets:
             raw_datasets[key] = raw_datasets[key].select(range(1024))
     train_dataset = raw_datasets[args.dataset_train_split]
-    train_dataset = prepare_dataset(train_dataset, tokenizer, args.dataset_text_field)
+    train_dataset = prepare_dataset(train_dataset, tokenizer, args.dataset_text_field, config.dataset_num_proc)

     if args.dataset_test_split is not None:
         eval_dataset = raw_datasets[args.dataset_test_split]
-        eval_dataset = prepare_dataset(eval_dataset, tokenizer, args.dataset_text_field)
+        eval_dataset = prepare_dataset(eval_dataset, tokenizer, args.dataset_text_field, config.dataset_num_proc)
     else:
         eval_dataset = None
     ################
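online_dpo.py shows the commit's other idiom: rather than reading the config inside the helper, `prepare_dataset` now takes the worker count as a parameter and the call sites pass `config.dataset_num_proc`. A reduced sketch of that shape (the names and toy `tokenize` body are illustrative):

```python
from typing import Optional

from datasets import Dataset


def prepare_dataset(dataset: Dataset, num_proc: Optional[int] = None) -> Dataset:
    """Pre-process up front; only collate during training."""

    def tokenize(element):
        return {"n_chars": [len(t) for t in element["text"]]}

    return dataset.map(tokenize, batched=True, num_proc=num_proc)


ds = prepare_dataset(Dataset.from_dict({"text": ["a", "bb", "ccc", "dddd"]}), num_proc=2)
```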
3 changes: 1 addition & 2 deletions examples/scripts/orpo.py
@@ -52,7 +52,6 @@
     --lora_alpha=16
 """

-import multiprocessing
 from dataclasses import dataclass, field

 from datasets import load_dataset
@@ -103,8 +102,8 @@ def process(row):

     ds = ds.map(
         process,
-        num_proc=1 if orpo_args.debug else multiprocessing.cpu_count(),
         load_from_cache_file=False,
+        num_proc=orpo_args.dataset_num_proc,
     )
     train_dataset = ds["train"]
     eval_dataset = ds["test"]
4 changes: 2 additions & 2 deletions examples/scripts/ppo.py
@@ -75,7 +75,7 @@ def build_dataset(config, query_dataset, input_min_text_length=2, input_max_text
     # load imdb with datasets
     ds = load_dataset(query_dataset, split="train")
     ds = ds.rename_columns({"text": "review"})
-    ds = ds.filter(lambda x: len(x["review"]) > 200, batched=False)
+    ds = ds.filter(lambda x: len(x["review"]) > 200, num_proc=args.dataset_num_proc)

     input_size = LengthSampler(input_min_text_length, input_max_text_length)

@@ -84,7 +84,7 @@ def tokenize(sample):
         sample["query"] = tokenizer.decode(sample["input_ids"])
         return sample

-    ds = ds.map(tokenize, batched=False)
+    ds = ds.map(tokenize, num_proc=args.dataset_num_proc)
     ds.set_format(type="torch")
     return ds

4 changes: 2 additions & 2 deletions examples/scripts/ppo/ppo.py
@@ -91,10 +91,10 @@ def tokenize(element):

         return dataset.map(
             tokenize,
-            remove_columns=dataset.column_names,
             batched=True,
-            num_proc=4,  # multiprocessing.cpu_count(),
+            remove_columns=dataset.column_names,
             load_from_cache_file=False,
+            num_proc=config.dataset_num_proc,
         )

     ################
7 changes: 3 additions & 4 deletions examples/scripts/ppo/ppo_tldr.py
@@ -1,4 +1,3 @@
-import multiprocessing
 import shutil

 from datasets import load_dataset
@@ -98,15 +97,15 @@ def tokenize(element):
         return dataset.map(
             tokenize,
             remove_columns=dataset.column_names,
-            num_proc=1 if config.sanity_check else multiprocessing.cpu_count(),
             load_from_cache_file=not config.sanity_check,
+            num_proc=config.dataset_num_proc,
         )

     train_dataset = prepare_dataset(train_dataset, tokenizer)
     eval_dataset = prepare_dataset(eval_dataset, tokenizer)
     # filtering
-    train_dataset = train_dataset.filter(lambda x: x["lengths"] <= 512)
-    eval_dataset = eval_dataset.filter(lambda x: x["lengths"] <= 512)
+    train_dataset = train_dataset.filter(lambda x: x["lengths"] <= 512, num_proc=config.dataset_num_proc)
+    eval_dataset = eval_dataset.filter(lambda x: x["lengths"] <= 512, num_proc=config.dataset_num_proc)
     assert train_dataset[0]["input_ids"][-1] != tokenizer.eos_token_id, "The last token should not be an EOS token"
     ################
     # Training