Skip to content

Commit

Permalink
[Refactor] compile compatibility improvements
Browse files Browse the repository at this point in the history
ghstack-source-id: 95f8241b56e42b80e828485cb5f377288bff6f5e
Pull Request resolved: #2578
  • Loading branch information
vmoens committed Nov 18, 2024
1 parent 507766a commit db7f08d
Show file tree
Hide file tree
Showing 15 changed files with 176 additions and 127 deletions.
22 changes: 0 additions & 22 deletions test/test_collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -3172,28 +3172,6 @@ def make_and_test_policy(
)


@pytest.mark.parametrize(
"ctype", [SyncDataCollector, MultiaSyncDataCollector, MultiSyncDataCollector]
)
def test_no_stopiteration(ctype):
# Tests that there is no StopIteration raised and that the length of the collector is properly set
if ctype is SyncDataCollector:
envs = SerialEnv(16, CountingEnv)
else:
envs = [SerialEnv(8, CountingEnv), SerialEnv(8, CountingEnv)]

collector = ctype(create_env_fn=envs, frames_per_batch=173, total_frames=300)
try:
c_iter = iter(collector)
for i in range(len(collector)): # noqa: B007
c = next(c_iter)
assert c is not None
assert i == 1
finally:
collector.shutdown()
del collector


if __name__ == "__main__":
args, unknown = argparse.ArgumentParser().parse_known_args()
pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
5 changes: 2 additions & 3 deletions torchrl/collectors/collectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,6 @@ class DataCollectorBase(IterableDataset, metaclass=abc.ABCMeta):
_iterator = None
total_frames: int
frames_per_batch: int
requested_frames_per_batch: int
trust_policy: bool
compiled_policy: bool
cudagraphed_policy: bool
Expand Down Expand Up @@ -306,7 +305,7 @@ def __class_getitem__(self, index):

