Commit

docs: update html documentation

leandro committed Mar 29, 2020
1 parent 275e624 commit c2e5bfe

Showing 8 changed files with 148 additions and 154 deletions.
39 changes: 23 additions & 16 deletions README.md
@@ -1,9 +1,13 @@
# lm_ppo: Language modeling with PPO
> A Pytorch implementation of Proximal Policy Optimization for transfomer language models.
# Welcome to lm_ppo
> Train transformer language models with Reinforcement Learning.

## What is it?
The library `lm_ppo` one can fine-tune transformer language models with Proximal Policy Optimization (PPO). The library is built with the `transformer` library by 🤗Huggingface. Therefore, one can load pre-trained language models directly via the transformer interface.
With `lm_ppo` you can train transformer language models with Proximal Policy Optimization (PPO). The library is built on top of the `transformers` library by 🤗Huggingface, so pre-trained language models can be loaded directly through its interface. At this point only GPT2 is implemented.

**Highlights:**
- GPT2 model with a value head: A transformer model with an additional scalar output for each token which can be used as a value function in Reinforcement Learning.
- PPOTrainer: A PPO trainer for language models that just needs (query, response, reward) triplets to optimise the language model.

## How it works
Fine-tuning a language model via PPO consists of roughly three steps:
@@ -14,7 +18,11 @@ Fine-tuning a language model via PPO consists of roughly three steps:

This process is illustrated in the sketch below:

![Overview](nbs/images/lm_ppo_overview.png)

<div style="text-align: center">
<img src="nbs/images/lm_ppo_overview.png" width="800">
<p style="text-align: center;"> <b>Figure:</b> Sketch of the workflow. </p>
</div>

## Install

@@ -28,7 +36,7 @@ If you want to run the example a few additional libraries are required. Clone th

## How to use

### Basic example
### Example
This is a basic example of how to use the library. Based on a query, the language model creates a response which is then evaluated. The evaluation could come from a human in the loop or from another model's output.

```
@@ -52,27 +60,26 @@
query_txt = "This morning I went to the "
query_tensor = gpt2_tokenizer.encode(query_txt, return_tensors="pt")
# get model response
response_tensor = respond_to_batch(gpt2_model, query_tensor, pad_token_id=gpt2_tokenizer.eos_token_id)
response_tensor = respond_to_batch(gpt2_model, query_tensor,
pad_token_id=gpt2_tokenizer.eos_token_id)
response_txt = gpt2_tokenizer.decode(response_tensor[0,:])
# define a reward for response
reward = torch.tensor([1.0]) # this could be any reward such as a human or another model
# (this could be any reward such as human feedback or output from another model)
reward = torch.tensor([1.0])
# train model with ppo
train_stats = ppo_trainer.step(query_tensor, response_tensor, reward)
```
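Note: the hunk above collapses the snippet's setup lines. As a rough reconstruction of what they contain (the module paths and the `GPT2HeadWithValueModel` class name are assumptions, not confirmed by this diff), the setup might look like:

```python
import torch
from transformers import GPT2Tokenizer
# Module paths and the value-head class name are assumed for illustration.
from lm_ppo.gpt2 import GPT2HeadWithValueModel, respond_to_batch
from lm_ppo.ppo import PPOTrainer

# Load a GPT2 model with a value head, a frozen reference copy for the
# KL penalty, and the matching tokenizer.
gpt2_model = GPT2HeadWithValueModel.from_pretrained("gpt2")
gpt2_model_ref = GPT2HeadWithValueModel.from_pretrained("gpt2")
gpt2_tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

# PPOTrainer(model, ref_model, **ppo_params) per the PPOTrainer docs below;
# batch_size=1 matches the single (query, response, reward) triplet above.
ppo_trainer = PPOTrainer(gpt2_model, gpt2_model_ref, batch_size=1)
```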
### Advanced example: IMDB sentiment
For a detailed example check out the notebook `nbs/04-gpt2-sentiment-training.ipynb`, where GPT2 is fine-tuned to generate positive movie reviews. An few examples from the language models before and after optimisation are given below:
For a detailed example check out the notebook *Tune GPT2 to generate positive reviews*, where GPT2 is fine-tuned to generate positive movie reviews. A few examples from the language model before and after optimisation are given below:

<div style="text-align: center">
<img src="nbs/images/table_imdb_preview.png" width="800">
<p style="text-align: center;"> <b>Figure:</b> A few review continuations before and after optimisation. </p>
</div>

![Overview](nbs/images/table_imdb_preview.png)

## Reference

29 changes: 14 additions & 15 deletions docs/00-core.html
@@ -1,12 +1,12 @@
---

title: Title
title: Utility functions

keywords: fastai
sidebar: home_sidebar

summary: "summary"
description: "summary"
summary: "A set of utility functions used throughout the library."
description: "A set of utility functions used throughout the library."
---
<!--
@@ -92,7 +92,7 @@ <h2 id="General-utils">General utils<a class="anchor-link" href="#General-utils"> </a>


<div class="output_markdown rendered_html output_subarea ">
<h4 id="flatten_dict" class="doc_header"><code>flatten_dict</code><a href="https://github.com/lvwerra/lm_ppo/tree/master/lm_ppo/core.py#L15" class="source_link" style="float:right">[source]</a></h4><blockquote><p><code>flatten_dict</code>(<strong><code>nested</code></strong>, <strong><code>sep</code></strong>=<em><code>'/'</code></em>)</p>
<h4 id="flatten_dict" class="doc_header"><code>flatten_dict</code><a href="https://github.com/lvwerra/lm_ppo/tree/master/lm_ppo/core.py#L14" class="source_link" style="float:right">[source]</a></h4><blockquote><p><code>flatten_dict</code>(<strong><code>nested</code></strong>, <strong><code>sep</code></strong>=<em><code>'/'</code></em>)</p>
</blockquote>
<p>Flatten dictionary and concatenate nested keys with separator.</p>
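A minimal sketch of the behaviour documented above (illustrative only, not necessarily the library's exact implementation):

```python
def flatten_dict(nested, sep="/"):
    # Recursively walk the dict, joining nested keys with `sep`.
    flat = {}
    for k, v in nested.items():
        if isinstance(v, dict):
            for sub_k, sub_v in flatten_dict(v, sep).items():
                flat[k + sep + sub_k] = sub_v
        else:
            flat[k] = v
    return flat

assert flatten_dict({"a": {"b": 1, "c": {"d": 2}}}) == {"a/b": 1, "a/c/d": 2}
```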

@@ -117,7 +117,7 @@ <h4 id="flatten_dict" class="doc_header"><code>flatten_dict</code><a href="https


<div class="output_markdown rendered_html output_subarea ">
<h4 id="stack_dicts" class="doc_header"><code>stack_dicts</code><a href="https://github.com/lvwerra/lm_ppo/tree/master/lm_ppo/core.py#L29" class="source_link" style="float:right">[source]</a></h4><blockquote><p><code>stack_dicts</code>(<strong><code>stats_dicts</code></strong>)</p>
<h4 id="stack_dicts" class="doc_header"><code>stack_dicts</code><a href="https://github.com/lvwerra/lm_ppo/tree/master/lm_ppo/core.py#L28" class="source_link" style="float:right">[source]</a></h4><blockquote><p><code>stack_dicts</code>(<strong><code>stats_dicts</code></strong>)</p>
</blockquote>
<p>Stack the values of a dict.</p>
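A sketch of what this suggests for a list of stats dicts, assuming shared keys and equally shaped tensor values (the library version may additionally handle padding):

```python
import torch

def stack_dicts(stats_dicts):
    # For each key, collect the values across all dicts and stack them
    # into a single tensor per key.
    return {k: torch.stack([d[k] for d in stats_dicts]) for k in stats_dicts[0]}

stats = stack_dicts([{"loss": torch.tensor(0.5)}, {"loss": torch.tensor(0.3)}])
# stats["loss"] -> tensor([0.5000, 0.3000])
```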

@@ -142,7 +142,7 @@ <h4 id="stack_dicts" class="doc_header"><code>stack_dicts</code><a href="https:/


<div class="output_markdown rendered_html output_subarea ">
<h4 id="add_suffix" class="doc_header"><code>add_suffix</code><a href="https://github.com/lvwerra/lm_ppo/tree/master/lm_ppo/core.py#L37" class="source_link" style="float:right">[source]</a></h4><blockquote><p><code>add_suffix</code>(<strong><code>input_dict</code></strong>, <strong><code>suffix</code></strong>)</p>
<h4 id="add_suffix" class="doc_header"><code>add_suffix</code><a href="https://github.com/lvwerra/lm_ppo/tree/master/lm_ppo/core.py#L36" class="source_link" style="float:right">[source]</a></h4><blockquote><p><code>add_suffix</code>(<strong><code>input_dict</code></strong>, <strong><code>suffix</code></strong>)</p>
</blockquote>
<p>Add suffix to dict keys.</p>

@@ -223,7 +223,6 @@ <h2 id="Torch-utils">Torch utils<a class="anchor-link" href="#Torch-utils"> </a>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">v</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span>
<span class="n">new_dict</span><span class="p">[</span><span class="n">k</span><span class="p">]</span> <span class="o">=</span> <span class="n">v</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span><span class="o">.</span><span class="n">cpu</span><span class="p">()</span><span class="o">.</span><span class="n">numpy</span><span class="p">()</span>
<span class="k">else</span><span class="p">:</span>
<span class="nb">print</span><span class="p">(</span><span class="n">k</span><span class="p">,</span> <span class="n">v</span><span class="p">)</span>
<span class="n">new_dict</span><span class="p">[</span><span class="n">k</span><span class="p">]</span> <span class="o">=</span> <span class="n">v</span>
<span class="k">if</span> <span class="n">np</span><span class="o">.</span><span class="n">isscalar</span><span class="p">(</span><span class="n">new_dict</span><span class="p">[</span><span class="n">k</span><span class="p">]):</span>
<span class="n">new_dict</span><span class="p">[</span><span class="n">k</span><span class="p">]</span> <span class="o">=</span> <span class="nb">float</span><span class="p">(</span><span class="n">new_dict</span><span class="p">[</span><span class="n">k</span><span class="p">])</span>
@@ -248,7 +247,7 @@ <h2 id="Torch-utils">Torch utils<a class="anchor-link" href="#Torch-utils"> </a>


<div class="output_markdown rendered_html output_subarea ">
<h4 id="pad_to_size" class="doc_header"><code>pad_to_size</code><a href="https://github.com/lvwerra/lm_ppo/tree/master/lm_ppo/core.py#L43" class="source_link" style="float:right">[source]</a></h4><blockquote><p><code>pad_to_size</code>(<strong><code>tensor</code></strong>, <strong><code>size</code></strong>, <strong><code>dim</code></strong>=<em><code>1</code></em>, <strong><code>padding</code></strong>=<em><code>50256</code></em>)</p>
<h4 id="pad_to_size" class="doc_header"><code>pad_to_size</code><a href="https://github.com/lvwerra/lm_ppo/tree/master/lm_ppo/core.py#L42" class="source_link" style="float:right">[source]</a></h4><blockquote><p><code>pad_to_size</code>(<strong><code>tensor</code></strong>, <strong><code>size</code></strong>, <strong><code>dim</code></strong>=<em><code>1</code></em>, <strong><code>padding</code></strong>=<em><code>50256</code></em>)</p>
</blockquote>
<p>Pad tensor to size.</p>
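A sketch of the padding behaviour implied by the signature; the default `padding=50256` is GPT2's end-of-text token id:

```python
import torch.nn.functional as F

def pad_to_size(tensor, size, dim=1, padding=50256):
    # Right-pad along `dim` until length `size` is reached; return the
    # tensor unchanged if it is already long enough.
    t_size = tensor.size()[dim]
    if t_size >= size:
        return tensor
    # This sketch assumes `dim` is the tensor's last dimension.
    return F.pad(tensor, (0, size - t_size), "constant", padding)
```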

@@ -273,7 +272,7 @@ <h4 id="pad_to_size" class="doc_header"><code>pad_to_size</code><a href="https:/


<div class="output_markdown rendered_html output_subarea ">
<h4 id="logprobs_from_logits" class="doc_header"><code>logprobs_from_logits</code><a href="https://github.com/lvwerra/lm_ppo/tree/master/lm_ppo/core.py#L51" class="source_link" style="float:right">[source]</a></h4><blockquote><p><code>logprobs_from_logits</code>(<strong><code>logits</code></strong>, <strong><code>labels</code></strong>)</p>
<h4 id="logprobs_from_logits" class="doc_header"><code>logprobs_from_logits</code><a href="https://github.com/lvwerra/lm_ppo/tree/master/lm_ppo/core.py#L50" class="source_link" style="float:right">[source]</a></h4><blockquote><p><code>logprobs_from_logits</code>(<strong><code>logits</code></strong>, <strong><code>labels</code></strong>)</p>
</blockquote>
<p>See: <a href="https://github.com/pytorch/pytorch/issues/563#issuecomment-330103591">https://github.com/pytorch/pytorch/issues/563#issuecomment-330103591</a></p>
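The linked issue describes the gather trick this function relies on; a sketch for `(batch, seq, vocab)` logits:

```python
import torch
import torch.nn.functional as F

def logprobs_from_logits(logits, labels):
    # logits: (batch, seq, vocab); labels: (batch, seq).
    # Log-softmax over the vocabulary, then gather the log-probability
    # assigned to each label token.
    logp = F.log_softmax(logits, dim=2)
    return torch.gather(logp, 2, labels.unsqueeze(2)).squeeze(-1)
```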

@@ -298,7 +297,7 @@ <h4 id="logprobs_from_logits" class="doc_header"><code>logprobs_from_logits</cod


<div class="output_markdown rendered_html output_subarea ">
<h4 id="whiten" class="doc_header"><code>whiten</code><a href="https://github.com/lvwerra/lm_ppo/tree/master/lm_ppo/core.py#L60" class="source_link" style="float:right">[source]</a></h4><blockquote><p><code>whiten</code>(<strong><code>values</code></strong>, <strong><code>shift_mean</code></strong>=<em><code>True</code></em>)</p>
<h4 id="whiten" class="doc_header"><code>whiten</code><a href="https://github.com/lvwerra/lm_ppo/tree/master/lm_ppo/core.py#L59" class="source_link" style="float:right">[source]</a></h4><blockquote><p><code>whiten</code>(<strong><code>values</code></strong>, <strong><code>shift_mean</code></strong>=<em><code>True</code></em>)</p>
</blockquote>
<p>Whiten values.</p>
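A sketch of whitening as commonly implemented: normalise to zero mean and unit variance, restoring the original mean when `shift_mean=False`:

```python
import torch

def whiten(values, shift_mean=True):
    # Scale to unit variance; add the mean back if it should be kept.
    mean, var = torch.mean(values), torch.var(values)
    whitened = (values - mean) * torch.rsqrt(var + 1e-8)
    if not shift_mean:
        whitened += mean
    return whitened
```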

@@ -323,7 +322,7 @@ <h4 id="whiten" class="doc_header"><code>whiten</code><a href="https://github.co


<div class="output_markdown rendered_html output_subarea ">
<h4 id="clip_by_value" class="doc_header"><code>clip_by_value</code><a href="https://github.com/lvwerra/lm_ppo/tree/master/lm_ppo/core.py#L68" class="source_link" style="float:right">[source]</a></h4><blockquote><p><code>clip_by_value</code>(<strong><code>x</code></strong>, <strong><code>tensor_min</code></strong>, <strong><code>tensor_max</code></strong>)</p>
<h4 id="clip_by_value" class="doc_header"><code>clip_by_value</code><a href="https://github.com/lvwerra/lm_ppo/tree/master/lm_ppo/core.py#L67" class="source_link" style="float:right">[source]</a></h4><blockquote><p><code>clip_by_value</code>(<strong><code>x</code></strong>, <strong><code>tensor_min</code></strong>, <strong><code>tensor_max</code></strong>)</p>
</blockquote>
<p>Tensor extension to torch.clamp
<a href="https://github.com/pytorch/pytorch/issues/2793#issuecomment-428784713">https://github.com/pytorch/pytorch/issues/2793#issuecomment-428784713</a></p>
@@ -349,7 +348,7 @@ <h4 id="clip_by_value" class="doc_header"><code>clip_by_value</code><a href="htt


<div class="output_markdown rendered_html output_subarea ">
<h4 id="entropy_from_logits" class="doc_header"><code>entropy_from_logits</code><a href="https://github.com/lvwerra/lm_ppo/tree/master/lm_ppo/core.py#L76" class="source_link" style="float:right">[source]</a></h4><blockquote><p><code>entropy_from_logits</code>(<strong><code>logits</code></strong>)</p>
<h4 id="entropy_from_logits" class="doc_header"><code>entropy_from_logits</code><a href="https://github.com/lvwerra/lm_ppo/tree/master/lm_ppo/core.py#L75" class="source_link" style="float:right">[source]</a></h4><blockquote><p><code>entropy_from_logits</code>(<strong><code>logits</code></strong>)</p>
</blockquote>
<p>Calculate entropy from logits.</p>
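A numerically stable sketch using the identity H = logsumexp(logits) - sum(p * logits):

```python
import torch
import torch.nn.functional as F

def entropy_from_logits(logits):
    # Equivalent to -sum(p * log p) over the last dimension, but avoids
    # forming log-probabilities explicitly.
    pd = F.softmax(logits, dim=-1)
    return torch.logsumexp(logits, dim=-1) - torch.sum(pd * logits, dim=-1)
```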

@@ -374,7 +373,7 @@ <h4 id="entropy_from_logits" class="doc_header"><code>entropy_from_logits</code>


<div class="output_markdown rendered_html output_subarea ">
<h4 id="average_torch_dicts" class="doc_header"><code>average_torch_dicts</code><a href="https://github.com/lvwerra/lm_ppo/tree/master/lm_ppo/core.py#L83" class="source_link" style="float:right">[source]</a></h4><blockquote><p><code>average_torch_dicts</code>(<strong><code>list_of_dicts</code></strong>)</p>
<h4 id="average_torch_dicts" class="doc_header"><code>average_torch_dicts</code><a href="https://github.com/lvwerra/lm_ppo/tree/master/lm_ppo/core.py#L82" class="source_link" style="float:right">[source]</a></h4><blockquote><p><code>average_torch_dicts</code>(<strong><code>list_of_dicts</code></strong>)</p>
</blockquote>
<p>Average values of a list of dicts with torch tensors.</p>

@@ -399,7 +398,7 @@ <h4 id="average_torch_dicts" class="doc_header"><code>average_torch_dicts</code>


<div class="output_markdown rendered_html output_subarea ">
<h4 id="stats_to_np" class="doc_header"><code>stats_to_np</code><a href="https://github.com/lvwerra/lm_ppo/tree/master/lm_ppo/core.py#L90" class="source_link" style="float:right">[source]</a></h4><blockquote><p><code>stats_to_np</code>(<strong><code>stats_dict</code></strong>)</p>
<h4 id="stats_to_np" class="doc_header"><code>stats_to_np</code><a href="https://github.com/lvwerra/lm_ppo/tree/master/lm_ppo/core.py#L89" class="source_link" style="float:right">[source]</a></h4><blockquote><p><code>stats_to_np</code>(<strong><code>stats_dict</code></strong>)</p>
</blockquote>
<p>Cast all torch.tensors in dict to numpy arrays.</p>
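The span-encoded diff earlier in this file shows this function's body (this commit removes a stray `print`); rendered as plain Python it reads roughly:

```python
import numpy as np
import torch

def stats_to_np(stats_dict):
    # Detach tensors, move them to CPU and convert to numpy arrays;
    # scalar results are cast to plain Python floats.
    new_dict = {}
    for k, v in stats_dict.items():
        new_dict[k] = v.detach().cpu().numpy() if isinstance(v, torch.Tensor) else v
        if np.isscalar(new_dict[k]):
            new_dict[k] = float(new_dict[k])
    return new_dict
```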

@@ -469,7 +468,7 @@ <h2 id="BERT-utils">BERT utils<a class="anchor-link" href="#BERT-utils"> </a></h


<div class="output_markdown rendered_html output_subarea ">
<h4 id="build_bert_batch_from_txt" class="doc_header"><code>build_bert_batch_from_txt</code><a href="https://github.com/lvwerra/lm_ppo/tree/master/lm_ppo/core.py#L106" class="source_link" style="float:right">[source]</a></h4><blockquote><p><code>build_bert_batch_from_txt</code>(<strong><code>text_list</code></strong>, <strong><code>tokenizer</code></strong>, <strong><code>device</code></strong>)</p>
<h4 id="build_bert_batch_from_txt" class="doc_header"><code>build_bert_batch_from_txt</code><a href="https://github.com/lvwerra/lm_ppo/tree/master/lm_ppo/core.py#L104" class="source_link" style="float:right">[source]</a></h4><blockquote><p><code>build_bert_batch_from_txt</code>(<strong><code>text_list</code></strong>, <strong><code>tokenizer</code></strong>, <strong><code>device</code></strong>)</p>
</blockquote>
<p>Create token id and attention mask tensors from text list for BERT classification.</p>
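A sketch of one way to build such a batch, reusing the `pad_to_size` sketch above; the choice of `0` as the padding id is an assumption:

```python
import torch

def build_bert_batch_from_txt(text_list, tokenizer, device):
    # Tokenise each text and move the id tensors to the target device.
    tensors = [torch.tensor(tokenizer.encode(txt)).unsqueeze(0).to(device)
               for txt in text_list]
    # Pad everything to the longest sequence; the attention mask is 1 for
    # real tokens and 0 for padding.
    max_len = max(t.size()[1] for t in tensors)
    padded, masks = [], []
    for t in tensors:
        masks.append(pad_to_size(torch.ones_like(t), max_len, padding=0))
        padded.append(pad_to_size(t, max_len, padding=0))
    return torch.cat(padded), torch.cat(masks)
```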

4 changes: 2 additions & 2 deletions docs/02-ppo.html
@@ -163,7 +163,7 @@ <h2 id="FixedKLController" class="doc_header"><code>class</code> <code>FixedKLCo
<div class="input_area">
<div class=" highlight hl-ipython3"><pre><span></span><span class="k">class</span> <span class="nc">PPOTrainer</span><span class="p">:</span>
<span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> The PPO_trainer uses Proximal Policy Optimization to tune a language model.</span>
<span class="sd"> The PPO_trainer uses Proximal Policy Optimization to optimise language models.</span>
<span class="sd"> &quot;&quot;&quot;</span>

<span class="n">default_params</span> <span class="o">=</span> <span class="p">{</span>
@@ -417,7 +417,7 @@ <h2 id="FixedKLController" class="doc_header"><code>class</code> <code>FixedKLCo
<div class="output_markdown rendered_html output_subarea ">
<h2 id="PPOTrainer" class="doc_header"><code>class</code> <code>PPOTrainer</code><a href="https://github.com/lvwerra/lm_ppo/tree/master/lm_ppo/ppo.py#L54" class="source_link" style="float:right">[source]</a></h2><blockquote><p><code>PPOTrainer</code>(<strong><code>model</code></strong>, <strong><code>ref_model</code></strong>, <strong>**<code>ppo_params</code></strong>)</p>
</blockquote>
<p>The PPO_trainer uses Proximal Policy Optimization to tune a language model.</p>
<p>The PPO_trainer uses Proximal Policy Optimization to optimise language models.</p>

</div>

2 changes: 1 addition & 1 deletion docs/03-bert-imdb-training.html
@@ -29,7 +29,7 @@

<div class="cell border-box-sizing text_cell rendered"><div class="inner_cell">
<div class="text_cell_render border-box-sizing rendered_html">
<p>We use the <code>simpletransformers</code> library to train BERT for sentiment classification.</p>
<p>We use the <code>simpletransformers</code> library to train BERT (large) for sentiment classification on the IMDB dataset.</p>

</div>
</div>