Skip to content

Commit

Permalink
[Refactor] compile compatibility improvements
Browse files Browse the repository at this point in the history
ghstack-source-id: 6d5df6e8d633ce1887beebf675796dcd627ca397
Pull Request resolved: #2578
  • Loading branch information
vmoens committed Nov 18, 2024
1 parent 568cd33 commit 8912c14
Show file tree
Hide file tree
Showing 26 changed files with 564 additions and 278 deletions.
2 changes: 1 addition & 1 deletion benchmarks/test_objectives_benchmarks.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
) # Anything from 2.5, incl. nightlies, allows for fullgraph


@pytest.fixture(scope="module")
@pytest.fixture(scope="module", autouse=True)
def set_default_device():
cur_device = torch.get_default_device()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
Expand Down
100 changes: 97 additions & 3 deletions test/_utils_internal.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import os.path
import time
import unittest
import warnings
from functools import wraps

# Get relative file path
Expand All @@ -20,9 +21,15 @@
import torch
import torch.cuda

from tensordict import tensorclass, TensorDict
from torch import nn
from torchrl._utils import implement_for, logger as torchrl_logger, seed_generator
from tensordict import NestedKey, tensorclass, TensorDict, TensorDictBase
from tensordict.nn import TensorDictModuleBase
from torch import nn, vmap
from torchrl._utils import (
implement_for,
logger as torchrl_logger,
RL_WARNINGS,
seed_generator,
)
from torchrl.data.utils import CloudpickleWrapper

from torchrl.envs import MultiThreadedEnv, ObservationNorm
Expand All @@ -35,6 +42,7 @@
ToTensorImage,
TransformedEnv,
)
from torchrl.objectives.value.advantages import _vmap_func

# Specified for test_utils.py
__version__ = "0.3"
Expand Down Expand Up @@ -713,3 +721,89 @@ def forward(
):
input = self.mlp(input)
return self._lstm(input, hidden0_in, hidden1_in)


def _call_value_nets(
value_net: TensorDictModuleBase,
data: TensorDictBase,
params: TensorDictBase,
next_params: TensorDictBase,
single_call: bool,
value_key: NestedKey,
detach_next: bool,
vmap_randomness: str = "error",
):
in_keys = value_net.in_keys
if single_call:
for i, name in enumerate(data.names):
if name == "time":
ndim = i + 1
break
else:
ndim = None
if ndim is not None:
# get data at t and last of t+1
idx0 = (slice(None),) * (ndim - 1) + (slice(-1, None),)
idx = (slice(None),) * (ndim - 1) + (slice(None, -1),)
idx_ = (slice(None),) * (ndim - 1) + (slice(1, None),)
data_in = torch.cat(
[
data.select(*in_keys, value_key, strict=False),
data.get("next").select(*in_keys, value_key, strict=False)[idx0],
],
ndim - 1,
)
else:
if RL_WARNINGS:
warnings.warn(
"Got a tensordict without a time-marked dimension, assuming time is along the last dimension. "
"This warning can be turned off by setting the environment variable RL_WARNINGS to False."
)
ndim = data.ndim
idx = (slice(None),) * (ndim - 1) + (slice(None, data.shape[ndim - 1]),)
idx_ = (slice(None),) * (ndim - 1) + (slice(data.shape[ndim - 1], None),)
data_in = torch.cat(
[
data.select(*in_keys, value_key, strict=False),
data.get("next").select(*in_keys, value_key, strict=False),
],
ndim - 1,
)

# next_params should be None or be identical to params
if next_params is not None and next_params is not params:
raise ValueError(
"the value at t and t+1 cannot be retrieved in a single call without recurring to vmap when both params and next params are passed."
)
if params is not None:
with params.to_module(value_net):
value_est = value_net(data_in).get(value_key)
else:
value_est = value_net(data_in).get(value_key)
value, value_ = value_est[idx], value_est[idx_]
else:
data_in = torch.stack(
[
data.select(*in_keys, value_key, strict=False),
data.get("next").select(*in_keys, value_key, strict=False),
],
0,
)
if (params is not None) ^ (next_params is not None):
raise ValueError(
"params and next_params must be either both provided or not."
)
elif params is not None:
params_stack = torch.stack([params, next_params], 0).contiguous()
data_out = _vmap_func(value_net, (0, 0), randomness=vmap_randomness)(
data_in, params_stack
)
else:
data_out = vmap(value_net, (0,), randomness=vmap_randomness)(data_in)
value_est = data_out.get(value_key)
value, value_ = value_est[0], value_est[1]
data.set(value_key, value)
data.set(("next", value_key), value_)
if detach_next:
value_ = value_.detach()
return value, value_
22 changes: 0 additions & 22 deletions test/test_collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -3172,28 +3172,6 @@ def make_and_test_policy(
)


