add minibatching (huggingface#153)
* add minibatching

* all the fixes i missed

* more fixes

* add dedicated variable for mini batch size

* style

* minor fixes

* fix rewards

* unbiased variance estimation

* mask values/returns

* moar fixes

* style

* change structure and add moar tests

* Apply suggestions from code review

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* deprecate `forward_batch_size`

* remove out of date warning about batching s2s and left padding models

* make style

* fixed failed merge

---------

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
lvwerra and younesbelkada authored Feb 23, 2023
1 parent a757ac4 commit f1300ec
Showing 15 changed files with 323 additions and 183 deletions.
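For orientation, the net effect of this commit on the public API: `forward_batch_size` is deprecated in favor of a dedicated `mini_batch_size` field on `PPOConfig`, and reward pipelines now receive their batch size directly. A minimal before/after sketch (the numeric values are illustrative, not defaults):

```python
from trl import PPOConfig

# Before (deprecated): one knob chunked the forward passes and was also
# reused as the sentiment pipeline's batch size.
# config = PPOConfig(batch_size=256, forward_batch_size=16)

# After: mini-batching gets a dedicated variable; batch_size should be a
# multiple of mini_batch_size.
config = PPOConfig(batch_size=256, mini_batch_size=16)

# Reward pipelines are batched explicitly, decoupled from the PPO config:
sent_kwargs = {"return_all_scores": True, "function_to_apply": "none", "batch_size": 16}
```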
1 change: 0 additions & 1 deletion README.md
@@ -71,7 +71,6 @@ tokenizer = AutoTokenizer.from_pretrained('gpt2')
# initialize trainer
ppo_config = PPOConfig(
    batch_size=1,
-     forward_batch_size=1
)

# encode a query
10 changes: 5 additions & 5 deletions docs/source/customization.mdx
@@ -16,7 +16,7 @@ model_ref = AutoModelForCausalLMWithValueHead.from_pretrained('gpt2')
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

# 2. define config
- ppo_config = {'batch_size': 1, 'forward_batch_size': 1, 'learning_rate':1e-5}
+ ppo_config = {'batch_size': 1, 'learning_rate':1e-5}
config = PPOConfig(**ppo_config)


@@ -43,7 +43,7 @@ model_ref = AutoModelForCausalLMWithValueHead.from_pretrained('gpt2')
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

# 2. define config
- ppo_config = {'batch_size': 1, 'forward_batch_size': 1, 'learning_rate':1e-5}
+ ppo_config = {'batch_size': 1, 'learning_rate':1e-5}
config = PPOConfig(**ppo_config)


@@ -84,7 +84,7 @@ model_ref = AutoModelForCausalLMWithValueHead.from_pretrained('gpt2')
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

# 2. define config
- ppo_config = {'batch_size': 1, 'forward_batch_size': 1, 'learning_rate':1e-5}
+ ppo_config = {'batch_size': 1, 'learning_rate':1e-5}
config = PPOConfig(**ppo_config)


@@ -110,7 +110,7 @@ model_ref = create_reference_model(model, num_shared_layers=6)
tokenizer = AutoTokenizer.from_pretrained('bigscience/bloom-560m')

# 2. initialize trainer
- ppo_config = {'batch_size': 1, 'forward_batch_size': 1}
+ ppo_config = {'batch_size': 1}
config = PPOConfig(**ppo_config)
ppo_trainer = PPOTrainer(config, model, model_ref, tokenizer)
```
@@ -138,7 +138,7 @@ model_ref = AutoModelForCausalLMWithValueHead.from_pretrained('bigscience/bloom-
tokenizer = AutoTokenizer.from_pretrained('bigscience/bloom-560m')

# 2. initialize trainer
- ppo_config = {'batch_size': 1, 'forward_batch_size': 1}
+ ppo_config = {'batch_size': 1}
config = PPOConfig(**ppo_config)
ppo_trainer = PPOTrainer(config, model, model_ref, tokenizer)
```
2 changes: 1 addition & 1 deletion docs/source/quickstart.mdx
@@ -28,7 +28,7 @@ model_ref = AutoModelForCausalLMWithValueHead.from_pretrained('gpt2')
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

# 2. initialize trainer
- ppo_config = {'batch_size': 1, 'forward_batch_size': 1}
+ ppo_config = {'batch_size': 1}
config = PPOConfig(**ppo_config)
ppo_trainer = PPOTrainer(config, model, model_ref, tokenizer)

5 changes: 1 addition & 4 deletions examples/sentiment/notebooks/gpt2-sentiment-control.ipynb
@@ -698,10 +698,7 @@
"4. Get sentiments for query/responses from BERT\n",
"5. Optimize policy with PPO using the (query, response, reward) triplet\n",
"6. Log all the training statistics\n",
"\n",
"**Forward batching**\n",
"\n",
"Since the models can be fairly big and we want to rollout large PPO batches this can lead to out-of-memory errors when doing the forward passes for text generation and sentiment analysis. We introduce the parameter `forward_batch_size` to split the forward passes into smaller batches. Although this hurts performance a little this is neglectible compared to the computations of the backward passes when optimizing the model. The same parameter is used in the `PPOTrainer` when doing forward passes. The `batch_size` should multiple of `forward_batch_size`.\n",

"\n",
"**Training time**\n",
"\n",
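The deleted note described a real constraint: with large models and large PPO batches, the generation and sentiment forward passes can run out of memory. After this commit the two knobs are separate; a hedged sketch of the replacement pattern (model name and sizes are illustrative):

```python
from transformers import pipeline
from trl import PPOConfig

# PPO forward passes and optimization steps are chunked by mini_batch_size.
config = PPOConfig(batch_size=128, mini_batch_size=8)

# The reward model's forward batch size is set on the pipeline call itself.
sentiment_pipe = pipeline("sentiment-analysis", model="lvwerra/distilbert-imdb")
sent_kwargs = {"return_all_scores": True, "function_to_apply": "none", "batch_size": 16}
# scores = sentiment_pipe(texts, **sent_kwargs)
```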
9 changes: 1 addition & 8 deletions examples/sentiment/notebooks/gpt2-sentiment.ipynb
@@ -93,7 +93,7 @@
"sent_kwargs = {\n",
" \"return_all_scores\": True,\n",
" \"function_to_apply\": \"none\",\n",
" \"batch_size\": config.forward_batch_size\n",
" \"batch_size\": 16\n",
"}"
]
},
@@ -107,13 +107,6 @@
"wandb.init()"
]
},
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "**Forward batching**: Since the models can be fairly big and we want to roll out large PPO batches, this can lead to out-of-memory errors when doing the forward passes for text generation and sentiment analysis. We introduce the parameter `forward_batch_size` to split the forward passes into smaller batches. Although this hurts performance a little, it is negligible compared to the computation of the backward passes when optimizing the model. The same parameter is used in the `PPOTrainer` when doing forward passes. The `batch_size` should be a multiple of `forward_batch_size`."
- ]
- },
{
"cell_type": "markdown",
"metadata": {},
3 changes: 2 additions & 1 deletion examples/sentiment/scripts/gpt2-sentiment.py
@@ -51,7 +51,8 @@

# We then define the arguments to pass to the sentiment analysis pipeline.
# We set `return_all_scores` to True to get the sentiment score for each token.
sent_kwargs = {"return_all_scores": True, "function_to_apply": "none", "batch_size": config.forward_batch_size}
sent_kwargs = {"return_all_scores": True, "function_to_apply": "none", "batch_size": 16}


# Below is an example function to build the dataset. In our case, we use the IMDB dataset
# from the `datasets` library. One should customize this function to train the model on
4 changes: 3 additions & 1 deletion examples/sentiment/scripts/t5-sentiment.py
@@ -45,7 +45,9 @@
config = PPOConfig(model_name="lvwerra/t5-imdb", learning_rate=5e-5, batch_size=256)
# We then define the arguments to pass to the sentiment analysis pipeline.
# We set `return_all_scores` to True to get the sentiment score for each token.
sent_kwargs = {"return_all_scores": True, "function_to_apply": "none", "batch_size": config.forward_batch_size}
sent_kwargs = {"return_all_scores": True, "function_to_apply": "none", "batch_size": 16}


# Below is an example function to build the dataset. In our case, we use the IMDB dataset
# from the `datasets` library. One should customize this function to train the model on
# its own dataset.
Empty file removed tests/models/__init__.py
42 changes: 42 additions & 0 deletions tests/test_core.py
@@ -0,0 +1,42 @@
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest

import torch

from trl.core import masked_mean, masked_var, masked_whiten, whiten


class CoreTester(unittest.TestCase):
    """
    A wrapper class for testing core utils functions
    """

    @classmethod
    def setUpClass(cls):
        cls.test_input = torch.Tensor([1, 2, 3, 4])
        cls.test_mask = torch.Tensor([0, 1, 1, 0])
        cls.test_input_unmasked = cls.test_input[1:3]

    def test_masked_mean(self):
        self.assertEqual(torch.mean(self.test_input_unmasked), masked_mean(self.test_input, self.test_mask))

    def test_masked_var(self):
        self.assertEqual(torch.var(self.test_input_unmasked), masked_var(self.test_input, self.test_mask))

    def test_masked_whiten(self):
        whiten_unmasked = whiten(self.test_input_unmasked)
        whiten_masked = masked_whiten(self.test_input, self.test_mask)[1:3]
        diffs = (whiten_unmasked - whiten_masked).sum()
        self.assertAlmostEqual(diffs, 0)
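For intuition: with `test_input = [1, 2, 3, 4]` and `test_mask = [0, 1, 1, 0]`, the masked statistics reduce to plain statistics over `[2, 3]`. A standalone check of the arithmetic the tests rely on:

```python
import torch

values = torch.Tensor([1, 2, 3, 4])
mask = torch.Tensor([0, 1, 1, 0])

masked_mean = (values * mask).sum() / mask.sum()  # (2 + 3) / 2 = 2.5
biased_var = (((values - masked_mean) ** 2) * mask).sum() / mask.sum()  # 0.25
unbiased_var = biased_var * mask.sum() / (mask.sum() - 1)  # Bessel: 0.25 * 2 = 0.5

assert masked_mean == torch.mean(torch.Tensor([2, 3]))  # 2.5
assert unbiased_var == torch.var(torch.Tensor([2, 3]))  # torch.var is unbiased by default
```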
File renamed without changes.
123 changes: 84 additions & 39 deletions tests/trainer/test_ppo_trainer.py → tests/test_ppo_trainer.py
@@ -24,10 +24,10 @@
from requests.exceptions import HTTPError
from transformers import AutoTokenizer

- from trl import AutoModelForCausalLMWithValueHead, AutoModelForSeq2SeqLMWithValueHead, PPOConfig, PPOTrainer
- from trl.core import logprobs_from_logits, respond_to_batch
+ from trl import AutoModelForCausalLMWithValueHead, AutoModelForSeq2SeqLMWithValueHead, PPOConfig, PPOTrainer, set_seed
+ from trl.core import respond_to_batch

- from ..testing_constants import CI_HUB_ENDPOINT, CI_HUB_USER, CI_HUB_USER_TOKEN
+ from .testing_constants import CI_HUB_ENDPOINT, CI_HUB_USER, CI_HUB_USER_TOKEN


EXPECTED_STATS = [
@@ -77,13 +77,30 @@ def __getitem__(self, idx):
        return self.query_data[idx], self.response_data[idx]


+ def apply_mask(values, mask):
+     unmasked_values = []
+     for v, m in zip(values, mask):
+         if m == 1:
+             unmasked_values.append(v)
+     return torch.Tensor(unmasked_values)
+
+
+ def abs_diff_masked_tensors(tensor_1, tensor_2, mask_1, mask_2):
+     diffs = []
+     for l1, l2, m1, m2 in zip(tensor_1, tensor_2, mask_1, mask_2):
+         diff = apply_mask(l1, m1) - apply_mask(l2, m2)
+         diffs.append(diff.sum())
+     return abs(sum(diffs))


class PPOTrainerTester(unittest.TestCase):
    """
    A wrapper class for testing PPOTrainer
    """

    @classmethod
    def setUpClass(cls):
+         set_seed(42)
        cls._token = CI_HUB_USER_TOKEN
        cls._api = HfApi(endpoint=CI_HUB_ENDPOINT)
        cls._api.set_access_token(CI_HUB_USER_TOKEN)
@@ -109,7 +126,7 @@ def setUpClass(cls):
        cls.t5_tokenizer = AutoTokenizer.from_pretrained(model_id)

        # initialize trainer
-         cls.ppo_config = PPOConfig(batch_size=2, forward_batch_size=1, log_with=None)
+         cls.ppo_config = PPOConfig(batch_size=2, mini_batch_size=1, log_with=None)

    @classmethod
    def tearDownClass(cls):
@@ -121,8 +138,8 @@ def tearDownClass(cls):

    def setUp(self):
        # initialize trainer
-         self.ppo_config = PPOConfig(batch_size=2, forward_batch_size=1, log_with=None)
-
+         self.ppo_config = PPOConfig(batch_size=2, mini_batch_size=1, log_with=None)
+         self.gpt2_model.train()
        return super().setUp()

    def tearDown(self):
@@ -148,7 +165,7 @@ def _init_dummy_dataset(self):
        return dummy_dataset

    def test_drop_last_dataloader(self):
-         self.ppo_config = PPOConfig(batch_size=3, forward_batch_size=1, log_with=None)
+         self.ppo_config = PPOConfig(batch_size=3, mini_batch_size=1, log_with=None)

        dummy_dataset = self._init_dummy_dataset()

@@ -502,6 +519,8 @@ def test_loss_trainer(self):
        # initialize dataset
        dummy_dataset = self._init_dummy_dataset()

+         self.gpt2_model.eval()
+
        ppo_trainer = PPOTrainer(
            config=self.ppo_config,
            model=self.gpt2_model,
@@ -510,31 +529,51 @@
            dataset=dummy_dataset,
        )

-         dummy_model_input = torch.tensor([[1, 2, 3, 4, 5, 6, 7]])
-         dummy_query = torch.tensor([[1, 2, 3, 4]])
-         dummy_response = torch.tensor([[5, 6, 7]])
-         rewards = torch.tensor([[0, 1, 0]])
-         gen_len = rewards.shape[-1]
-
-         old_logits, _, values = ppo_trainer.ref_model(dummy_model_input)
-         old_logprobs = logprobs_from_logits(old_logits[:, :-1, :], dummy_model_input[:, 1:])
-         values = values[:, -gen_len - 1 : -1]
-         old_logprobs = old_logprobs[:, -gen_len:]
+         dummy_queries = [torch.tensor([1, 2, 3, 4]), torch.tensor([1, 2, 3, 4, 5, 6, 7])]
+         dummy_responses = [torch.tensor([5, 6, 7, 8, 9]), torch.tensor([8, 9, 10, 11, 12, 13])]
+         dummy_scores = torch.Tensor([1, 2])

-         # logprobs, logits, vpred = ppo_trainer.model(dummy_tokens)
-         logprobs, vpred, logits = ppo_trainer.compute_logits_vpred(
-             dummy_model_input, dummy_query, dummy_response, rewards
-         )
+         ppo_trainer.config.mini_batch_size = 1
+         ppo_trainer.config.batch_size = 1
+         model_inputs = ppo_trainer.prepare_model_inputs(dummy_queries, dummy_responses)
+         all_logprobs, _, values, mask = ppo_trainer.batched_forward_pass(
+             self.gpt2_model, dummy_queries, dummy_responses, model_inputs
+         )
+
+         # dummy values
+         ref_logprobs = all_logprobs + 1
+         logits = torch.exp(all_logprobs)
+         vpreds = values + 0.1
+
+         score, non_score = ppo_trainer.compute_rewards(dummy_scores, all_logprobs, ref_logprobs, mask)

        # just make sure a dummy loss is computed
-         pg_loss, vf_loss, _ = ppo_trainer.loss(
-             old_logprobs=old_logprobs,
-             logprob=logprobs,
-             logits=logits,
-             values=values,
-             rewards=rewards,
-             vpred=vpred,
-         )
+         idx = 0
+         pg_loss, v_loss, _ = ppo_trainer.loss(
+             all_logprobs[idx].unsqueeze(0),
+             values[idx].unsqueeze(0),
+             score[idx].unsqueeze(0),
+             logits[idx].unsqueeze(0),
+             vpreds[idx].unsqueeze(0),
+             ref_logprobs[idx].unsqueeze(0),
+             mask[idx].unsqueeze(0),
+         )
+
+         self.assertAlmostEqual(pg_loss.item(), 0.62516, 4)
+         self.assertAlmostEqual(v_loss.item(), 0.09950, 4)
+
+         # check if we get same results with masked parts removed
+         pg_loss_unmasked, v_loss_unmasked, _ = ppo_trainer.loss(
+             apply_mask(all_logprobs[idx], mask[idx]).unsqueeze(0),
+             apply_mask(values[idx], mask[idx]).unsqueeze(0),
+             apply_mask(score[idx], mask[idx]).unsqueeze(0),
+             apply_mask(logits[idx], mask[idx]).unsqueeze(0),
+             apply_mask(vpreds[idx], mask[idx]).unsqueeze(0),
+             apply_mask(ref_logprobs[idx], mask[idx]).unsqueeze(0),
+             apply_mask(mask[idx], mask[idx]).unsqueeze(0),
+         )
+         self.assertAlmostEqual(pg_loss_unmasked.item(), 0.62516, 4)
+         self.assertAlmostEqual(v_loss_unmasked.item(), 0.09950, 4)

    @parameterized.expand(
        [
@@ -578,25 +617,31 @@ def test_batched_forward_pass(self, name):
        # if fwd_bs=1/bs=2: padding is applied and results computed in two fwd passes
        # if fwd_bs=bs=2: padding is applied and results computed in one fwd pass

-         ppo_trainer.config.forward_batch_size = 1
+         ppo_trainer.config.mini_batch_size = 1
+         ppo_trainer.config.batch_size = 1

-         logprobs_0, ref_logprobs_0, values_0 = ppo_trainer.batched_forward_pass(
-             [dummy_queries[0]], [dummy_responses[0]]
-         )
+         model_inputs = ppo_trainer.prepare_model_inputs([dummy_queries[0]], [dummy_responses[0]])
+         logprobs_0, logits_0, values_0, mask_0 = ppo_trainer.batched_forward_pass(
+             model, [dummy_queries[0]], [dummy_responses[0]], model_inputs
+         )

        ppo_trainer.config.batch_size = 2
-         logprobs_1, ref_logprobs_1, values_1 = ppo_trainer.batched_forward_pass(dummy_queries, dummy_responses)
+         model_inputs = ppo_trainer.prepare_model_inputs(dummy_queries, dummy_responses)
+         logprobs_1, logits_1, values_1, mask_1 = ppo_trainer.batched_forward_pass(
+             model, dummy_queries, dummy_responses, model_inputs
+         )

-         ppo_trainer.config.forward_batch_size = 2
-         logprobs_2, ref_logprobs_2, values_2 = ppo_trainer.batched_forward_pass(dummy_queries, dummy_responses)
+         ppo_trainer.config.mini_batch_size = 2
+         model_inputs = ppo_trainer.prepare_model_inputs(dummy_queries, dummy_responses)
+         logprobs_2, logits_2, values_2, mask_2 = ppo_trainer.batched_forward_pass(
+             model, dummy_queries, dummy_responses, model_inputs
+         )

-         self.assertLessEqual(abs(sum([(l1 - l2).sum() for l1, l2 in zip(logprobs_1, logprobs_2)])), 1e-4)
-         self.assertLessEqual(abs(sum([(l1 - l2).sum() for l1, l2 in zip(ref_logprobs_1, ref_logprobs_2)])), 1e-4)
-         self.assertLessEqual(abs(sum([(v1 - v2).sum() for v1, v2 in zip(values_1, values_2)])), 1e-4)
+         self.assertLessEqual(abs_diff_masked_tensors(logprobs_1, logprobs_2, mask_1, mask_2), 1e-4)
+         self.assertLessEqual(abs_diff_masked_tensors(values_1, values_2, mask_1, mask_2), 1e-4)

-         self.assertLessEqual(abs((logprobs_0[0] - logprobs_2[0]).sum()), 1e-4)
-         self.assertLessEqual(abs((ref_logprobs_0[0] - ref_logprobs_2[0]).sum()), 1e-4)
-         self.assertLessEqual(abs((values_0[0] - values_2[0]).sum()), 1e-4)
+         self.assertLessEqual(abs_diff_masked_tensors(logprobs_0, logprobs_2[:1], mask_0, mask_2[:1]), 1e-4)
+         self.assertLessEqual(abs_diff_masked_tensors(values_0, values_2[:1], mask_0, mask_2[:1]), 1e-4)

@unittest.skip("Fix by either patching `whomai()` to work in the staging endpoint or use a dummy prod user.")
def test_push_to_hub(self):
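The updated tests exercise the new flow: `prepare_model_inputs` pads and collates the query/response pairs, and `batched_forward_pass` takes the model plus those inputs and returns `(logprobs, logits, values, mask)`, computed in `mini_batch_size`-sized chunks. The chunking itself follows the usual pattern; a generic sketch, not the trainer's actual implementation:

```python
import torch

def forward_in_minibatches(model, input_ids, mini_batch_size):
    """Run a forward pass in chunks to cap peak memory, then reassemble.
    Assumes a Hugging Face-style model whose output exposes `.logits`."""
    chunks = []
    for i in range(0, input_ids.size(0), mini_batch_size):
        batch = input_ids[i : i + mini_batch_size]
        chunks.append(model(batch).logits)
    return torch.cat(chunks, dim=0)
```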
Empty file removed tests/trainer/__init__.py
25 changes: 25 additions & 0 deletions trl/core.py
@@ -104,6 +104,31 @@ def whiten(values, shift_mean=True):
    return whitened


def masked_mean(values, mask, axis=None):
    """Compute mean of tensor with masked values."""
    return (values * mask).sum(axis=axis) / mask.sum(axis=axis)


def masked_var(values, mask, unbiased=True):
    """Compute variance of tensor with masked values."""
    mean = masked_mean(values, mask)
    centered_values = values - mean
    variance = masked_mean(centered_values**2, mask)
    if unbiased:
        bessel_correction = mask.sum() / (mask.sum() - 1)
        variance = variance * bessel_correction
    return variance


def masked_whiten(values, mask, shift_mean=True):
    """Whiten values using the masked positions' mean and variance."""
    mean, var = masked_mean(values, mask), masked_var(values, mask)
    whitened = (values - mean) * torch.rsqrt(var + 1e-8)
    if not shift_mean:
        whitened += mean
    return whitened


def clip_by_value(x, tensor_min, tensor_max):
    """
    Tensor extension to torch.clamp
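These utilities let the trainer restrict statistics to response tokens: `masked_var` applies a Bessel correction of n/(n-1) with n = `mask.sum()` (so the mask must select at least two elements), and `masked_whiten` normalizes all values using only the masked positions' mean and variance. A usage sketch under those assumptions (shapes and mask are illustrative):

```python
import torch
from trl.core import masked_whiten

# Whiten per-token advantages over response positions only; query and
# padding positions (mask == 0) do not contribute to the statistics.
advantages = torch.randn(2, 8)
mask = torch.tensor([[0, 0, 1, 1, 1, 1, 0, 0],
                     [0, 1, 1, 1, 1, 1, 1, 0]], dtype=torch.float)
whitened = masked_whiten(advantages, mask)
```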
