
Enable accurate torch hashing #1599

Open · wants to merge 5 commits into main

Conversation

landoskape

As described in the docs, torch tensors have non-deterministic pickle representations. This makes it impossible to accurately hash & cache when torch tensors are involved.

This is described in issue #1282.
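
For reference, the non-determinism can be observed directly (a minimal sketch; whether the digests actually differ depends on the torch and joblib versions in use):

from joblib import hash
import torch

# two tensors with identical values
a = torch.tensor([1.0, 2.0, 3.0])
b = torch.tensor([1.0, 2.0, 3.0])

# without this PR the pickled bytes are not deterministic, so the digests
# may not match; with this PR they are equal
print(hash(a) == hash(b))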

This PR solves the problem by converting torch tensors to numpy arrays before hashing. It handles both torch tensors and torch modules; a module is converted by mapping its state_dict to a dictionary of {key: numpy_array} pairs.
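
Conceptually, the conversion is along these lines (a simplified sketch, not the exact code added in this PR; _torch_to_hashable is a hypothetical helper name):

import torch

def _torch_to_hashable(obj):
    # hypothetical helper: map torch objects to deterministic numpy equivalents
    if isinstance(obj, torch.Tensor):
        # detach from the autograd graph and move to CPU before converting
        return obj.detach().cpu().numpy()
    if isinstance(obj, torch.nn.Module):
        # represent a module by its state_dict as {key: numpy_array} pairs
        return {k: v.detach().cpu().numpy() for k, v in obj.state_dict().items()}
    return obj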

Test for hashing:

from joblib import hash
import torch, torchvision

# net1 and net2 load identical pretrained weights; net3 is randomly initialized
net1 = torchvision.models.alexnet(weights=torchvision.models.AlexNet_Weights.DEFAULT)
net2 = torchvision.models.alexnet(weights=torchvision.models.AlexNet_Weights.DEFAULT)
net3 = torchvision.models.alexnet()

print(hash(net1)) # 0e8b7766fbc3f36ecd92377e02868999
print(hash(net2)) # 0e8b7766fbc3f36ecd92377e02868999
print(hash(net3)) # f784717a3920ef6dd166d6b018419ac8

# t1 and t2 hold the same value; t3 differs
t1 = torch.tensor(1.5)
t2 = torch.tensor(1.5)
t3 = torch.tensor(2.5)

print(hash(t1)) # 9293bd3ea4c30e5e1ef0881447869946
print(hash(t2)) # 9293bd3ea4c30e5e1ef0881447869946
print(hash(t3)) # 441a82b6cc7a97f60693ed5344a8e5bc

Test for caching:

import time
from joblib import Memory
import torch, torchvision

memory = Memory(location='.cache', verbose=0)

@memory.cache
def check_torch_cache(tensor, module, string):
    time.sleep(1)
    return tensor

t = torch.tensor(1.5)
m = torchvision.models.alexnet()
s = 'hello'

t_alt = torch.tensor(2.5)
m_alt = torchvision.models.alexnet(weights=torchvision.models.AlexNet_Weights.DEFAULT)
s_alt = 'world'

# initialize cache
t0 = time.time()
print(check_torch_cache(t, m, s))
print(time.time() - t0)

# should use cache and be very fast
t1 = time.time()
print(check_torch_cache(t, m, s))
print(time.time() - t1)

# changing tensor should prevent use of the cache
t2 = time.time()
print(check_torch_cache(t_alt, m, s))
print(time.time() - t2)

# changing model should prevent use of the cache
t3 = time.time()
print(check_torch_cache(t, m_alt, s))
print(time.time() - t3)

# changing the string (and tensor) should prevent use of the cache
t4 = time.time()
print(check_torch_cache(t_alt, m, s_alt))
print(time.time() - t4)


# returns:
# tensor(1.5000)
# 1.245258092880249
# tensor(1.5000)
# 0.27921152114868164
# tensor(2.5000)
# 1.263361930847168
# tensor(1.5000)
# 1.2799842357635498
# tensor(2.5000)
# 1.271822452545166
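
If you re-run the timing script, the first call will already be a cache hit from the previous run; clearing the cache first restores the timings shown above (Memory.clear is standard joblib API, unrelated to this PR):

# remove all cached results so the first call is a cache miss again
memory.clear(warn=False)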


codecov bot commented Jun 28, 2024

Codecov Report

Attention: Patch coverage is 24.00000% with 19 lines in your changes missing coverage. Please review.

Project coverage is 95.03%. Comparing base (f70939c) to head (1160fdd).

Files               Patch %   Lines
joblib/hashing.py   24.00%    19 Missing ⚠️

Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1599      +/-   ##
==========================================
- Coverage   95.24%   95.03%   -0.21%     
==========================================
  Files          45       45              
  Lines        7715     7739      +24     
==========================================
+ Hits         7348     7355       +7     
- Misses        367      384      +17     

