[core] Add torch_dtype support (huggingface#147)
* fix torch_dtype

- add tests

* Update tests/models/test_modeling_value_head.py

Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com>

* Update tests/models/test_modeling_value_head.py

Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com>

---------

Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com>
younesbelkada and lvwerra authored Feb 16, 2023
1 parent 00aa31e commit 032676a
Showing 2 changed files with 56 additions and 6 deletions.
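What the change enables in practice: dtype kwargs such as torch_dtype now flow through the value-head wrappers down to transformers instead of being silently dropped. A minimal usage sketch, assuming the trl API at this commit (the "gpt2" checkpoint is illustrative, not taken from the diff):

    import torch
    from trl import AutoModelForCausalLMWithValueHead

    # torch_dtype is forwarded to transformers' from_pretrained, so the
    # wrapped pretrained_model loads directly in bfloat16.
    model = AutoModelForCausalLMWithValueHead.from_pretrained(
        "gpt2", torch_dtype=torch.bfloat16
    )
    assert model.pretrained_model.lm_head.weight.dtype == torch.bfloat16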
56 changes: 55 additions & 1 deletion tests/models/test_modeling_value_head.py
@@ -122,7 +122,7 @@ def test_from_save_trl_sharded(self):
             with tempfile.TemporaryDirectory() as tmp_dir:
                 model.save_pretrained(tmp_dir)
 
-                model_from_save = self.trl_model_class.from_pretrained(tmp_dir, max_shard_size="1MB")
+                model_from_save = self.trl_model_class.from_pretrained(tmp_dir)
 
             # Check if the weights are the same
             for key in model_from_save.state_dict():
@@ -264,6 +264,31 @@ def test_raise_error_not_causallm(self):
             pretrained_model = AutoModelForCausalLM.from_pretrained(model_id)
             _ = AutoModelForCausalLMWithValueHead.from_pretrained(pretrained_model.transformer)
 
+    def test_transformers_bf16_kwargs(self):
+        r"""
+        Test if the transformers kwargs are correctly passed.
+        Here we check that loading a model in half precision works as expected, i.e. the weights of
+        the `pretrained_model` attribute are loaded in half precision and a dummy
+        forward pass runs without any issue.
+        """
+        for model_name in self.all_model_names:
+            trl_model = self.trl_model_class.from_pretrained(model_name, torch_dtype=torch.bfloat16)
+
+            lm_head_namings = self.trl_model_class.lm_head_namings
+
+            self.assertTrue(
+                any(hasattr(trl_model.pretrained_model, lm_head_naming) for lm_head_naming in lm_head_namings)
+            )
+
+            for lm_head_naming in lm_head_namings:
+                if hasattr(trl_model.pretrained_model, lm_head_naming):
+                    self.assertTrue(getattr(trl_model.pretrained_model, lm_head_naming).weight.dtype == torch.bfloat16)
+
+            dummy_input = torch.LongTensor([[0, 1, 0, 1]])
+
+            # check dummy forward pass works in half precision
+            _ = trl_model(dummy_input)
+
     @unittest.skip("This test needs to be run manually due to HF token issue.")
     def test_push_to_hub(self):
         for model_name in self.all_model_names:
@@ -384,6 +409,35 @@ def test_push_to_hub(self):
                 f"Parameter {name} is not the same after push_to_hub and from_pretrained",
             )
 
+    def test_transformers_bf16_kwargs(self):
+        r"""
+        Test if the transformers kwargs are correctly passed.
+        Here we check that loading a model in half precision works as expected, i.e. the weights of
+        the `pretrained_model` attribute are loaded in half precision and a dummy
+        forward pass runs without any issue.
+        """
+        for model_name in self.all_model_names:
+            trl_model = self.trl_model_class.from_pretrained(model_name, torch_dtype=torch.bfloat16)
+
+            lm_head_namings = self.trl_model_class.lm_head_namings
+
+            if model_name == "trl-internal-testing/tiny-random-FSMTForConditionalGeneration":
+                # skip the test for FSMT as it does not support mixed-precision
+                continue
+
+            self.assertTrue(
+                any(hasattr(trl_model.pretrained_model, lm_head_naming) for lm_head_naming in lm_head_namings)
+            )
+
+            for lm_head_naming in lm_head_namings:
+                if hasattr(trl_model.pretrained_model, lm_head_naming):
+                    self.assertTrue(getattr(trl_model.pretrained_model, lm_head_naming).weight.dtype == torch.bfloat16)
+
+            dummy_input = torch.LongTensor([[0, 1, 0, 1]])
+
+            # check dummy forward pass works in half precision
+            _ = trl_model(input_ids=dummy_input, decoder_input_ids=dummy_input)
+
+
 class ReferenceModelTest(unittest.TestCase):
     def setUp(self):
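The seq2seq variant of the test above differs from the causal-LM one in two ways: FSMT is skipped because it does not support mixed precision, and the dummy forward pass must supply decoder_input_ids, since encoder-decoder models need decoder inputs. A minimal sketch of that call, assuming the trl API at this commit (the "t5-small" checkpoint is illustrative, not taken from the diff):

    import torch
    from trl import AutoModelForSeq2SeqLMWithValueHead

    # Sketch only; the checkpoint name is an assumption for illustration.
    model = AutoModelForSeq2SeqLMWithValueHead.from_pretrained(
        "t5-small", torch_dtype=torch.bfloat16
    )
    dummy_input = torch.LongTensor([[0, 1, 0, 1]])
    # Encoder-decoder forward needs both encoder and decoder token ids.
    _ = model(input_ids=dummy_input, decoder_input_ids=dummy_input)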
6 changes: 1 addition & 5 deletions trl/models/modeling_base.py
@@ -11,7 +11,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import inspect
 import json
 import os
 from copy import deepcopy
@@ -145,10 +144,7 @@ def _split_kwargs(cls, kwargs):
         unsupported_kwargs = {}
 
         for key, value in kwargs.items():
-            if (
-                key in cls.supported_args
-                or key not in inspect.signature(cls.transformers_parent_class.from_pretrained).parameters.keys()
-            ):
+            if key in cls.supported_args:
                 supported_kwargs[key] = value
             else:
                 unsupported_kwargs[key] = value
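The deleted inspect-based branch was the root of the bug: transformers' from_pretrained accepts torch_dtype through **kwargs rather than as a named parameter, so the old signature check classified it as a wrapper argument and silently dropped it. A standalone sketch of the tightened split, with simplified names (not trl's exact internals):

    # Simplified sketch of the new _split_kwargs behavior.
    def split_kwargs(supported_args, kwargs):
        supported, unsupported = {}, {}
        for key, value in kwargs.items():
            if key in supported_args:
                supported[key] = value  # consumed by the value-head wrapper
            else:
                unsupported[key] = value  # forwarded to transformers' from_pretrained
        return supported, unsupported

    supported, unsupported = split_kwargs(
        {"v_head_init_strategy", "v_head_initializer_range", "summary_dropout_prob"},
        {"v_head_init_strategy": "normal", "torch_dtype": "bfloat16"},
    )
    assert "torch_dtype" in unsupported  # now reaches transformers instead of vanishing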
