[Feature]: Added support for TensorDictSequence module subsampling #332

Merged
12 commits merged on Aug 10, 2022
Changed to module subsequencing
nicolas-dufour committed Aug 5, 2022
commit 974eb9750278ae673b54eb623b53a18d6fc416b6
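For orientation, here is a minimal sketch (not part of the diff) of the API this commit introduces, mirroring the updated test below; the import paths are an assumption for the torchrl version this PR targets:

import torch
from torch import nn

# Import paths are an assumption for the torchrl version this PR targets.
from torchrl.data import TensorDict
from torchrl.modules import TensorDictModule, TensorDictSequence

# Two-stage sequence: "in" -> "hidden" -> "out" (mirrors the updated test).
td_module = TensorDictSequence(
    TensorDictModule(nn.Linear(3, 2), in_keys=["in"], out_keys=["hidden"]),
    TensorDictModule(nn.Linear(2, 4), in_keys=["hidden"], out_keys=["out"]),
)

# Keep only the modules needed to produce "hidden"; "out" is never written.
sub_seq_1 = td_module.select_subsequence(out_keys=["hidden"])
td_1 = sub_seq_1(TensorDict({"in": torch.randn(5, 3)}, [5]))

# Keep only the modules that can run from a precomputed "hidden"; this writes "out".
sub_seq_2 = td_module.select_subsequence(in_keys=["hidden"])
td_2 = sub_seq_2(TensorDict({"hidden": torch.randn(5, 2)}, [5]))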
50 changes: 37 additions & 13 deletions test/test_tensordictmodules.py
@@ -1542,42 +1542,66 @@ def test_submodule_sequence(self, functional):
td_module = TensorDictSequence(td_module_1, td_module_2)

if functional:
td_module, (params, buffers) = td_module.make_functional_with_buffers()
td_0 = TensorDict({"in": torch.randn(5, 3)}, [5])
td_module(td_0, params=params, buffers=buffers)
assert td_0.get("out").shape == torch.Size([5, 4])
td_1 = TensorDict({"in": torch.randn(5, 3)}, [5])
td_module(
sub_seq_1 = td_module.select_subsequence(out_keys=["hidden"])
sub_seq_1, (params, buffers) = sub_seq_1.make_functional_with_buffers()
sub_seq_1(
td_1,
out_keys_filter=["hidden"],
params=params,
buffers=buffers,
)
assert "hidden" in td_1.keys()
assert "out" not in td_1.keys()
td_2 = TensorDict({"hidden": torch.randn(5, 2)}, [5])
td_module(
sub_seq_2 = td_module.select_subsequence(in_keys=["hidden"])
sub_seq_2, (params, buffers) = sub_seq_2.make_functional_with_buffers()
sub_seq_2(
td_2,
in_keys_filter=["hidden"],
params=params,
buffers=buffers,
)
assert "out" in td_2.keys()
assert td_2.get("out").shape == torch.Size([5, 4])
else:
td_0 = TensorDict({"in": torch.randn(5, 3)}, [5])
td_module(td_0)
assert td_0.get("out").shape == torch.Size([5, 4])
td_1 = TensorDict({"in": torch.randn(5, 3)}, [5])
td_module(td_1, out_keys_filter=["hidden"])
sub_seq_1 = td_module.select_subsequence(out_keys=["hidden"])
sub_seq_1(td_1)
assert "hidden" in td_1.keys()
assert "out" not in td_1.keys()
td_2 = TensorDict({"hidden": torch.randn(5, 2)}, [5])
td_module(td_2, in_keys_filter=["hidden"])
sub_seq_2 = td_module.select_subsequence(in_keys=["hidden"])
sub_seq_2(td_2)
assert "out" in td_2.keys()
assert td_2.get("out").shape == torch.Size([5, 4])


def test_subsequence_weight_update():
td_module_1 = TensorDictModule(
nn.Linear(3, 2),
in_keys=["in"],
out_keys=["hidden"],
)
td_module_2 = TensorDictModule(
nn.Linear(2, 4),
in_keys=["hidden"],
out_keys=["out"],
)
td_module = TensorDictSequence(td_module_1, td_module_2)

td_1 = TensorDict({"in": torch.randn(5, 3)}, [5])
sub_seq_1 = td_module.select_subsequence(out_keys=["hidden"])
copy = sub_seq_1[0].module.weight.clone()

opt = torch.optim.SGD(td_module.parameters(), lr=0.1)
opt.zero_grad()
td_1 = td_module(td_1)
td_1["out"].mean().backward()
opt.step()

assert not torch.allclose(copy, sub_seq_1[0].module.weight)
assert torch.allclose(td_module[0].module.weight, sub_seq_1[0].module.weight)


if __name__ == "__main__":
args, unknown = argparse.ArgumentParser().parse_known_args()
pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
63 changes: 36 additions & 27 deletions torchrl/modules/tensordict_module/sequence.py
@@ -195,44 +195,57 @@ def _split_param(
out.append(param_list[a:b])
return out

def forward(
self,
tensordict: TensorDictBase,
in_keys_filter=None,
out_keys_filter=None,
tensordict_out=None,
**kwargs,
) -> TensorDictBase:
def select_subsequence(
self, in_keys: Iterable[str] = None, out_keys: Iterable[str] = None
) -> "TensorDictSequence":
"""
Returns a new TensorDictSequence containing only the modules that are necessary to compute
the given output keys from the given input keys.

Args:
in_keys: input keys of the subsequence we want to select
out_keys: output keys of the subsequence we want to select

# Filter modules to avoid calling modules that don't require the desired in_keys or out keys.
if in_keys_filter is None:
in_keys_filter = deepcopy(self.in_keys)
if out_keys_filter is None:
out_keys_filter = deepcopy(self.out_keys)
Returns:
A new TensorDictSequence with only the modules that are necessary according to the given input and output keys.
"""
if in_keys is None:
in_keys = deepcopy(self.in_keys)
if out_keys is None:
out_keys = deepcopy(self.out_keys)
id_to_keep = set([i for i in range(len(self.module))])
for i, module in enumerate(self.module):
if all(key in in_keys_filter for key in module.in_keys):
in_keys_filter.extend(module.out_keys)
if all(key in in_keys for key in module.in_keys):
Contributor:
What's the usage of selecting in_keys? I can understand why we want to restrict the outputs, but I don't really see when we want to restrict inputs.

Contributor:
Also what happens if you say you want some out_keys but they conflict with the in_keys? Is the sequence going to be empty?

Contributor (Author):
We want to be able to select the in_keys so that we can directly feed an intermediary block. For example, imagine you have a hidden layer that you want to inject from a precomputed tensordict; this allows you to do so.
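
For instance, a minimal sketch of that use case (not part of the diff), following the non-functional branch of the updated test; td_module, the "hidden" key, and its shape come from the diff above:

# Inject a precomputed hidden activation and run only the downstream module(s).
sub_seq = td_module.select_subsequence(in_keys=["hidden"])
precomputed = TensorDict({"hidden": torch.randn(5, 2)}, [5])
out = sub_seq(precomputed)  # only the hidden -> out module is executed
assert "out" in out.keys()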

Contributor (Author):
Yes, it would be empty if your out_keys come before the in_keys.

Contributor:
Got it
We're doing something similar in #352: you can input an incomplete tensordict and only the relevant ops will be executed. I wonder if we need both ways of doing the same thing. The advantage of your implementation is that it is self-consistent though.

in_keys.extend(module.out_keys)
else:
id_to_keep.remove(i)
for i, module in reversed(list(enumerate(self.module))):
if i in id_to_keep:
if any(key in out_keys_filter for key in module.out_keys):
out_keys_filter.extend(module.in_keys)
if any(key in out_keys for key in module.out_keys):
out_keys.extend(module.in_keys)
else:
id_to_keep.remove(i)
id_to_keep = sorted(list(id_to_keep))

modules = [self.module[i] for i in id_to_keep]

in_keys, _ = self._compute_in_and_out_keys(modules)
return TensorDictSequence(*modules)

if not all(key in tensordict.keys() for key in in_keys):
def forward(
self,
tensordict: TensorDictBase,
tensordict_out=None,
**kwargs,
) -> TensorDictBase:

if not all(key in tensordict.keys() for key in self.in_keys):

raise ValueError(
f"Not all in_keys found in input TensorDict. missing keys:{set(in_keys) - set(tensordict.keys())}"
f"Not all in_keys found in input TensorDict. missing keys:{set(self.in_keys) - set(tensordict.keys())}"
)

# Filter modules to avoid calling modules that don't require the desired in_keys or out keys

if "params" in kwargs and "buffers" in kwargs:
param_splits = self._split_param(kwargs["params"], "params")
buffer_splits = self._split_param(kwargs["buffers"], "buffers")
@@ -241,11 +254,8 @@ def forward(
for key, item in kwargs.items()
if key not in ("params", "buffers")
}
param_splits = [param_splits[i] for i in id_to_keep]
buffer_splits = [buffer_splits[i] for i in id_to_keep]

for i, (module, param, buffer) in enumerate(
zip(modules, param_splits, buffer_splits)
zip(self.module, param_splits, buffer_splits)
):
if "vmap" in kwargs_pruned and i > 0:
# the tensordict is already expanded
@@ -259,15 +269,14 @@ def forward(
kwargs_pruned = {
key: item for key, item in kwargs.items() if key not in ("params",)
}
param_splits = [param_splits[i] for i in id_to_keep]
for i, (module, param) in enumerate(zip(modules, param_splits)):
for i, (module, param) in enumerate(zip(self.module, param_splits)):
if "vmap" in kwargs_pruned and i > 0:
# the tensordict is already expanded
kwargs_pruned["vmap"] = (0, *(0,) * len(module.in_keys))
tensordict = module(tensordict, params=param, **kwargs_pruned)

elif not len(kwargs):
for module in modules:
for module in self.module:
tensordict = module(tensordict)
else:
raise RuntimeError(