# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import re
from numbers import Number
import numpy as np
import pytest
import torch
from _utils_internal import get_default_devices, retry
from mocking_classes import MockBatchedUnLockedEnv
from packaging import version
from tensordict import TensorDict
from torch import nn
from torchrl.data.tensor_specs import BoundedTensorSpec, CompositeSpec
from torchrl.modules import (
CEMPlanner,
DTActor,
GRU,
GRUCell,
LSTM,
LSTMCell,
MultiAgentConvNet,
MultiAgentMLP,
OnlineDTActor,
QMixer,
SafeModule,
TanhModule,
ValueOperator,
VDNMixer,
)
from torchrl.modules.distributions.utils import safeatanh, safetanh
from torchrl.modules.models import (
BatchRenorm1d,
Conv3dNet,
ConvNet,
MLP,
NoisyLazyLinear,
NoisyLinear,
)
from torchrl.modules.models.decision_transformer import (
_has_transformers,
DecisionTransformer,
)
from torchrl.modules.models.model_based import (
DreamerActor,
ObsDecoder,
ObsEncoder,
RSSMPosterior,
RSSMPrior,
RSSMRollout,
)
from torchrl.modules.models.utils import SquashDims
from torchrl.modules.planners.mppi import MPPIPlanner
from torchrl.objectives.value import TDLambdaEstimator
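# Fixture that temporarily switches the default dtype to double precision and restores the previous dtype on teardown.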
@pytest.fixture
def double_prec_fixture():
dtype = torch.get_default_dtype()
torch.set_default_dtype(torch.double)
yield
torch.set_default_dtype(dtype)
class TestMLP:
@pytest.mark.parametrize("in_features", [3, 10, None])
@pytest.mark.parametrize("out_features", [3, (3, 10)])
@pytest.mark.parametrize("depth, num_cells", [(3, 32), (None, (32, 32, 32))])
@pytest.mark.parametrize(
"activation_class, activation_kwargs",
[(nn.ReLU, {"inplace": True}), (nn.ReLU, {}), (nn.PReLU, {})],
)
@pytest.mark.parametrize(
"norm_class, norm_kwargs",
[
(nn.LazyBatchNorm1d, {}),
(nn.BatchNorm1d, {"num_features": 32}),
(nn.LayerNorm, {"normalized_shape": 32}),
],
)
@pytest.mark.parametrize("dropout", [0.0, 0.5])
@pytest.mark.parametrize("bias_last_layer", [True, False])
@pytest.mark.parametrize("single_bias_last_layer", [True, False])
@pytest.mark.parametrize("layer_class", [nn.Linear, NoisyLinear])
@pytest.mark.parametrize("device", get_default_devices())
def test_mlp(
self,
in_features,
out_features,
depth,
num_cells,
activation_class,
activation_kwargs,
dropout,
bias_last_layer,
norm_class,
norm_kwargs,
single_bias_last_layer,
layer_class,
device,
seed=0,
):
torch.manual_seed(seed)
batch = 2
mlp = MLP(
in_features=in_features,
out_features=out_features,
depth=depth,
num_cells=num_cells,
activation_class=activation_class,
activation_kwargs=activation_kwargs,
norm_class=norm_class,
norm_kwargs=norm_kwargs,
dropout=dropout,
bias_last_layer=bias_last_layer,
single_bias_last_layer=False,
layer_class=layer_class,
device=device,
)
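# in_features=None exercises the lazy input layers; any input width (here 5) works for the first forward pass.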
if in_features is None:
in_features = 5
x = torch.randn(batch, in_features, device=device)
y = mlp(x)
out_features = (
[out_features] if isinstance(out_features, Number) else out_features
)
assert y.shape == torch.Size([batch, *out_features])
def test_kwargs(self):
def make_activation(shift):
return lambda x: x + shift
def layer(*args, **kwargs):
linear = nn.Linear(*args, **kwargs)
linear.weight.data.copy_(torch.eye(4))
return linear
in_features = 4
out_features = 4
num_cells = [4, 4, 4]
mlp = MLP(
in_features=in_features,
out_features=out_features,
num_cells=num_cells,
activation_class=make_activation,
activation_kwargs=[{"shift": 0}, {"shift": 1}, {"shift": 2}],
layer_class=layer,
layer_kwargs=[{"bias": False}] * 4,
bias_last_layer=False,
)
x = torch.zeros(4)
y = mlp(x)
for i, module in enumerate(mlp.modules()):
if isinstance(module, nn.Linear):
assert (module.weight == torch.eye(4)).all(), i
assert module.bias is None, i
assert (y == 3).all()
@pytest.mark.parametrize("in_features", [3, 10, None])
@pytest.mark.parametrize(
"input_size, depth, num_cells, kernel_sizes, strides, paddings, expected_features",
[(100, None, None, 3, 1, 0, 32 * 94 * 94), (100, 3, 32, 3, 1, 1, 32 * 100 * 100)],
)
@pytest.mark.parametrize(
"activation_class, activation_kwargs",
[(nn.ReLU, {"inplace": True}), (nn.ReLU, {}), (nn.PReLU, {})],
)
@pytest.mark.parametrize(
"norm_class, norm_kwargs",
[(None, None), (nn.LazyBatchNorm2d, {}), (nn.BatchNorm2d, {"num_features": 32})],
)
@pytest.mark.parametrize("bias_last_layer", [True, False])
@pytest.mark.parametrize(
"aggregator_class, aggregator_kwargs",
[(SquashDims, {})],
)
@pytest.mark.parametrize("squeeze_output", [False])
@pytest.mark.parametrize("device", get_default_devices())
@pytest.mark.parametrize("batch", [(2,), (2, 2)])
def test_convnet(
batch,
in_features,
depth,
num_cells,
kernel_sizes,
strides,
paddings,
activation_class,
activation_kwargs,
norm_class,
norm_kwargs,
bias_last_layer,
aggregator_class,
aggregator_kwargs,
squeeze_output,
device,
input_size,
expected_features,
seed=0,
):
torch.manual_seed(seed)
convnet = ConvNet(
in_features=in_features,
depth=depth,
num_cells=num_cells,
kernel_sizes=kernel_sizes,
strides=strides,
paddings=paddings,
activation_class=activation_class,
activation_kwargs=activation_kwargs,
norm_class=norm_class,
norm_kwargs=norm_kwargs,
bias_last_layer=bias_last_layer,
aggregator_class=aggregator_class,
aggregator_kwargs=aggregator_kwargs,
squeeze_output=squeeze_output,
device=device,
)
if in_features is None:
in_features = 5
x = torch.randn(*batch, in_features, input_size, input_size, device=device)
y = convnet(x)
assert y.shape == torch.Size([*batch, expected_features])
class TestConv3d:
@pytest.mark.parametrize("in_features", [3, 10, None])
@pytest.mark.parametrize(
"input_size, depth, num_cells, kernel_sizes, strides, paddings, expected_features",
[
(10, None, None, 3, 1, 0, 32 * 4 * 4 * 4),
(10, 3, 32, 3, 1, 1, 32 * 10 * 10 * 10),
],
)
@pytest.mark.parametrize(
"activation_class, activation_kwargs",
[(nn.ReLU, {"inplace": True}), (nn.ReLU, {}), (nn.PReLU, {})],
)
@pytest.mark.parametrize(
"norm_class, norm_kwargs",
[
(None, None),
(nn.LazyBatchNorm3d, {}),
(nn.BatchNorm3d, {"num_features": 32}),
],
)
@pytest.mark.parametrize("bias_last_layer", [True, False])
@pytest.mark.parametrize(
"aggregator_class, aggregator_kwargs",
[(SquashDims, None)],
)
@pytest.mark.parametrize("squeeze_output", [False])
@pytest.mark.parametrize("device", get_default_devices())
@pytest.mark.parametrize("batch", [(2,), (2, 2)])
def test_conv3dnet(
self,
batch,
in_features,
depth,
num_cells,
kernel_sizes,
strides,
paddings,
activation_class,
activation_kwargs,
norm_class,
norm_kwargs,
bias_last_layer,
aggregator_class,
aggregator_kwargs,
squeeze_output,
device,
input_size,
expected_features,
seed=0,
):
torch.manual_seed(seed)
conv3dnet = Conv3dNet(
in_features=in_features,
depth=depth,
num_cells=num_cells,
kernel_sizes=kernel_sizes,
strides=strides,
paddings=paddings,
activation_class=activation_class,
activation_kwargs=activation_kwargs,
norm_class=norm_class,
norm_kwargs=norm_kwargs,
bias_last_layer=bias_last_layer,
aggregator_class=aggregator_class,
aggregator_kwargs=aggregator_kwargs,
squeeze_output=squeeze_output,
device=device,
)
if in_features is None:
in_features = 5
x = torch.randn(
*batch, in_features, input_size, input_size, input_size, device=device
)
y = conv3dnet(x)
assert y.shape == torch.Size([*batch, expected_features])
with pytest.raises(ValueError, match="must have at least 4 dimensions"):
conv3dnet(torch.randn(3, 16, 16))
def test_errors(self):
with pytest.raises(
ValueError, match="Null depth is not permitted with Conv3dNet"
):
conv3dnet = Conv3dNet(
in_features=5,
num_cells=32,
depth=0,
)
with pytest.raises(
ValueError, match="depth=None requires one of the input args"
):
conv3dnet = Conv3dNet(
in_features=5,
num_cells=32,
depth=None,
)
with pytest.raises(
ValueError, match="consider matching or specifying a constant num_cells"
):
conv3dnet = Conv3dNet(
in_features=5,
num_cells=[32],
depth=None,
kernel_sizes=[3, 3],
)
@pytest.mark.parametrize(
"layer_class",
[
NoisyLinear,
NoisyLazyLinear,
],
)
@pytest.mark.parametrize("device", get_default_devices())
def test_noisy(layer_class, device, seed=0):
torch.manual_seed(seed)
layer = layer_class(3, 4, device=device)
x = torch.randn(10, 3, device=device)
y1 = layer(x)
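# The sampled noise is reused across calls until reset_noise() is invoked, so y2 == y3 while y1 != y2.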
layer.reset_noise()
y2 = layer(x)
y3 = layer(x)
torch.testing.assert_close(y2, y3)
with pytest.raises(AssertionError):
torch.testing.assert_close(y1, y2)
@pytest.mark.parametrize("device", get_default_devices())
@pytest.mark.parametrize("batch_size", [3, 5])
class TestPlanner:
def test_CEM_model_free_env(self, device, batch_size, seed=1):
env = MockBatchedUnLockedEnv(device=device)
torch.manual_seed(seed)
planner = CEMPlanner(
env,
planning_horizon=10,
optim_steps=2,
num_candidates=100,
top_k=2,
)
td = env.reset(TensorDict({}, batch_size=batch_size).to(device))
td_copy = td.clone()
td = planner(td)
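# The planner should write an in-spec "action" entry and leave every other key of the tensordict unchanged.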
assert (
td.get("action").shape[-len(env.action_spec.shape) :]
== env.action_spec.shape
)
assert env.action_spec.is_in(td.get("action"))
for key in td.keys():
if key != "action":
assert torch.allclose(td[key], td_copy[key])
def test_MPPI(self, device, batch_size, seed=1):
torch.manual_seed(seed)
env = MockBatchedUnLockedEnv(device=device)
value_net = nn.LazyLinear(1, device=device)
value_net = ValueOperator(value_net, in_keys=["observation"])
advantage_module = TDLambdaEstimator(
gamma=0.99,
lmbda=0.95,
value_network=value_net,
)
value_net(env.reset())
planner = MPPIPlanner(
env,
advantage_module,
temperature=1.0,
planning_horizon=10,
optim_steps=2,
num_candidates=100,
top_k=2,
)
td = env.reset(TensorDict({}, batch_size=batch_size).to(device))
td_copy = td.clone()
td = planner(td)
assert (
td.get("action").shape[-len(env.action_spec.shape) :]
== env.action_spec.shape
)
assert env.action_spec.is_in(td.get("action"))
for key in td.keys():
if key != "action":
assert torch.allclose(td[key], td_copy[key])
@pytest.mark.parametrize("device", get_default_devices())
@pytest.mark.parametrize("batch_size", [[], [3], [5]])
@pytest.mark.skipif(
version.parse(torch.__version__) < version.parse("1.11.0"),
reason="""Dreamer works with batches of null to 2 dimensions. Torch < 1.11
requires one-dimensional batches (for RNN and Conv nets for instance). If you'd like
to see torch < 1.11 supported for dreamer, please submit an issue.""",
)
class TestDreamerComponents:
@pytest.mark.parametrize("out_features", [3, 5])
@pytest.mark.parametrize("temporal_size", [[], [2], [4]])
def test_dreamer_actor(self, device, batch_size, temporal_size, out_features):
actor = DreamerActor(
out_features,
).to(device)
emb = torch.randn(*batch_size, *temporal_size, 15, device=device)
state = torch.randn(*batch_size, *temporal_size, 2, device=device)
loc, scale = actor(emb, state)
assert loc.shape == (*batch_size, *temporal_size, out_features)
assert scale.shape == (*batch_size, *temporal_size, out_features)
assert torch.all(scale > 0)
@pytest.mark.parametrize("depth", [32, 64])
@pytest.mark.parametrize("temporal_size", [[], [2], [4]])
def test_dreamer_encoder(self, device, temporal_size, batch_size, depth):
encoder = ObsEncoder(channels=depth).to(device)
obs = torch.randn(*batch_size, *temporal_size, 3, 64, 64, device=device)
emb = encoder(obs)
assert emb.shape == (*batch_size, *temporal_size, depth * 8 * 4)
@pytest.mark.parametrize("depth", [32, 64])
@pytest.mark.parametrize("stoch_size", [10, 20])
@pytest.mark.parametrize("deter_size", [20, 30])
@pytest.mark.parametrize("temporal_size", [[], [2], [4]])
def test_dreamer_decoder(
self, device, batch_size, temporal_size, depth, stoch_size, deter_size
):
decoder = ObsDecoder(channels=depth).to(device)
stoch_state = torch.randn(
*batch_size, *temporal_size, stoch_size, device=device
)
det_state = torch.randn(*batch_size, *temporal_size, deter_size, device=device)
obs = decoder(stoch_state, det_state)
assert obs.shape == (*batch_size, *temporal_size, 3, 64, 64)
@pytest.mark.parametrize("stoch_size", [10, 20])
@pytest.mark.parametrize("deter_size", [20, 30])
@pytest.mark.parametrize("action_size", [3, 6])
def test_rssm_prior(self, device, batch_size, stoch_size, deter_size, action_size):
action_spec = BoundedTensorSpec(
shape=(action_size,), dtype=torch.float32, low=-1, high=1
)
rssm_prior = RSSMPrior(
action_spec,
hidden_dim=stoch_size,
rnn_hidden_dim=stoch_size,
state_dim=deter_size,
).to(device)
state = torch.randn(*batch_size, deter_size, device=device)
action = torch.randn(*batch_size, action_size, device=device)
belief = torch.randn(*batch_size, stoch_size, device=device)
prior_mean, prior_std, next_state, belief = rssm_prior(state, belief, action)
assert prior_mean.shape == (*batch_size, deter_size)
assert prior_std.shape == (*batch_size, deter_size)
assert next_state.shape == (*batch_size, deter_size)
assert belief.shape == (*batch_size, stoch_size)
assert torch.all(prior_std > 0)
@pytest.mark.parametrize("stoch_size", [10, 20])
@pytest.mark.parametrize("deter_size", [20, 30])
def test_rssm_posterior(self, device, batch_size, stoch_size, deter_size):
rssm_posterior = RSSMPosterior(
hidden_dim=stoch_size,
state_dim=deter_size,
).to(device)
belief = torch.randn(*batch_size, stoch_size, device=device)
obs_emb = torch.randn(*batch_size, 1024, device=device)
# Init of lazy linears
_ = rssm_posterior(belief.clone(), obs_emb.clone())
torch.manual_seed(0)
posterior_mean, posterior_std, next_state = rssm_posterior(
belief.clone(), obs_emb.clone()
)
assert posterior_mean.shape == (*batch_size, deter_size)
assert posterior_std.shape == (*batch_size, deter_size)
assert next_state.shape == (*batch_size, deter_size)
assert torch.all(posterior_std > 0)
torch.manual_seed(0)
posterior_mean_bis, posterior_std_bis, next_state_bis = rssm_posterior(
belief.clone(), obs_emb.clone()
)
assert torch.allclose(posterior_mean, posterior_mean_bis)
assert torch.allclose(posterior_std, posterior_std_bis)
assert torch.allclose(next_state, next_state_bis)
@pytest.mark.parametrize("stoch_size", [10, 20])
@pytest.mark.parametrize("deter_size", [20, 30])
@pytest.mark.parametrize("temporal_size", [2, 4])
@pytest.mark.parametrize("action_size", [3, 6])
def test_rssm_rollout(
self, device, batch_size, temporal_size, stoch_size, deter_size, action_size
):
action_spec = BoundedTensorSpec(
shape=(action_size,), dtype=torch.float32, low=-1, high=1
)
rssm_prior = RSSMPrior(
action_spec,
hidden_dim=stoch_size,
rnn_hidden_dim=stoch_size,
state_dim=deter_size,
).to(device)
rssm_posterior = RSSMPosterior(
hidden_dim=stoch_size,
state_dim=deter_size,
).to(device)
rssm_rollout = RSSMRollout(
SafeModule(
rssm_prior,
in_keys=["state", "belief", "action"],
out_keys=[
("next", "prior_mean"),
("next", "prior_std"),
"_",
("next", "belief"),
],
),
SafeModule(
rssm_posterior,
in_keys=[("next", "belief"), ("next", "encoded_latents")],
out_keys=[
("next", "posterior_mean"),
("next", "posterior_std"),
("next", "state"),
],
),
)
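# RSSMRollout steps through the time dimension, alternating the prior (state, belief, action -> prior stats, next belief)
# and the posterior (belief, encoded obs -> posterior stats, next state), feeding each step's outputs into the next step.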
state = torch.randn(*batch_size, temporal_size, deter_size, device=device)
belief = torch.randn(*batch_size, temporal_size, stoch_size, device=device)
action = torch.randn(*batch_size, temporal_size, action_size, device=device)
obs_emb = torch.randn(*batch_size, temporal_size, 1024, device=device)
tensordict = TensorDict(
{
"state": state.clone(),
"action": action.clone(),
"next": {
"encoded_latents": obs_emb.clone(),
"belief": belief.clone(),
},
},
device=device,
batch_size=torch.Size([*batch_size, temporal_size]),
)
# Init of lazy linears
_ = rssm_rollout(tensordict.clone())
torch.manual_seed(0)
rollout = rssm_rollout(tensordict)
assert rollout["next", "prior_mean"].shape == (
*batch_size,
temporal_size,
deter_size,
)
assert rollout["next", "prior_std"].shape == (
*batch_size,
temporal_size,
deter_size,
)
assert rollout["next", "state"].shape == (
*batch_size,
temporal_size,
deter_size,
)
assert rollout["next", "belief"].shape == (
*batch_size,
temporal_size,
stoch_size,
)
assert rollout["next", "posterior_mean"].shape == (
*batch_size,
temporal_size,
deter_size,
)
assert rollout["next", "posterior_std"].shape == (
*batch_size,
temporal_size,
deter_size,
)
assert torch.all(rollout["next", "prior_std"] > 0)
assert torch.all(rollout["next", "posterior_std"] > 0)
state[..., 1:, :] = 0
belief[..., 1:, :] = 0
# Only the first state is used for the prior. The rest are recomputed
tensordict_bis = TensorDict(
{
"state": state.clone(),
"action": action.clone(),
"next": {"encoded_latents": obs_emb.clone(), "belief": belief.clone()},
},
device=device,
batch_size=torch.Size([*batch_size, temporal_size]),
)
torch.manual_seed(0)
rollout_bis = rssm_rollout(tensordict_bis)
assert torch.allclose(
rollout["next", "prior_mean"], rollout_bis["next", "prior_mean"]
), (rollout["next", "prior_mean"] - rollout_bis["next", "prior_mean"]).norm()
assert torch.allclose(
rollout["next", "prior_std"], rollout_bis["next", "prior_std"]
)
assert torch.allclose(rollout["next", "state"], rollout_bis["next", "state"])
assert torch.allclose(rollout["next", "belief"], rollout_bis["next", "belief"])
assert torch.allclose(
rollout["next", "posterior_mean"], rollout_bis["next", "posterior_mean"]
)
assert torch.allclose(
rollout["next", "posterior_std"], rollout_bis["next", "posterior_std"]
)
class TestTanh:
def test_errors(self):
with pytest.raises(
ValueError, match="in_keys and out_keys should have the same length"
):
TanhModule(in_keys=["a", "b"], out_keys=["a"])
with pytest.raises(ValueError, match=r"The minimum value \(-2\) provided"):
spec = BoundedTensorSpec(-1, 1, shape=())
TanhModule(in_keys=["act"], low=-2, spec=spec)
with pytest.raises(ValueError, match=r"The maximum value \(-2\) provided to"):
spec = BoundedTensorSpec(-1, 1, shape=())
TanhModule(in_keys=["act"], high=-2, spec=spec)
with pytest.raises(ValueError, match="Got high < low"):
TanhModule(in_keys=["act"], high=-2, low=-1)
def test_minmax(self):
mod = TanhModule(
in_keys=["act"],
high=2,
)
assert isinstance(mod.act_high, torch.Tensor)
mod = TanhModule(
in_keys=["act"],
low=-2,
)
assert isinstance(mod.act_low, torch.Tensor)
mod = TanhModule(
in_keys=["act"],
high=np.ones((1,)),
)
assert isinstance(mod.act_high, torch.Tensor)
mod = TanhModule(
in_keys=["act"],
low=-np.ones((1,)),
)
assert isinstance(mod.act_low, torch.Tensor)
@pytest.mark.parametrize("clamp", [True, False])
def test_boundaries(self, clamp):
torch.manual_seed(0)
eps = torch.finfo(torch.float).resolution
for _ in range(10):
min, max = (5 * torch.randn(2)).sort()[0]
mod = TanhModule(in_keys=["act"], low=min, high=max, clamp=clamp)
assert mod.non_trivial
td = TensorDict({"act": (2 * torch.rand(100) - 1) * 10}, [])
mod(td)
# we should have a good proportion of samples close to the boundaries
assert torch.isclose(td["act"], max).any()
assert torch.isclose(td["act"], min).any()
if not clamp:
assert (td["act"] <= max + eps).all()
assert (td["act"] >= min - eps).all()
else:
assert (td["act"] < max + eps).all()
assert (td["act"] > min - eps).all()
@pytest.mark.parametrize("out_keys", [[("a", "c"), "b"], None])
@pytest.mark.parametrize("has_spec", [[True, True], [True, False], [False, False]])
def test_multi_inputs(self, out_keys, has_spec):
in_keys = [("x", "z"), "y"]
real_out_keys = out_keys if out_keys is not None else in_keys
if any(has_spec):
spec = {}
if has_spec[0]:
spec.update({real_out_keys[0]: BoundedTensorSpec(-2.0, 2.0, shape=())})
low, high = -2.0, 2.0
if has_spec[1]:
spec.update({real_out_keys[1]: BoundedTensorSpec(-3.0, 3.0, shape=())})
low, high = None, None
spec = CompositeSpec(spec)
else:
spec = None
low, high = -2.0, 2.0
mod = TanhModule(
in_keys=in_keys,
out_keys=out_keys,
low=low,
high=high,
spec=spec,
clamp=False,
)
data = TensorDict({in_key: torch.randn(100) * 100 for in_key in in_keys}, [])
mod(data)
assert all(out_key in data.keys(True, True) for out_key in real_out_keys)
eps = torch.finfo(torch.float).resolution
for out_key in real_out_keys:
key = out_key if isinstance(out_key, str) else "_".join(out_key)
low_key = f"{key}_low"
high_key = f"{key}_high"
min, max = getattr(mod, low_key), getattr(mod, high_key)
assert torch.isclose(data[out_key], max).any()
assert torch.isclose(data[out_key], min).any()
assert (data[out_key] <= max + eps).all()
assert (data[out_key] >= min - eps).all()
class TestMultiAgent:
def _get_mock_input_td(
self, n_agents, n_agents_inputs, state_shape=(64, 64, 3), T=None, batch=(2,)
):
if T is not None:
batch = batch + (T,)
obs = torch.randn(*batch, n_agents, n_agents_inputs)
state = torch.randn(*batch, *state_shape)
td = TensorDict(
{
"agents": TensorDict(
{"observation": obs},
[*batch, n_agents],
),
"state": state,
},
batch_size=batch,
)
return td
@retry(AssertionError, 5)
@pytest.mark.parametrize("n_agents", [1, 3])
@pytest.mark.parametrize("share_params", [True, False])
@pytest.mark.parametrize("centralized", [True, False])
@pytest.mark.parametrize("n_agent_inputs", [6, None])
@pytest.mark.parametrize("batch", [(4,), (4, 3), ()])
def test_multiagent_mlp(
self,
n_agents,
centralized,
share_params,
batch,
n_agent_inputs,
n_agent_outputs=2,
):
torch.manual_seed(1)
mlp = MultiAgentMLP(
n_agent_inputs=n_agent_inputs,
n_agent_outputs=n_agent_outputs,
n_agents=n_agents,
centralized=centralized,
share_params=share_params,
depth=2,
)
if n_agent_inputs is None:
n_agent_inputs = 6
td = self._get_mock_input_td(n_agents, n_agent_inputs, batch=batch)
obs = td.get(("agents", "observation"))
out = mlp(obs)
assert out.shape == (*batch, n_agents, n_agent_outputs)
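# When centralized and share_params are both True, every agent processes the same concatenated input with the same
# weights, so all agent outputs coincide; otherwise the outputs should differ across agents.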
for i in range(n_agents):
if centralized and share_params:
assert torch.allclose(out[..., i, :], out[..., 0, :])
else:
for j in range(i + 1, n_agents):
assert not torch.allclose(out[..., i, :], out[..., j, :])
obs[..., 0, 0] += 1
out2 = mlp(obs)
for i in range(n_agents):
if centralized:
# a modification to the input of agent 0 will impact all agents
assert not torch.allclose(out[..., i, :], out2[..., i, :])
elif i > 0:
assert torch.allclose(out[..., i, :], out2[..., i, :])
obs = (
torch.randn(*batch, 1, n_agent_inputs)
.expand(*batch, n_agents, n_agent_inputs)
.clone()
)
out = mlp(obs)
for i in range(n_agents):
if share_params:
# same input same output
assert torch.allclose(out[..., i, :], out[..., 0, :])
else:
for j in range(i + 1, n_agents):
# same input different output
assert not torch.allclose(out[..., i, :], out[..., j, :])
pattern = rf"""MultiAgentMLP\(
MLP\(
\(0\): Linear\(in_features=\d+, out_features=32, bias=True\)
\(1\): Tanh\(\)
\(2\): Linear\(in_features=32, out_features=32, bias=True\)
\(3\): Tanh\(\)
\(4\): Linear\(in_features=32, out_features=2, bias=True\)
\),
n_agents={n_agents},
share_params={share_params},
centralized={centralized},
agent_dim={-2}\)"""
assert re.match(pattern, str(mlp), re.DOTALL)
def test_multiagent_mlp_lazy(self):
mlp = MultiAgentMLP(
n_agent_inputs=None,
n_agent_outputs=6,
n_agents=3,
centralized=True,
share_params=False,
depth=2,
)
optim = torch.optim.SGD(mlp.parameters(), lr=1e-3)
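# Before the first forward pass, both the module and the optimizer should still hold UninitializedParameter entries
# (the for/else falls through to the AssertionError only if no such parameter is found).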
for p in mlp.parameters():
if isinstance(p, torch.nn.parameter.UninitializedParameter):
break
else:
raise AssertionError("No UninitializedParameter found")
for p in optim.param_groups[0]["params"]:
if isinstance(p, torch.nn.parameter.UninitializedParameter):
break
else:
raise AssertionError("No UninitializedParameter found")
for _ in range(2):
td = self._get_mock_input_td(3, 4, batch=(10,))
obs = td.get(("agents", "observation"))
out = mlp(obs)
assert (
not mlp.params[0]
.apply(lambda x, y: torch.isclose(x, y), mlp.params[1])
.any()
)
out.mean().backward()
optim.step()
for p in mlp.parameters():
if isinstance(p, torch.nn.parameter.UninitializedParameter):
raise AssertionError("UninitializedParameter found")
for p in optim.param_groups[0]["params"]:
if isinstance(p, torch.nn.parameter.UninitializedParameter):
raise AssertionError("UninitializedParameter found")
@pytest.mark.parametrize("n_agents", [1, 3])
@pytest.mark.parametrize("share_params", [True, False])
@pytest.mark.parametrize("centralized", [True, False])
def test_multiagent_reset_mlp(
self,
n_agents,
centralized,
share_params,
):
actor_net = MultiAgentMLP(
n_agent_inputs=4,
n_agent_outputs=6,
num_cells=(4, 4),
n_agents=n_agents,
centralized=centralized,
share_params=share_params,
)
params_before = actor_net.params.clone()
actor_net.reset_parameters()
params_after = actor_net.params
assert not params_before.apply(
lambda x, y: torch.isclose(x, y), params_after, batch_size=[]
).any()
if params_after.numel() > 1:
assert (
not params_after[0]
.apply(lambda x, y: torch.isclose(x, y), params_after[1], batch_size=[])
.any()
)
@pytest.mark.parametrize("n_agents", [1, 3])
@pytest.mark.parametrize("share_params", [True, False])
@pytest.mark.parametrize("centralized", [True, False])
@pytest.mark.parametrize("channels", [3, None])
@pytest.mark.parametrize("batch", [(4,), (4, 3), ()])
def test_multiagent_cnn(
self,
n_agents,
centralized,
share_params,
batch,
channels,
x=15,
y=15,
):
torch.manual_seed(0)
cnn = MultiAgentConvNet(
n_agents=n_agents,
centralized=centralized,
share_params=share_params,
in_features=channels,
kernel_sizes=3,
)
if channels is None:
channels = 3
td = TensorDict(
{
"agents": TensorDict(
{"observation": torch.randn(*batch, n_agents, channels, x, y)},
[*batch, n_agents],
)
},
batch_size=batch,
)
obs = td[("agents", "observation")]
out = cnn(obs)
assert out.shape[:-1] == (*batch, n_agents)
if centralized and share_params:
torch.testing.assert_close(out, out[..., :1, :].expand_as(out))
else:
for i in range(n_agents):
for j in range(i + 1, n_agents):
assert not torch.allclose(out[..., i, :], out[..., j, :])
obs[..., 0, 0, 0, 0] += 1
out2 = cnn(obs)
if centralized:
# a modification to the input of agent 0 will impact all agents
assert not torch.isclose(out, out2).all()
elif n_agents > 1:
assert not torch.isclose(out[..., 0, :], out2[..., 0, :]).all()
torch.testing.assert_close(out[..., 1:, :], out2[..., 1:, :])
obs = torch.randn(*batch, 1, channels, x, y).expand(
*batch, n_agents, channels, x, y
)
out = cnn(obs)
for i in range(n_agents):
if share_params:
# same input same output
assert torch.allclose(out[..., i, :], out[..., 0, :])
else:
for j in range(i + 1, n_agents):
# same input different output
assert not torch.allclose(out[..., i, :], out[..., j, :])
def test_multiagent_cnn_lazy(self):
n_agents = 5
n_channels = 3
cnn = MultiAgentConvNet(
n_agents=n_agents,
centralized=False,
share_params=False,
in_features=None,
kernel_sizes=3,
)
optim = torch.optim.SGD(cnn.parameters(), lr=1e-3)
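# Same check as the lazy MLP test: parameters (and the optimizer's view of them) must be uninitialized before the first call.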
for p in cnn.parameters():
if isinstance(p, torch.nn.parameter.UninitializedParameter):
break
else:
raise AssertionError("No UninitializedParameter found")
for p in optim.param_groups[0]["params"]:
if isinstance(p, torch.nn.parameter.UninitializedParameter):
break
else:
raise AssertionError("No UninitializedParameter found")
for _ in range(2):
td = TensorDict(
{
"agents": TensorDict(
{"observation": torch.randn(4, n_agents, n_channels, 15, 15)},
[4, n_agents],
)
},
batch_size=[4],
)
obs = td[("agents", "observation")]
out = cnn(obs)
assert (
not cnn.params[0]
.apply(lambda x, y: torch.isclose(x, y), cnn.params[1])
.any()
)
out.mean().backward()
optim.step()
for p in cnn.parameters():
if isinstance(p, torch.nn.parameter.UninitializedParameter):
raise AssertionError("UninitializedParameter found")
for p in optim.param_groups[0]["params"]:
if isinstance(p, torch.nn.parameter.UninitializedParameter):
raise AssertionError("UninitializedParameter found")
@pytest.mark.parametrize("n_agents", [1, 3])
@pytest.mark.parametrize("share_params", [True, False])
@pytest.mark.parametrize("centralized", [True, False])
def test_multiagent_reset_cnn(
self,
n_agents,
centralized,
share_params,
):
actor_net = MultiAgentConvNet(
in_features=4,
num_cells=[5, 5],
n_agents=n_agents,
centralized=centralized,
share_params=share_params,
)
params_before = actor_net.params.clone()
actor_net.reset_parameters()
params_after = actor_net.params
assert not params_before.apply(
lambda x, y: torch.isclose(x, y), params_after, batch_size=[]
).any()
if params_after.numel() > 1:
assert (
not params_after[0]
.apply(lambda x, y: torch.isclose(x, y), params_after[1], batch_size=[])
.any()
)
@pytest.mark.parametrize("n_agents", [1, 3])
@pytest.mark.parametrize("batch", [(10,), (10, 3), ()])
def test_vdn(self, n_agents, batch):
torch.manual_seed(0)
mixer = VDNMixer(n_agents=n_agents, device="cpu")
td = self._get_mock_input_td(n_agents, batch=batch, n_agents_inputs=1)
obs = td.get(("agents", "observation"))
assert obs.shape == (*batch, n_agents, 1)
out = mixer(obs)
assert out.shape == (*batch, 1)
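# VDN mixes by summing the per-agent values over the agent dimension.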
assert torch.equal(obs.sum(-2), out)
@pytest.mark.parametrize("n_agents", [1, 3])
@pytest.mark.parametrize("batch", [(10,), (10, 3), ()])
@pytest.mark.parametrize("state_shape", [(64, 64, 3), (10,)])
def test_qmix(self, n_agents, batch, state_shape):
torch.manual_seed(0)
mixer = QMixer(
n_agents=n_agents,
state_shape=state_shape,
mixing_embed_dim=32,
device="cpu",
)
td = self._get_mock_input_td(
n_agents, batch=batch, n_agents_inputs=1, state_shape=state_shape
)
obs = td.get(("agents", "observation"))
state = td.get("state")
assert obs.shape == (*batch, n_agents, 1)
assert state.shape == (*batch, *state_shape)
out = mixer(obs, state)
assert out.shape == (*batch, 1)
@pytest.mark.parametrize("mixer", ["qmix", "vdn"])
def test_mixer_malformed_input(
self, mixer, n_agents=3, batch=(32,), state_shape=(64, 64, 3)
):
td = self._get_mock_input_td(
n_agents, batch=batch, n_agents_inputs=3, state_shape=state_shape
)
if mixer == "qmix":
mixer = QMixer(
n_agents=n_agents,
state_shape=state_shape,
mixing_embed_dim=32,
device="cpu",
)
else:
mixer = VDNMixer(n_agents=n_agents, device="cpu")
obs = td.get(("agents", "observation"))
state = td.get("state")
if mixer.needs_state:
with pytest.raises(
ValueError,
match="Mixer that needs state was passed more than 2 inputs",
):
mixer(obs)
else:
with pytest.raises(
ValueError,
match="Mixer that doesn't need state was passed more than 1 input",
):
mixer(obs, state)
in_put = [obs, state] if mixer.needs_state else [obs]
with pytest.raises(
ValueError,
match="Mixer network expected chosen_action_value with last 2 dimensions",
):
mixer(*in_put)
if mixer.needs_state:
state_diff = state.unsqueeze(-1)
with pytest.raises(
ValueError,
match="Mixer network expected state with ending shape",
):
mixer(obs, state_diff)
td = self._get_mock_input_td(
n_agents, batch=batch, n_agents_inputs=1, state_shape=state_shape
)
obs = td.get(("agents", "observation"))
state = td.get("state")
obs = obs.sum(-2)
in_put = [obs, state] if mixer.needs_state else [obs]
with pytest.raises(
ValueError,
match="Mixer network expected chosen_action_value with last 2 dimensions",
):
mixer(*in_put)
obs = td.get(("agents", "observation"))
state = td.get("state")
in_put = [obs, state] if mixer.needs_state else [obs]
mixer(*in_put)
@pytest.mark.skipif(
version.parse(torch.__version__) < version.parse("2.0"),
reason="torch 2.0 is required",
)
@pytest.mark.parametrize("use_vmap", [False, True])
@pytest.mark.parametrize("scale", range(10))
def test_tanh_atanh(use_vmap, scale):
if use_vmap:
try:
from torch import vmap
except ImportError:
try:
from functorch import vmap
except ImportError:
pytest.skip("functorch not found")
torch.manual_seed(0)
x = (torch.randn(10, dtype=torch.double) * scale).requires_grad_(True)
if not use_vmap:
y = safetanh(x, 1e-6)
else:
y = vmap(safetanh, (0, None))(x, 1e-6)
if not use_vmap:
xp = safeatanh(y, 1e-6)
else:
xp = vmap(safeatanh, (0, None))(y, 1e-6)
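# safeatanh(safetanh(x)) should recover x, so the gradient of the round trip w.r.t. x is exactly 1.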
xp.sum().backward()
torch.testing.assert_close(x.grad, torch.ones_like(x))
@pytest.mark.skipif(
not _has_transformers, reason="transformers needed for TestDecisionTransformer"
)
class TestDecisionTransformer:
def test_init(self):
DecisionTransformer(
3,
4,
)
with pytest.raises(TypeError):
DecisionTransformer(3, 4, config="some_str")
DecisionTransformer(
3,
4,
config=DecisionTransformer.DTConfig(
n_layer=2, n_embd=16, n_positions=16, n_inner=16, n_head=2
),
)
@pytest.mark.parametrize("batch_dims", [[], [3], [3, 4]])
def test_exec(self, batch_dims, T=5):
observations = torch.randn(*batch_dims, T, 3)
actions = torch.randn(*batch_dims, T, 4)
r2go = torch.randn(*batch_dims, T, 1)
model = DecisionTransformer(
3,
4,
config=DecisionTransformer.DTConfig(
n_layer=2, n_embd=16, n_positions=16, n_inner=16, n_head=2
),
)
out = model(observations, actions, r2go)
assert out.shape == torch.Size([*batch_dims, T, 16])
@pytest.mark.parametrize("batch_dims", [[], [3], [3, 4]])
def test_dtactor(self, batch_dims, T=5):
dtactor = DTActor(
3,
4,
transformer_config=DecisionTransformer.DTConfig(
n_layer=2, n_embd=16, n_positions=16, n_inner=16, n_head=2
),
)
observations = torch.randn(*batch_dims, T, 3)
actions = torch.randn(*batch_dims, T, 4)
r2go = torch.randn(*batch_dims, T, 1)
out = dtactor(observations, actions, r2go)
assert out.shape == torch.Size([*batch_dims, T, 4])
@pytest.mark.parametrize("batch_dims", [[], [3], [3, 4]])
def test_onlinedtactor(self, batch_dims, T=5):
dtactor = OnlineDTActor(
3,
4,
transformer_config=DecisionTransformer.DTConfig(
n_layer=2, n_embd=16, n_positions=16, n_inner=16, n_head=2
),
)
observations = torch.randn(*batch_dims, T, 3)
actions = torch.randn(*batch_dims, T, 4)
r2go = torch.randn(*batch_dims, T, 1)
mu, sig = dtactor(observations, actions, r2go)
assert mu.shape == torch.Size([*batch_dims, T, 4])
assert sig.shape == torch.Size([*batch_dims, T, 4])
assert (dtactor.log_std_min < sig.log()).all()
assert (dtactor.log_std_max > sig.log()).all()
@pytest.mark.parametrize("device", get_default_devices())
@pytest.mark.parametrize("bias", [True, False])
def test_python_lstm_cell(device, bias):
lstm_cell1 = LSTMCell(10, 20, device=device, bias=bias)
lstm_cell2 = nn.LSTMCell(10, 20, device=device, bias=bias)
lstm_cell1.load_state_dict(lstm_cell2.state_dict())
# Make sure parameters match
for (k1, v1), (k2, v2) in zip(
lstm_cell1.named_parameters(), lstm_cell2.named_parameters()
):
assert k1 == k2, f"Parameter names do not match: {k1} != {k2}"
assert (
v1.shape == v2.shape
), f"Parameter shapes do not match: {k1} shape {v1.shape} != {k2} shape {v2.shape}"
# Run loop
input = torch.randn(2, 3, 10, device=device)
h0 = torch.randn(3, 20, device=device)
c0 = torch.randn(3, 20, device=device)
with torch.no_grad():
for i in range(input.size()[0]):
h1, c1 = lstm_cell1(input[i], (h0, c0))
h2, c2 = lstm_cell2(input[i], (h0, c0))
# Make sure the final hidden states have the same shape
assert h1.shape == h2.shape
assert c1.shape == c2.shape
torch.testing.assert_close(h1, h2)
torch.testing.assert_close(c1, c2)
h0 = h1
c0 = c1
@pytest.mark.parametrize("device", get_default_devices())
@pytest.mark.parametrize("bias", [True, False])
def test_python_gru_cell(device, bias):
gru_cell1 = GRUCell(10, 20, device=device, bias=bias)
gru_cell2 = nn.GRUCell(10, 20, device=device, bias=bias)
gru_cell2.load_state_dict(gru_cell1.state_dict())
# Make sure parameters match
for (k1, v1), (k2, v2) in zip(
gru_cell1.named_parameters(), gru_cell2.named_parameters()
):
assert k1 == k2, f"Parameter names do not match: {k1} != {k2}"
assert (v1 == v2).all()
assert (
v1.shape == v2.shape
), f"Parameter shapes do not match: {k1} shape {v1.shape} != {k2} shape {v2.shape}"
# Run loop
input = torch.randn(2, 3, 10, device=device)
h0 = torch.zeros(3, 20, device=device)
with torch.no_grad():
for i in range(input.size()[0]):
h1 = gru_cell1(input[i], h0)
h2 = gru_cell2(input[i], h0)
# Make sure the final hidden states have the same shape
assert h1.shape == h2.shape
torch.testing.assert_close(h1, h2)
h0 = h1
@pytest.mark.parametrize("device", get_default_devices())
@pytest.mark.parametrize("bias", [True, False])
@pytest.mark.parametrize("batch_first", [True, False])
@pytest.mark.parametrize("dropout", [0.0, 0.5])
@pytest.mark.parametrize("num_layers", [1, 2])
def test_python_lstm(device, bias, dropout, batch_first, num_layers):
B = 5
T = 3
lstm1 = LSTM(
input_size=10,
hidden_size=20,
num_layers=num_layers,
device=device,
bias=bias,
batch_first=batch_first,
)
lstm2 = nn.LSTM(
input_size=10,
hidden_size=20,
num_layers=num_layers,
device=device,
bias=bias,
batch_first=batch_first,
)
lstm2.load_state_dict(lstm1.state_dict())
# Make sure parameters match
for (k1, v1), (k2, v2) in zip(lstm1.named_parameters(), lstm2.named_parameters()):
assert k1 == k2, f"Parameter names do not match: {k1} != {k2}"
assert (
v1.shape == v2.shape
), f"Parameter shapes do not match: {k1} shape {v1.shape} != {k2} shape {v2.shape}"
if batch_first:
input = torch.randn(B, T, 10, device=device)
else:
input = torch.randn(T, B, 10, device=device)
h0 = torch.randn(num_layers, B, 20, device=device)
c0 = torch.randn(num_layers, B, 20, device=device)
# Test without hidden states
with torch.no_grad():
output1, (h1, c1) = lstm1(input)
output2, (h2, c2) = lstm2(input)
assert h1.shape == h2.shape
assert c1.shape == c2.shape
assert output1.shape == output2.shape
if dropout == 0.0:
torch.testing.assert_close(output1, output2)
torch.testing.assert_close(h1, h2)
torch.testing.assert_close(c1, c2)
# Test with hidden states
with torch.no_grad():
output1, (h1, c1) = lstm1(input, (h0, c0))
output2, (h2, c2) = lstm2(input, (h0, c0))
assert h1.shape == h2.shape
assert c1.shape == c2.shape
assert output1.shape == output2.shape
if dropout == 0.0:
torch.testing.assert_close(output1, output2)
torch.testing.assert_close(h1, h2)
torch.testing.assert_close(c1, c2)
@pytest.mark.parametrize("device", get_default_devices())
@pytest.mark.parametrize("bias", [True, False])
@pytest.mark.parametrize("batch_first", [True, False])
@pytest.mark.parametrize("dropout", [0.0, 0.5])
@pytest.mark.parametrize("num_layers", [1, 2])
def test_python_gru(device, bias, dropout, batch_first, num_layers):
B = 5
T = 3
gru1 = GRU(
input_size=10,
hidden_size=20,
num_layers=num_layers,
device=device,
bias=bias,
batch_first=batch_first,
)
gru2 = nn.GRU(
input_size=10,
hidden_size=20,
num_layers=num_layers,
device=device,
bias=bias,
batch_first=batch_first,
)
gru2.load_state_dict(gru1.state_dict())
# Make sure parameters match
for (k1, v1), (k2, v2) in zip(gru1.named_parameters(), gru2.named_parameters()):
assert k1 == k2, f"Parameter names do not match: {k1} != {k2}"
torch.testing.assert_close(v1, v2)
assert (
v1.shape == v2.shape
), f"Parameter shapes do not match: {k1} shape {v1.shape} != {k2} shape {v2.shape}"
if batch_first:
input = torch.randn(B, T, 10, device=device)
else:
input = torch.randn(T, B, 10, device=device)
h0 = torch.randn(num_layers, B, 20, device=device)
# Test without hidden states
with torch.no_grad():
output1, h1 = gru1(input)
output2, h2 = gru2(input)
assert h1.shape == h2.shape
assert output1.shape == output2.shape
if dropout == 0.0:
torch.testing.assert_close(output1, output2)
torch.testing.assert_close(h1, h2)
# Test with hidden states
with torch.no_grad():
output1, h1 = gru1(input, h0)
output2, h2 = gru2(input, h0)
assert h1.shape == h2.shape
assert output1.shape == output2.shape
if dropout == 0.0:
torch.testing.assert_close(output1, output2)
torch.testing.assert_close(h1, h2)
class TestBatchRenorm:
@pytest.mark.parametrize("num_steps", [0, 5])
@pytest.mark.parametrize("smooth", [False, True])
def test_batchrenorm(self, num_steps, smooth):
torch.manual_seed(0)
bn = torch.nn.BatchNorm1d(5, momentum=0.1, eps=1e-5)
brn = BatchRenorm1d(
5,
momentum=0.1,
eps=1e-5,
warmup_steps=num_steps,
max_d=10000,
max_r=10000,
smooth=smooth,
)
bn.train()
brn.train()
data_train = torch.randn(100, 5).split(25)
data_test = torch.randn(100, 5)
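# During warm-up (or only on the very first step when smooth=True), BatchRenorm should match BatchNorm exactly;
# once the running statistics kick in the train-mode outputs diverge, but the eval-mode outputs agree again.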
for i, d in enumerate(data_train):
b = bn(d)
a = brn(d)
if num_steps > 0 and (
(i < num_steps and not smooth) or (i == 0 and smooth)
):
torch.testing.assert_close(a, b)
else:
assert not torch.isclose(a, b).all(), i
bn.eval()
brn.eval()
torch.testing.assert_close(bn(data_test), brn(data_test))
if __name__ == "__main__":
args, unknown = argparse.ArgumentParser().parse_known_args()
pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)