Skip to content

Commit

Permalink
[Feature] Masking actions (pytorch#1421)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored Sep 3, 2023
1 parent 545a28c commit 9fded1a
Show file tree
Hide file tree
Showing 9 changed files with 762 additions and 62 deletions.
1 change: 1 addition & 0 deletions docs/source/reference/envs.rst
Original file line number Diff line number Diff line change
Expand Up @@ -445,6 +445,7 @@ to be able to create this other composition:

Transform
TransformedEnv
ActionMask
BinarizeReward
CatFrames
CatTensors
Expand Down
120 changes: 120 additions & 0 deletions test/test_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -3169,6 +3169,126 @@ def get_all_keys(spec: TensorSpec, include_exclusive: bool):
return keys


@pytest.mark.parametrize("shape", ((), (1,), (2, 3), (2, 3, 4)))
@pytest.mark.parametrize(
"spectype", ["one_hot", "categorical", "mult_one_hot", "mult_discrete"]
)
@pytest.mark.parametrize("device", get_default_devices())
@pytest.mark.parametrize("rand_shape", ((), (2,), (2, 3)))
class TestSpecMasking:
def _make_mask(self, shape):
torch.manual_seed(0)
mask = torch.zeros(shape, dtype=torch.bool).bernoulli_()
if len(shape) == 1:
while not mask.any() or mask.all():
mask = torch.zeros(shape, dtype=torch.bool).bernoulli_()
return mask
mask_view = mask.view(-1, shape[-1])
for i in range(mask_view.shape[0]):
t = mask_view[i]
while not t.any() or t.all():
t.copy_(torch.zeros_like(t).bernoulli_())
return mask

def _one_hot_spec(self, shape, device, n):
shape = torch.Size([*shape, n])
mask = self._make_mask(shape).to(device)
return OneHotDiscreteTensorSpec(n, shape, device, mask=mask)

def _mult_one_hot_spec(self, shape, device, n):
shape = torch.Size([*shape, n + n + 2])
mask = torch.cat(
[
self._make_mask(shape[:-1] + (n,)).to(device),
self._make_mask(shape[:-1] + (n + 2,)).to(device),
],
-1,
)
return MultiOneHotDiscreteTensorSpec([n, n + 2], shape, device, mask=mask)

def _discrete_spec(self, shape, device, n):
mask = self._make_mask(torch.Size([*shape, n])).to(device)
return DiscreteTensorSpec(n, shape, device, mask=mask)

def _mult_discrete_spec(self, shape, device, n):
shape = torch.Size([*shape, 2])
mask = torch.cat(
[
self._make_mask(shape[:-1] + (n,)).to(device),
self._make_mask(shape[:-1] + (n + 2,)).to(device),
],
-1,
)
return MultiDiscreteTensorSpec([n, n + 2], shape, device, mask=mask)

def test_equal(self, shape, device, spectype, rand_shape, n=5):
shape = torch.Size(shape)
spec = (
self._one_hot_spec(shape, device, n=n)
if spectype == "one_hot"
else self._discrete_spec(shape, device, n=n)
if spectype == "categorical"
else self._mult_one_hot_spec(shape, device, n=n)
if spectype == "mult_one_hot"
else self._mult_discrete_spec(shape, device, n=n)
if spectype == "mult_discrete"
else None
)
spec_clone = spec.clone()
assert spec == spec_clone
assert spec.unsqueeze(0).squeeze(0) == spec
spec.update_mask(~spec.mask)
assert (spec.mask != spec_clone.mask).any()
assert spec != spec_clone

def test_is_in(self, shape, device, spectype, rand_shape, n=5):
shape = torch.Size(shape)
rand_shape = torch.Size(rand_shape)
spec = (
self._one_hot_spec(shape, device, n=n)
if spectype == "one_hot"
else self._discrete_spec(shape, device, n=n)
if spectype == "categorical"
else self._mult_one_hot_spec(shape, device, n=n)
if spectype == "mult_one_hot"
else self._mult_discrete_spec(shape, device, n=n)
if spectype == "mult_discrete"
else None
)
s = spec.rand(rand_shape)
assert spec.is_in(s)
spec.update_mask(~spec.mask)
assert not spec.is_in(s)

def test_project(self, shape, device, spectype, rand_shape, n=5):
shape = torch.Size(shape)
rand_shape = torch.Size(rand_shape)
spec = (
self._one_hot_spec(shape, device, n=n)
if spectype == "one_hot"
else self._discrete_spec(shape, device, n=n)
if spectype == "categorical"
else self._mult_one_hot_spec(shape, device, n=n)
if spectype == "mult_one_hot"
else self._mult_discrete_spec(shape, device, n=n)
if spectype == "mult_discrete"
else None
)
s = spec.rand(rand_shape)
assert (spec.project(s) == s).all()
spec.update_mask(~spec.mask)
sp = spec.project(s)
assert sp.shape == s.shape
if spectype == "one_hot":
assert (sp != s).any(-1).all()
assert (sp.any(-1)).all()
elif spectype == "mult_one_hot":
assert (sp != s).any(-1).all()
assert (sp.sum(-1) == 2).all()
else:
assert (sp != s).all()


if __name__ == "__main__":
args, unknown = argparse.ArgumentParser().parse_known_args()
pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
111 changes: 111 additions & 0 deletions test/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
UnboundedContinuousTensorSpec,
)
from torchrl.envs import (
ActionMask,
BinarizeReward,
CatFrames,
CatTensors,
Expand Down Expand Up @@ -8183,6 +8184,116 @@ def test_kl_lstm(self):
klt(env.rollout(3, policy))


class TestActionMask(TransformBase):
@property
def _env_class(self):
from torchrl.data import BinaryDiscreteTensorSpec, DiscreteTensorSpec

class MaskedEnv(EnvBase):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.action_spec = DiscreteTensorSpec(4)
self.state_spec = CompositeSpec(
mask=BinaryDiscreteTensorSpec(4, dtype=torch.bool)
)
self.observation_spec = CompositeSpec(
obs=UnboundedContinuousTensorSpec(3),
mask=BinaryDiscreteTensorSpec(4, dtype=torch.bool),
)
self.reward_spec = UnboundedContinuousTensorSpec(1)

def _reset(self, tensordict):
td = self.observation_spec.rand()
td.update(torch.ones_like(self.state_spec.rand()))
return td

def _step(self, data):
td = self.observation_spec.rand()
mask = data.get("mask")
action = data.get("action")
mask = mask.scatter(-1, action.unsqueeze(-1), 0)

td.set("mask", mask)
td.set("reward", self.reward_spec.rand())
td.set("done", ~(mask.any().view(1)))
return td

def _set_seed(self, seed):
return seed

return MaskedEnv

def test_single_trans_env_check(self):
env = self._env_class()
env = TransformedEnv(env, ActionMask())
check_env_specs(env)

def test_serial_trans_env_check(self):
env = SerialEnv(2, lambda: TransformedEnv(self._env_class(), ActionMask()))
check_env_specs(env)

def test_parallel_trans_env_check(self):
env = ParallelEnv(2, lambda: TransformedEnv(self._env_class(), ActionMask()))
check_env_specs(env)

def test_trans_serial_env_check(self):
env = TransformedEnv(SerialEnv(2, self._env_class), ActionMask())
check_env_specs(env)

def test_trans_parallel_env_check(self):
env = TransformedEnv(ParallelEnv(2, self._env_class), ActionMask())
check_env_specs(env)

def test_transform_no_env(self):
t = ActionMask()
with pytest.raises(RuntimeError, match="parent cannot be None"):
t._call(TensorDict({}, []))

def test_transform_compose(self):
env = self._env_class()
env = TransformedEnv(env, Compose(ActionMask()))
check_env_specs(env)

def test_transform_env(self):
env = TransformedEnv(ContinuousActionVecMockEnv(), ActionMask())
with pytest.raises(ValueError, match="The action spec must be one of"):
env.rollout(2)
env = self._env_class()
env = TransformedEnv(env, ActionMask())
td = env.reset()
for _ in range(1000):
td = env.rand_action(td)
assert env.action_spec.is_in(td.get("action"))
td = env.step(td)
td = step_mdp(td)
if td.get("done"):
break
else:
raise RuntimeError
assert not td.get("mask").any()

def test_transform_model(self):
t = ActionMask()
with pytest.raises(
RuntimeError, match="ActionMask must be executed within an environment"
):
t(TensorDict({}, []))

def test_transform_rb(self):
t = ActionMask()
rb = ReplayBuffer(storage=LazyTensorStorage(100))
rb.append_transform(t)
rb.extend(TensorDict({"a": [1]}, [1]).expand(10))
with pytest.raises(
RuntimeError, match="ActionMask must be executed within an environment"
):
rb.sample(3)

def test_transform_inverse(self):
# no inverse transform
return


class TestDeviceCastTransform(TransformBase):
def test_single_trans_env_check(self):
env = ContinuousActionVecMockEnv(device="cpu:0")
Expand Down
Loading

0 comments on commit 9fded1a

Please sign in to comment.