
[Feature] Fine control over devices in collectors #1835

Merged: 69 commits, merged Jan 30, 2024
Changes from 1 commit

Commits:
35d8847
init
vmoens Jan 25, 2024
0f9fad2
amend
vmoens Jan 25, 2024
213c477
amend
vmoens Jan 25, 2024
a1fd23f
amend
vmoens Jan 25, 2024
17aef47
clear_device_
vmoens Jan 25, 2024
601d6d6
stream
vmoens Jan 25, 2024
439022b
amend
vmoens Jan 25, 2024
d26a8d2
amend
vmoens Jan 25, 2024
0e20e03
amend
vmoens Jan 26, 2024
c56fda4
amend
vmoens Jan 26, 2024
bd9ee28
amend
vmoens Jan 26, 2024
61b1f87
amend
vmoens Jan 26, 2024
03a3747
amend
vmoens Jan 26, 2024
515de7d
amend
vmoens Jan 26, 2024
b994996
amend
vmoens Jan 26, 2024
dc264c6
amend
vmoens Jan 26, 2024
12a8ae1
amend
vmoens Jan 28, 2024
2150f0b
amend
vmoens Jan 28, 2024
0759eb1
lint
vmoens Jan 28, 2024
5d3d9c4
amend
vmoens Jan 28, 2024
c1870f9
amend
vmoens Jan 29, 2024
b033426
amend
vmoens Jan 29, 2024
0dcbc0e
Merge remote-tracking branch 'origin/main' into fix-device-collectors
vmoens Jan 29, 2024
7759b3f
amend
vmoens Jan 29, 2024
bd06a5e
amend
vmoens Jan 29, 2024
f0406eb
amend
vmoens Jan 29, 2024
b63aca3
amend
vmoens Jan 29, 2024
0f08d26
amend
vmoens Jan 29, 2024
2e4445c
amend
vmoens Jan 29, 2024
14ab95c
Merge remote-tracking branch 'origin/main' into fix-device-collectors
vmoens Jan 29, 2024
a4d1df6
amend
vmoens Jan 29, 2024
5976de0
Merge remote-tracking branch 'origin/main' into fix-device-collectors
vmoens Jan 29, 2024
b215157
amend
vmoens Jan 29, 2024
5560b70
amend
vmoens Jan 30, 2024
b98a806
amend
vmoens Jan 30, 2024
1f1dd92
amend
vmoens Jan 30, 2024
fc228d9
amend
vmoens Jan 30, 2024
1ef6ce1
amend
vmoens Jan 30, 2024
6258a92
amend
vmoens Jan 30, 2024
1e2fc62
amend
vmoens Jan 30, 2024
4feac80
amend
vmoens Jan 30, 2024
b12e564
amend
vmoens Jan 30, 2024
90d2343
amend
vmoens Jan 30, 2024
78d7c43
Merge remote-tracking branch 'origin/fix-device-collectors' into fix-…
vmoens Jan 30, 2024
8fd0ba3
amend
vmoens Jan 30, 2024
cec3ab1
amend
vmoens Jan 30, 2024
377944c
lint
vmoens Jan 30, 2024
4688f98
amend
vmoens Jan 30, 2024
9ec6786
amend
vmoens Jan 30, 2024
3b0653a
amend
vmoens Jan 30, 2024
9dcc6c3
amend
vmoens Jan 30, 2024
0d921c5
amend
vmoens Jan 30, 2024
8412a0c
amend
vmoens Jan 30, 2024
27c8605
amend
vmoens Jan 30, 2024
6326c6d
amend
vmoens Jan 30, 2024
5350bcf
amend
vmoens Jan 30, 2024
ecaa8b0
amend
vmoens Jan 30, 2024
5312610
amend
vmoens Jan 30, 2024
1221b4e
amend
vmoens Jan 30, 2024
6df335a
amend
vmoens Jan 30, 2024
7cd328e
amend
vmoens Jan 30, 2024
6f4c637
amend
vmoens Jan 30, 2024
52a1197
amend
vmoens Jan 30, 2024
f66c73b
amend
vmoens Jan 30, 2024
abe7436
lint
vmoens Jan 30, 2024
5f86a24
amend
vmoens Jan 30, 2024
205d224
lint
vmoens Jan 30, 2024
76bb4cf
amend
vmoens Jan 30, 2024
a7110b0
remove asserts
vmoens Jan 30, 2024
amend
vmoens committed Jan 30, 2024
commit 6258a9272a8ad3e7a60b9ab80748efe40762ccde
147 changes: 70 additions & 77 deletions torchrl/collectors/collectors.py
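For orientation before the diff: the feature under review exposes separate env, policy and storing devices on the collector. A minimal usage sketch, assuming the post-merge SyncDataCollector accepts keyword arguments matching the attributes used in the diff below (env_device, policy_device, storing_device) and that a Gym Pendulum environment is available; the policy is a toy module, not taken from this PR:

    import torch
    from tensordict.nn import TensorDictModule
    from torchrl.collectors import SyncDataCollector
    from torchrl.envs import GymEnv

    env = GymEnv("Pendulum-v1")
    # toy policy mapping the 3-dim observation to a 1-dim action
    policy = TensorDictModule(torch.nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"])
    collector = SyncDataCollector(
        env,
        policy,
        frames_per_batch=64,
        total_frames=256,
        env_device="cpu",      # where env.step() is executed
        policy_device="cpu",   # where the policy runs (e.g. "cuda:0")
        storing_device="cpu",  # where the collected batches are stored
    )
    for batch in collector:
        print(batch.device)    # batches live on storing_device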
@@ -683,17 +683,18 @@ def __init__(
)
self.return_same_td = return_same_td

self._tensordict = env.reset()
traj_ids = torch.arange(self.n_env, device=self.env_device).view(
# Shuttle is a deviceless tensordict that just carries data from env to policy and from policy to env
self._shuttle = env.reset().clear_device_()
traj_ids = torch.arange(self.n_env, device=self.storing_device).view(
self.env.batch_size
)
self._tensordict.set(
self._shuttle.set(
("collector", "traj_ids"),
traj_ids,
)
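A minimal sketch of the deviceless shuttle idea (keys, shapes and devices are illustrative, not taken from this file): clearing the device does not move any tensors, it only removes the tensordict-level device so the data can later be cast to whichever device the policy or the env expects.

    import torch
    from tensordict import TensorDict

    shuttle = TensorDict({"observation": torch.zeros(4, 3)}, batch_size=[4], device="cpu")
    shuttle.clear_device_()      # tensors stay where they are, shuttle.device becomes None
    assert shuttle.device is None
    # cast on demand, e.g. towards a hypothetical policy device
    policy_input = shuttle.to("cuda") if torch.cuda.is_available() else shuttle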

with torch.no_grad():
self._tensordict_out = self.env.fake_tensordict()
self._final_rollout = self.env.fake_tensordict()

# If storing device is not None, we use this to cast the storage.
# If it is None and the env and policy are on the same device,
@@ -702,12 +703,12 @@
# In all other cases, we can't really put a device on the storage,
# since at least one data source has a device that is not clear.
if self.storing_device:
self._tensordict_out = self._tensordict_out.to(
self._final_rollout = self._final_rollout.to(
self.storing_device, non_blocking=True
)
else:
# erase all devices
self._tensordict_out.clear_device_()
self._final_rollout.clear_device_()

# If the policy has a valid spec, we use it
if (
@@ -716,29 +717,29 @@
and all(v is not None for v in self.policy.spec.values(True, True))
):
if any(
key not in self._tensordict_out.keys(isinstance(key, tuple))
key not in self._final_rollout.keys(isinstance(key, tuple))
for key in self.policy.spec.keys(True, True)
):
# if the policy spec is non-empty, all its values are not None and its keys
# match the out_keys, we assume the user has given all relevant information;
# the policy could have more keys than the env:
policy_spec = self.policy.spec
if policy_spec.ndim < self._tensordict_out.ndim:
policy_spec = policy_spec.expand(self._tensordict_out.shape)
if policy_spec.ndim < self._final_rollout.ndim:
policy_spec = policy_spec.expand(self._final_rollout.shape)
for key, spec in policy_spec.items(True, True):
if key in self._tensordict_out.keys(isinstance(key, tuple)):
if key in self._final_rollout.keys(isinstance(key, tuple)):
continue
self._tensordict_out.set(key, spec.zero())
self._final_rollout.set(key, spec.zero())

else:
# otherwise, we perform a small number of steps with the policy to
# determine the relevant keys with which to pre-populate _tensordict_out.
# determine the relevant keys with which to pre-populate _final_rollout.
# This is the safest thing to do if the spec has None fields or if there is
# no spec at all.
# See #505 for additional context.
self._tensordict_out.update(self._tensordict)
self._final_rollout.update(self._shuttle)
with torch.no_grad():
input = self._tensordict_out
input = self._final_rollout
if self.policy_device:
input = input.to(self.policy_device)
elif self._cast_to_policy_device and input.device is not None:
@@ -749,33 +750,25 @@
_tensordict_out = _tensordict_out.to(self.storing_device, non_blocking=True)
else:
_tensordict_out.clear_device_()
self._tensordict_out = _tensordict_out
self._final_rollout = _tensordict_out

self._tensordict_out = (
self._tensordict_out.unsqueeze(-1)
self._final_rollout = (
self._final_rollout.unsqueeze(-1)
.expand(*env.batch_size, self.frames_per_batch)
.clone()
.zero_()
)
# in addition to outputs of the policy, we add traj_ids to
# _tensordict_out which will be collected during rollout
print("env", self.env_device)
print("policy", self.policy_device)
print("traj ids", torch.zeros(
*self._tensordict_out.batch_size,
dtype=torch.int64,
device=self.env_device, # will be cast to storing_device if needed
))
print(self._tensordict_out)
self._tensordict_out.set(
# _final_rollout which will be collected during rollout
self._final_rollout.set(
("collector", "traj_ids"),
torch.zeros(
*self._tensordict_out.batch_size,
*self._final_rollout.batch_size,
dtype=torch.int64,
device=self.env_device, # will be cast to storing_device if needed
device=self.storing_device,
),
)
self._tensordict_out.refine_names(..., "time")
self._final_rollout.refine_names(..., "time")

if split_trajs is None:
split_trajs = False
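The _final_rollout allocated above is a pre-sized buffer: a single fake step expanded along a trailing time dimension, which collected steps are later stacked into. A small sketch of that pattern with illustrative keys and sizes:

    import torch
    from tensordict import TensorDict

    fake_step = TensorDict({"reward": torch.zeros(4, 1)}, batch_size=[4])  # 4 envs
    frames_per_batch = 8
    buffer = fake_step.unsqueeze(-1).expand(4, frames_per_batch).clone().zero_()
    assert buffer.batch_size == torch.Size([4, 8])
    steps = [fake_step.clone() for _ in range(frames_per_batch)]
    torch.stack(steps, buffer.ndim - 1, out=buffer)  # writes into the pre-allocated buffer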
@@ -859,7 +852,7 @@ def cuda_check(tensor: torch.Tensor):
if tensor.is_cuda:
cuda_devices.add(tensor.device)

self._tensordict_out.apply(cuda_check)
self._final_rollout.apply(cuda_check)
for device in cuda_devices:
streams.append(torch.cuda.Stream(device, priority=-1))
events.append(streams[-1].record_event())
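The loop above registers one high-priority stream, and an event recorded on it, per CUDA device found in _final_rollout. A minimal sketch of those two calls, assuming a CUDA device is available:

    import torch

    if torch.cuda.is_available():
        stream = torch.cuda.Stream(torch.device("cuda:0"), priority=-1)  # higher-priority stream
        event = stream.record_event()  # an event recorded on that stream
        event.synchronize()            # waits for work queued before the event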
@@ -929,12 +922,12 @@ def _update_traj_ids(self, tensordict) -> None:
tensordict.get("next"), done_keys=self.env.done_keys
)
if traj_sop.any():
traj_ids = self._tensordict.get(("collector", "traj_ids"))
traj_ids = self._shuttle.get(("collector", "traj_ids"))
traj_ids = traj_ids.clone()
traj_ids[traj_sop] = traj_ids.max() + torch.arange(
1, traj_sop.sum() + 1, device=traj_ids.device
)
self._tensordict.set(("collector", "traj_ids"), traj_ids)
self._shuttle.set(("collector", "traj_ids"), traj_ids)
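A worked example of the id bump above, with hypothetical values: four envs with trajectory ids [0, 1, 2, 3], where envs 1 and 3 just ended a trajectory.

    import torch

    traj_ids = torch.tensor([0, 1, 2, 3])
    traj_sop = torch.tensor([False, True, False, True])  # end-of-trajectory mask per env
    traj_ids = traj_ids.clone()
    traj_ids[traj_sop] = traj_ids.max() + torch.arange(1, int(traj_sop.sum()) + 1)
    print(traj_ids)  # tensor([0, 4, 2, 5]): the two restarted envs get fresh ids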

@torch.no_grad()
def rollout(self) -> TensorDictBase:
@@ -945,94 +938,94 @@ def rollout(self) -> TensorDictBase:

"""
if self.reset_at_each_iter:
self._tensordict.update(self.env.reset())
self._shuttle.update(self.env.reset())

# self._tensordict.fill_(("collector", "step_count"), 0)
self._tensordict_out.fill_(("collector", "traj_ids"), -1)
# self._shuttle.fill_(("collector", "step_count"), 0)
self._final_rollout.fill_(("collector", "traj_ids"), -1)
tensordicts = []
with set_exploration_type(self.exploration_type):
for t in range(self.frames_per_batch):
if (
self.init_random_frames is not None
and self._frames < self.init_random_frames
):
self.env.rand_action(self._tensordict)
self.env.rand_action(self._shuttle)
else:
if self._cast_to_policy_device:
if self.policy_device is not None:
self._tensordict = self._tensordict.to(self.policy_device, non_blocking=True)
policy_input = self._shuttle.to(self.policy_device, non_blocking=True)
elif self.policy_device is None:
# we know the tensordict has a device otherwise we would not be here
self._tensordict.clear_device_()
policy_input = self._shuttle.copy().clear_device_()
else:
policy_input = self._shuttle
# we still do the assignment for security
_tensordict_in = self._tensordict
_tensordict_out = self.policy(_tensordict_in)
if _tensordict_out is not _tensordict_in:
self._tensordict.update(_tensordict_out)
policy_output = self.policy(policy_input)
# update is a no-op if identities match, so this is safe and efficient in all cases
self._shuttle.update(policy_output, inplace=True)

if self._cast_to_policy_device:
if self.env_device is not None:
self._tensordict = self._tensordict.to(self.env_device)
env_input = self._shuttle.to(self.env_device)
elif self.env_device is None:
# we know the tensordict has a device otherwise we would not be here
self._tensordict.clear_device_()
env_input = self._shuttle.copy().clear_device_()
else:
env_input = self._shuttle
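A self-contained sketch of the per-step device handshake implemented above (the policy module, keys and devices are illustrative): the shuttle is cast to the policy device, the policy output is folded back into the shuttle, and the result is cast to the env device before stepping.

    import torch
    from tensordict import TensorDict
    from tensordict.nn import TensorDictModule

    policy = TensorDictModule(torch.nn.Linear(3, 2), in_keys=["observation"], out_keys=["action"])
    shuttle = TensorDict({"observation": torch.randn(4, 3)}, batch_size=[4])
    policy_device = torch.device("cpu")  # e.g. torch.device("cuda:0") in practice
    env_device = torch.device("cpu")

    policy_input = shuttle.to(policy_device)  # cast only what the policy sees
    shuttle.update(policy(policy_input))      # bring the action back into the shuttle
    env_input = shuttle.to(env_device)        # then cast for env.step()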

tensordict, tensordict_ = self.env.step_and_maybe_reset(
self._tensordict
)
self._tensordict = tensordict_.set(
"collector", tensordict.get("collector").clone(False)
env_output, env_next_output = self.env.step_and_maybe_reset(
env_input
)
self._shuttle.update(env_next_output.set(
"collector", env_output.get("collector").copy()
), inplace=True)

if self.storing_device is not None:
tensordicts.append(
tensordict.to(self.storing_device, non_blocking=True)
env_output.to(self.storing_device, non_blocking=True)
)
else:
tensordicts.append(tensordict)
tensordicts.append(env_output)

self._update_traj_ids(tensordict)
self._update_traj_ids(env_output)
if (
self.interruptor is not None
and self.interruptor.collection_stopped()
):
try:
torch.stack(
tensordicts,
self._tensordict_out.ndim - 1,
out=self._tensordict_out[: t + 1],
self._final_rollout.ndim - 1,
out=self._final_rollout[: t + 1],
)
except RuntimeError:
with self._tensordict_out.unlock_():
with self._final_rollout.unlock_():
torch.stack(
tensordicts,
self._tensordict_out.ndim - 1,
out=self._tensordict_out[: t + 1],
self._final_rollout.ndim - 1,
out=self._final_rollout[: t + 1],
)
break
else:
print(tensordicts)
print("\n\n\n")
print("self._tensordict_out", self._tensordict_out)
print("\n\n\n")
try:
self._tensordict_out = torch.stack(
self._final_rollout = torch.stack(
tensordicts,
self._tensordict_out.ndim - 1,
out=self._tensordict_out,
self._final_rollout.ndim - 1,
out=self._final_rollout,
)
except RuntimeError:
with self._tensordict_out.unlock_():
self._tensordict_out = torch.stack(
with self._final_rollout.unlock_():
self._final_rollout = torch.stack(
tensordicts,
self._tensordict_out.ndim - 1,
out=self._tensordict_out,
self._final_rollout.ndim - 1,
out=self._final_rollout,
)
return self._tensordict_out
return self._final_rollout
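The try/except RuntimeError blocks above exist because _final_rollout may be locked, and the fallback temporarily unlocks it before stacking. A minimal sketch of that lock/unlock behaviour, with an illustrative key:

    import torch
    from tensordict import TensorDict

    td = TensorDict({"a": torch.zeros(3)}, batch_size=[3])
    td.lock_()                       # a locked tensordict rejects new keys
    try:
        td.set("b", torch.ones(3))   # raises RuntimeError
    except RuntimeError:
        with td.unlock_():           # temporarily unlock, as the fallback above does
            td.set("b", torch.ones(3))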

def reset(self, index=None, **kwargs) -> None:
"""Resets the environments to a new initial state."""
# metadata
md = self._tensordict.get("collector").clone()
md = self._shuttle.get("collector").clone()
if index is not None:
# check that the env supports partial reset
if prod(self.env.batch_size) == 0:
@@ -1046,20 +1039,20 @@ def reset(self, index=None, **kwargs) -> None:
device=self.env.device,
)
_reset[index] = 1
self._tensordict.set(reset_key, _reset)
self._shuttle.set(reset_key, _reset)
else:
_reset = None
self._tensordict.zero_()
self._shuttle.zero_()

self._tensordict.update(self.env.reset(**kwargs))
self._shuttle.update(self.env.reset(**kwargs))
md["traj_ids"] = md["traj_ids"] - md["traj_ids"].min()
self._tensordict["collector"] = md
self._shuttle["collector"] = md

def shutdown(self) -> None:
"""Shuts down all workers and/or closes the local environment."""
if not self.closed:
self.closed = True
del self._tensordict, self._tensordict_out
del self._shuttle, self._final_rollout
if not self.env.is_closed:
self.env.close()
del self.env
@@ -1124,7 +1117,7 @@ def load_state_dict(self, state_dict: OrderedDict, **kwargs) -> None:
def __repr__(self) -> str:
env_str = indent(f"env={self.env}", 4 * " ")
policy_str = indent(f"policy={self.policy}", 4 * " ")
td_out_str = indent(f"td_out={self._tensordict_out}", 4 * " ")
td_out_str = indent(f"td_out={self._final_rollout}", 4 * " ")
string = (
f"{self.__class__.__name__}("
f"\n{env_str},"