Skip to content

Commit

Permalink
[Feature] TorchRL2Gym conversion (pytorch#1795)
Browse files: browse the repository at this point in the history
  • Loading branch information
vmoens authored Jan 19, 2024
1 parent 57139bd commit c3ffb5a
Show file tree
Hide file tree
Showing 23 changed files with 1,964 additions and 141 deletions.
1 change: 1 addition & 0 deletions docs/source/reference/envs.rst
Original file line number Diff line number Diff line change
Expand Up @@ -604,6 +604,7 @@ to be able to create this other composition:
RewardScaling
RewardSum
Reward2GoTransform
RemoveEmptySpecs
SelectTransform
SignTransform
SqueezeTransform
Expand Down
2 changes: 1 addition & 1 deletion test/_utils_internal.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def _set_gym_environments(): # noqa: F811
PONG_VERSIONED = "ALE/Pong-v5"


@implement_for("gymnasium", "0.27.0", None)
@implement_for("gymnasium")
def _set_gym_environments(): # noqa: F811
global CARTPOLE_VERSIONED, HALFCHEETAH_VERSIONED, PENDULUM_VERSIONED, PONG_VERSIONED

Expand Down
4 changes: 3 additions & 1 deletion test/mocking_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -515,7 +515,7 @@ def _step(
assert (a.sum(-1) == 1).all()

obs = self._get_in_obs(tensordict.get(self._out_key)) + a / self.maxstep
tensordict = tensordict.empty() # empty tensordict
tensordict = tensordict.empty()

tensordict.set(self.out_key, self._get_out_obs(obs))
tensordict.set(self._out_key, self._get_out_obs(obs))
Expand Down Expand Up @@ -603,6 +603,7 @@ def _reset(self, tensordict: TensorDictBase) -> TensorDictBase:
# state = torch.zeros(self.size) + self.counter
if tensordict is None:
tensordict = TensorDict({}, self.batch_size, device=self.device)

tensordict = tensordict.empty()
tensordict.update(self.observation_spec.rand())
# tensordict.set("next_" + self.out_key, self._get_out_obs(state))
Expand All @@ -622,6 +623,7 @@ def _step(
a = tensordict.get("action")

obs = self._obs_step(self._get_in_obs(tensordict.get(self._out_key)), a)

tensordict = tensordict.empty() # empty tensordict

tensordict.set(self.out_key, self._get_out_obs(obs))
Expand Down
Loading

0 comments on commit c3ffb5a

Please sign in to comment.