Enable safetensors conversion from PyTorch to other frameworks without the torch requirement (huggingface#27599)

* Initial commit

* Requirements & tests

* Tests

* Tests

* Rogue import

* Rogue torch import

* Cleanup

* Apply suggestions from code review

Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>

* bfloat16 management

* Sanchit's comments

* Import shield

* apply suggestions from code review

* correct bf16

* rebase

---------

Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
Co-authored-by: sanchit-gandhi <sanchit@huggingface.co>
3 people authored Jan 23, 2024
1 parent 0398660 commit 008a6a2
Showing 4 changed files with 94 additions and 66 deletions.
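The core of the change: a safetensors checkpoint serialized from PyTorch can be opened with the Flax backend of safetensors, so conversion to Flax no longer requires torch at all. A minimal sketch of that idea follows (not the library's own code; the file path and helper name are illustrative), assuming safetensors>=0.4.1 as pinned in the diff below:

```python
# Minimal sketch, assuming a local *.safetensors file that was serialized from PyTorch.
# safetensors files are framework-agnostic, so opening them with framework="flax"
# returns JAX arrays directly and never imports torch.
from safetensors import safe_open


def load_pt_safetensors_as_flax(path: str) -> dict:
    state_dict = {}
    with safe_open(path, framework="flax") as f:
        for key in f.keys():
            state_dict[key] = f.get_tensor(key)  # jax.numpy array, bf16 preserved
    return state_dict
```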
2 changes: 1 addition & 1 deletion setup.py
@@ -158,7 +158,7 @@
"ruff==0.1.5",
"sacrebleu>=1.4.12,<2.0.0",
"sacremoses",
"safetensors>=0.3.1",
"safetensors>=0.4.1",
"sagemaker>=2.31.0",
"scikit-learn",
"sentencepiece>=0.1.91,!=0.1.92",
2 changes: 1 addition & 1 deletion src/transformers/dependency_versions_table.py
@@ -64,7 +64,7 @@
"ruff": "ruff==0.1.5",
"sacrebleu": "sacrebleu>=1.4.12,<2.0.0",
"sacremoses": "sacremoses",
"safetensors": "safetensors>=0.3.1",
"safetensors": "safetensors>=0.4.1",
"sagemaker": "sagemaker>=2.31.0",
"scikit-learn": "scikit-learn",
"sentencepiece": "sentencepiece>=0.1.91,!=0.1.92",
57 changes: 28 additions & 29 deletions src/transformers/modeling_flax_pytorch_utils.py
@@ -27,10 +27,13 @@

import transformers

from . import is_safetensors_available
from . import is_safetensors_available, is_torch_available
from .utils import logging


if is_torch_available():
import torch

if is_safetensors_available():
from safetensors import safe_open
from safetensors.flax import load_file as safe_load_file
@@ -48,30 +48,51 @@ def load_pytorch_checkpoint_in_flax_state_dict(
flax_model, pytorch_checkpoint_path, is_sharded, allow_missing_keys=False
):
"""Load pytorch checkpoints in a flax model"""
try:
import torch # noqa: F401

from .pytorch_utils import is_torch_greater_or_equal_than_1_13 # noqa: F401
except (ImportError, ModuleNotFoundError):
logger.error(
"Loading a PyTorch model in Flax, requires both PyTorch and Flax to be installed. Please see"
" https://pytorch.org/ and https://flax.readthedocs.io/en/latest/installation.html for installation"
" instructions."
)
raise

if not is_sharded:
pt_path = os.path.abspath(pytorch_checkpoint_path)
logger.info(f"Loading PyTorch weights from {pt_path}")

if pt_path.endswith(".safetensors"):
pt_state_dict = {}
with safe_open(pt_path, framework="pt") as f:
with safe_open(pt_path, framework="flax") as f:
for k in f.keys():
pt_state_dict[k] = f.get_tensor(k)
else:
try:
import torch # noqa: F401

from .pytorch_utils import is_torch_greater_or_equal_than_1_13 # noqa: F401
except (ImportError, ModuleNotFoundError):
logger.error(
"Loading a PyTorch model in Flax, requires both PyTorch and Flax to be installed. Please see"
" https://pytorch.org/ and https://flax.readthedocs.io/en/latest/installation.html for installation"
" instructions."
)
raise

pt_state_dict = torch.load(pt_path, map_location="cpu", weights_only=is_torch_greater_or_equal_than_1_13)
logger.info(f"PyTorch checkpoint contains {sum(t.numel() for t in pt_state_dict.values()):,} parameters.")
logger.info(f"PyTorch checkpoint contains {sum(t.numel() for t in pt_state_dict.values()):,} parameters.")

flax_state_dict = convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model)
else:
@@ -149,21 +153,17 @@ def is_key_or_prefix_key_in_dict(key: Tuple[str]) -> bool:

def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model):
# convert pytorch tensor to numpy
# numpy currently does not support bfloat16, need to go over float32 in this case to not lose precision
try:
import torch # noqa: F401
except (ImportError, ModuleNotFoundError):
logger.error(
"Loading a PyTorch model in Flax, requires both PyTorch and Flax to be installed. Please see"
" https://pytorch.org/ and https://flax.readthedocs.io/en/latest/installation.html for installation"
" instructions."
)
raise
from_bin = is_torch_available() and isinstance(next(iter(pt_state_dict.values())), torch.Tensor)
bfloat16 = torch.bfloat16 if from_bin else "bfloat16"

weight_dtypes = {k: v.dtype for k, v in pt_state_dict.items()}
pt_state_dict = {
k: v.numpy() if not v.dtype == torch.bfloat16 else v.float().numpy() for k, v in pt_state_dict.items()
}

if from_bin:
for k, v in pt_state_dict.items():
# numpy currently does not support bfloat16, need to go over float32 in this case to not lose precision
if v.dtype == bfloat16:
v = v.float()
pt_state_dict[k] = v.numpy()

model_prefix = flax_model.base_model_prefix

@@ -191,7 +191,7 @@ def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model):
# Need to change some parameters name to match Flax names
for pt_key, pt_tensor in pt_state_dict.items():
pt_tuple_key = tuple(pt_key.split("."))
is_bfloat_16 = weight_dtypes[pt_key] == torch.bfloat16
is_bfloat_16 = weight_dtypes[pt_key] == bfloat16

# remove base model prefix if necessary
has_base_model_prefix = pt_tuple_key[0] == model_prefix
@@ -229,7 +229,6 @@ def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model):
flax_state_dict[("params",) + flax_key] = (
jnp.asarray(flax_tensor) if not is_bfloat_16 else jnp.asarray(flax_tensor, dtype=jnp.bfloat16)
)

else:
# also add unexpected weight so that warning is thrown
flax_state_dict[flax_key] = (
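To pull the bf16 handling above into one place: the converter now distinguishes state dicts that came from a torch .bin checkpoint (torch tensors, checked against torch.bfloat16) from state dicts already loaded as arrays via safetensors (dtype compared against the string "bfloat16"), so the safetensors path never needs torch. A simplified sketch under that reading of the diff; the helper name is illustrative and not part of the library:

```python
# Illustrative sketch of the dtype handling above; only the branching logic mirrors the diff.
from transformers.utils import is_torch_available

if is_torch_available():
    import torch


def normalize_pt_state_dict(pt_state_dict):
    # Values from a torch .bin checkpoint are torch.Tensor; values loaded via
    # safetensors with framework="flax" are already arrays whose dtype compares
    # equal to the string "bfloat16".
    from_bin = is_torch_available() and isinstance(next(iter(pt_state_dict.values())), torch.Tensor)
    bfloat16 = torch.bfloat16 if from_bin else "bfloat16"

    weight_dtypes = {k: v.dtype for k, v in pt_state_dict.items()}

    if from_bin:
        for k, v in pt_state_dict.items():
            # numpy has no bfloat16: go through float32 so no precision is lost
            if v.dtype == bfloat16:
                v = v.float()
            pt_state_dict[k] = v.numpy()

    return pt_state_dict, weight_dtypes
```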
99 changes: 64 additions & 35 deletions tests/test_modeling_flax_utils.py
@@ -23,13 +23,14 @@
from transformers.testing_utils import (
TOKEN,
USER,
CaptureLogger,
is_pt_flax_cross_test,
is_staging_test,
require_flax,
require_safetensors,
require_torch,
)
from transformers.utils import FLAX_WEIGHTS_NAME, SAFE_WEIGHTS_NAME
from transformers.utils import FLAX_WEIGHTS_NAME, SAFE_WEIGHTS_NAME, logging


if is_flax_available():
@@ -42,6 +43,9 @@

os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.12" # assumed parallelism: 8

if is_torch_available():
import torch


@require_flax
@is_staging_test
@@ -251,7 +255,6 @@ def test_safetensors_load_from_local(self):

self.assertTrue(check_models_equal(flax_model, safetensors_model))

@require_torch
@require_safetensors
@is_pt_flax_cross_test
def test_safetensors_load_from_hub_from_safetensors_pt(self):
@@ -265,57 +268,44 @@ def test_safetensors_load_from_hub_from_safetensors_pt(self):
safetensors_model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-safetensors")
self.assertTrue(check_models_equal(flax_model, safetensors_model))

@require_torch
@require_safetensors
@require_torch
@is_pt_flax_cross_test
def test_safetensors_load_from_local_from_safetensors_pt(self):
def test_safetensors_load_from_hub_from_safetensors_pt_bf16(self):
"""
This test checks that we can load safetensors from a checkpoint that only has those on the Hub,
saved in the "pt" format.
"""
with tempfile.TemporaryDirectory() as tmp:
location = snapshot_download("hf-internal-testing/tiny-bert-msgpack", cache_dir=tmp)
flax_model = FlaxBertModel.from_pretrained(location)
import torch

model = BertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-safetensors")
model.to(torch.bfloat16)

# Can load from the PyTorch-formatted checkpoint
with tempfile.TemporaryDirectory() as tmp:
location = snapshot_download("hf-internal-testing/tiny-bert-pt-safetensors", cache_dir=tmp)
safetensors_model = FlaxBertModel.from_pretrained(location)
model.save_pretrained(tmp)
flax_model = FlaxBertModel.from_pretrained(tmp)

# Can load from the PyTorch-formatted checkpoint
safetensors_model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-safetensors-bf16")
self.assertTrue(check_models_equal(flax_model, safetensors_model))

@require_safetensors
def test_safetensors_load_from_hub_from_safetensors_pt_without_torch_installed(self):
"""
This test checks that we cannot load safetensors from a checkpoint that only has safetensors
saved in the "pt" format if torch isn't installed.
"""
if is_torch_available():
# This test verifies that a correct error message is shown when loading from a pt safetensors
# PyTorch shouldn't be installed for this to work correctly.
return

# Cannot load from the PyTorch-formatted checkpoint without PyTorch installed
with self.assertRaises(ModuleNotFoundError):
_ = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-safetensors")

@require_safetensors
def test_safetensors_load_from_local_from_safetensors_pt_without_torch_installed(self):
@is_pt_flax_cross_test
def test_safetensors_load_from_local_from_safetensors_pt(self):
"""
This test checks that we cannot load safetensors from a checkpoint that only has safetensors
saved in the "pt" format if torch isn't installed.
This test checks that we can load safetensors from a checkpoint that only has those on the Hub,
saved in the "pt" format.
"""
if is_torch_available():
# This test verifies that a correct error message is shown when loading from a pt safetensors
# PyTorch shouldn't be installed for this to work correctly.
return
with tempfile.TemporaryDirectory() as tmp:
location = snapshot_download("hf-internal-testing/tiny-bert-msgpack", cache_dir=tmp)
flax_model = FlaxBertModel.from_pretrained(location)

# Can load from the PyTorch-formatted checkpoint
with tempfile.TemporaryDirectory() as tmp:
location = snapshot_download("hf-internal-testing/tiny-bert-pt-safetensors", cache_dir=tmp)
safetensors_model = FlaxBertModel.from_pretrained(location)

# Cannot load from the PyTorch-formatted checkpoint without PyTorch installed
with self.assertRaises(ModuleNotFoundError):
_ = FlaxBertModel.from_pretrained(location)
self.assertTrue(check_models_equal(flax_model, safetensors_model))

@require_safetensors
def test_safetensors_load_from_hub_msgpack_before_safetensors(self):
@@ -347,6 +337,7 @@ def test_safetensors_flax_from_flax(self):

@require_safetensors
@require_torch
@is_pt_flax_cross_test
def test_safetensors_flax_from_torch(self):
hub_model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-flax-only")
model = BertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-only")
@@ -372,3 +363,41 @@ def test_safetensors_flax_from_sharded_msgpack_with_sharded_safetensors_hub(self
# This should not raise even if there are two types of sharded weights
# This should discard the safetensors weights in favor of the msgpack sharded weights
FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-flax-safetensors-msgpack-sharded")

@require_safetensors
def test_safetensors_from_pt_bf16(self):
# This should not raise; should be able to load bf16-serialized torch safetensors without issue
# and without torch.
logger = logging.get_logger("transformers.modeling_flax_utils")

with CaptureLogger(logger) as cl:
FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-safetensors-bf16")

self.assertTrue(
"Some of the weights of FlaxBertModel were initialized in bfloat16 precision from the model checkpoint"
in cl.out
)

@require_torch
@require_safetensors
@is_pt_flax_cross_test
def test_from_pt_bf16(self):
model = BertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-only")
model.to(torch.bfloat16)

with tempfile.TemporaryDirectory() as tmp_dir:
model.save_pretrained(tmp_dir, safe_serialization=False)

logger = logging.get_logger("transformers.modeling_flax_utils")

with CaptureLogger(logger) as cl:
new_model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-safetensors-bf16")

self.assertTrue(
"Some of the weights of FlaxBertModel were initialized in bfloat16 precision from the model checkpoint"
in cl.out
)

flat_params_1 = flatten_dict(new_model.params)
for value in flat_params_1.values():
self.assertEqual(value.dtype, "bfloat16")
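
Taken together, the tests above exercise the new behaviour end to end: a Flax model can be instantiated from a repo that only ships PyTorch-format safetensors, without torch installed, and bf16 weights stay in bfloat16. A hedged usage sketch (repo names are the internal test checkpoints referenced above; flax and safetensors>=0.4.1 are assumed to be installed):

```python
# Usage sketch; assumes flax and safetensors>=0.4.1 are installed, torch is not required.
from flax.traverse_util import flatten_dict
from transformers import FlaxBertModel

# PyTorch-format safetensors checkpoint, loaded directly into Flax
model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-safetensors")

# bf16 variant: weights come through in bfloat16 precision
bf16_model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-safetensors-bf16")
for value in flatten_dict(bf16_model.params).values():
    assert value.dtype == "bfloat16"
```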
