Skip to content

[BUG] check_env_specs + PixelRenderTransform does not tolerate "cuda" device #2236

Closed
@N00bcak

Description

Describe the bug

Running check_env_specs on a TransformedEnv that contains a PixelRenderTransform fails when the base environment is created with device = "cuda".

To Reproduce

# BEFORE THE PROGRAM EVEN RUNS, FORCE THE START METHOD TO BE 'SPAWN'
from torch import multiprocessing as mp
mp.set_start_method("spawn", force = True)

from copy import deepcopy
import tqdm
import numpy as np

import math

import torch
from torch import nn
import torch.distributions as D

from torchrl.envs import check_env_specs, PettingZooEnv, ParallelEnv
from torchrl.modules import ProbabilisticActor
from torchrl.modules.models import MLP
from torchrl.modules.models.multiagent import MultiAgentNetBase
from torchrl.collectors import SyncDataCollector
from tensordict.nn import TensorDictModule, TensorDictSequential, NormalParamExtractor
from torchrl.envs import TransformedEnv, Compose, RewardSum, StepCounter
from torchrl.record import CSVLogger, VideoRecorder, PixelRenderTransform

EPS = 1e-7
    
# Main Function
if __name__ == "__main__":    
    # Script configuration. Of these constants, only NUM_AGENTS,
    # MAX_EPISODE_STEPS, DEVICE and SEED are read by the code shown in this
    # snippet; the others are defined but not used here.
    NUM_AGENTS = 3
    NUM_CRITICS = 2
    NUM_EXPLORE_WORKERS = 1
    EXPLORATION_STEPS = 30000
    MAX_EPISODE_STEPS = 1000
    # Device string without an index. The error reported in this issue is a
    # mismatch between "cuda:0" and "cuda".
    DEVICE = "cuda"
    REPLAY_BUFFER_SIZE = int(1e6)
    VALUE_GAMMA = 0.99
    MAX_GRAD_NORM = 1.0
    BATCH_SIZE = 512
    LR = 3e-4
    UPDATE_STEPS_PER_EXPLORATION = 1500
    WARMUP_STEPS = int(3e5)
    TRAIN_TIMESTEPS = int(1e7)
    EVAL_INTERVAL = 10
    EVAL_EPISODES = 20

    # Seed both torch and numpy RNGs with the same value.
    SEED = 42
    torch.manual_seed(SEED)
    np.random.seed(SEED)

    def env_fn(mode, parallel = True, rew_scale = True):
        """Return a zero-argument factory that builds the transformed multiwalker env.

        Args:
            mode: forwarded to ``PettingZooEnv`` as ``render_mode``.
            parallel: not read anywhere in this function; the base env is
                always created with ``parallel = True``.
            rew_scale: when truthy, use the custom reward values set below;
                otherwise use the values that the inline comment labels as
                the PettingZoo defaults.

        Returns:
            The ``env_with_transforms`` callable. Nothing is constructed
            until that callable is invoked.
        """

        if rew_scale:
            terminate_scale = -3.0
            forward_scale = 2.5
            fall_scale = -3.0
        else:
            # Use the defaults from PZ
            terminate_scale, forward_scale, fall_scale = -100.0, 1.0, -10.0

        # Factory for the untransformed env. It closes over `mode` and the
        # reward values chosen above, and reads NUM_AGENTS,
        # MAX_EPISODE_STEPS and DEVICE from the enclosing script.
        def base_env_fn():
            return PettingZooEnv(task = "multiwalker_v9", 
                                    parallel = True,
                                    seed = 42,
                                    n_walkers = NUM_AGENTS, 
                                    terminate_reward = terminate_scale,
                                    forward_reward = forward_scale,
                                    fall_reward = fall_scale,
                                    shared_reward = False, 
                                    max_cycles = MAX_EPISODE_STEPS, 
                                    render_mode = mode, 
                                    device = DEVICE
                                )

        # Plain alias of base_env_fn; env() below calls it.
        env = base_env_fn # noqa: E731

        # Builds the base env and wraps it with StepCounter and RewardSum.
        def env_with_transforms():
            init_env = env()
            # RewardSum receives the same reward key, the same out key and
            # the same reset key, each repeated NUM_AGENTS times.
            init_env = TransformedEnv(init_env, Compose(
                                            StepCounter(max_steps = MAX_EPISODE_STEPS),
                                            RewardSum(
                                                in_keys = [init_env.reward_key for _ in range(NUM_AGENTS)], 
                                                out_keys = [("walker", "episode_reward")] * NUM_AGENTS, 
                                                reset_keys = ["_reset"] * NUM_AGENTS
                                            ),
                                        )
                                    )
            return init_env

        return env_with_transforms

    # Build the training env: render_mode is None and rew_scale keeps its
    # default (True). env_fn returns a factory, hence the trailing call.
    train_env = env_fn(None, parallel = False)()

    # Start the env only if it reports itself as closed.
    if train_env.is_closed:
        train_env.start()


    def create_eval_env(tag = "rendered"):
        """Build the evaluation env with pixel rendering and video recording.

        The env is created through ``env_fn`` with ``render_mode`` set to
        ``"rgb_array"`` and ``rew_scale = False``. A ``PixelRenderTransform``
        and then a ``VideoRecorder`` are appended to it.

        Args:
            tag: passed to ``VideoRecorder`` as its ``tag`` argument.

        Returns:
            The transformed env, started if it reported itself as closed.
        """
        
        eval_env = env_fn("rgb_array", parallel = False, rew_scale = False)()
        # The recorder reads the "pixels_record" key, which is the out key
        # given to PixelRenderTransform below.
        video_recorder = VideoRecorder(
                                        CSVLogger("multiwalker-toy-test", video_format = "mp4"), 
                                        tag = tag, 
                                        in_keys = ["pixels_record"]
                                    )
        
        # Call the parent's render function
        # NOTE(review): the traceback in this report fails while
        # "pixels_record" is being added to the observation spec in
        # torchrl/record/recorder.py -- presumably by this transform; confirm.
        eval_env.append_transform(PixelRenderTransform(out_keys = ["pixels_record"]))
        eval_env.append_transform(video_recorder)

        if eval_env.is_closed:
            eval_env.start()
        return eval_env

    # Check the training env first (no rendering transforms attached).
    check_env_specs(train_env)
    eval_env = create_eval_env()
    # According to the traceback in this report, this is the call that
    # raises the RuntimeError.
    check_env_specs(eval_env)

    # Only train_env is closed here; eval_env is not closed in this snippet.
    train_env.close()

