Skip to content

Commit

Permalink
[Feature]: Added support for TensorDictSequence module subsampling (#332)
Browse files Browse the repository at this point in the history
  • Loading branch information
nicolas-dufour authored Aug 10, 2022
1 parent d4b2ced commit 69e7948
Show file tree
Hide file tree
Showing 2 changed files with 134 additions and 13 deletions.
78 changes: 75 additions & 3 deletions test/test_tensordictmodules.py
Original file line number Diff line number Diff line change
Expand Up @@ -1528,7 +1528,54 @@ def test_vmap_probabilistic(self, safe, spec_type):
assert ((td_out.get("out") < 0.1) | (td_out.get("out") > -0.1)).all()

@pytest.mark.parametrize("functional", [True, False])
def test_submodule_sequence(self, functional):
    """select_subsequence() must return runnable subsequences, both as plain
    modules and as functional modules (params/buffers passed explicitly).

    The functional and non-functional paths previously duplicated the whole
    setup and every assertion; they now share one execution helper.
    """
    td_module_1 = TensorDictModule(
        nn.Linear(3, 2),
        in_keys=["in"],
        out_keys=["hidden"],
    )
    td_module_2 = TensorDictModule(
        nn.Linear(2, 4),
        in_keys=["hidden"],
        out_keys=["out"],
    )
    td_module = TensorDictSequence(td_module_1, td_module_2)

    def _run(sub_seq, td):
        # Functionalize on demand so both parametrizations share one path.
        if functional:
            sub_seq, (params, buffers) = sub_seq.make_functional_with_buffers()
            sub_seq(td, params=params, buffers=buffers)
        else:
            sub_seq(td)

    # First half of the chain: "in" -> "hidden"; "out" must NOT be produced.
    td_1 = TensorDict({"in": torch.randn(5, 3)}, [5])
    _run(td_module.select_subsequence(out_keys=["hidden"]), td_1)
    assert "hidden" in td_1.keys()
    assert "out" not in td_1.keys()

    # Second half of the chain: "hidden" -> "out".
    td_2 = TensorDict({"hidden": torch.randn(5, 2)}, [5])
    _run(td_module.select_subsequence(in_keys=["hidden"]), td_2)
    assert "out" in td_2.keys()
    assert td_2.get("out").shape == torch.Size([5, 4])

@pytest.mark.parametrize("stack", [True, False])
@pytest.mark.parametrize("functional", [True, False])
def test_sequential_partial(self, stack, functional):
torch.manual_seed(0)
param_multiplier = 2
Expand Down Expand Up @@ -1621,7 +1668,32 @@ def test_sequential_partial(self, stack, functional):
assert "out" in td.keys()
assert "b" in td.keys()

def test_subsequence_weight_update(self):
    """A subsequence returned by select_subsequence() shares its parameters
    with the parent sequence: an optimizer step on the parent must be
    visible through the subsequence."""
    first = TensorDictModule(nn.Linear(3, 2), in_keys=["in"], out_keys=["hidden"])
    second = TensorDictModule(nn.Linear(2, 4), in_keys=["hidden"], out_keys=["out"])
    seq = TensorDictSequence(first, second)

    data = TensorDict({"in": torch.randn(5, 3)}, [5])
    sub = seq.select_subsequence(out_keys=["hidden"])
    # Snapshot the shared weight before the update.
    weight_before = sub[0].module.weight.clone()

    optimizer = torch.optim.SGD(seq.parameters(), lr=0.1)
    optimizer.zero_grad()
    data = seq(data)
    data["out"].mean().backward()
    optimizer.step()

    # The step changed the weight, and the subsequence sees the same tensor.
    assert not torch.allclose(weight_before, sub[0].module.weight)
    assert torch.allclose(seq[0].module.weight, sub[0].module.weight)

# Script entry point: forward unknown CLI args to pytest.
# (The scraped diff showed this guard twice; a file needs exactly one.)
if __name__ == "__main__":
    args, unknown = argparse.ArgumentParser().parse_known_args()
    pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
69 changes: 59 additions & 10 deletions torchrl/modules/tensordict_module/sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,17 @@ def __init__(
*modules: TensorDictModule,
partial_tolerant: bool = False,
):
in_keys, out_keys = self._compute_in_and_out_keys(modules)

super().__init__(
spec=None,
module=nn.ModuleList(list(modules)),
in_keys=in_keys,
out_keys=out_keys,
)
self.partial_tolerant = partial_tolerant

def _compute_in_and_out_keys(self, modules: List[TensorDictModule]) -> Tuple[List]:
in_keys = []
out_keys = []
for module in modules:
Expand All @@ -137,14 +148,7 @@ def __init__(
for i, out_key in enumerate(out_keys)
if out_key not in out_keys[i + 1 :]
]

super().__init__(
spec=None,
module=nn.ModuleList(list(modules)),
in_keys=in_keys,
out_keys=out_keys,
)
self.partial_tolerant = partial_tolerant
return in_keys, out_keys

@staticmethod
def _find_functional_module(module: TensorDictModule) -> nn.Module:
Expand Down Expand Up @@ -199,6 +203,47 @@ def _split_param(
out.append(param_list[a:b])
return out

def select_subsequence(
    self, in_keys: Iterable[str] = None, out_keys: Iterable[str] = None
) -> "TensorDictSequence":
    """Returns a new TensorDictSequence with only the modules that are
    necessary to compute the given output keys with the given input keys.

    Args:
        in_keys: input keys of the subsequence we want to select.
            Defaults to this sequence's own ``in_keys``.
        out_keys: output keys of the subsequence we want to select.
            Defaults to this sequence's own ``out_keys``.

    Returns:
        A new TensorDictSequence with only the modules that are necessary
        according to the given input and output keys.

    Raises:
        ValueError: if no module satisfies both the input and output
            constraints.
    """
    # Work on fresh lists: the sweeps below extend these collections, and a
    # caller-supplied iterable must not be mutated (the original code
    # extended the caller's own list in place). list() also makes
    # membership tests valid for any Iterable, e.g. a generator.
    in_keys = deepcopy(self.in_keys) if in_keys is None else list(in_keys)
    out_keys = deepcopy(self.out_keys) if out_keys is None else list(out_keys)

    id_to_keep = set(range(len(self.module)))
    # Forward sweep: a module is computable if all its inputs are available;
    # its outputs then become available to downstream modules.
    for i, module in enumerate(self.module):
        if all(key in in_keys for key in module.in_keys):
            in_keys.extend(module.out_keys)
        else:
            id_to_keep.discard(i)
    # Backward sweep: a computable module is needed only if one of its
    # outputs contributes (possibly transitively) to the requested out_keys.
    for i, module in reversed(list(enumerate(self.module))):
        if i in id_to_keep:
            if any(key in out_keys for key in module.out_keys):
                out_keys.extend(module.in_keys)
            else:
                id_to_keep.remove(i)

    # Preserve the original execution order of the surviving modules.
    modules = [self.module[i] for i in sorted(id_to_keep)]

    if not modules:
        raise ValueError(
            "No modules left after selection. Make sure that in_keys and out_keys are coherent."
        )

    return TensorDictSequence(*modules)

def _run_module(self, module, tensordict, **kwargs):
tensordict_keys = set(tensordict.keys())
if not self.partial_tolerant or all(
Expand All @@ -214,8 +259,12 @@ def _run_module(self, module, tensordict, **kwargs):
return tensordict

def forward(
self, tensordict: TensorDictBase, tensordict_out=None, **kwargs
self,
tensordict: TensorDictBase,
tensordict_out=None,
**kwargs,
) -> TensorDictBase:

if "params" in kwargs and "buffers" in kwargs:
param_splits = self._split_param(kwargs["params"], "params")
buffer_splits = self._split_param(kwargs["buffers"], "buffers")
Expand Down Expand Up @@ -252,7 +301,7 @@ def forward(
tensordict = self._run_module(module, tensordict, **kwargs)
else:
raise RuntimeError(
"TensorDictSequence does not support keyword arguments other than 'tensordict_out', 'params', 'buffers' and 'vmap'"
"TensorDictSequence does not support keyword arguments other than 'tensordict_out', 'in_keys', 'out_keys' 'params', 'buffers' and 'vmap'"
)
if tensordict_out is not None:
tensordict_out.update(tensordict, inplace=True)
Expand Down

0 comments on commit 69e7948

Please sign in to comment.