-
Notifications
You must be signed in to change notification settings - Fork 328
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Feature] TensorDictMap hashing functions
ghstack-source-id: 1c959eeeec5bbd0093b6c2367c853d66b355c8e1 Pull Request resolved: #2304
- Loading branch information
Showing
5 changed files
with
246 additions
and
13 deletions.
There are no files selected for viewing
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# | ||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
import argparse | ||
import importlib.util | ||
|
||
import pytest | ||
|
||
import torch | ||
|
||
from torchrl.data.map import BinaryToDecimal, RandomProjectionHash, SipHash | ||
|
||
_has_gym = importlib.util.find_spec("gymnasium", None) or importlib.util.find_spec( | ||
"gym", None | ||
) | ||
|
||
|
||
class TestHash: | ||
def test_binary_to_decimal(self): | ||
binary_to_decimal = BinaryToDecimal( | ||
num_bits=4, device="cpu", dtype=torch.int32, convert_to_binary=True | ||
) | ||
binary = torch.Tensor([[0, 0, 1, 0, 0, 0, 0, 1], [0, 0, 0, 0, 0, 0, 10, 0]]) | ||
decimal = binary_to_decimal(binary) | ||
|
||
assert decimal.shape == (2,) | ||
assert (decimal == torch.Tensor([3, 2])).all() | ||
|
||
def test_sip_hash(self): | ||
a = torch.rand((3, 2)) | ||
b = a.clone() | ||
hash_module = SipHash(as_tensor=True) | ||
hash_a = torch.tensor(hash_module(a)) | ||
hash_b = torch.tensor(hash_module(b)) | ||
assert (hash_a == hash_b).all() | ||
|
||
@pytest.mark.parametrize("n_components", [None, 14]) | ||
@pytest.mark.parametrize("scale", [0.001, 0.01, 1, 100, 1000]) | ||
def test_randomprojection_hash(self, n_components, scale): | ||
torch.manual_seed(0) | ||
r = RandomProjectionHash(n_components=n_components) | ||
x = torch.randn(10000, 100).mul_(scale) | ||
y = r(x) | ||
if n_components is None: | ||
assert r.n_components == r._N_COMPONENTS_DEFAULT | ||
else: | ||
assert r.n_components == n_components | ||
|
||
assert y.shape == (10000,) | ||
assert y.unique().numel() == y.numel() | ||
|
||
|
||
if __name__ == "__main__": | ||
args, unknown = argparse.ArgumentParser().parse_known_args() | ||
pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# | ||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
from .hash import BinaryToDecimal, RandomProjectionHash, SipHash |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,183 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# | ||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
from __future__ import annotations | ||
|
||
from typing import Callable, List | ||
|
||
import torch | ||
|
||
|
||
class BinaryToDecimal(torch.nn.Module): | ||
"""A Module to convert binaries encoded tensors to decimals. | ||
This is a utility class that allow to convert a binary encoding tensor (e.g. `1001`) to | ||
its decimal value (e.g. `9`) | ||
Args: | ||
num_bits (int): the number of bits to use for the bases table. | ||
The number of bits must be lower or equal to the input length and the input length | ||
must be divisible by ``num_bits``. If ``num_bits`` is lower than the number of | ||
bits in the input, the end result will be aggregated on the last dimension using | ||
:func:`~torch.sum`. | ||
device (torch.device): the device where inputs and outputs are to be expected. | ||
dtype (torch.dtype): the output dtype. | ||
convert_to_binary (bool, optional): if ``True``, the input to the ``forward`` | ||
method will be cast to a binary input using :func:`~torch.heavyside`. | ||
Defaults to ``False``. | ||
Examples: | ||
>>> binary_to_decimal = BinaryToDecimal( | ||
... num_bits=4, device="cpu", dtype=torch.int32, convert_to_binary=True | ||
... ) | ||
>>> binary = torch.Tensor([[0, 0, 1, 0, 0, 0, 0, 1], [0, 0, 0, 0, 0, 0, 10, 0]]) | ||
>>> decimal = binary_to_decimal(binary) | ||
>>> assert decimal.shape == (2,) | ||
>>> assert (decimal == torch.Tensor([3, 2])).all() | ||
""" | ||
|
||
def __init__( | ||
self, | ||
num_bits: int, | ||
device: torch.device, | ||
dtype: torch.dtype, | ||
convert_to_binary: bool = False, | ||
): | ||
super().__init__() | ||
self.convert_to_binary = convert_to_binary | ||
self.bases = 2 ** torch.arange(num_bits - 1, -1, -1, device=device, dtype=dtype) | ||
self.num_bits = num_bits | ||
self.zero_tensor = torch.zeros((1,), device=device) | ||
|
||
def forward(self, features: torch.Tensor) -> torch.Tensor: | ||
num_features = features.shape[-1] | ||
if self.num_bits > num_features: | ||
raise ValueError(f"{num_features=} is less than {self.num_bits=}") | ||
elif num_features % self.num_bits != 0: | ||
raise ValueError(f"{num_features=} is not divisible by {self.num_bits=}") | ||
|
||
binary_features = ( | ||
torch.heaviside(features, self.zero_tensor) | ||
if self.convert_to_binary | ||
else features | ||
) | ||
feature_parts = binary_features.reshape(shape=(-1, self.num_bits)) | ||
digits = torch.vmap(torch.dot, (None, 0))( | ||
self.bases, feature_parts.to(self.bases.dtype) | ||
) | ||
digits = digits.reshape(shape=(-1, features.shape[-1] // self.num_bits)) | ||
aggregated_digits = torch.sum(digits, dim=-1) | ||
return aggregated_digits | ||
|
||
|
||
class SipHash(torch.nn.Module): | ||
"""A Module to Compute SipHash values for given tensors. | ||
A hash function module based on SipHash implementation in python. | ||
Args: | ||
as_tensor (bool, optional): if ``True``, the bytes will be turned into integers | ||
through the builtin ``hash`` function and mapped to a tensor. Default: ``True``. | ||
.. warning:: This module relies on the builtin ``hash`` function. | ||
To get reproducible results across runs, the ``PYTHONHASHSEED`` environment | ||
variable must be set before the code is run (changing this value during code | ||
execution is without effect). | ||
Examples: | ||
>>> # Assuming we set PYTHONHASHSEED=0 prior to running this code | ||
>>> a = torch.tensor([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]) | ||
>>> b = a.clone() | ||
>>> hash_module = SipHash(as_tensor=True) | ||
>>> hash_a = hash_module(a) | ||
>>> hash_a | ||
tensor([-4669941682990263259, -3778166555168484291, -9122128731510687521]) | ||
>>> hash_b = hash_module(b) | ||
>>> assert (hash_a == hash_b).all() | ||
""" | ||
|
||
def __init__(self, as_tensor: bool = True): | ||
super().__init__() | ||
self.as_tensor = as_tensor | ||
|
||
def forward(self, x: torch.Tensor) -> torch.Tensor | List[bytes]: | ||
hash_values = [] | ||
if x.dtype in (torch.bfloat16,): | ||
x = x.to(torch.float16) | ||
for x_i in x.detach().cpu().numpy(): | ||
hash_value = x_i.tobytes() | ||
hash_values.append(hash_value) | ||
if not self.as_tensor: | ||
return hash_value | ||
result = torch.tensor([hash(x) for x in hash_values], dtype=torch.int64) | ||
return result | ||
|
||
|
||
class RandomProjectionHash(SipHash): | ||
"""A module that combines random projections with SipHash to get a low-dimensional tensor, easier to embed through :class:`~.SipHash`. | ||
This module requires sklearn to be installed. | ||
Keyword Args: | ||
n_components (int, optional): the low-dimensional number of components of the projections. | ||
Defaults to 16. | ||
dtype_cast (torch.dtype, optional): the dtype to cast the projection to. | ||
Defaults to ``torch.bfloat16``. | ||
as_tensor (bool, optional): if ``True``, the bytes will be turned into integers | ||
through the builtin ``hash`` function and mapped to a tensor. Default: ``True``. | ||
.. warning:: This module relies on the builtin ``hash`` function. | ||
To get reproducible results across runs, the ``PYTHONHASHSEED`` environment | ||
variable must be set before the code is run (changing this value during code | ||
execution is without effect). | ||
init_method: TODO | ||
""" | ||
|
||
_N_COMPONENTS_DEFAULT = 16 | ||
|
||
def __init__( | ||
self, | ||
*, | ||
n_components: int | None = None, | ||
dtype_cast=torch.bfloat16, | ||
as_tensor: bool = True, | ||
init_method: Callable[[torch.Tensor], torch.Tensor | None] | None = None, | ||
**kwargs, | ||
): | ||
if n_components is None: | ||
n_components = self._N_COMPONENTS_DEFAULT | ||
|
||
super().__init__(as_tensor=as_tensor) | ||
self.register_buffer("_n_components", torch.as_tensor(n_components)) | ||
|
||
self._init = False | ||
if init_method is None: | ||
init_method = torch.nn.init.normal_ | ||
self.init_method = init_method | ||
|
||
self.dtype_cast = dtype_cast | ||
self.register_buffer("transform", torch.nn.UninitializedBuffer()) | ||
|
||
@property | ||
def n_components(self): | ||
return self._n_components.item() | ||
|
||
def fit(self, x): | ||
"""Fits the random projection to the input data.""" | ||
self.transform.materialize( | ||
(x.shape[-1], self.n_components), dtype=self.dtype_cast, device=x.device | ||
) | ||
self.init_method(self.transform) | ||
self._init = True | ||
|
||
def forward(self, x: torch.Tensor) -> torch.Tensor: | ||
if not self._init: | ||
self.fit(x) | ||
elif not self._init: | ||
raise RuntimeError( | ||
f"The {type(self).__name__} has not been initialized. Call fit before calling this method." | ||
) | ||
x = x.to(self.dtype_cast) @ self.transform | ||
return super().forward(x) |
1a4b2cc
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Possible performance regression was detected for benchmark 'CPU Benchmark Results'.
Benchmark result of this commit is worse than the previous benchmark result exceeding threshold
2
.benchmarks/test_replaybuffer_benchmark.py::test_rb_populate[TensorDictPrioritizedReplayBuffer-ListStorage-None-400]
35.13226131049105
iter/sec (stddev: 0.16820353058832604
)226.47160250795446
iter/sec (stddev: 0.0008046443795841312
)6.45
This comment was automatically generated by workflow using github-action-benchmark.
CC: @vmoens