Traceback (most recent call last):
  File "/mnt/c/Users/N00bcak/Desktop/programming/drones_go_brr/scripts/torchrl_cuda_hangs.py", line 115, in <module>
    check_env_specs(eval_env)
  File "/home/n00bcak/venvs/torchrl-3.11/lib/python3.11/site-packages/torchrl/envs/utils.py", line 728, in check_env_specs
    fake_tensordict = env.fake_tensordict()
                      ^^^^^^^^^^^^^^^^^^^^^
  File "/home/n00bcak/venvs/torchrl-3.11/lib/python3.11/site-packages/torchrl/envs/common.py", line 2922, in fake_tensordict
    observation_spec = self.observation_spec
                       ^^^^^^^^^^^^^^^^^^^^^
  File "/home/n00bcak/venvs/torchrl-3.11/lib/python3.11/site-packages/torchrl/envs/common.py", line 1303, in observation_spec
    observation_spec = self.output_spec["full_observation_spec"]
                       ^^^^^^^^^^^^^^^^
  File "/home/n00bcak/venvs/torchrl-3.11/lib/python3.11/site-packages/torchrl/envs/transforms/transforms.py", line 748, in output_spec
    output_spec = self.transform.transform_output_spec(output_spec)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/n00bcak/venvs/torchrl-3.11/lib/python3.11/site-packages/torchrl/envs/transforms/transforms.py", line 1104, in transform_output_spec
    output_spec = t.transform_output_spec(output_spec)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/n00bcak/venvs/torchrl-3.11/lib/python3.11/site-packages/torchrl/envs/transforms/transforms.py", line 376, in transform_output_spec
    output_spec["full_observation_spec"] = self.transform_observation_spec(
                                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/n00bcak/venvs/torchrl-3.11/lib/python3.11/site-packages/torchrl/record/recorder.py", line 501, in transform_observation_spec
    observation_spec[self.out_keys[0]] = spec
    ~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^
  File "/home/n00bcak/venvs/torchrl-3.11/lib/python3.11/site-packages/torchrl/data/tensor_specs.py", line 3783, in __setitem__
    raise RuntimeError(
RuntimeError: Setting a new attribute (pixels_record) on another device (cuda:0 against cuda). All devices of CompositeSpec must match.

Expected behavior

check_env_specs succeeds and program terminates.

System info

Describe the characteristics of your environment:

  • Describe how the library was installed (pip, source, ...)
  • Python version
  • Versions of any other relevant libraries
>>> import torchrl, numpy, sys
>>> print(torchrl.__version__, numpy.__version__, sys.version, sys.platform)
0.4.0 1.26.4 3.11.9 (main, Jun  5 2024, 10:27:27) [GCC 12.2.0] linux

Reason and Possible fixes

The devices of the specs appear to be compared strictly, so "cuda" and "cuda:0" are treated as different devices, which results in the error.

For consistency with PyTorch in general, TorchRL could consider substituting a bare "cuda" with f"cuda:{torch.cuda.current_device()}".

Depending on whether an equivalent of current_device() is available for other device types, TorchRL could consider implementing the same checks for those too.

Checklist

  • I have checked that there is no similar issue in the repo (required)
  • I have read the documentation (required)
  • I have provided a minimal working example to reproduce the bug (required)

Metadata

Assignees

Labels

bug: Something isn't working

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions