
Commit

chore: rename repo
leandro committed Mar 30, 2020
1 parent d4122bc commit df428af
Showing 25 changed files with 130 additions and 58 deletions.
6 changes: 3 additions & 3 deletions Makefile
@@ -1,10 +1,10 @@
SRC = $(wildcard ./nbs//*.ipynb)

all: lm_ppo docs
all: trl docs

lm_ppo: $(SRC)
trl: $(SRC)
nbdev_build_lib
touch lm_ppo
touch trl

docs_serve: docs
cd docs && bundle exec jekyll serve
12 changes: 6 additions & 6 deletions README.md
@@ -1,9 +1,9 @@
# Welcome to lm_ppo
# Welcome to trl
> Train transformer language models with Reinforcement Learning.

## What is it?
With `lm_ppo` you can train transformer language models with Proximal Policy Optimization (PPO). The library is built with the `transformers` library by 🤗Huggingface. Therefore, pre-trained language models can be directly loaded via the transformer interface. At this point only GPT2 is implemented.
With `trl` you can train transformer language models with Proximal Policy Optimization (PPO). The library is built with the `transformers` library by 🤗Huggingface. Therefore, pre-trained language models can be directly loaded via the transformer interface. At this point only GPT2 is implemented.

**Highlights:**
- GPT2 model with a value head: A transformer model with an additional scalar output for each token which can be used as a value function in Reinforcement Learning.
@@ -20,7 +20,7 @@ This process is illustrated in the sketch below:


<div style="text-align: center">
<img src="nbs/images/lm_ppo_overview.png" width="800">
<img src="nbs/images/trl_overview.png" width="800">
<p style="text-align: center;"> <b>Figure:</b> Sketch of the workflow. </p>
</div>

@@ -29,7 +29,7 @@ This process is illustrated in the sketch below:
### Python package
Install the library with pip:

`pip install lm_ppo`
`pip install trl`

### Repository
If you want to run the examples in the repository a few additional libraries are required. Clone the repository and install it with pip:
@@ -59,8 +59,8 @@ This is a basic example on how to use the library. Based on a query the language
# imports
import torch
from transformers import GPT2Tokenizer
from lm_ppo.gpt2 import GPT2HeadWithValueModel, respond_to_batch
from lm_ppo.ppo import PPOTrainer
from trl.gpt2 import GPT2HeadWithValueModel, respond_to_batch
from trl.ppo import PPOTrainer
# get models
gpt2_model = GPT2HeadWithValueModel.from_pretrained('gpt2')
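The README usage snippet above is truncated in this view. As a rough sketch of how the renamed `trl` imports fit together, assuming the `PPOTrainer(model, ref_model, **ppo_params)` and `respond_to_batch` signatures documented below; the query text, reward value, and the `step` call are illustrative assumptions rather than the repository's exact example:

```python
# Sketch only: based on the imports shown above and the PPOTrainer signature
# documented below. The query text, reward value, and the `step` call are
# illustrative assumptions, not necessarily the repository's exact example.
import torch
from transformers import GPT2Tokenizer
from trl.gpt2 import GPT2HeadWithValueModel, respond_to_batch
from trl.ppo import PPOTrainer

# model to be tuned plus a frozen reference copy used for the KL penalty
gpt2_model = GPT2HeadWithValueModel.from_pretrained('gpt2')
gpt2_model_ref = GPT2HeadWithValueModel.from_pretrained('gpt2')
gpt2_tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

# encode a query and sample a 20-token response
query_txt = "This morning I went to the "
query_tensor = gpt2_tokenizer.encode(query_txt, return_tensors="pt")
response_tensor = respond_to_batch(gpt2_model, query_tensor, txt_len=20)
response_txt = gpt2_tokenizer.decode(response_tensor[0, :])

# score the response (here a hard-coded dummy reward) and run one PPO step
reward = torch.tensor([1.0])
ppo_trainer = PPOTrainer(gpt2_model, gpt2_model_ref)
train_stats = ppo_trainer.step(query_tensor, response_tensor, reward)
```

The reference copy stays frozen so the trainer can penalise divergence from the original language model during optimisation.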
22 changes: 11 additions & 11 deletions docs/00-core.html
@@ -92,7 +92,7 @@ <h2 id="General-utils">General utils<a class="anchor-link" href="#General-utils"


<div class="output_markdown rendered_html output_subarea ">
<h4 id="flatten_dict" class="doc_header"><code>flatten_dict</code><a href="https://github.com/lvwerra/lm_ppo/tree/master/lm_ppo/core.py#L14" class="source_link" style="float:right">[source]</a></h4><blockquote><p><code>flatten_dict</code>(<strong><code>nested</code></strong>, <strong><code>sep</code></strong>=<em><code>'/'</code></em>)</p>
<h4 id="flatten_dict" class="doc_header"><code>flatten_dict</code><a href="https://github.com/lvwerra/trl/tree/master/trl/core.py#L14" class="source_link" style="float:right">[source]</a></h4><blockquote><p><code>flatten_dict</code>(<strong><code>nested</code></strong>, <strong><code>sep</code></strong>=<em><code>'/'</code></em>)</p>
</blockquote>
<p>Flatten dictionary and concatenate nested keys with separator.</p>

@@ -117,7 +117,7 @@ <h4 id="flatten_dict" class="doc_header"><code>flatten_dict</code><a href="https


<div class="output_markdown rendered_html output_subarea ">
<h4 id="stack_dicts" class="doc_header"><code>stack_dicts</code><a href="https://github.com/lvwerra/lm_ppo/tree/master/lm_ppo/core.py#L28" class="source_link" style="float:right">[source]</a></h4><blockquote><p><code>stack_dicts</code>(<strong><code>stats_dicts</code></strong>)</p>
<h4 id="stack_dicts" class="doc_header"><code>stack_dicts</code><a href="https://github.com/lvwerra/trl/tree/master/trl/core.py#L28" class="source_link" style="float:right">[source]</a></h4><blockquote><p><code>stack_dicts</code>(<strong><code>stats_dicts</code></strong>)</p>
</blockquote>
<p>Stack the values of a dict.</p>

@@ -142,7 +142,7 @@ <h4 id="stack_dicts" class="doc_header"><code>stack_dicts</code><a href="https:/


<div class="output_markdown rendered_html output_subarea ">
<h4 id="add_suffix" class="doc_header"><code>add_suffix</code><a href="https://github.com/lvwerra/lm_ppo/tree/master/lm_ppo/core.py#L36" class="source_link" style="float:right">[source]</a></h4><blockquote><p><code>add_suffix</code>(<strong><code>input_dict</code></strong>, <strong><code>suffix</code></strong>)</p>
<h4 id="add_suffix" class="doc_header"><code>add_suffix</code><a href="https://github.com/lvwerra/trl/tree/master/trl/core.py#L36" class="source_link" style="float:right">[source]</a></h4><blockquote><p><code>add_suffix</code>(<strong><code>input_dict</code></strong>, <strong><code>suffix</code></strong>)</p>
</blockquote>
<p>Add suffix to dict keys.</p>

@@ -247,7 +247,7 @@ <h2 id="Torch-utils">Torch utils<a class="anchor-link" href="#Torch-utils"> </a>


<div class="output_markdown rendered_html output_subarea ">
<h4 id="pad_to_size" class="doc_header"><code>pad_to_size</code><a href="https://github.com/lvwerra/lm_ppo/tree/master/lm_ppo/core.py#L42" class="source_link" style="float:right">[source]</a></h4><blockquote><p><code>pad_to_size</code>(<strong><code>tensor</code></strong>, <strong><code>size</code></strong>, <strong><code>dim</code></strong>=<em><code>1</code></em>, <strong><code>padding</code></strong>=<em><code>50256</code></em>)</p>
<h4 id="pad_to_size" class="doc_header"><code>pad_to_size</code><a href="https://github.com/lvwerra/trl/tree/master/trl/core.py#L42" class="source_link" style="float:right">[source]</a></h4><blockquote><p><code>pad_to_size</code>(<strong><code>tensor</code></strong>, <strong><code>size</code></strong>, <strong><code>dim</code></strong>=<em><code>1</code></em>, <strong><code>padding</code></strong>=<em><code>50256</code></em>)</p>
</blockquote>
<p>Pad tensor to size.</p>

@@ -272,7 +272,7 @@ <h4 id="pad_to_size" class="doc_header"><code>pad_to_size</code><a href="https:/


<div class="output_markdown rendered_html output_subarea ">
<h4 id="logprobs_from_logits" class="doc_header"><code>logprobs_from_logits</code><a href="https://github.com/lvwerra/lm_ppo/tree/master/lm_ppo/core.py#L50" class="source_link" style="float:right">[source]</a></h4><blockquote><p><code>logprobs_from_logits</code>(<strong><code>logits</code></strong>, <strong><code>labels</code></strong>)</p>
<h4 id="logprobs_from_logits" class="doc_header"><code>logprobs_from_logits</code><a href="https://github.com/lvwerra/trl/tree/master/trl/core.py#L50" class="source_link" style="float:right">[source]</a></h4><blockquote><p><code>logprobs_from_logits</code>(<strong><code>logits</code></strong>, <strong><code>labels</code></strong>)</p>
</blockquote>
<p>See: <a href="https://github.com/pytorch/pytorch/issues/563#issuecomment-330103591">https://github.com/pytorch/pytorch/issues/563#issuecomment-330103591</a></p>

@@ -297,7 +297,7 @@ <h4 id="logprobs_from_logits" class="doc_header"><code>logprobs_from_logits</cod


<div class="output_markdown rendered_html output_subarea ">
<h4 id="whiten" class="doc_header"><code>whiten</code><a href="https://github.com/lvwerra/lm_ppo/tree/master/lm_ppo/core.py#L59" class="source_link" style="float:right">[source]</a></h4><blockquote><p><code>whiten</code>(<strong><code>values</code></strong>, <strong><code>shift_mean</code></strong>=<em><code>True</code></em>)</p>
<h4 id="whiten" class="doc_header"><code>whiten</code><a href="https://github.com/lvwerra/trl/tree/master/trl/core.py#L59" class="source_link" style="float:right">[source]</a></h4><blockquote><p><code>whiten</code>(<strong><code>values</code></strong>, <strong><code>shift_mean</code></strong>=<em><code>True</code></em>)</p>
</blockquote>
<p>Whiten values.</p>

@@ -322,7 +322,7 @@ <h4 id="whiten" class="doc_header"><code>whiten</code><a href="https://github.co


<div class="output_markdown rendered_html output_subarea ">
<h4 id="clip_by_value" class="doc_header"><code>clip_by_value</code><a href="https://github.com/lvwerra/lm_ppo/tree/master/lm_ppo/core.py#L67" class="source_link" style="float:right">[source]</a></h4><blockquote><p><code>clip_by_value</code>(<strong><code>x</code></strong>, <strong><code>tensor_min</code></strong>, <strong><code>tensor_max</code></strong>)</p>
<h4 id="clip_by_value" class="doc_header"><code>clip_by_value</code><a href="https://github.com/lvwerra/trl/tree/master/trl/core.py#L67" class="source_link" style="float:right">[source]</a></h4><blockquote><p><code>clip_by_value</code>(<strong><code>x</code></strong>, <strong><code>tensor_min</code></strong>, <strong><code>tensor_max</code></strong>)</p>
</blockquote>
<p>Tensor extension to torch.clamp
<a href="https://github.com/pytorch/pytorch/issues/2793#issuecomment-428784713">https://github.com/pytorch/pytorch/issues/2793#issuecomment-428784713</a></p>
@@ -348,7 +348,7 @@ <h4 id="clip_by_value" class="doc_header"><code>clip_by_value</code><a href="htt


<div class="output_markdown rendered_html output_subarea ">
<h4 id="entropy_from_logits" class="doc_header"><code>entropy_from_logits</code><a href="https://github.com/lvwerra/lm_ppo/tree/master/lm_ppo/core.py#L75" class="source_link" style="float:right">[source]</a></h4><blockquote><p><code>entropy_from_logits</code>(<strong><code>logits</code></strong>)</p>
<h4 id="entropy_from_logits" class="doc_header"><code>entropy_from_logits</code><a href="https://github.com/lvwerra/trl/tree/master/trl/core.py#L75" class="source_link" style="float:right">[source]</a></h4><blockquote><p><code>entropy_from_logits</code>(<strong><code>logits</code></strong>)</p>
</blockquote>
<p>Calculate entropy from logits.</p>

@@ -373,7 +373,7 @@ <h4 id="entropy_from_logits" class="doc_header"><code>entropy_from_logits</code>


<div class="output_markdown rendered_html output_subarea ">
<h4 id="average_torch_dicts" class="doc_header"><code>average_torch_dicts</code><a href="https://github.com/lvwerra/lm_ppo/tree/master/lm_ppo/core.py#L82" class="source_link" style="float:right">[source]</a></h4><blockquote><p><code>average_torch_dicts</code>(<strong><code>list_of_dicts</code></strong>)</p>
<h4 id="average_torch_dicts" class="doc_header"><code>average_torch_dicts</code><a href="https://github.com/lvwerra/trl/tree/master/trl/core.py#L82" class="source_link" style="float:right">[source]</a></h4><blockquote><p><code>average_torch_dicts</code>(<strong><code>list_of_dicts</code></strong>)</p>
</blockquote>
<p>Average values of a list of dicts with torch tensors.</p>

@@ -398,7 +398,7 @@ <h4 id="average_torch_dicts" class="doc_header"><code>average_torch_dicts</code>


<div class="output_markdown rendered_html output_subarea ">
<h4 id="stats_to_np" class="doc_header"><code>stats_to_np</code><a href="https://github.com/lvwerra/lm_ppo/tree/master/lm_ppo/core.py#L89" class="source_link" style="float:right">[source]</a></h4><blockquote><p><code>stats_to_np</code>(<strong><code>stats_dict</code></strong>)</p>
<h4 id="stats_to_np" class="doc_header"><code>stats_to_np</code><a href="https://github.com/lvwerra/trl/tree/master/trl/core.py#L89" class="source_link" style="float:right">[source]</a></h4><blockquote><p><code>stats_to_np</code>(<strong><code>stats_dict</code></strong>)</p>
</blockquote>
<p>Cast all torch.tensors in dict to numpy arrays.</p>

@@ -468,7 +468,7 @@ <h2 id="BERT-utils">BERT utils<a class="anchor-link" href="#BERT-utils"> </a></h


<div class="output_markdown rendered_html output_subarea ">
<h4 id="build_bert_batch_from_txt" class="doc_header"><code>build_bert_batch_from_txt</code><a href="https://github.com/lvwerra/lm_ppo/tree/master/lm_ppo/core.py#L104" class="source_link" style="float:right">[source]</a></h4><blockquote><p><code>build_bert_batch_from_txt</code>(<strong><code>text_list</code></strong>, <strong><code>tokenizer</code></strong>, <strong><code>device</code></strong>)</p>
<h4 id="build_bert_batch_from_txt" class="doc_header"><code>build_bert_batch_from_txt</code><a href="https://github.com/lvwerra/trl/tree/master/trl/core.py#L104" class="source_link" style="float:right">[source]</a></h4><blockquote><p><code>build_bert_batch_from_txt</code>(<strong><code>text_list</code></strong>, <strong><code>tokenizer</code></strong>, <strong><code>device</code></strong>)</p>
</blockquote>
<p>Create token id and attention mask tensors from text list for BERT classification.</p>

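The `docs/00-core.html` changes above only rename source links from `lm_ppo` to `trl`. For orientation, a small sketch of the utilities they document, using only the signatures shown in this diff; the module path `trl.core` and all example inputs are assumptions:

```python
# Sketch only: exercises the documented core utilities with made-up inputs.
# The module path `trl.core` is assumed from the nbdev source links above.
import torch
from trl.core import flatten_dict, pad_to_size, logprobs_from_logits, whiten

# flatten_dict(nested, sep='/'): concatenate nested keys with a separator
stats = {'ppo': {'loss': 0.1, 'kl': 0.02}}
flat = flatten_dict(stats)            # e.g. {'ppo/loss': 0.1, 'ppo/kl': 0.02}

# pad_to_size(tensor, size, dim=1, padding=50256): pad token ids to a fixed length
tokens = torch.tensor([[15496, 995]])
padded = pad_to_size(tokens, 5)       # shape (1, 5), padded with the GPT2 eos id

# logprobs_from_logits(logits, labels): per-token log-probabilities of the labels
logits = torch.randn(1, 5, 50257)
labels = torch.randint(0, 50257, (1, 5))
logprobs = logprobs_from_logits(logits, labels)

# whiten(values, shift_mean=True): normalise values, e.g. PPO advantages
advantages = whiten(torch.randn(4, 5))
```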
6 changes: 3 additions & 3 deletions docs/01-gpt2-with-value-head.html
@@ -117,7 +117,7 @@ <h2 id="Detach-head">Detach head<a class="anchor-link" href="#Detach-head"> </a>


<div class="output_markdown rendered_html output_subarea ">
<h2 id="ValueHead" class="doc_header"><code>class</code> <code>ValueHead</code><a href="https://github.com/lvwerra/lm_ppo/tree/master/lm_ppo/gpt2.py#L16" class="source_link" style="float:right">[source]</a></h2><blockquote><p><code>ValueHead</code>(<strong><code>config</code></strong>) :: <code>Module</code></p>
<h2 id="ValueHead" class="doc_header"><code>class</code> <code>ValueHead</code><a href="https://github.com/lvwerra/trl/tree/master/trl/gpt2.py#L16" class="source_link" style="float:right">[source]</a></h2><blockquote><p><code>ValueHead</code>(<strong><code>config</code></strong>) :: <code>Module</code></p>
</blockquote>
<p>The ValueHead class implements a head for GPT2 that returns a scalar for each output token.</p>

@@ -207,7 +207,7 @@ <h2 id="ValueHead" class="doc_header"><code>class</code> <code>ValueHead</code><


<div class="output_markdown rendered_html output_subarea ">
<h2 id="GPT2HeadWithValueModel" class="doc_header"><code>class</code> <code>GPT2HeadWithValueModel</code><a href="https://github.com/lvwerra/lm_ppo/tree/master/lm_ppo/gpt2.py#L61" class="source_link" style="float:right">[source]</a></h2><blockquote><p><code>GPT2HeadWithValueModel</code>(<strong><code>config</code></strong>) :: <code>GPT2PreTrainedModel</code></p>
<h2 id="GPT2HeadWithValueModel" class="doc_header"><code>class</code> <code>GPT2HeadWithValueModel</code><a href="https://github.com/lvwerra/trl/tree/master/trl/gpt2.py#L61" class="source_link" style="float:right">[source]</a></h2><blockquote><p><code>GPT2HeadWithValueModel</code>(<strong><code>config</code></strong>) :: <code>GPT2PreTrainedModel</code></p>
</blockquote>
<p>The GPT2HeadWithValueModel class implements a GPT2 language model with a secondary, scalar head.</p>

@@ -508,7 +508,7 @@ <h2 id="Batched-response-to-queries">Batched response to queries<a class="anchor


<div class="output_markdown rendered_html output_subarea ">
<h4 id="respond_to_batch" class="doc_header"><code>respond_to_batch</code><a href="https://github.com/lvwerra/lm_ppo/tree/master/lm_ppo/gpt2.py#L113" class="source_link" style="float:right">[source]</a></h4><blockquote><p><code>respond_to_batch</code>(<strong><code>model</code></strong>, <strong><code>queries</code></strong>, <strong><code>txt_len</code></strong>=<em><code>20</code></em>, <strong><code>top_k</code></strong>=<em><code>0</code></em>, <strong><code>top_p</code></strong>=<em><code>1.0</code></em>)</p>
<h4 id="respond_to_batch" class="doc_header"><code>respond_to_batch</code><a href="https://github.com/lvwerra/trl/tree/master/trl/gpt2.py#L113" class="source_link" style="float:right">[source]</a></h4><blockquote><p><code>respond_to_batch</code>(<strong><code>model</code></strong>, <strong><code>queries</code></strong>, <strong><code>txt_len</code></strong>=<em><code>20</code></em>, <strong><code>top_k</code></strong>=<em><code>0</code></em>, <strong><code>top_p</code></strong>=<em><code>1.0</code></em>)</p>
</blockquote>
<p>Sample text from language model.</p>

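As with the other documentation pages, the `docs/01-gpt2-with-value-head.html` changes only touch source links. A sketch of the documented `respond_to_batch(model, queries, txt_len=20, top_k=0, top_p=1.0)` API, with illustrative prompts and sampling settings:

```python
# Sketch only: batched sampling with the value-head model, using the
# respond_to_batch signature documented above; prompts and settings are
# illustrative assumptions.
import torch
from transformers import GPT2Tokenizer
from trl.gpt2 import GPT2HeadWithValueModel, respond_to_batch

model = GPT2HeadWithValueModel.from_pretrained('gpt2')
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

# two prompts that tokenise to the same length, so they can be stacked directly
queries = ["My favourite movie is", "My favourite book is"]
query_tensors = torch.cat([tokenizer.encode(q, return_tensors="pt") for q in queries])

# sample 20 continuation tokens per query with nucleus sampling (top_p=0.9)
responses = respond_to_batch(model, query_tensors, txt_len=20, top_p=0.9)
print([tokenizer.decode(r) for r in responses])
```

Queries of unequal length would first need padding, e.g. with `pad_to_size` from the core utilities above.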
6 changes: 3 additions & 3 deletions docs/02-ppo.html
@@ -91,7 +91,7 @@ <h2 id="KL-controllers">KL-controllers<a class="anchor-link" href="#KL-controlle


<div class="output_markdown rendered_html output_subarea ">
<h2 id="AdaptiveKLController" class="doc_header"><code>class</code> <code>AdaptiveKLController</code><a href="https://github.com/lvwerra/lm_ppo/tree/master/lm_ppo/ppo.py#L26" class="source_link" style="float:right">[source]</a></h2><blockquote><p><code>AdaptiveKLController</code>(<strong><code>init_kl_coef</code></strong>, <strong><code>target</code></strong>, <strong><code>horizon</code></strong>)</p>
<h2 id="AdaptiveKLController" class="doc_header"><code>class</code> <code>AdaptiveKLController</code><a href="https://github.com/lvwerra/trl/tree/master/trl/ppo.py#L26" class="source_link" style="float:right">[source]</a></h2><blockquote><p><code>AdaptiveKLController</code>(<strong><code>init_kl_coef</code></strong>, <strong><code>target</code></strong>, <strong><code>horizon</code></strong>)</p>
</blockquote>
<p>Adaptive KL controller described in the paper:
<a href="https://arxiv.org/pdf/1909.08593.pdf">https://arxiv.org/pdf/1909.08593.pdf</a></p>
@@ -140,7 +140,7 @@ <h2 id="AdaptiveKLController" class="doc_header"><code>class</code> <code>Adapti


<div class="output_markdown rendered_html output_subarea ">
<h2 id="FixedKLController" class="doc_header"><code>class</code> <code>FixedKLController</code><a href="https://github.com/lvwerra/lm_ppo/tree/master/lm_ppo/ppo.py#L44" class="source_link" style="float:right">[source]</a></h2><blockquote><p><code>FixedKLController</code>(<strong><code>kl_coef</code></strong>)</p>
<h2 id="FixedKLController" class="doc_header"><code>class</code> <code>FixedKLController</code><a href="https://github.com/lvwerra/trl/tree/master/trl/ppo.py#L44" class="source_link" style="float:right">[source]</a></h2><blockquote><p><code>FixedKLController</code>(<strong><code>kl_coef</code></strong>)</p>
</blockquote>
<p>Fixed KL controller.</p>

@@ -414,7 +414,7 @@ <h2 id="FixedKLController" class="doc_header"><code>class</code> <code>FixedKLCo


<div class="output_markdown rendered_html output_subarea ">
<h2 id="PPOTrainer" class="doc_header"><code>class</code> <code>PPOTrainer</code><a href="https://github.com/lvwerra/lm_ppo/tree/master/lm_ppo/ppo.py#L54" class="source_link" style="float:right">[source]</a></h2><blockquote><p><code>PPOTrainer</code>(<strong><code>model</code></strong>, <strong><code>ref_model</code></strong>, <strong>**<code>ppo_params</code></strong>)</p>
<h2 id="PPOTrainer" class="doc_header"><code>class</code> <code>PPOTrainer</code><a href="https://github.com/lvwerra/trl/tree/master/trl/ppo.py#L54" class="source_link" style="float:right">[source]</a></h2><blockquote><p><code>PPOTrainer</code>(<strong><code>model</code></strong>, <strong><code>ref_model</code></strong>, <strong>**<code>ppo_params</code></strong>)</p>
</blockquote>
<p>The PPO_trainer uses Proximal Policy Optimization to optimise language models.</p>

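A sketch of the `trl.ppo` constructors documented above (`AdaptiveKLController(init_kl_coef, target, horizon)`, `FixedKLController(kl_coef)`, `PPOTrainer(model, ref_model, **ppo_params)`); the hyperparameter values and the keyword names passed through `ppo_params` are illustrative assumptions:

```python
# Sketch only: the constructors documented above. The hyperparameter values
# and the keyword names passed through **ppo_params are illustrative assumptions.
from trl.gpt2 import GPT2HeadWithValueModel
from trl.ppo import PPOTrainer, AdaptiveKLController, FixedKLController

# AdaptiveKLController(init_kl_coef, target, horizon): adapts the KL penalty
# coefficient towards a target KL (https://arxiv.org/pdf/1909.08593.pdf)
adaptive_kl = AdaptiveKLController(init_kl_coef=0.2, target=6.0, horizon=10000)

# FixedKLController(kl_coef): keeps the KL penalty coefficient constant
fixed_kl = FixedKLController(kl_coef=0.1)

# PPOTrainer(model, ref_model, **ppo_params): the policy being tuned plus a
# frozen reference model; extra keyword arguments configure the optimisation
model = GPT2HeadWithValueModel.from_pretrained('gpt2')
ref_model = GPT2HeadWithValueModel.from_pretrained('gpt2')
ppo_trainer = PPOTrainer(model, ref_model, batch_size=16, lr=1.41e-5)
```

The adaptive controller loosens or tightens the KL penalty depending on how far the measured KL is from the target, while the fixed controller keeps it constant throughout training.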