[BUG] check_env_specs() fails when run_type_checks=True #1567

Closed · Sefank opened this issue Sep 22, 2023 · 0 comments · Fixed by #1570
Labels: bug (Something isn't working)

Sefank commented Sep 22, 2023

Describe the bug

When defining a custom environment that inherits from EnvBase and passing run_type_checks=True to the base-class constructor, the sanity check check_env_specs() never passes. The failure appears to be caused by an unexpected attempt to invoke a get method on a CompositeSpec instance, which does not have this method by design.
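To isolate the failing access pattern, here is a minimal sketch against the affected nightly (the reward sub-spec is just an example, not taken from the traceback): key indexing on a CompositeSpec works, while the get call used internally by _step_proc_data does not.

from torchrl.data import CompositeSpec, UnboundedContinuousTensorSpec

# A trivial composite spec with a single sub-spec.
spec = CompositeSpec(reward=UnboundedContinuousTensorSpec(shape=(1,)))
print(spec["reward"].dtype)      # key indexing works: torch.float32
print(spec.get("reward").dtype)  # AttributeError on this nightly: no 'get' attribute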

To Reproduce

Any environment based on EnvBase with run_type_checks=True produces the error when check_env_specs() is invoked:

from typing import Optional

import torch
from tensordict import TensorDict, TensorDictBase
from torchrl.envs import EnvBase
from torchrl.data import (
    CompositeSpec,
    UnboundedContinuousTensorSpec,
    OneHotDiscreteTensorSpec,
)
from torchrl.envs.utils import check_env_specs


class DemoEnv(EnvBase):
    batch_locked = True

    def __init__(self):
        super().__init__(run_type_checks=True)

        # For demo purpose, params are hard coded.
        self.node_num = 3
        self.batch_size = (2,)

        self.set_seed(torch.empty((), dtype=torch.int64).random_().item())
        self._make_spec()

    def _make_spec(self):
        """Set the observation, action, and reward specs."""
        self.observation_spec = CompositeSpec(
            current_node=OneHotDiscreteTensorSpec(
                n=self.node_num,
                shape=(*self.batch_size, self.node_num),
                dtype=torch.bool,
            ),
            shape=self.batch_size,
        )
        self.state_spec = self.observation_spec.clone()
        self.action_spec = UnboundedContinuousTensorSpec(shape=self.batch_size)
        self.reward_spec = UnboundedContinuousTensorSpec(shape=self.batch_size)

    def _set_seed(self, seed: Optional[int]):
        rng = torch.manual_seed(seed)
        self.rng = rng

    def _step(self, td: TensorDictBase) -> TensorDict:
        # Trivial step function for demo purpose.
        out = TensorDict(
            {
                "current_node": td["current_node"],
                "reward": td["action"],
                "done": td["done"],
            },
            td.shape,
        )
        return out

    def _reset(self, td) -> TensorDict:
        current_node = torch.zeros(*self.batch_size, self.node_num).bool()
        current_node[:, 0] = True

        out = TensorDict(
            {
                "current_node": current_node,
            },
            self.batch_size,
        )

        return out

env = DemoEnv()
check_env_specs(env)

Running this script fails with the following traceback:

Traceback (most recent call last):
  File "<project-path>/demo.py", line 72, in <module>
    check_env_specs(env)
  File "<project-path>/venv/lib/python3.11/site-packages/torchrl/envs/utils.py", line 436, in check_env_specs
    real_tensordict = env.rollout(3, return_contiguous=return_contiguous)
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "<project-path>/venv/lib/python3.11/site-packages/torchrl/envs/common.py", line 1585, in rollout
    tensordict = self.step(tensordict)
                 ^^^^^^^^^^^^^^^^^^^^^
  File "<project-path>/venv/lib/python3.11/site-packages/torchrl/envs/common.py", line 1149, in step
    next_tensordict = self._step_proc_data(next_tensordict)
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "<project-path>/venv/lib/python3.11/site-packages/torchrl/envs/common.py", line 1216, in _step_proc_data
    is not self.output_spec["full_done_spec"].get(done_key).dtype
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AttributeError: 'CompositeSpec' object has no attribute 'get'. Did you mean: 'set'?

Expected behavior

check_env_specs() should pass without errors.

System info

  • TorchRL version: nightly-2023.9.21
  • Python version: 3.11.5 | packaged by conda-forge | (main, Aug 27 2023, 03:34:09) [GCC 12.3.0]
  • OS Platform: Linux

Reason and Possible fixes

Here is a possible direct fix, though I'm not sure whether it is correct:

https://github.com/Sefank/rl/commit/3e44b4b32a67a2031b2f73766a9c98f2be9c4b33
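For reference, a minimal sketch of one possible direction (it may or may not match the linked commit): replace the .get(done_key) access in _step_proc_data with key indexing, which CompositeSpec already supports. Continuing from the DemoEnv reproduction script above:

done_spec = env.output_spec["full_done_spec"]
# Key indexing retrieves the sub-spec without relying on the missing .get();
# the dtype comparison in _step_proc_data could then read
# self.output_spec["full_done_spec"][done_key].dtype instead.
assert done_spec["done"].dtype is torch.bool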

Checklist

  • I have checked that there is no similar issue in the repo (required)
  • I have read the documentation (required)
  • I have provided a minimal working example to reproduce the bug (required)
@Sefank Sefank added the bug Something isn't working label Sep 22, 2023
@vmoens vmoens linked a pull request Sep 23, 2023 that will close this issue