Skip to content

Commit

Permalink
refactor
Browse files Browse the repository at this point in the history
- move notebooks to `examples/notebooks``
- removed `_nbdev`file
- refactored `gpt2.py` to make it work with more recent `transformers`
- update `requirements` to add recent `transformers`
  • Loading branch information
younesbelkada committed Dec 16, 2022
1 parent 4fe9988 commit dfb864d
Show file tree
Hide file tree
Showing 19 changed files with 173 additions and 49 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ This process is illustrated in the sketch below:


<div style="text-align: center">
<img src="nbs/images/trl_overview.png" width="800">
<img src="examples/images/trl_overview.png" width="800">
<p style="text-align: center;"> <b>Figure:</b> Sketch of the workflow. </p>
</div>

Expand Down
File renamed without changes
File renamed without changes
File renamed without changes
File renamed without changes
File renamed without changes
File renamed without changes
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
139 changes: 139 additions & 0 deletions examples/scripts/04-ppo-sentiment.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
import torch
import wandb
import time
import os
from tqdm import tqdm
import numpy as np
import pandas as pd
tqdm.pandas()

from datasets import load_dataset

from transformers import AutoTokenizer, pipeline

from trl.gpt2 import GPT2HeadWithValueModel, respond_to_batch
from trl.ppo import PPOTrainer
from trl.core import build_bert_batch_from_txt, listify_batch

config = {
"model_name": "lvwerra/gpt2-imdb",
"cls_model_name": "lvwerra/distilbert-imdb",
"steps": 20000,
"batch_size": 256,
"forward_batch_size": 16,
"ppo_epochs": 4,
"txt_in_min_len": 2,
"txt_in_max_len": 8,
"txt_out_min_len": 4,
"txt_out_max_len": 16,
"lr": 1.41e-5,
"init_kl_coef":0.2,
"target": 6,
"horizon":10000,
"gamma":1,
"lam":0.95,
"cliprange": .2,
"cliprange_value":.2,
"vf_coef":.1,
}

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
pipe_device = 0 if torch.cuda.is_available() else -1

wandb.init(name='run-42', project='gpt2-test', config=config)

# load imdb with datasets
ds = load_dataset('imdb', split='train')
ds = ds.rename_columns({'text': 'review', 'label': 'sentiment'})
ds = ds.filter(lambda x: len(x["review"])>200, batched=False)

sent_kwargs = {
"return_all_scores": True,
"function_to_apply": "none",
"batch_size": config["forward_batch_size"]
}

sentiment_pipe = pipeline("sentiment-analysis","lvwerra/distilbert-imdb", device=pipe_device)

gpt2_model = GPT2HeadWithValueModel.from_pretrained(config['model_name'])
gpt2_model_ref = GPT2HeadWithValueModel.from_pretrained(config['model_name'])

gpt2_tokenizer = AutoTokenizer.from_pretrained(config['model_name'])
gpt2_tokenizer.pad_token = gpt2_tokenizer.eos_token

wandb.watch(gpt2_model, log='all')

gpt2_model.to(device)
gpt2_model_ref.to(device)

class LengthSampler:
def __init__(self, min_value, max_value):
self.values = list(range(min_value, max_value))
def __call__(self):
return np.random.choice(self.values)

input_size = LengthSampler(config["txt_in_min_len"], config["txt_in_max_len"])
output_size = LengthSampler(config["txt_out_min_len"], config["txt_out_max_len"])

def tokenize(sample):
sample["tokens"] = gpt2_tokenizer.encode(sample["review"])[:input_size()]
sample["query"] = gpt2_tokenizer.decode(sample["tokens"])
return sample

ds = ds.map(tokenize, batched=False)

gen_kwargs = {
"min_length":-1,
"top_k": 0.0,
"top_p": 1.0,
"do_sample": True,
"pad_token_id": gpt2_tokenizer.eos_token_id
}

def collater(data):
return dict((key, [d[key] for d in data]) for key in data[0])

dataloader = torch.utils.data.DataLoader(ds, batch_size=config['batch_size'], collate_fn=collater)

ppo_trainer = PPOTrainer(gpt2_model, gpt2_model_ref, gpt2_tokenizer, **config)

total_ppo_epochs = int(np.ceil(config["steps"]/config['batch_size']))

for epoch, batch in tqdm(zip(range(total_ppo_epochs), iter(dataloader))):
logs, timing = dict(), dict()
t0 = time.time()
query_tensors = [torch.tensor(t).long().to(device) for t in batch["tokens"]]

#### Get response from gpt2
t = time.time()
response_tensors = []
for i in range(config['batch_size']):
gen_len = output_size()
response = gpt2_model.generate(query_tensors[i].unsqueeze(dim=0),
max_new_tokens=gen_len, **gen_kwargs)
response_tensors.append(response.squeeze()[-gen_len:])
batch['response'] = [gpt2_tokenizer.decode(r.squeeze()) for r in response_tensors]
timing['time/get_response'] = time.time()-t

#### Compute sentiment score
t = time.time()
texts = [q + r for q,r in zip(batch['query'], batch['response'])]
pipe_outputs = sentiment_pipe(texts, **sent_kwargs)
rewards = torch.tensor([output[1]["score"] for output in pipe_outputs]).to(device)
timing['time/get_sentiment_preds'] = time.time()-t

#### Run PPO step
t = time.time()
stats = ppo_trainer.step(query_tensors, response_tensors, rewards)
timing['time/optimization'] = time.time()-t

#### Log everything
timing['time/epoch'] = time.time()-t0
table_rows = [list(r) for r in zip(batch['query'], batch['response'], rewards.cpu().tolist())]
logs.update({'game_log': wandb.Table(columns=['query', 'response', 'reward'], rows=table_rows)})
logs.update(timing)
logs.update(stats)
logs['env/reward_mean'] = torch.mean(rewards).cpu().numpy()
logs['env/reward_std'] = torch.std(rewards).cpu().numpy()
logs['env/reward_dist'] = rewards.cpu().numpy()
wandb.log(logs)
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,6 @@ nbdev==0.2.16
datasets==1.17.0
torch>=1.4.0
tqdm
transformers==4.15.0
transformers
wandb==0.10.20
matplotlib==3.5.1
34 changes: 0 additions & 34 deletions trl/_nbdev.py

This file was deleted.

12 changes: 1 addition & 11 deletions trl/core.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,13 @@
# AUTOGENERATED! DO NOT EDIT! File to edit: nbs/00-core.ipynb (unless otherwise specified).

__all__ = ['WANDB_PADDING', 'flatten_dict', 'stack_dicts', 'add_suffix', 'pad_to_size', 'logprobs_from_logits',
'whiten', 'clip_by_value', 'entropy_from_logits', 'average_torch_dicts', 'stats_to_np', 'listify_batch',
'build_bert_batch_from_txt']

# Cell
import torch
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence

import collections
import numpy as np

# Cell

WANDB_PADDING = -1

# Cell

def flatten_dict(nested, sep='/'):
"""Flatten dictionary and concatenate nested keys with separator."""
Expand Down Expand Up @@ -43,7 +35,6 @@ def add_suffix(input_dict, suffix):
"""Add suffix to dict keys."""
return dict((k + suffix, v) for k,v in input_dict.items())

# Cell

def pad_to_size(tensor, size, dim=1, padding=50256):
"""Pad tensor to size."""
Expand Down Expand Up @@ -108,7 +99,6 @@ def listify_batch(tensor):
"""Turns the first dimension of a tensor into a list."""
return [tensor[i] for i in range(tensor.shape[0])]

# Cell

def build_bert_batch_from_txt(text_list, tokenizer, device):
"""Create token id and attention mask tensors from text list for BERT classification."""
Expand Down
33 changes: 31 additions & 2 deletions trl/gpt2.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ def forward(
return_dict=False,
output_attentions=False,
output_hidden_states=False,
use_cache=True,
):
loss=None
transformer_outputs = self.transformer(
Expand All @@ -114,6 +115,7 @@ def forward(
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
)

hidden_states = transformer_outputs[0]
Expand All @@ -123,7 +125,7 @@ def forward(


if not return_dict:
outputs = (lm_logits,) + transformer_outputs[1:] + (value,)
outputs = (lm_logits, loss, value,)
return outputs

return CausalLMOutputWithCrossAttentions(
Expand All @@ -135,7 +137,34 @@ def forward(
cross_attentions=transformer_outputs.cross_attentions,
value=value,
)
return outputs

def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
token_type_ids = kwargs.get("token_type_ids", None)
# only last token for inputs_ids if past is defined in kwargs
if past:
input_ids = input_ids[:, -1].unsqueeze(-1)
if token_type_ids is not None:
token_type_ids = token_type_ids[:, -1].unsqueeze(-1)

attention_mask = kwargs.get("attention_mask", None)
position_ids = kwargs.get("position_ids", None)

if attention_mask is not None and position_ids is None:
# create position_ids on the fly for batch generation
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past:
position_ids = position_ids[:, -1].unsqueeze(-1)
else:
position_ids = None
return {
"input_ids": input_ids,
"past_key_values": past,
"use_cache": kwargs.get("use_cache"),
"position_ids": position_ids,
"attention_mask": attention_mask,
"token_type_ids": token_type_ids,
}

# Cell

Expand Down

0 comments on commit dfb864d

Please sign in to comment.