Commit 12b0e83: Updates README

edbeeching committed Jan 4, 2023 (1 parent: 8b0b393)
Showing 1 changed file with 42 additions and 41 deletions: README.md

With `trl` you can train transformer language models with Proximal Policy Optimisation (PPO).

**Highlights:**
- PPOTrainer: A PPO trainer for language models that just needs (query, response, reward) triplets to optimise the language model.
- AutoModelForCausalLMWithValueHead: A transformer model with an additional scalar output for each token, which can be used as a value function in reinforcement learning (see the short sketch below this list).
- Example: Train GPT2 to generate positive movie reviews with a BERT sentiment classifier.
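
As a small illustration of the value-head model, here is a minimal sketch, under the assumption that its forward pass returns the language-model logits, an optional loss, and a per-token value estimate in that order:

```python
# Minimal sketch of AutoModelForCausalLMWithValueHead (assumption: the forward
# pass returns lm_logits, loss and per-token values, in that order).
from transformers import AutoTokenizer
from trl import AutoModelForCausalLMWithValueHead

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLMWithValueHead.from_pretrained("gpt2")

input_ids = tokenizer("The movie was great because", return_tensors="pt").input_ids
lm_logits, loss, values = model(input_ids)

print(lm_logits.shape)  # (1, sequence_length, vocab_size), as with a plain causal LM
print(values.shape)     # (1, sequence_length): one scalar value estimate per token
```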

## How it works
If you want to run the examples in the repository a few additional libraries are required. Clone the repository and install it with pip:
```bash
git clone https://github.com/lvwerra/trl.git
cd trl/
pip install .
```

## How to use
This is a basic example of how to use the library. Based on a query the language model creates a response, which is then assigned a reward that is used to optimise the model with PPO.
```python
# imports
import torch
from transformers import AutoTokenizer
from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead
from trl.core import respond_to_batch

# get models
gpt2_model = AutoModelForCausalLMWithValueHead.from_pretrained('gpt2')
gpt2_model_ref = AutoModelForCausalLMWithValueHead.from_pretrained('gpt2')
gpt2_tokenizer = AutoTokenizer.from_pretrained('gpt2')

# create a ppo config
ppo_config = PPOConfig(
    batch_size=1,
    forward_batch_size=1
)

# encode a query
query_txt = "This morning I went to the "
query_tensor = gpt2_tokenizer.encode(query_txt, return_tensors="pt")

# get model response
response_tensor = respond_to_batch(gpt2_model_ref, query_tensor)
response_txt = gpt2_tokenizer.decode(response_tensor[0,:])

# create a dummy dataset
class DummyDataset(torch.utils.data.Dataset):
    def __init__(self, query_data, response_data):
        self.query_data = query_data
        self.response_data = response_data

    def __len__(self):
        return len(self.query_data)

    def __getitem__(self, idx):
        return self.query_data[idx], self.response_data[idx]

min_length = min(len(query_tensor[0]), len(response_tensor[0]))

dummy_dataset = DummyDataset(
    [query_tensor[:, :min_length].squeeze(0) for _ in range(2)],
    [response_tensor[:, :min_length].squeeze(0) for _ in range(2)],
)

# create a ppo trainer
ppo_trainer = PPOTrainer(ppo_config, gpt2_model, gpt2_model_ref, gpt2_tokenizer, dummy_dataset)
device = ppo_trainer.accelerator.device

# define a reward for response
# (this could be any reward such as human feedback or output from another model)
reward = [torch.tensor(1.0).to(device)]

# train model for one step with ppo
train_stats = ppo_trainer.step([query_tensor[0].to(device)], [response_tensor[0].to(device)], reward)
```
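
The snippet above performs a single PPO optimisation step. As a rough sketch of how this extends to a loop, reusing only the objects and calls from the example above (`compute_reward` is a hypothetical placeholder, and it is assumed that `gpt2_model_ref` and the query tensors end up on compatible devices):

```python
# Hypothetical training loop built from the calls used in the example above.
# compute_reward is a placeholder: swap in any scoring function (a classifier,
# human feedback, another model) that returns one scalar per response.
def compute_reward(text):
    return torch.tensor(1.0)

queries = ["This morning I went to the ", "The weather today is "]

for query_txt in queries:
    query_tensor = gpt2_tokenizer.encode(query_txt, return_tensors="pt")
    response_tensor = respond_to_batch(gpt2_model_ref, query_tensor)
    response_txt = gpt2_tokenizer.decode(response_tensor[0, :])
    reward = [compute_reward(response_txt).to(device)]
    train_stats = ppo_trainer.step(
        [query_tensor[0].to(device)], [response_tensor[0].to(device)], reward
    )
```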

### Advanced example: IMDB sentiment
For a detailed example check out the example script `examples/scripts/ppo-sentiment.py`, where GPT2 is fine-tuned to generate positive movie reviews. A few examples from the language model before and after optimisation are given below:

<div style="text-align: center">
<img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/table_imdb_preview.png" width="800">
<p style="text-align: center;"> <b>Figure:</b> A few review continuations before and after optimisation. </p>
</div>
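
In that example the reward comes from a sentiment classifier rather than a constant. As a rough sketch of the idea (not the example script itself: it assumes the default `transformers` sentiment-analysis pipeline and a hypothetical `sentiment_reward` helper), positive sentiment scores can be turned into per-response rewards like this:

```python
# Rough sketch of a sentiment-based reward (assumption: the default
# "sentiment-analysis" pipeline with POSITIVE/NEGATIVE labels; the real
# example script may use a different classifier and post-processing).
import torch
from transformers import pipeline

sentiment_pipe = pipeline("sentiment-analysis")

def sentiment_reward(texts):
    # score each generated text and use the probability of POSITIVE as the reward
    rewards = []
    for scores in sentiment_pipe(texts, return_all_scores=True):
        positive = next(s["score"] for s in scores if s["label"] == "POSITIVE")
        rewards.append(torch.tensor(positive))
    return rewards

rewards = sentiment_reward(["This movie was surprisingly good!"])
print(rewards)  # e.g. [tensor(0.99)]
```

Rewards produced this way can be passed to `ppo_trainer.step` exactly like the constant reward in the minimal example above.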


## References

### Proximal Policy Optimisation