
Commit

Merge remote-tracking branch 'origin/main' into announce-npw-deprec
vmoens committed Jul 25, 2024
2 parents 54b0c13 + 94abb50 commit 5066ec0
Showing 31 changed files with 733 additions and 152 deletions.
9 changes: 7 additions & 2 deletions docs/source/reference/modules.rst
@@ -355,13 +355,18 @@ algorithms, such as DQN, DDPG or Dreamer.
Multi-agent-specific modules
~~~~~~~~~~~~~~~~~~~~~~~~~~~~

These networks implement models that can be used in multi-agent contexts.
They use :func:`~torch.vmap` to execute several networks at once on the
network inputs. Because the parameters are batched, initialization may differ
from what is usually done with other PyTorch modules; see
:meth:`~torchrl.modules.MultiAgentNetBase.get_stateful_net`
for more information.
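
For example, a minimal sketch of this workflow (the constructor arguments below are purely illustrative):

>>> import torch
>>> from torchrl.modules import MultiAgentMLP
>>> mlp = MultiAgentMLP(
...     n_agent_inputs=6, n_agent_outputs=2, n_agents=3,
...     centralized=False, share_params=False, depth=2)
>>> obs = torch.randn(4, 3, 6)  # (*batch, n_agents, n_agent_inputs)
>>> _ = mlp(obs)  # a first call materializes the lazy, batched parameters
>>> snet = mlp.get_stateful_net()  # stateful module viewing the same parameters
>>> def init_(module):
...     if hasattr(module, "weight"):
...         module.weight.data.normal_(0.0, 0.01)  # in-place: propagates back to mlp's batched parameters
>>> snet.apply(init_)
>>> mlp.from_stateful_net(snet)  # only needed if parameters were replaced out-of-place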

.. autosummary::
:toctree: generated/
:template: rl_template_noinherit.rst

MultiAgentNetBase
MultiAgentMLP
MultiAgentConvNet
QMixer
29 changes: 29 additions & 0 deletions docs/source/reference/objectives.rst
@@ -29,6 +29,35 @@ The main characteristics of TorchRL losses are:

>>> loss_val = sum(loss for key, loss in loss_vals.items() if key.startswith("loss_"))

.. note::
Parameters in losses can be initialized by calling :meth:`~torchrl.objectives.LossModule.get_stateful_net`,
which returns a stateful version of the network that can be initialized like any other module.
If the modification is done in-place, it propagates to any other module that uses the same parameter
set (within and outside of the loss): for instance, modifying the ``actor_network`` parameters from the loss
will also modify the actor in the collector.
If the parameters are modified out-of-place, :meth:`~torchrl.objectives.LossModule.from_stateful_net` can be
used to reset the parameters in the loss to the new values.
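
A minimal sketch of this pattern (assuming ``loss`` is an already-constructed loss, such as a DDPG or SAC loss, that owns an ``actor_network``):

>>> actor = loss.get_stateful_net("actor_network")
>>> def init_(module):
...     if hasattr(module, "weight"):
...         module.weight.data.normal_(0.0, 0.01)  # in-place: shared with the loss parameters
>>> actor.apply(init_)
>>> # if the parameters were re-created out-of-place instead, write them back explicitly:
>>> loss.from_stateful_net("actor_network", actor)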

torch.vmap and randomness
-------------------------

TorchRL loss modules make extensive use of :func:`~torch.vmap` to amortize the cost of calling multiple similar models
in a loop by vectorizing these operations instead. `vmap` needs to be told explicitly what to do when random numbers
are generated within the call. To do this, a randomness mode needs to be set; it must be one of `"error"` (the default,
which raises an error when a pseudo-random function is encountered), `"same"` (replicates the results across the batch) or `"different"`
(each element of the batch is treated separately).
Relying on the default will typically result in an error such as this one:

>>> RuntimeError: vmap: called random operation while in randomness error mode.

Since the calls to `vmap` are buried within the loss modules, TorchRL
provides an interface to set the vmap mode from the outside through `loss.vmap_randomness = str_value`; see
:meth:`~torchrl.objectives.LossModule.vmap_randomness` for more information.

``LossModule.vmap_randomness`` defaults to `"error"` if no random module is detected, and to `"different"`
otherwise. By default, only a limited number of modules are registered as random, but the list can be extended
using the :func:`~torchrl.objectives.common.add_random_module` function.
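
For example (a sketch, assuming ``loss`` is an already-built loss module and ``data`` a batch of collected transitions):

>>> loss.vmap_randomness  # "error" unless a known random module was detected
>>> loss.vmap_randomness = "different"  # draw independent samples across the vmap batch
>>> loss_vals = loss(data)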

Training value functions
------------------------

188 changes: 184 additions & 4 deletions test/test_cost.py
@@ -256,6 +256,9 @@ def __init__(self):
net = nn.Sequential(*layers).to(device)
model = TensorDictModule(net, in_keys=["obs"], out_keys=["action"])
self.convert_to_functional(model, "model", expand_dim=4)
self._make_vmap()

def _make_vmap(self):
self.vmap_model = _vmap_func(
self.model,
(None, 0),
@@ -6852,6 +6855,71 @@ def test_cql(
p.grad is None or p.grad.norm() == 0.0
), f"target parameter {name} (shape: {p.shape}) has a non-null gradient"

@pytest.mark.parametrize("delay_actor", (True,))
@pytest.mark.parametrize("delay_qvalue", (True,))
@pytest.mark.parametrize(
"max_q_backup",
[
True,
],
)
@pytest.mark.parametrize(
"deterministic_backup",
[
True,
],
)
@pytest.mark.parametrize(
"with_lagrange",
[
True,
],
)
@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("td_est", [None])
def test_cql_qvalfromlist(
self,
delay_actor,
delay_qvalue,
max_q_backup,
deterministic_backup,
with_lagrange,
device,
td_est,
):
torch.manual_seed(self.seed)
td = self._create_mock_data_cql(device=device)

actor = self._create_mock_actor(device=device)
qvalue0 = self._create_mock_qvalue(device=device)
qvalue1 = self._create_mock_qvalue(device=device)

loss_fn_single = CQLLoss(
actor_network=actor,
qvalue_network=qvalue0,
loss_function="l2",
max_q_backup=max_q_backup,
deterministic_backup=deterministic_backup,
with_lagrange=with_lagrange,
delay_actor=delay_actor,
delay_qvalue=delay_qvalue,
)
loss_fn_mult = CQLLoss(
actor_network=actor,
qvalue_network=[qvalue0, qvalue1],
loss_function="l2",
max_q_backup=max_q_backup,
deterministic_backup=deterministic_backup,
with_lagrange=with_lagrange,
delay_actor=delay_actor,
delay_qvalue=delay_qvalue,
)
# Check that all params have the same shape
p2 = dict(loss_fn_mult.named_parameters())
for key, val in loss_fn_single.named_parameters():
assert val.shape == p2[key].shape
assert len(dict(loss_fn_single.named_parameters())) == len(p2)

@pytest.mark.parametrize("delay_actor", (True, False))
@pytest.mark.parametrize("delay_qvalue", (True, False))
@pytest.mark.parametrize("max_q_backup", [True])
@@ -14547,6 +14615,118 @@ def __init__(self, compare_against, expand_dim):
for key in ["module.1.bias", "module.1.weight"]:
loss_module.module_b_params.flatten_keys()[key].requires_grad

def test_init_params(self):
class MyLoss(LossModule):
module_a: TensorDictModule
module_b: TensorDictModule
module_a_params: TensorDict
module_b_params: TensorDict
target_module_a_params: TensorDict
target_module_b_params: TensorDict

def __init__(self, expand_dim=2):
super().__init__()
module1 = nn.Linear(3, 4)
module2 = nn.Linear(3, 4)
module3 = nn.Linear(3, 4)
module_a = TensorDictModule(
nn.Sequential(module1, module2), in_keys=["a"], out_keys=["c"]
)
module_b = TensorDictModule(
nn.Sequential(module1, module3), in_keys=["b"], out_keys=["c"]
)
self.convert_to_functional(module_a, "module_a")
self.convert_to_functional(
module_b,
"module_b",
compare_against=module_a.parameters(),
expand_dim=expand_dim,
)

loss = MyLoss()

module_a = loss.get_stateful_net("module_a", copy=False)
assert module_a is loss.module_a

module_a = loss.get_stateful_net("module_a")
assert module_a is not loss.module_a

def init(mod):
if hasattr(mod, "weight"):
mod.weight.data.zero_()
if hasattr(mod, "bias"):
mod.bias.data.zero_()

module_a.apply(init)
assert (loss.module_a_params == 0).all()

def init(mod):
if hasattr(mod, "weight"):
mod.weight = torch.nn.Parameter(mod.weight.data + 1)
if hasattr(mod, "bias"):
mod.bias = torch.nn.Parameter(mod.bias.data + 1)

module_a.apply(init)
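# the out-of-place re-assignment is not reflected in the loss until from_stateful_net is called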
assert (loss.module_a_params == 0).all()
loss.from_stateful_net("module_a", module_a)
assert (loss.module_a_params == 1).all()

def test_from_module_list(self):
class MyLoss(LossModule):
module_a: TensorDictModule
module_b: TensorDictModule

module_a_params: TensorDict
module_b_params: TensorDict

target_module_a_params: TensorDict
target_module_b_params: TensorDict

def __init__(self, module_a, module_b0, module_b1, expand_dim=2):
super().__init__()
self.convert_to_functional(module_a, "module_a")
self.convert_to_functional(
[module_b0, module_b1],
"module_b",
# This will be ignored
compare_against=module_a.parameters(),
expand_dim=expand_dim,
)

module1 = nn.Linear(3, 4)
module2 = nn.Linear(3, 4)
module3a = nn.Linear(3, 4)
module3b = nn.Linear(3, 4)

module_a = TensorDictModule(
nn.Sequential(module1, module2), in_keys=["a"], out_keys=["c"]
)

module_b0 = TensorDictModule(
nn.Sequential(module1, module3a), in_keys=["b"], out_keys=["c"]
)
module_b1 = TensorDictModule(
nn.Sequential(module1, module3b), in_keys=["b"], out_keys=["c"]
)

loss = MyLoss(module_a, module_b0, module_b1)

# This should be extended
assert not isinstance(
loss.module_b_params["module", "0", "weight"], nn.Parameter
)
assert loss.module_b_params["module", "0", "weight"].shape[0] == 2
assert (
loss.module_b_params["module", "0", "weight"].data.data_ptr()
== loss.module_a_params["module", "0", "weight"].data.data_ptr()
)
assert isinstance(loss.module_b_params["module", "1", "weight"], nn.Parameter)
assert loss.module_b_params["module", "1", "weight"].shape[0] == 2
assert (
loss.module_b_params["module", "1", "weight"].data.data_ptr()
!= loss.module_a_params["module", "1", "weight"].data.data_ptr()
)

def test_tensordict_keys(self):
"""Test configurable tensordict key behavior with derived classes."""

@@ -14904,10 +15084,10 @@ def __init__(self):
assert v_p1 == v_p2
assert v_params1 == v_params2
assert v_buffers1 == v_buffers2
for p in mod.parameters():
assert isinstance(p, nn.Parameter)
for p in mod.buffers():
assert isinstance(p, Buffer)
for k, p in mod.named_parameters():
assert isinstance(p, nn.Parameter), k
for k, p in mod.named_buffers():
assert isinstance(p, Buffer), k
for p in mod.actor_params.values(True, True):
assert isinstance(p, (nn.Parameter, Buffer))
for p in mod.value_params.values(True, True):
59 changes: 59 additions & 0 deletions test/test_modules.py
@@ -839,6 +839,65 @@ def test_multiagent_mlp(
agent_dim={-2}\)"""
assert re.match(pattern, str(mlp), re.DOTALL)

@retry(AssertionError, 5)
@pytest.mark.parametrize("n_agents", [1, 3])
@pytest.mark.parametrize("share_params", [True, False])
@pytest.mark.parametrize("centralized", [True, False])
@pytest.mark.parametrize("n_agent_inputs", [6, None])
@pytest.mark.parametrize("batch", [(4,), (4, 3), ()])
def test_multiagent_mlp_init(
self,
n_agents,
centralized,
share_params,
batch,
n_agent_inputs,
n_agent_outputs=2,
):
torch.manual_seed(1)
mlp = MultiAgentMLP(
n_agent_inputs=n_agent_inputs,
n_agent_outputs=n_agent_outputs,
n_agents=n_agents,
centralized=centralized,
share_params=share_params,
depth=2,
)
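# before the first call, the batched parameters are lazily initialized on the 'meta' device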
for m in mlp.modules():
if isinstance(m, nn.Linear):
assert not isinstance(m.weight, nn.Parameter)
assert m.weight.device == torch.device("meta")
break
else:
raise RuntimeError("could not find a Linear module")
if n_agent_inputs is None:
n_agent_inputs = 6
td = self._get_mock_input_td(n_agents, n_agent_inputs, batch=batch)
obs = td.get(("agents", "observation"))
mlp(obs)
snet = mlp.get_stateful_net()
assert snet is not mlp._empty_net

def zero_inplace(mod):
if hasattr(mod, "weight"):
mod.weight.data *= 0
if hasattr(mod, "bias"):
mod.bias.data *= 0

snet.apply(zero_inplace)
assert (mlp.params == 0).all()

def one_outofplace(mod):
if hasattr(mod, "weight"):
mod.weight = nn.Parameter(torch.ones_like(mod.weight.data))
if hasattr(mod, "bias"):
mod.bias = nn.Parameter(torch.ones_like(mod.bias.data))

snet.apply(one_outofplace)
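# out-of-place re-assignment does not propagate until from_stateful_net is called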
assert (mlp.params == 0).all()
mlp.from_stateful_net(snet)
assert (mlp.params == 1).all()

def test_multiagent_mlp_lazy(self):
mlp = MultiAgentMLP(
n_agent_inputs=None,
17 changes: 13 additions & 4 deletions torchrl/envs/libs/brax.py
@@ -80,8 +80,10 @@ class BraxWrapper(_EnvWrapper):
Examples:
>>> import brax.envs
>>> from torchrl.envs import BraxWrapper
>>> import torch
>>> device = "cuda" if torch.cuda.is_available() else "cpu"
>>> base_env = brax.envs.get_environment("ant")
>>> env = BraxWrapper(base_env)
>>> env = BraxWrapper(base_env, device=device)
>>> env.set_seed(0)
>>> td = env.reset()
>>> td["action"] = env.action_spec.rand()
@@ -111,15 +113,17 @@ class BraxWrapper(_EnvWrapper):
and report the execution time for a short rollout:
Examples:
>>> import torch
>>> from torch.utils.benchmark import Timer
>>> device = "cuda" if torch.cuda.is_available() else "cpu"
>>> for batch_size in [4, 16, 128]:
... timer = Timer('''
... env.rollout(100)
... ''',
... setup=f'''
... import brax.envs
... from torchrl.envs import BraxWrapper
... env = BraxWrapper(brax.envs.get_environment("ant"), batch_size=[{batch_size}])
... env = BraxWrapper(brax.envs.get_environment("ant"), batch_size=[{batch_size}], device="{device}")
... env.set_seed(0)
... env.rollout(2)
... ''')
@@ -459,7 +463,9 @@ class BraxEnv(BraxWrapper):
Examples:
>>> from torchrl.envs import BraxEnv
>>> env = BraxEnv("ant")
>>> import torch
>>> device = "cuda" if torch.cuda.is_available() else "cpu"
>>> env = BraxEnv("ant", device=device)
>>> env.set_seed(0)
>>> td = env.reset()
>>> td["action"] = env.action_spec.rand()
Expand Down Expand Up @@ -489,13 +495,16 @@ class BraxEnv(BraxWrapper):
and report the execution time for a short rollout:
Examples:
>>> import torch
>>> from torch.utils.benchmark import Timer
>>> device = "cuda" if torch.cuda.is_available() else "cpu"
>>> for batch_size in [4, 16, 128]:
... timer = Timer('''
... env.rollout(100)
... ''',
... setup=f'''
... from torchrl.envs import BraxEnv
... env = BraxEnv("ant", batch_size=[{batch_size}])
... env = BraxEnv("ant", batch_size=[{batch_size}], device="{device}")
... env.set_seed(0)
... env.rollout(2)
... ''')