Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable accurate torch hashing #1599

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 49 additions & 2 deletions joblib/hashing.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,9 +241,54 @@
Hasher.save(self, obj)


class TorchHasher(NumpyHasher):
""" Special case for the hasher for when torch is loaded.

This class extends the NumpyHasher class to handle torch tensors and
torch modules. It converts torch tensors and torch modules to numpy
arrays for deterministic hashing.
"""

def __init__(self, hash_name="md5", coerce_mmap=False):
super().__init__(hash_name, coerce_mmap)
from torch.nn import Module as torch_nnModule # noqa: import-outside-toplevel
from torch import Tensor as torch_Tensor # noqa: import-outside-toplevel

Check warning on line 255 in joblib/hashing.py

View check run for this annotation

Codecov / codecov/patch

joblib/hashing.py#L253-L255

Added lines #L253 - L255 were not covered by tests

self.torch_nnModule = torch_nnModule
self.torch_Tensor = torch_Tensor

Check warning on line 258 in joblib/hashing.py

View check run for this annotation

Codecov / codecov/patch

joblib/hashing.py#L257-L258

Added lines #L257 - L258 were not covered by tests

def _convert_tensors_to_numpy(self, obj):
# Recursively convert torch tensors in obj to numpy arrays
if isinstance(obj, dict):
for key, value in obj.items():
obj[key] = self._convert_tensors_to_numpy(value)
if isinstance(obj, self.torch_nnModule):
state_dict = obj.state_dict()
obj = {key: self._convert_tensors_to_numpy(value)

Check warning on line 267 in joblib/hashing.py

View check run for this annotation

Codecov / codecov/patch

joblib/hashing.py#L262-L267

Added lines #L262 - L267 were not covered by tests
for key, value in state_dict.items()}
return obj
if isinstance(obj, self.torch_Tensor):
obj_as_numpy = obj.cpu().detach().numpy()
return obj_as_numpy
return obj

Check warning on line 273 in joblib/hashing.py

View check run for this annotation

Codecov / codecov/patch

joblib/hashing.py#L269-L273

Added lines #L269 - L273 were not covered by tests

def save(self, obj):
""" Subclass again to convert torch tensors and torch modules to numpy
arrays for deterministic hashing. Torch tensors do not have
deterministic pickle representations and therefore hashing them is
not reliable.

Torch tensors are converted to numpy arrays directly, and torch
modules are converted to dictionaries of numpy arrays
corresponding to each component in the state_dict.
"""
obj = self._convert_tensors_to_numpy(obj)
NumpyHasher.save(self, obj)

Check warning on line 286 in joblib/hashing.py

View check run for this annotation

Codecov / codecov/patch

joblib/hashing.py#L285-L286

Added lines #L285 - L286 were not covered by tests


def hash(obj, hash_name='md5', coerce_mmap=False):
""" Quick calculation of a hash to identify uniquely Python objects
containing numpy arrays.
containing numpy arrays or torch tensors.

Parameters
----------
Expand All @@ -258,7 +303,9 @@
raise ValueError("Valid options for 'hash_name' are {}. "
"Got hash_name={!r} instead."
.format(valid_hash_names, hash_name))
if 'numpy' in sys.modules:
if 'torch' in sys.modules:
hasher = TorchHasher(hash_name=hash_name, coerce_mmap=coerce_mmap)

Check warning on line 307 in joblib/hashing.py

View check run for this annotation

Codecov / codecov/patch

joblib/hashing.py#L307

Added line #L307 was not covered by tests
elif 'numpy' in sys.modules:
hasher = NumpyHasher(hash_name=hash_name, coerce_mmap=coerce_mmap)
else:
hasher = Hasher(hash_name=hash_name)
Expand Down