@pytest.mark.parametrize(
"ctype", [SyncDataCollector, MultiaSyncDataCollector, MultiSyncDataCollector]
)
def test_no_stopiteration(ctype):
# Tests that there is no StopIteration raised and that the length of the collector is properly set
if ctype is SyncDataCollector:
envs = SerialEnv(16, CountingEnv)
else:
envs = [SerialEnv(8, CountingEnv), SerialEnv(8, CountingEnv)]

collector = ctype(create_env_fn=envs, frames_per_batch=173, total_frames=300)
try:
c_iter = iter(collector)
for i in range(len(collector)): # noqa: B007
c = next(c_iter)
assert c is not None
assert i == 1
finally:
collector.shutdown()
del collector


if __name__ == "__main__":
args, unknown = argparse.ArgumentParser().parse_known_args()
pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
3 changes: 2 additions & 1 deletion test/test_cost.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,6 @@
ValueEstimators,
)
from torchrl.objectives.value.advantages import (
_call_value_nets,
GAE,
TD1Estimator,
TDLambdaEstimator,
Expand Down Expand Up @@ -135,13 +134,15 @@

if os.getenv("PYTORCH_TEST_FBCODE"):
from pytorch.rl.test._utils_internal import ( # noqa
_call_value_nets,
dtype_fixture,
get_available_devices,
get_default_devices,
)
from pytorch.rl.test.mocking_classes import ContinuousActionConvMockEnv
else:
from _utils_internal import ( # noqa
_call_value_nets,
dtype_fixture,
get_available_devices,
get_default_devices,
Expand Down
20 changes: 17 additions & 3 deletions torchrl/collectors/collectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,15 @@
set_exploration_type,
)

try:
from torch.compiler import cudagraph_mark_step_begin
except ImportError:

def cudagraph_mark_step_begin():
"""Placeholder for missing cudagraph_mark_step_begin method."""
raise NotImplementedError("cudagraph_mark_step_begin not implemented.")


_TIMEOUT = 1.0
INSTANTIATE_TIMEOUT = 20
_MIN_TIMEOUT = 1e-3 # should be several orders of magnitude inferior wrt time spent collecting a trajectory
Expand Down Expand Up @@ -138,7 +147,6 @@ class DataCollectorBase(IterableDataset, metaclass=abc.ABCMeta):
_iterator = None
total_frames: int
frames_per_batch: int
requested_frames_per_batch: int
trust_policy: bool
compiled_policy: bool
cudagraphed_policy: bool
Expand Down Expand Up @@ -297,7 +305,7 @@ def __class_getitem__(self, index):

def __len__(self) -> int:
if self.total_frames > 0:
return -(self.total_frames // -self.requested_frames_per_batch)
return -(self.total_frames // -self.frames_per_batch)
raise RuntimeError("Non-terminating collectors do not have a length")


Expand Down Expand Up @@ -692,7 +700,7 @@ def __init__(
remainder = total_frames % frames_per_batch
if remainder != 0 and RL_WARNINGS:
warnings.warn(
f"total_frames ({total_frames}) is not exactly divisible by frames_per_batch ({frames_per_batch}). "
f"total_frames ({total_frames}) is not exactly divisible by frames_per_batch ({frames_per_batch})."
f"This means {frames_per_batch - remainder} additional frames will be collected."
"To silence this message, set the environment variable RL_WARNINGS to False."
)
Expand Down Expand Up @@ -833,6 +841,8 @@ def _make_final_rollout(self):
policy_input_clone = (
policy_input.clone()
) # to test if values have changed in-place
if self.compiled_policy:
cudagraph_mark_step_begin()
policy_output = self.policy(policy_input)

# check that we don't have exclusive keys, because they don't appear in keys
Expand Down Expand Up @@ -1146,7 +1156,11 @@ def rollout(self) -> TensorDictBase:
else:
policy_input = self._shuttle
# we still do the assignment for security
if self.compiled_policy:
cudagraph_mark_step_begin()
policy_output = self.policy(policy_input)
if self.compiled_policy:
policy_output = policy_output.clone()
if self._shuttle is not policy_output:
# ad-hoc update shuttle
self._shuttle.update(
Expand Down
14 changes: 11 additions & 3 deletions torchrl/data/tensor_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1392,6 +1392,7 @@ def type_check(self, value: torch.Tensor, key: NestedKey | None = None) -> None:
spec.type_check(val)

def is_in(self, value) -> bool:
raise RuntimeError
if self.dim == 0 and not hasattr(value, "unbind"):
# We don't use unbind because value could be a tuple or a nested tensor
return all(
Expand Down Expand Up @@ -1820,6 +1821,7 @@ def _project(self, val: torch.Tensor) -> torch.Tensor:
return val

def is_in(self, val: torch.Tensor) -> bool:
raise RuntimeError
if self.mask is None:
shape = torch.broadcast_shapes(self._safe_shape, val.shape)
shape_match = val.shape == shape
Expand Down Expand Up @@ -2270,6 +2272,7 @@ def _project(self, val: torch.Tensor) -> torch.Tensor:
return val

def is_in(self, val: torch.Tensor) -> bool:
raise RuntimeError
val_shape = _remove_neg_shapes(tensordict.utils._shape(val))
shape = torch.broadcast_shapes(self._safe_shape, val_shape)
shape = list(shape)
Expand Down Expand Up @@ -2312,10 +2315,10 @@ def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> Bounded:
dest_device = torch.device(dest)
if dest_device == self.device and dest_dtype == self.dtype:
return self
self.space.device = dest_device
space = self.space.to(dest_device)
return Bounded(
low=self.space.low,
high=self.space.high,
low=space.low,
high=space.high,
shape=self.shape,
device=dest_device,
dtype=dest_dtype,
Expand Down Expand Up @@ -2467,6 +2470,7 @@ def one(self, shape=None):
)

def is_in(self, val: torch.Tensor) -> bool:
raise RuntimeError
shape = torch.broadcast_shapes(self._safe_shape, val.shape)
return (
isinstance(val, NonTensorData)
Expand Down Expand Up @@ -2659,6 +2663,7 @@ def rand(self, shape: torch.Size = None) -> torch.Tensor:
return torch.empty(shape, device=self.device, dtype=self.dtype).random_()

def is_in(self, val: torch.Tensor) -> bool:
raise RuntimeError
shape = torch.broadcast_shapes(self._safe_shape, val.shape)
return val.shape == shape and val.dtype == self.dtype

Expand Down Expand Up @@ -3007,6 +3012,7 @@ def index(self, index: INDEX_TYPING, tensor_to_index: torch.Tensor) -> torch.Ten
return torch.cat(out, -1)

def is_in(self, val: torch.Tensor) -> bool:
raise RuntimeError
vals = self._split(val)
if vals is None:
return False
Expand Down Expand Up @@ -3352,6 +3358,7 @@ def _project(self, val: torch.Tensor) -> torch.Tensor:
return val

def is_in(self, val: torch.Tensor) -> bool:
raise RuntimeError
if self.mask is None:
shape = torch.broadcast_shapes(self._safe_shape, val.shape)
shape_match = val.shape == shape
Expand Down Expand Up @@ -3977,6 +3984,7 @@ def _project(self, val: torch.Tensor) -> torch.Tensor:
return val.squeeze(0) if val_is_scalar else val

def is_in(self, val: torch.Tensor) -> bool:
raise RuntimeError
if self.mask is not None:
vals = val.unbind(-1)
splits = self._split_self()
Expand Down
11 changes: 7 additions & 4 deletions torchrl/envs/batched_envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1356,12 +1356,15 @@ def _start_workers(self) -> None:

from torchrl.envs.env_creator import EnvCreator

num_threads = max(
1, torch.get_num_threads() - self.num_workers
) # 1 more thread for this proc

if self.num_threads is None:
self.num_threads = max(
1, torch.get_num_threads() - self.num_workers
) # 1 more thread for this proc
self.num_threads = num_threads

torch.set_num_threads(self.num_threads)
if self.num_threads != torch.get_num_threads():
torch.set_num_threads(self.num_threads)

if self._mp_start_method is not None:
ctx = mp.get_context(self._mp_start_method)
Expand Down
3 changes: 1 addition & 2 deletions torchrl/envs/gym_like.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,6 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase:

if _reward is not None:
reward = reward + _reward

terminated, truncated, done, do_break = self.read_done(
terminated=terminated, truncated=truncated, done=done
)
Expand All @@ -323,7 +322,7 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase:
# if truncated/terminated is not in the keys, we just don't pass it even if it
# is defined.
if terminated is None:
terminated = done
terminated = done.clone()
if truncated is not None:
obs_dict["truncated"] = truncated
obs_dict["done"] = done
Expand Down
2 changes: 1 addition & 1 deletion torchrl/envs/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1423,7 +1423,7 @@ def _make_compatible_policy(
env_maker=None,
env_maker_kwargs=None,
):
if trust_policy:
if trust_policy or isinstance(policy, torch._dynamo.eval_frame.OptimizedModule):
return policy
if policy is None:
input_spec = None
Expand Down
14 changes: 14 additions & 0 deletions torchrl/modules/distributions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# LICENSE file in the root directory of this source tree.

from tensordict.nn import NormalParamExtractor
from torch import distributions as torch_dist

from .continuous import (
Delta,
Expand Down Expand Up @@ -37,3 +38,16 @@
OneHotOrdinal,
)
}

HAS_ENTROPY = {
Delta: False,
IndependentNormal: True,
TanhDelta: False,
TanhNormal: False,
TruncatedNormal: False,
MaskedCategorical: False,
MaskedOneHotCategorical: False,
OneHotCategorical: True,
torch_dist.Categorical: True,
torch_dist.Normal: True,
}
Loading

0 comments on commit 8912c14

Please sign in to comment.