Skip to content

Commit

Permalink
Merge pull request huggingface#29 from lvwerra/add-fixed-kl-controller
Browse files (browse the repository at this point in the history)
add `FixedKLController`
  • Loading branch information
lvwerra authored Jan 1, 2022
2 parents 6895c6a + bc4f170 commit caed471
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 11 deletions.
11 changes: 7 additions & 4 deletions docs/02-ppo.html
Original file line number Diff line number Diff line change
Expand Up @@ -211,10 +211,13 @@ <h2 id="FixedKLController" class="doc_header"><code>class</code> <code>FixedKLCo
<span class="bp">self</span><span class="o">.</span><span class="n">ref_model</span> <span class="o">=</span> <span class="n">ref_model</span>
<span class="bp">self</span><span class="o">.</span><span class="n">model</span> <span class="o">=</span> <span class="n">model</span>
<span class="bp">self</span><span class="o">.</span><span class="n">optimizer</span> <span class="o">=</span> <span class="n">Adam</span><span class="p">(</span><span class="n">model</span><span class="o">.</span><span class="n">parameters</span><span class="p">(),</span> <span class="n">lr</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">ppo_params</span><span class="p">[</span><span class="s1">&#39;lr&#39;</span><span class="p">])</span>

<span class="bp">self</span><span class="o">.</span><span class="n">kl_ctl</span> <span class="o">=</span> <span class="n">AdaptiveKLController</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">ppo_params</span><span class="p">[</span><span class="s1">&#39;init_kl_coef&#39;</span><span class="p">],</span>
<span class="bp">self</span><span class="o">.</span><span class="n">ppo_params</span><span class="p">[</span><span class="s1">&#39;target&#39;</span><span class="p">],</span>
<span class="bp">self</span><span class="o">.</span><span class="n">ppo_params</span><span class="p">[</span><span class="s1">&#39;horizon&#39;</span><span class="p">])</span>

<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">ppo_params</span><span class="p">[</span><span class="s1">&#39;adap_kl_ctrl&#39;</span><span class="p">]:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">kl_ctl</span> <span class="o">=</span> <span class="n">AdaptiveKLController</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">ppo_params</span><span class="p">[</span><span class="s1">&#39;init_kl_coef&#39;</span><span class="p">],</span>
<span class="bp">self</span><span class="o">.</span><span class="n">ppo_params</span><span class="p">[</span><span class="s1">&#39;target&#39;</span><span class="p">],</span>
<span class="bp">self</span><span class="o">.</span><span class="n">ppo_params</span><span class="p">[</span><span class="s1">&#39;horizon&#39;</span><span class="p">])</span>
<span class="k">else</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">kl_ctl</span> <span class="o">=</span> <span class="n">FixedKLController</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">ppo_params</span><span class="p">[</span><span class="s1">&#39;init_kl_coef&#39;</span><span class="p">])</span>


<span class="k">def</span> <span class="nf">step</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">query</span><span class="p">,</span> <span class="n">response</span><span class="p">,</span> <span class="n">scores</span><span class="p">):</span>
Expand Down
11 changes: 7 additions & 4 deletions nbs/02-ppo.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -161,10 +161,13 @@
" self.ref_model = ref_model\n",
" self.model = model\n",
" self.optimizer = Adam(model.parameters(), lr=self.ppo_params['lr'])\n",
" \n",
" self.kl_ctl = AdaptiveKLController(self.ppo_params['init_kl_coef'],\n",
" self.ppo_params['target'],\n",
" self.ppo_params['horizon'])\n",
" \n",
" if self.ppo_params['adap_kl_ctrl']:\n",
" self.kl_ctl = AdaptiveKLController(self.ppo_params['init_kl_coef'],\n",
" self.ppo_params['target'],\n",
" self.ppo_params['horizon'])\n",
" else:\n",
" self.kl_ctl = FixedKLController(self.ppo_params['init_kl_coef'])\n",
"\n",
"\n",
" def step(self, query, response, scores):\n",
Expand Down
9 changes: 6 additions & 3 deletions trl/ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,9 +102,12 @@ def __init__(self, model, ref_model, **ppo_params):
self.model = model
self.optimizer = Adam(model.parameters(), lr=self.ppo_params['lr'])

self.kl_ctl = AdaptiveKLController(self.ppo_params['init_kl_coef'],
self.ppo_params['target'],
self.ppo_params['horizon'])
if self.ppo_params['adap_kl_ctrl']:
self.kl_ctl = AdaptiveKLController(self.ppo_params['init_kl_coef'],
self.ppo_params['target'],
self.ppo_params['horizon'])
else:
self.kl_ctl = FixedKLController(self.ppo_params['init_kl_coef'])


def step(self, query, response, scores):
Expand Down

0 comments on commit caed471

Please sign in to comment.