def __len__(self) -> int:
if self.total_frames > 0:
return -(self.total_frames // -self.requested_frames_per_batch)
return -(self.total_frames // -self.frames_per_batch)
raise RuntimeError("Non-terminating collectors do not have a length")


Expand Down Expand Up @@ -701,7 +700,7 @@ def __init__(
remainder = total_frames % frames_per_batch
if remainder != 0 and RL_WARNINGS:
warnings.warn(
f"total_frames ({total_frames}) is not exactly divisible by frames_per_batch ({frames_per_batch}). "
f"total_frames ({total_frames}) is not exactly divisible by frames_per_batch ({frames_per_batch})."
f"This means {frames_per_batch - remainder} additional frames will be collected."
"To silence this message, set the environment variable RL_WARNINGS to False."
)
Expand Down
6 changes: 3 additions & 3 deletions torchrl/data/tensor_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2312,10 +2312,10 @@ def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> Bounded:
dest_device = torch.device(dest)
if dest_device == self.device and dest_dtype == self.dtype:
return self
self.space.device = dest_device
space = self.space.to(dest_device)
return Bounded(
low=self.space.low,
high=self.space.high,
low=space.low,
high=space.high,
shape=self.shape,
device=dest_device,
dtype=dest_dtype,
Expand Down
11 changes: 7 additions & 4 deletions torchrl/envs/batched_envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1356,12 +1356,15 @@ def _start_workers(self) -> None:

from torchrl.envs.env_creator import EnvCreator

num_threads = max(
1, torch.get_num_threads() - self.num_workers
) # 1 more thread for this proc

if self.num_threads is None:
self.num_threads = max(
1, torch.get_num_threads() - self.num_workers
) # 1 more thread for this proc
self.num_threads = num_threads

torch.set_num_threads(self.num_threads)
if self.num_threads != torch.get_num_threads():
torch.set_num_threads(self.num_threads)

if self._mp_start_method is not None:
ctx = mp.get_context(self._mp_start_method)
Expand Down
21 changes: 14 additions & 7 deletions torchrl/modules/distributions/continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,7 +397,6 @@ def __init__(
event_dims: int | None = None,
tanh_loc: bool = False,
safe_tanh: bool = True,
**kwargs,
):
if not isinstance(loc, torch.Tensor):
loc = torch.as_tensor(loc, dtype=torch.get_default_dtype())
Expand Down Expand Up @@ -683,6 +682,7 @@ def __init__(
event_dims: int = 1,
atol: float = 1e-6,
rtol: float = 1e-6,
safe: bool = True,
):
minmax_msg = "high value has been found to be equal or less than low value"
if isinstance(high, torch.Tensor) or isinstance(low, torch.Tensor):
Expand All @@ -695,12 +695,19 @@ def __init__(
if not all(high > low):
raise ValueError(minmax_msg)

t = SafeTanhTransform()
non_trivial_min = (isinstance(low, torch.Tensor) and (low != -1.0).any()) or (
not isinstance(low, torch.Tensor) and low != -1.0
if safe:
if is_dynamo_compiling():
_err_compile_safetanh()
t = SafeTanhTransform()
else:
t = torch.distributions.TanhTransform()
non_trivial_min = is_dynamo_compiling or (
(isinstance(low, torch.Tensor) and (low != -1.0).any())
or (not isinstance(low, torch.Tensor) and low != -1.0)
)
non_trivial_max = (isinstance(high, torch.Tensor) and (high != 1.0).any()) or (
not isinstance(high, torch.Tensor) and high != 1.0
non_trivial_max = is_dynamo_compiling or (
(isinstance(high, torch.Tensor) and (high != 1.0).any())
or (not isinstance(high, torch.Tensor) and high != 1.0)
)
self.non_trivial = non_trivial_min or non_trivial_max

Expand Down Expand Up @@ -778,7 +785,7 @@ def _uniform_sample_delta(dist: Delta, size=None) -> torch.Tensor:
def _err_compile_safetanh():
raise RuntimeError(
"safe_tanh=True in TanhNormal is not compatible with torch.compile with torch pre 2.6.0. "
"To deactivate it, pass safe_tanh=False. "
" To deactivate it, pass safe_tanh=False. "
"If you are using a ProbabilisticTensorDictModule, this can be done via "
"`distribution_kwargs={'safe_tanh': False}`. "
"See https://github.com/pytorch/pytorch/issues/133529 for more details."
Expand Down
13 changes: 10 additions & 3 deletions torchrl/modules/distributions/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,11 @@
from torch import autograd, distributions as d
from torch.distributions import Independent, Transform, TransformedDistribution

try:
from torch.compiler import is_dynamo_compiling
except ImportError:
from torch._dynamo import is_compiling as is_dynamo_compiling


def _cast_device(elt: Union[torch.Tensor, float], device) -> Union[torch.Tensor, float]:
if isinstance(elt, torch.Tensor):
Expand Down Expand Up @@ -40,10 +45,12 @@ class FasterTransformedDistribution(TransformedDistribution):
__doc__ = __doc__ + TransformedDistribution.__doc__

def __init__(self, base_distribution, transforms, validate_args=None):
if is_dynamo_compiling():
return super().__init__(
base_distribution, transforms, validate_args=validate_args
)
if isinstance(transforms, Transform):
self.transforms = [
transforms,
]
self.transforms = [transforms]
elif isinstance(transforms, list):
raise ValueError("Make a ComposeTransform first.")
else:
Expand Down
5 changes: 5 additions & 0 deletions torchrl/modules/models/decision_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,12 @@ def __init__(
state_dim,
action_dim,
config: dict | DTConfig = None,
device: torch.device | None = None,
):
if device is not None:
with torch.device(device):
return self.__init__(state_dim, action_dim, config)

if not _has_transformers:
raise ImportError(
"transformers is not installed. Please install it with `pip install transformers`."
Expand Down
4 changes: 4 additions & 0 deletions torchrl/modules/tensordict_module/actors.py
Original file line number Diff line number Diff line change
Expand Up @@ -1783,6 +1783,7 @@ class DecisionTransformerInferenceWrapper(TensorDictModuleWrapper):
For example for an observation input of shape [batch_size, context, obs_dim] with context=20 and inference_context=5, the first 15 entries
of the context will be masked. Defaults to 5.
spec (Optional[TensorSpec]): The spec of the input TensorDict. If None, it will be inferred from the policy module.
device (torch.device, optional): if provided, the device where the buffers / specs will be placed.
Examples:
>>> import torch
Expand Down Expand Up @@ -1836,6 +1837,7 @@ def __init__(
*,
inference_context: int = 5,
spec: Optional[TensorSpec] = None,
device: torch.device | None = None,
):
super().__init__(policy)
self.observation_key = "observation"
Expand All @@ -1857,6 +1859,8 @@ def __init__(
self._spec[self.action_key] = None
else:
self._spec = Composite({key: None for key in policy.out_keys})
if device is not None:
self._spec = self._spec.to(device)
self.checked = False

@property
Expand Down
3 changes: 2 additions & 1 deletion torchrl/modules/tensordict_module/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,8 @@ def _forward_hook_safe_action(module, tensordict_in, tensordict_out):
keys = [out_key]
values = [spec]
else:
keys = list(spec.keys(True, True))
# Make dynamo happy with the list creation
keys = [key for key in spec.keys(True, True)] # noqa: C416
values = [spec[key] for key in keys]
for _spec, _key in zip(values, keys):
if _spec is None:
Expand Down
Loading

0 comments on commit db7f08d

Please sign in to comment.