[BugFix] Discrete SAC rewrite (#1461)
Signed-off-by: Matteo Bettini <matbet@meta.com>
matteobettini authored Aug 30, 2023
1 parent bd316e9 commit 78b2bb2
Showing 11 changed files with 717 additions and 353 deletions.
6 changes: 6 additions & 0 deletions .circleci/unittest/linux_examples/scripts/run_test.sh
@@ -270,6 +270,12 @@ python .circleci/unittest/helpers/coverage_run_parallel.py examples/multiagent/q
train.num_epochs=3 \
train.minibatch_size=100 \
logger.backend=
+python .circleci/unittest/helpers/coverage_run_parallel.py examples/multiagent/sac.py \
+collector.n_iters=2 \
+collector.frames_per_batch=200 \
+train.num_epochs=3 \
+train.minibatch_size=100 \
+logger.backend=


python .circleci/unittest/helpers/coverage_run_parallel.py examples/bandits/dqn.py --n_steps=100
8 changes: 5 additions & 3 deletions examples/discrete_sac/discrete_sac.py
@@ -9,7 +9,7 @@
import torch
import torch.cuda
import tqdm
-from tensordict.nn import InteractionType
+from tensordict.nn import InteractionType, TensorDictModule

from torch import nn, optim
from torchrl.collectors import SyncDataCollector
@@ -27,7 +27,7 @@
from torchrl.modules import MLP, SafeModule
from torchrl.modules.distributions import OneHotCategorical

-from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator
+from torchrl.modules.tensordict_module.actors import ProbabilisticActor

from torchrl.objectives import DiscreteSACLoss, SoftUpdate
from torchrl.record.loggers import generate_exp_name, get_logger
@@ -150,8 +150,9 @@ def env_factory(num_workers):
**qvalue_net_kwargs,
)

-qvalue = ValueOperator(
+qvalue = TensorDictModule(
in_keys=in_keys,
out_keys=["action_value"],
module=qvalue_net,
).to(device)

@@ -171,6 +172,7 @@ def env_factory(num_workers):
# Create SAC loss
loss_module = DiscreteSACLoss(
actor_network=model[0],
+action_space=test_env.action_spec,
qvalue_network=model[1],
num_actions=num_actions,
num_qvalue_nets=2,
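The hunks above replace ValueOperator with a plain TensorDictModule, so the Q-network now writes one value per action under the "action_value" key, and they pass the action spec to DiscreteSACLoss. A minimal, self-contained sketch of that wiring, not taken from the commit: the observation size, action count, and MLP widths below are illustrative, while the constructor arguments mirror the ones visible in the diff.

from tensordict.nn import TensorDictModule
from torchrl.data import OneHotDiscreteTensorSpec
from torchrl.modules import MLP, ProbabilisticActor
from torchrl.modules.distributions import OneHotCategorical
from torchrl.objectives import DiscreteSACLoss

obs_dim, num_actions = 8, 4  # illustrative sizes, not from the example

# Actor: an MLP producing logits over the discrete actions, sampled as a one-hot action.
actor = ProbabilisticActor(
    module=TensorDictModule(
        MLP(in_features=obs_dim, out_features=num_actions, num_cells=[256, 256]),
        in_keys=["observation"],
        out_keys=["logits"],
    ),
    in_keys=["logits"],
    out_keys=["action"],
    distribution_class=OneHotCategorical,
    return_log_prob=False,
)

# Critic: one Q-value per action, written under "action_value" -- the key that
# the switch from ValueOperator to TensorDictModule exposes.
qvalue = TensorDictModule(
    MLP(in_features=obs_dim, out_features=num_actions, num_cells=[256, 256]),
    in_keys=["observation"],
    out_keys=["action_value"],
)

loss_module = DiscreteSACLoss(
    actor_network=actor,
    qvalue_network=qvalue,
    action_space=OneHotDiscreteTensorSpec(num_actions),
    num_actions=num_actions,
    num_qvalue_nets=2,
)

Passing the action spec (or num_actions) tells the loss how many discrete actions to enumerate when it forms the expected Q-value under the current policy.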
321 changes: 321 additions & 0 deletions examples/multiagent/sac.py
@@ -0,0 +1,321 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import time

import hydra
import torch

from tensordict.nn import TensorDictModule
from tensordict.nn.distributions import NormalParamExtractor
from torch import nn
from torch.distributions import Categorical
from torchrl.collectors import SyncDataCollector
from torchrl.data import TensorDictReplayBuffer
from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
from torchrl.data.replay_buffers.storages import LazyTensorStorage
from torchrl.envs import RewardSum, TransformedEnv
from torchrl.envs.libs.vmas import VmasEnv
from torchrl.envs.utils import ExplorationType, set_exploration_type

from torchrl.modules import ProbabilisticActor, TanhNormal, ValueOperator
from torchrl.modules.models.multiagent import MultiAgentMLP
from torchrl.objectives import DiscreteSACLoss, SACLoss, SoftUpdate, ValueEstimators
from utils.logging import init_logging, log_evaluation, log_training


def rendering_callback(env, td):
env.frames.append(env.render(mode="rgb_array", agent_index_focus=None))


@hydra.main(version_base="1.1", config_path=".", config_name="sac")
def train(cfg: "DictConfig"): # noqa: F821
# Device
cfg.train.device = "cpu" if not torch.has_cuda else "cuda:0"
cfg.env.device = cfg.train.device

# Seeding
torch.manual_seed(cfg.seed)

# Sampling
cfg.env.vmas_envs = cfg.collector.frames_per_batch // cfg.env.max_steps
cfg.collector.total_frames = cfg.collector.frames_per_batch * cfg.collector.n_iters
cfg.buffer.memory_size = cfg.collector.frames_per_batch

# Create env and env_test
env = VmasEnv(
scenario=cfg.env.scenario_name,
num_envs=cfg.env.vmas_envs,
continuous_actions=cfg.env.continuous_actions,
max_steps=cfg.env.max_steps,
device=cfg.env.device,
seed=cfg.seed,
# Scenario kwargs
**cfg.env.scenario,
)
env = TransformedEnv(
env,
RewardSum(in_keys=[env.reward_key], out_keys=[("agents", "episode_reward")]),
)

env_test = VmasEnv(
scenario=cfg.env.scenario_name,
num_envs=cfg.eval.evaluation_episodes,
continuous_actions=cfg.env.continuous_actions,
max_steps=cfg.env.max_steps,
device=cfg.env.device,
seed=cfg.seed,
# Scenario kwargs
**cfg.env.scenario,
)

# Policy
if cfg.env.continuous_actions:
actor_net = nn.Sequential(
MultiAgentMLP(
n_agent_inputs=env.observation_spec["agents", "observation"].shape[-1],
n_agent_outputs=2 * env.action_spec.shape[-1],
n_agents=env.n_agents,
centralised=False,
share_params=cfg.model.shared_parameters,
device=cfg.train.device,
depth=2,
num_cells=256,
activation_class=nn.Tanh,
),
NormalParamExtractor(),
)
policy_module = TensorDictModule(
actor_net,
in_keys=[("agents", "observation")],
out_keys=[("agents", "loc"), ("agents", "scale")],
)

policy = ProbabilisticActor(
module=policy_module,
spec=env.unbatched_action_spec,
in_keys=[("agents", "loc"), ("agents", "scale")],
out_keys=[env.action_key],
distribution_class=TanhNormal,
distribution_kwargs={
"min": env.unbatched_action_spec[("agents", "action")].space.minimum,
"max": env.unbatched_action_spec[("agents", "action")].space.maximum,
},
return_log_prob=True,
)

# Critic
module = MultiAgentMLP(
n_agent_inputs=env.observation_spec["agents", "observation"].shape[-1]
+ env.action_spec.shape[-1],  # Q critic takes both observation and action as input
n_agent_outputs=1,
n_agents=env.n_agents,
centralised=cfg.model.centralised_critic,
share_params=cfg.model.shared_parameters,
device=cfg.train.device,
depth=2,
num_cells=256,
activation_class=nn.Tanh,
)
value_module = ValueOperator(
module=module,
in_keys=[("agents", "observation"), env.action_key],
out_keys=[("agents", "state_action_value")],
)
else:
actor_net = nn.Sequential(
MultiAgentMLP(
n_agent_inputs=env.observation_spec["agents", "observation"].shape[-1],
n_agent_outputs=env.action_spec.space.n,
n_agents=env.n_agents,
centralised=False,
share_params=cfg.model.shared_parameters,
device=cfg.train.device,
depth=2,
num_cells=256,
activation_class=nn.Tanh,
),
)
policy_module = TensorDictModule(
actor_net,
in_keys=[("agents", "observation")],
out_keys=[("agents", "logits")],
)
policy = ProbabilisticActor(
module=policy_module,
spec=env.unbatched_action_spec,
in_keys=[("agents", "logits")],
out_keys=[env.action_key],
distribution_class=Categorical,
return_log_prob=True,
)

# Critic
module = MultiAgentMLP(
n_agent_inputs=env.observation_spec["agents", "observation"].shape[-1],
n_agent_outputs=env.action_spec.space.n,
n_agents=env.n_agents,
centralised=cfg.model.centralised_critic,
share_params=cfg.model.shared_parameters,
device=cfg.train.device,
depth=2,
num_cells=256,
activation_class=nn.Tanh,
)
value_module = ValueOperator(
module=module,
in_keys=[("agents", "observation")],
out_keys=[("agents", "action_value")],
)

collector = SyncDataCollector(
env,
policy,
device=cfg.env.device,
storing_device=cfg.train.device,
frames_per_batch=cfg.collector.frames_per_batch,
total_frames=cfg.collector.total_frames,
)

replay_buffer = TensorDictReplayBuffer(
storage=LazyTensorStorage(cfg.buffer.memory_size, device=cfg.train.device),
sampler=SamplerWithoutReplacement(),
batch_size=cfg.train.minibatch_size,
)

if cfg.env.continuous_actions:
loss_module = SACLoss(
actor_network=policy, qvalue_network=value_module, delay_qvalue=True
)
loss_module.set_keys(
state_action_value=("agents", "state_action_value"),
action=env.action_key,
reward=env.reward_key,
)
else:
loss_module = DiscreteSACLoss(
actor_network=policy,
qvalue_network=value_module,
delay_qvalue=True,
num_actions=env.action_spec.space.n,
action_space=env.unbatched_action_spec,
)
loss_module.set_keys(
action_value=("agents", "action_value"),
action=env.action_key,
reward=env.reward_key,
)

loss_module.make_value_estimator(ValueEstimators.TD0, gamma=cfg.loss.gamma)
target_net_updater = SoftUpdate(loss_module, eps=1 - cfg.loss.tau)

optim = torch.optim.Adam(loss_module.parameters(), cfg.train.lr)

# Logging
if cfg.logger.backend:
model_name = (
("Het" if not cfg.model.shared_parameters else "")
+ ("MA" if cfg.model.centralised_critic else "I")
+ "SAC"
)
logger = init_logging(cfg, model_name)

total_time = 0
total_frames = 0
sampling_start = time.time()
for i, tensordict_data in enumerate(collector):
print(f"\nIteration {i}")

sampling_time = time.time() - sampling_start

tensordict_data.set(
("next", "done"),
tensordict_data.get(("next", "done"))
.unsqueeze(-1)
.expand(tensordict_data.get(("next", env.reward_key)).shape),
) # We need to expand the done to match the reward shape

current_frames = tensordict_data.numel()
total_frames += current_frames
data_view = tensordict_data.reshape(-1)
replay_buffer.extend(data_view)

training_tds = []
training_start = time.time()
for _ in range(cfg.train.num_epochs):
for _ in range(cfg.collector.frames_per_batch // cfg.train.minibatch_size):
subdata = replay_buffer.sample()
loss_vals = loss_module(subdata)
training_tds.append(loss_vals.detach())

loss_value = (
loss_vals["loss_actor"]
+ loss_vals["loss_alpha"]
+ loss_vals["loss_qvalue"]
+ loss_vals["loss_alpha"]
)

loss_value.backward()

total_norm = torch.nn.utils.clip_grad_norm_(
loss_module.parameters(), cfg.train.max_grad_norm
)
training_tds[-1].set("grad_norm", total_norm.mean())

optim.step()
optim.zero_grad()
target_net_updater.step()

collector.update_policy_weights_()

training_time = time.time() - training_start

iteration_time = sampling_time + training_time
total_time += iteration_time
training_tds = torch.stack(training_tds)

# More logs
if cfg.logger.backend:
log_training(
logger,
training_tds,
tensordict_data,
sampling_time,
training_time,
total_time,
i,
current_frames,
total_frames,
step=i,
)

if (
cfg.eval.evaluation_episodes > 0
and i % cfg.eval.evaluation_interval == 0
and cfg.logger.backend
):
evaluation_start = time.time()
with torch.no_grad(), set_exploration_type(ExplorationType.MODE):
env_test.frames = []
rollouts = env_test.rollout(
max_steps=cfg.env.max_steps,
policy=policy,
callback=rendering_callback,
auto_cast_to_device=True,
break_when_any_done=False,
# We are running vectorized evaluation; we do not want it to stop when just one env is done
)

evaluation_time = time.time() - evaluation_start

log_evaluation(logger, rollouts, env_test, evaluation_time, step=i)

if cfg.logger.backend == "wandb":
logger.experiment.log({}, commit=True)
sampling_start = time.time()


if __name__ == "__main__":
train()
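A note on the done-expansion step in the training loop above: the collected ("next", "done") entry has no agent dimension, while rewards are stored per agent under the "agents" sub-tensordict, so the flag is broadcast to the reward shape before the batch enters the replay buffer. A toy illustration with hypothetical sizes (10 environments, 3 agents), separate from the example code:

import torch

n_envs, n_agents = 10, 3  # hypothetical sizes

done = torch.zeros(n_envs, 1, dtype=torch.bool)  # one termination flag per environment
reward = torch.zeros(n_envs, n_agents, 1)        # one reward per agent

# Same operation as in the loop: add an agent dimension, then broadcast
# so every agent carries the termination flag of its environment.
expanded_done = done.unsqueeze(-1).expand(reward.shape)
assert expanded_done.shape == reward.shape       # torch.Size([10, 3, 1])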
