[BUG] check_env_specs + PixelRenderTransform does not tolerate "cuda" device #2236
Closed
Describe the bug
Running check_env_specs on a TransformedEnv that contains a PixelRenderTransform fails when the environment device is "cuda".
To Reproduce
```python
# BEFORE THE PROGRAM EVEN RUNS, FORCE THE START METHOD TO BE 'SPAWN'
from torch import multiprocessing as mp

mp.set_start_method("spawn", force=True)

from copy import deepcopy
import tqdm
import numpy as np
import math
import torch
from torch import nn
import torch.distributions as D

from torchrl.envs import check_env_specs, PettingZooEnv, ParallelEnv
from torchrl.modules import ProbabilisticActor
from torchrl.modules.models import MLP
from torchrl.modules.models.multiagent import MultiAgentNetBase
from torchrl.collectors import SyncDataCollector
from tensordict.nn import TensorDictModule, TensorDictSequential, NormalParamExtractor
from torchrl.envs import TransformedEnv, Compose, RewardSum, StepCounter
from torchrl.record import CSVLogger, VideoRecorder, PixelRenderTransform

EPS = 1e-7

# Main Function
if __name__ == "__main__":
    NUM_AGENTS = 3
    NUM_CRITICS = 2
    NUM_EXPLORE_WORKERS = 1
    EXPLORATION_STEPS = 30000
    MAX_EPISODE_STEPS = 1000
    DEVICE = "cuda"
    REPLAY_BUFFER_SIZE = int(1e6)
    VALUE_GAMMA = 0.99
    MAX_GRAD_NORM = 1.0
    BATCH_SIZE = 512
    LR = 3e-4
    UPDATE_STEPS_PER_EXPLORATION = 1500
    WARMUP_STEPS = int(3e5)
    TRAIN_TIMESTEPS = int(1e7)
    EVAL_INTERVAL = 10
    EVAL_EPISODES = 20
    SEED = 42

    torch.manual_seed(SEED)
    np.random.seed(SEED)

    def env_fn(mode, parallel=True, rew_scale=True):
        if rew_scale:
            terminate_scale = -3.0
            forward_scale = 2.5
            fall_scale = -3.0
        else:
            # Use the defaults from PZ
            terminate_scale, forward_scale, fall_scale = -100.0, 1.0, -10.0

        def base_env_fn():
            return PettingZooEnv(
                task="multiwalker_v9",
                parallel=True,
                seed=42,
                n_walkers=NUM_AGENTS,
                terminate_reward=terminate_scale,
                forward_reward=forward_scale,
                fall_reward=fall_scale,
                shared_reward=False,
                max_cycles=MAX_EPISODE_STEPS,
                render_mode=mode,
                device=DEVICE,
            )

        env = base_env_fn  # noqa: E731

        def env_with_transforms():
            init_env = env()
            init_env = TransformedEnv(
                init_env,
                Compose(
                    StepCounter(max_steps=MAX_EPISODE_STEPS),
                    RewardSum(
                        in_keys=[init_env.reward_key for _ in range(NUM_AGENTS)],
                        out_keys=[("walker", "episode_reward")] * NUM_AGENTS,
                        reset_keys=["_reset"] * NUM_AGENTS,
                    ),
                ),
            )
            return init_env

        return env_with_transforms

    train_env = env_fn(None, parallel=False)()
    if train_env.is_closed:
        train_env.start()

    def create_eval_env(tag="rendered"):
        eval_env = env_fn("rgb_array", parallel=False, rew_scale=False)()
        video_recorder = VideoRecorder(
            CSVLogger("multiwalker-toy-test", video_format="mp4"),
            tag=tag,
            in_keys=["pixels_record"],
        )
        # Call the parent's render function
        eval_env.append_transform(PixelRenderTransform(out_keys=["pixels_record"]))
        eval_env.append_transform(video_recorder)
        if eval_env.is_closed:
            eval_env.start()
        return eval_env

    check_env_specs(train_env)

    eval_env = create_eval_env()
    check_env_specs(eval_env)

    train_env.close()
```
File "/mnt/c/Users/N00bcak/Desktop/programming/drones_go_brr/scripts/torchrl_cuda_hangs.py", line 115, in <module>
check_env_specs(eval_env)
File "/home/n00bcak/venvs/torchrl-3.11/lib/python3.11/site-packages/torchrl/envs/utils.py", line 728, in check_env_specs
fake_tensordict = env.fake_tensordict()
^^^^^^^^^^^^^^^^^^^^^
File "/home/n00bcak/venvs/torchrl-3.11/lib/python3.11/site-packages/torchrl/envs/common.py", line 2922, in fake_tensordict
observation_spec = self.observation_spec
^^^^^^^^^^^^^^^^^^^^^
File "/home/n00bcak/venvs/torchrl-3.11/lib/python3.11/site-packages/torchrl/envs/common.py", line 1303, in observation_spec
observation_spec = self.output_spec["full_observation_spec"]
^^^^^^^^^^^^^^^^
File "/home/n00bcak/venvs/torchrl-3.11/lib/python3.11/site-packages/torchrl/envs/transforms/transforms.py", line 748, in output_spec
output_spec = self.transform.transform_output_spec(output_spec)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/n00bcak/venvs/torchrl-3.11/lib/python3.11/site-packages/torchrl/envs/transforms/transforms.py", line 1104, in transform_output_spec
output_spec = t.transform_output_spec(output_spec)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/n00bcak/venvs/torchrl-3.11/lib/python3.11/site-packages/torchrl/envs/transforms/transforms.py", line 376, in transform_output_spec
output_spec["full_observation_spec"] = self.transform_observation_spec(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/n00bcak/venvs/torchrl-3.11/lib/python3.11/site-packages/torchrl/record/recorder.py", line 501, in transform_observation_spec
observation_spec[self.out_keys[0]] = spec
~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^
File "/home/n00bcak/venvs/torchrl-3.11/lib/python3.11/site-packages/torchrl/data/tensor_specs.py", line 3783, in __setitem__
raise RuntimeError(
RuntimeError: Setting a new attribute (pixels_record) on another device (cuda:0 against cuda). All devices of CompositeSpec must match.
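For context, the device mismatch in the error message can be reproduced with plain torch.device objects; this standalone snippet (my own illustration, not part of the repro above) shows that an un-indexed "cuda" does not compare equal to "cuda:0":

```python
import torch

# torch.device equality compares both type and index, so the un-indexed "cuda"
# (index None) is not equal to the fully qualified "cuda:0" (index 0).
print(torch.device("cuda") == torch.device("cuda:0"))  # False
print(torch.device("cuda").index, torch.device("cuda:0").index)  # None 0
```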
Expected behavior
check_env_specs succeeds and the program terminates.
System info
Describe the characteristics of your environment:
- Describe how the library was installed (pip, source, ...)
- Python version
- Versions of any other relevant libraries
```python
>>> import torchrl, numpy, sys
>>> print(torchrl.__version__, numpy.__version__, sys.version, sys.platform)
0.4.0 1.26.4 3.11.9 (main, Jun 5 2024, 10:27:27) [GCC 12.2.0] linux
```
Reason and Possible fixes
A strict equality check appears to be performed on the spec devices, so the un-indexed "cuda" device of the base environment clashes with the fully qualified "cuda:0" carried by the spec that PixelRenderTransform adds (hence "cuda:0 against cuda" in the RuntimeError above).
For consistency with PyTorch in general, one could consider substituting "cuda" with f"cuda:{torch.cuda.current_device()}" before the comparison, so that both sides resolve to the same fully qualified device. Depending on whether current_device() (or an equivalent) is available for other device types, similar handling could be considered for those as well.
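As a user-side workaround, here is a minimal sketch assuming the clash only comes from the un-indexed device string: qualify DEVICE with an explicit index before building the envs, so every spec (including the one added by PixelRenderTransform) ends up on "cuda:0".

```python
import torch

# Workaround sketch: use an index-qualified device string instead of plain "cuda".
if torch.cuda.is_available():
    DEVICE = f"cuda:{torch.cuda.current_device()}"  # e.g. "cuda:0"
else:
    DEVICE = "cpu"
```

With this change the CompositeSpec device check compares "cuda:0" against "cuda:0" and should no longer raise.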
Checklist
- I have checked that there is no similar issue in the repo (required)
- I have read the documentation (required)
- I have provided a minimal working example to reproduce the bug (required)