Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens committed Mar 25, 2024
1 parent f1a2196 commit 7c5e618
Showing 1 changed file with 9 additions and 8 deletions.
17 changes: 9 additions & 8 deletions test/test_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -2664,7 +2664,8 @@ def _step(self, tensordict):
"reward": action.sum().unsqueeze(0),
**self.full_done_spec.zero(),
"observation": obs,
}
},
batch_size=[],
)

torch.manual_seed(0)
Expand Down Expand Up @@ -2829,19 +2830,19 @@ def test_single_task_share_individual_td():

def test_stackable():
# Tests the _stackable util
stack = [TensorDict({"a": 0}), TensorDict({"b": 1})]
stack = [TensorDict({"a": 0}, []), TensorDict({"b": 1}, [])]
assert not _stackable(*stack), torch.stack(stack)
stack = [TensorDict({"a": [0]}), TensorDict({"a": 1})]
stack = [TensorDict({"a": [0]}, []), TensorDict({"a": 1}, [])]
assert not _stackable(*stack)
stack = [TensorDict({"a": [0]}), TensorDict({"a": [1]})]
stack = [TensorDict({"a": [0]}, []), TensorDict({"a": [1]}, [])]
assert _stackable(*stack)
stack = [TensorDict({"a": [0]}), TensorDict({"a": [1], "b": {}})]
stack = [TensorDict({"a": [0]}, []), TensorDict({"a": [1], "b": {}}, [])]
assert _stackable(*stack)
stack = [TensorDict({"a": {"b": [0]}}), TensorDict({"a": {"b": [1]}})]
stack = [TensorDict({"a": {"b": [0]}}, []), TensorDict({"a": {"b": [1]}}, [])]
assert _stackable(*stack)
stack = [TensorDict({"a": {"b": [0]}}), TensorDict({"a": {"b": 1}})]
stack = [TensorDict({"a": {"b": [0]}}, []), TensorDict({"a": {"b": 1}}, [])]
assert not _stackable(*stack)
stack = [TensorDict({"a": "a string"}), TensorDict({"a": "another string"})]
stack = [TensorDict({"a": "a string"}, []), TensorDict({"a": "another string"}, [])]
assert _stackable(*stack)


Expand Down

0 comments on commit 7c5e618

Please sign in to comment.