Dynamic input sizes (huggingface#35)
* change ppo input from tensor to list of tensors for varying shapes (see the sketch below the commit summary)

* update readme example with new input type

* update docs

* add listification of tensors need for new API

* replace nans in tensors for wandb compatibility

* add `listify_batch` helper function for backwards compatibility

* update sentiment example with new api

* update docs

* update library

* ignore wandb artifacts

* update requirements

* run experiment

* replace respond to batch with generate

* add experiment

* update docs

* fix action

* fix action
lvwerra authored May 15, 2022
1 parent 5410be6 commit 52910d3
Showing 24 changed files with 1,329 additions and 950 deletions.
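
The headline change in the commit message above is that PPO inputs move from a single padded tensor per argument to a Python list of tensors, one entry per example, so each query/response pair can keep its own length. A minimal before/after sketch of the calling convention (hypothetical token IDs and an already-configured `ppo_trainer`, based on the README changes further down in this diff):

```python
import torch

# Before: one fixed-shape (batch, seq_len) tensor per argument, so every
# sequence had to be padded to a common length.
# queries     = torch.tensor([[31373, 995, 50256], [17250, 50256, 50256]])
# responses   = torch.tensor([[340, 373, 257], [843, 50256, 50256]])
# rewards     = torch.tensor([1.0, 0.5])
# train_stats = ppo_trainer.step(queries, responses, rewards)

# After: lists of 1-D tensors, one entry per example, each with its own length.
queries   = [torch.tensor([31373, 995]), torch.tensor([17250])]
responses = [torch.tensor([340, 373, 257]), torch.tensor([843])]
rewards   = [torch.tensor(1.0), torch.tensor(0.5)]
# train_stats = ppo_trainer.step(queries, responses, rewards)
```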
2 changes: 1 addition & 1 deletion .github/workflows/main.yml
@@ -30,4 +30,4 @@ jobs:
if [ -n "$(nbdev_diff_nbs)" ]; then echo -e "!!! Detected difference between the notebooks and the library"; false; fi
- name: Run tests
run: |
nbdev_test_nbs
nbdev_test_nbs --fname 'nbs/[!03|!04|!05|]*.ipynb'
2 changes: 2 additions & 0 deletions .gitignore
@@ -139,3 +139,5 @@ checklink/cookies.txt
# .gitconfig is now autogenerated
.gitconfig


nbs/wandb/
46 changes: 25 additions & 21 deletions README.md
@@ -3,11 +3,11 @@

## What is it?
With `trl` you can train transformer language models with Proximal Policy Optimization (PPO). The library is built with the `transformer` library by 🤗 Hugging Face ([link](https://github.com/huggingface/transformers)). Therefore, pre-trained language models can be directly loaded via the transformer interface. At this point only GPT2 is implemented.
With `trl` you can train transformer language models with Proximal Policy Optimization (PPO). The library is built on top of the [`transformers`](https://github.com/huggingface/transformers) library by 🤗 Hugging Face. Therefore, pre-trained language models can be directly loaded via `transformers`. At this point only decoder architectures such as GPT2 are implemented.

**Highlights:**
- GPT2 model with a value head: A transformer model with an additional scalar output for each token which can be used as a value function in reinforcement learning.
- PPOTrainer: A PPO trainer for language models that just needs (query, response, reward) triplets to optimise the language model.
- GPT2 model with a value head: A transformer model with an additional scalar output for each token which can be used as a value function in reinforcement learning.
- Example: Train GPT2 to generate positive movie reviews with a BERT sentiment classifier.

## How it works
@@ -29,27 +29,29 @@ This process is illustrated in the sketch below:

### Python package
Install the library with pip:
```bash
pip install trl
```

`pip install trl`

### Repository
### From source
If you want to run the examples in the repository, a few additional libraries are required. Clone the repository and install it with pip:

`git clone https://github.com/lvwerra/trl.git`

`cd trl/`

`pip install -r requirements.txt`

```bash
git clone https://github.com/lvwerra/trl.git
cd trl/
pip install -r requirements.txt
```
### Jupyter notebooks

If you run Jupyter notebooks, you might need to run the following:

`jupyter nbextension enable --py --sys-prefix widgetsnbextension`
```bash
jupyter nbextension enable --py --sys-prefix widgetsnbextension
```

For JupyterLab, additionally run this command:

`jupyter labextension install @jupyter-widgets/jupyterlab-manager`
```bash
jupyter labextension install @jupyter-widgets/jupyterlab-manager
```

## How to use

@@ -70,7 +72,7 @@ gpt2_tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
# initialize trainer
ppo_config = {'batch_size': 1, 'forward_batch_size': 1}
ppo_trainer = PPOTrainer(gpt2_model, gpt2_model_ref, **ppo_config)
ppo_trainer = PPOTrainer(gpt2_model, gpt2_model_ref, gpt2_tokenizer, **ppo_config)
# encode a query
query_txt = "This morning I went to the "
@@ -82,14 +84,14 @@ response_txt = gpt2_tokenizer.decode(response_tensor[0,:])
# define a reward for response
# (this could be any reward such as human feedback or output from another model)
reward = torch.tensor([1.0])
reward = [torch.tensor(1.0)]
# train model with ppo
train_stats = ppo_trainer.step(query_tensor, response_tensor, reward)
train_stats = ppo_trainer.step([query_tensor[0]], [response_tensor[0]], reward)
```
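
Pulling the changed lines above together, a minimal end-to-end sketch of the updated quickstart might look like the following. The imports and the `GPT2HeadWithValueModel`/`respond_to_batch` helpers are assumptions taken from the library's other examples rather than from this diff hunk, so treat the block as illustrative only:

```python
# Sketch of the new list-based API; lines marked "new in this commit" come from
# the diff above, everything else is assumed boilerplate.
import torch
from transformers import GPT2Tokenizer
from trl.gpt2 import GPT2HeadWithValueModel, respond_to_batch  # assumed import path
from trl.ppo import PPOTrainer                                 # assumed import path

# load models and tokenizer
gpt2_model = GPT2HeadWithValueModel.from_pretrained('gpt2')
gpt2_model_ref = GPT2HeadWithValueModel.from_pretrained('gpt2')
gpt2_tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

# initialize trainer; the tokenizer is now passed to PPOTrainer (new in this commit)
ppo_config = {'batch_size': 1, 'forward_batch_size': 1}
ppo_trainer = PPOTrainer(gpt2_model, gpt2_model_ref, gpt2_tokenizer, **ppo_config)

# encode a query and generate a response (generation helper assumed; the commit
# message notes the examples move from respond_to_batch to generate)
query_tensor = gpt2_tokenizer.encode("This morning I went to the ", return_tensors="pt")
response_tensor = respond_to_batch(gpt2_model, query_tensor)

# reward is now a list of scalar tensors, one per example (new in this commit)
reward = [torch.tensor(1.0)]

# queries and responses are passed as lists of 1-D tensors (new in this commit)
train_stats = ppo_trainer.step([query_tensor[0]], [response_tensor[0]], reward)
```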

### Advanced example: IMDB sentiment
For a detailed example, check out the notebook *Tune GPT2 to generate positive reviews*, where GPT2 is fine-tuned to generate positive movie reviews. A few examples from the language models before and after optimisation are given below:
For a detailed example, check out the notebook `04-gpt2-sentiment-ppo-training.ipynb`, where GPT2 is fine-tuned to generate positive movie reviews. A few examples from the language models before and after optimisation are given below:

<div style="text-align: center">
<img src="nbs/images/table_imdb_preview.png" width="800">
@@ -104,8 +106,10 @@ This library is built with `nbdev` and as such all the library code as well as e
- `00-core.ipynb`: Contains the utility functions used throughout the library and examples.
- `01-gpt2-with-value-head.ipynb`: Implementation of a `transformer` compatible GPT2 model with an additional value head as well as a function to generate sequences.
- `02-ppo.ipynb`: Implementation of the PPOTrainer used to train language models.
- `03-bert-imdb-training.ipynb`: Training of BERT with `simpletransformers` to classify sentiment on the IMDB dataset.
- `03-bert-imdb-training.ipynb`: Training of DistilBERT to classify sentiment on the IMDB dataset.
- `04-gpt2-sentiment-ppo-training.ipynb`: Fine-tune GPT2 with the BERT sentiment classifier to produce positive movie reviews.

Currently using `trl==0.0.3`:
- `05-gpt2-sentiment-control.ipynb`: Fine-tune GPT2 with the BERT sentiment classifier to produce movie reviews with controlled sentiment.

## References
@@ -114,4 +118,4 @@ This library is built with `nbdev` and as such all the library code as well as e
The PPO implementation largely follows the structure introduced in the paper **"Fine-Tuning Language Models from Human Preferences"** by D. Ziegler et al. \[[paper](https://arxiv.org/pdf/1909.08593.pdf), [code](https://github.com/openai/lm-human-preferences)].

### Language models
The language models utilize the `transformer` library by 🤗Hugging Face.
The language models utilize the `transformers` library by 🤗 Hugging Face.
66 changes: 54 additions & 12 deletions docs/00-core.html
@@ -31,6 +31,19 @@

<div class="cell border-box-sizing code_cell rendered">

</div>
{% endraw %}

<div class="cell border-box-sizing text_cell rendered"><div class="inner_cell">
<div class="text_cell_render border-box-sizing rendered_html">
<h2 id="Constants">Constants<a class="anchor-link" href="#Constants"> </a></h2>
</div>
</div>
</div>
{% raw %}

<div class="cell border-box-sizing code_cell rendered">

</div>
{% endraw %}

@@ -66,7 +79,7 @@ <h2 id="General-utils">General utils<a class="anchor-link" href="#General-utils"
<span class="n">results</span> <span class="o">=</span> <span class="nb">dict</span><span class="p">()</span>
<span class="k">for</span> <span class="n">k</span> <span class="ow">in</span> <span class="n">stats_dicts</span><span class="p">[</span><span class="mi">0</span><span class="p">]:</span>
<span class="n">stats_list</span> <span class="o">=</span> <span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">flatten</span><span class="p">(</span><span class="n">d</span><span class="p">[</span><span class="n">k</span><span class="p">])</span> <span class="k">for</span> <span class="n">d</span> <span class="ow">in</span> <span class="n">stats_dicts</span><span class="p">]</span>
<span class="n">results</span><span class="p">[</span><span class="n">k</span><span class="p">]</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">stack</span><span class="p">(</span><span class="n">stats_list</span><span class="p">)</span>
<span class="n">results</span><span class="p">[</span><span class="n">k</span><span class="p">]</span> <span class="o">=</span> <span class="n">pad_sequence</span><span class="p">(</span><span class="n">stats_list</span><span class="p">,</span> <span class="n">batch_first</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">padding_value</span><span class="o">=</span><span class="n">WANDB_PADDING</span><span class="p">)</span>
<span class="k">return</span> <span class="n">results</span>

<span class="k">def</span> <span class="nf">add_suffix</span><span class="p">(</span><span class="n">input_dict</span><span class="p">,</span> <span class="n">suffix</span><span class="p">):</span>
@@ -92,7 +105,7 @@ <h2 id="General-utils">General utils<a class="anchor-link" href="#General-utils"


<div class="output_markdown rendered_html output_subarea ">
<h4 id="flatten_dict" class="doc_header"><code>flatten_dict</code><a href="https://github.com/lvwerra/trl/tree/master/trl/core.py#L14" class="source_link" style="float:right">[source]</a></h4><blockquote><p><code>flatten_dict</code>(<strong><code>nested</code></strong>, <strong><code>sep</code></strong>=<em><code>'/'</code></em>)</p>
<h4 id="flatten_dict" class="doc_header"><code>flatten_dict</code><a href="https://github.com/lvwerra/trl/tree/master/trl/core.py#L20" class="source_link" style="float:right">[source]</a></h4><blockquote><p><code>flatten_dict</code>(<strong><code>nested</code></strong>, <strong><code>sep</code></strong>=<em><code>'/'</code></em>)</p>
</blockquote>
<p>Flatten dictionary and concatenate nested keys with separator.</p>

@@ -117,7 +130,7 @@ <h4 id="flatten_dict" class="doc_header"><code>flatten_dict</code><a href="https


<div class="output_markdown rendered_html output_subarea ">
<h4 id="stack_dicts" class="doc_header"><code>stack_dicts</code><a href="https://github.com/lvwerra/trl/tree/master/trl/core.py#L28" class="source_link" style="float:right">[source]</a></h4><blockquote><p><code>stack_dicts</code>(<strong><code>stats_dicts</code></strong>)</p>
<h4 id="stack_dicts" class="doc_header"><code>stack_dicts</code><a href="https://github.com/lvwerra/trl/tree/master/trl/core.py#L34" class="source_link" style="float:right">[source]</a></h4><blockquote><p><code>stack_dicts</code>(<strong><code>stats_dicts</code></strong>)</p>
</blockquote>
<p>Stack the values of a dict.</p>

@@ -142,7 +155,7 @@ <h4 id="stack_dicts" class="doc_header"><code>stack_dicts</code><a href="https:/


<div class="output_markdown rendered_html output_subarea ">
<h4 id="add_suffix" class="doc_header"><code>add_suffix</code><a href="https://github.com/lvwerra/trl/tree/master/trl/core.py#L36" class="source_link" style="float:right">[source]</a></h4><blockquote><p><code>add_suffix</code>(<strong><code>input_dict</code></strong>, <strong><code>suffix</code></strong>)</p>
<h4 id="add_suffix" class="doc_header"><code>add_suffix</code><a href="https://github.com/lvwerra/trl/tree/master/trl/core.py#L42" class="source_link" style="float:right">[source]</a></h4><blockquote><p><code>add_suffix</code>(<strong><code>input_dict</code></strong>, <strong><code>suffix</code></strong>)</p>
</blockquote>
<p>Add suffix to dict keys.</p>

@@ -227,6 +240,10 @@ <h2 id="Torch-utils">Torch utils<a class="anchor-link" href="#Torch-utils"> </a>
<span class="k">if</span> <span class="n">np</span><span class="o">.</span><span class="n">isscalar</span><span class="p">(</span><span class="n">new_dict</span><span class="p">[</span><span class="n">k</span><span class="p">]):</span>
<span class="n">new_dict</span><span class="p">[</span><span class="n">k</span><span class="p">]</span> <span class="o">=</span> <span class="nb">float</span><span class="p">(</span><span class="n">new_dict</span><span class="p">[</span><span class="n">k</span><span class="p">])</span>
<span class="k">return</span> <span class="n">new_dict</span>

<span class="k">def</span> <span class="nf">listify_batch</span><span class="p">(</span><span class="n">tensor</span><span class="p">):</span>
<span class="sd">&quot;&quot;&quot;Turns the first dimension of a tensor into a list.&quot;&quot;&quot;</span>
<span class="k">return</span> <span class="p">[</span><span class="n">tensor</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">tensor</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">])]</span>
</pre></div>

</div>
@@ -247,7 +264,7 @@ <h2 id="Torch-utils">Torch utils<a class="anchor-link" href="#Torch-utils"> </a>


<div class="output_markdown rendered_html output_subarea ">
<h4 id="pad_to_size" class="doc_header"><code>pad_to_size</code><a href="https://github.com/lvwerra/trl/tree/master/trl/core.py#L42" class="source_link" style="float:right">[source]</a></h4><blockquote><p><code>pad_to_size</code>(<strong><code>tensor</code></strong>, <strong><code>size</code></strong>, <strong><code>dim</code></strong>=<em><code>1</code></em>, <strong><code>padding</code></strong>=<em><code>50256</code></em>)</p>
<h4 id="pad_to_size" class="doc_header"><code>pad_to_size</code><a href="https://github.com/lvwerra/trl/tree/master/trl/core.py#L48" class="source_link" style="float:right">[source]</a></h4><blockquote><p><code>pad_to_size</code>(<strong><code>tensor</code></strong>, <strong><code>size</code></strong>, <strong><code>dim</code></strong>=<em><code>1</code></em>, <strong><code>padding</code></strong>=<em><code>50256</code></em>)</p>
</blockquote>
<p>Pad tensor to size.</p>

@@ -272,7 +289,7 @@ <h4 id="pad_to_size" class="doc_header"><code>pad_to_size</code><a href="https:/


<div class="output_markdown rendered_html output_subarea ">
<h4 id="logprobs_from_logits" class="doc_header"><code>logprobs_from_logits</code><a href="https://github.com/lvwerra/trl/tree/master/trl/core.py#L50" class="source_link" style="float:right">[source]</a></h4><blockquote><p><code>logprobs_from_logits</code>(<strong><code>logits</code></strong>, <strong><code>labels</code></strong>)</p>
<h4 id="logprobs_from_logits" class="doc_header"><code>logprobs_from_logits</code><a href="https://github.com/lvwerra/trl/tree/master/trl/core.py#L56" class="source_link" style="float:right">[source]</a></h4><blockquote><p><code>logprobs_from_logits</code>(<strong><code>logits</code></strong>, <strong><code>labels</code></strong>)</p>
</blockquote>
<p>See: <a href="https://github.com/pytorch/pytorch/issues/563#issuecomment-330103591">https://github.com/pytorch/pytorch/issues/563#issuecomment-330103591</a></p>

@@ -297,7 +314,7 @@ <h4 id="logprobs_from_logits" class="doc_header"><code>logprobs_from_logits</cod


<div class="output_markdown rendered_html output_subarea ">
<h4 id="whiten" class="doc_header"><code>whiten</code><a href="https://github.com/lvwerra/trl/tree/master/trl/core.py#L59" class="source_link" style="float:right">[source]</a></h4><blockquote><p><code>whiten</code>(<strong><code>values</code></strong>, <strong><code>shift_mean</code></strong>=<em><code>True</code></em>)</p>
<h4 id="whiten" class="doc_header"><code>whiten</code><a href="https://github.com/lvwerra/trl/tree/master/trl/core.py#L65" class="source_link" style="float:right">[source]</a></h4><blockquote><p><code>whiten</code>(<strong><code>values</code></strong>, <strong><code>shift_mean</code></strong>=<em><code>True</code></em>)</p>
</blockquote>
<p>Whiten values.</p>

@@ -322,7 +339,7 @@ <h4 id="whiten" class="doc_header"><code>whiten</code><a href="https://github.co


<div class="output_markdown rendered_html output_subarea ">
<h4 id="clip_by_value" class="doc_header"><code>clip_by_value</code><a href="https://github.com/lvwerra/trl/tree/master/trl/core.py#L67" class="source_link" style="float:right">[source]</a></h4><blockquote><p><code>clip_by_value</code>(<strong><code>x</code></strong>, <strong><code>tensor_min</code></strong>, <strong><code>tensor_max</code></strong>)</p>
<h4 id="clip_by_value" class="doc_header"><code>clip_by_value</code><a href="https://github.com/lvwerra/trl/tree/master/trl/core.py#L73" class="source_link" style="float:right">[source]</a></h4><blockquote><p><code>clip_by_value</code>(<strong><code>x</code></strong>, <strong><code>tensor_min</code></strong>, <strong><code>tensor_max</code></strong>)</p>
</blockquote>
<p>Tensor extension to torch.clamp
<a href="https://github.com/pytorch/pytorch/issues/2793#issuecomment-428784713">https://github.com/pytorch/pytorch/issues/2793#issuecomment-428784713</a></p>
@@ -348,7 +365,7 @@ <h4 id="clip_by_value" class="doc_header"><code>clip_by_value</code><a href="htt


<div class="output_markdown rendered_html output_subarea ">
<h4 id="entropy_from_logits" class="doc_header"><code>entropy_from_logits</code><a href="https://github.com/lvwerra/trl/tree/master/trl/core.py#L75" class="source_link" style="float:right">[source]</a></h4><blockquote><p><code>entropy_from_logits</code>(<strong><code>logits</code></strong>)</p>
<h4 id="entropy_from_logits" class="doc_header"><code>entropy_from_logits</code><a href="https://github.com/lvwerra/trl/tree/master/trl/core.py#L81" class="source_link" style="float:right">[source]</a></h4><blockquote><p><code>entropy_from_logits</code>(<strong><code>logits</code></strong>)</p>
</blockquote>
<p>Calculate entropy from logits.</p>

@@ -373,7 +390,7 @@ <h4 id="entropy_from_logits" class="doc_header"><code>entropy_from_logits</code>


<div class="output_markdown rendered_html output_subarea ">
<h4 id="average_torch_dicts" class="doc_header"><code>average_torch_dicts</code><a href="https://github.com/lvwerra/trl/tree/master/trl/core.py#L82" class="source_link" style="float:right">[source]</a></h4><blockquote><p><code>average_torch_dicts</code>(<strong><code>list_of_dicts</code></strong>)</p>
<h4 id="average_torch_dicts" class="doc_header"><code>average_torch_dicts</code><a href="https://github.com/lvwerra/trl/tree/master/trl/core.py#L88" class="source_link" style="float:right">[source]</a></h4><blockquote><p><code>average_torch_dicts</code>(<strong><code>list_of_dicts</code></strong>)</p>
</blockquote>
<p>Average values of a list of dicts with torch tensors.</p>

@@ -398,7 +415,7 @@ <h4 id="average_torch_dicts" class="doc_header"><code>average_torch_dicts</code>


<div class="output_markdown rendered_html output_subarea ">
<h4 id="stats_to_np" class="doc_header"><code>stats_to_np</code><a href="https://github.com/lvwerra/trl/tree/master/trl/core.py#L89" class="source_link" style="float:right">[source]</a></h4><blockquote><p><code>stats_to_np</code>(<strong><code>stats_dict</code></strong>)</p>
<h4 id="stats_to_np" class="doc_header"><code>stats_to_np</code><a href="https://github.com/lvwerra/trl/tree/master/trl/core.py#L95" class="source_link" style="float:right">[source]</a></h4><blockquote><p><code>stats_to_np</code>(<strong><code>stats_dict</code></strong>)</p>
</blockquote>
<p>Cast all torch.tensors in dict to numpy arrays.</p>

@@ -409,6 +426,31 @@ <h4 id="stats_to_np" class="doc_header"><code>stats_to_np</code><a href="https:/
</div>
</div>

</div>
{% endraw %}

{% raw %}

<div class="cell border-box-sizing code_cell rendered">

<div class="output_wrapper">
<div class="output">

<div class="output_area">


<div class="output_markdown rendered_html output_subarea ">
<h4 id="listify_batch" class="doc_header"><code>listify_batch</code><a href="https://github.com/lvwerra/trl/tree/master/trl/core.py#L107" class="source_link" style="float:right">[source]</a></h4><blockquote><p><code>listify_batch</code>(<strong><code>tensor</code></strong>)</p>
</blockquote>
<p>Turns the first dimension of a tensor into a list.</p>

</div>

</div>

</div>
</div>

</div>
{% endraw %}
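
The two `trl/core.py` helpers this commit adds or changes are easy to exercise on their own (`listify_batch` is the backwards-compatibility helper mentioned in the commit message). A small sketch with the function bodies copied from the diff above and a hypothetical `WANDB_PADDING` value; the actual constant lives in the library's new Constants section and is not shown in this excerpt:

```python
import torch
from torch.nn.utils.rnn import pad_sequence

WANDB_PADDING = -1  # assumed value of the padding sentinel used for wandb logging

def listify_batch(tensor):
    """Turns the first dimension of a tensor into a list."""
    return [tensor[i] for i in range(tensor.shape[0])]

def stack_dicts(stats_dicts):
    """Stack the values of a dict, padding ragged entries so wandb can log them."""
    results = dict()
    for k in stats_dicts[0]:
        stats_list = [torch.flatten(d[k]) for d in stats_dicts]
        results[k] = pad_sequence(stats_list, batch_first=True, padding_value=WANDB_PADDING)
    return results

batch = torch.arange(6).reshape(2, 3)
print(listify_batch(batch))        # [tensor([0, 1, 2]), tensor([3, 4, 5])]

stats = [{'loss': torch.tensor([0.1, 0.2])}, {'loss': torch.tensor([0.3])}]
print(stack_dicts(stats)['loss'])  # the shorter entry is padded with WANDB_PADDING
```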

@@ -468,7 +510,7 @@ <h2 id="BERT-utils">BERT utils<a class="anchor-link" href="#BERT-utils"> </a></h


<div class="output_markdown rendered_html output_subarea ">
<h4 id="build_bert_batch_from_txt" class="doc_header"><code>build_bert_batch_from_txt</code><a href="https://github.com/lvwerra/trl/tree/master/trl/core.py#L104" class="source_link" style="float:right">[source]</a></h4><blockquote><p><code>build_bert_batch_from_txt</code>(<strong><code>text_list</code></strong>, <strong><code>tokenizer</code></strong>, <strong><code>device</code></strong>)</p>
<h4 id="build_bert_batch_from_txt" class="doc_header"><code>build_bert_batch_from_txt</code><a href="https://github.com/lvwerra/trl/tree/master/trl/core.py#L113" class="source_link" style="float:right">[source]</a></h4><blockquote><p><code>build_bert_batch_from_txt</code>(<strong><code>text_list</code></strong>, <strong><code>tokenizer</code></strong>, <strong><code>device</code></strong>)</p>
</blockquote>
<p>Create token id and attention mask tensors from text list for BERT classification.</p>
