Commit
[Example] A2C simplified example (pytorch#1076)
albertbou92 authored Apr 25, 2023
1 parent df233fb commit 6c89a65
Showing 8 changed files with 663 additions and 207 deletions.
34 changes: 14 additions & 20 deletions .circleci/unittest/linux_examples/scripts/run_test.sh
@@ -41,16 +41,13 @@ python .circleci/unittest/helpers/coverage_run_parallel.py examples/ddpg/ddpg.py
 record_frames=4 \
 buffer_size=120
 python .circleci/unittest/helpers/coverage_run_parallel.py examples/a2c/a2c.py \
-total_frames=48 \
-batch_size=10 \
-frames_per_batch=16 \
-num_workers=4 \
-env_per_collector=2 \
-collector_devices=cuda:0 \
-optim_steps_per_batch=1 \
-record_video=True \
-record_frames=4 \
-logger=csv
+env.num_envs=1 \
+collector.total_frames=48 \
+collector.frames_per_batch=16 \
+collector.collector_device=cuda:0 \
+logger.backend= \
+logger.log_interval=4 \
+optim.lr_scheduler=False
 python .circleci/unittest/helpers/coverage_run_parallel.py examples/dqn/dqn.py \
 total_frames=48 \
 init_random_frames=10 \
@@ -142,16 +139,13 @@ python .circleci/unittest/helpers/coverage_run_parallel.py examples/ddpg/ddpg.py
 record_frames=4 \
 buffer_size=120
 python .circleci/unittest/helpers/coverage_run_parallel.py examples/a2c/a2c.py \
-total_frames=48 \
-batch_size=10 \
-frames_per_batch=16 \
-num_workers=2 \
-env_per_collector=1 \
-collector_devices=cuda:0 \
-optim_steps_per_batch=1 \
-record_video=True \
-record_frames=4 \
-logger=csv
+env.num_envs=1 \
+collector.total_frames=48 \
+collector.frames_per_batch=16 \
+collector.collector_device=cuda:0 \
+logger.backend= \
+logger.log_interval=4 \
+optim.lr_scheduler=False
 python .circleci/unittest/helpers/coverage_run_parallel.py examples/dqn/dqn.py \
 total_frames=48 \
 init_random_frames=10 \
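The overrides above use Hydra's dotted group.key=value syntax, which assumes the new example reads a nested config with env, collector, logger and optim groups. A minimal sketch of how such overrides resolve, assuming placeholder defaults (the actual config.yaml the script loads is not shown in this excerpt):

# Illustrative only: the group names mirror the overrides above; the default values are invented.
from omegaconf import OmegaConf

defaults = OmegaConf.create(
    {
        "env": {"num_envs": 1, "frame_skip": 1},
        "collector": {"total_frames": 1_000_000, "frames_per_batch": 64, "collector_device": "cpu"},
        "logger": {"backend": "wandb", "log_interval": 10_000},
        "optim": {"lr": 1e-4, "lr_scheduler": True, "device": "cpu"},
    }
)
# Dotted CLI overrides like the ones passed in run_test.sh merge into that tree.
overrides = OmegaConf.from_dotlist(
    ["collector.total_frames=48", "collector.frames_per_batch=16", "optim.lr_scheduler=False"]
)
cfg = OmegaConf.merge(defaults, overrides)
print(cfg.collector.total_frames, cfg.optim.lr_scheduler)  # 48 False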
2 changes: 1 addition & 1 deletion README.md
@@ -490,7 +490,7 @@ A series of [examples](examples/) are provided with an illustrative purpose:
 - [DDPG](examples/ddpg/ddpg.py)
 - [IQL](examples/iql/iql.py)
 - [TD3](examples/td3/td3.py)
-- [A2C](examples/a2c/a2c.py)
+- [A2C](examples/a2c_old/a2c.py)
 - [PPO](examples/ppo/ppo.py)
 - [SAC](examples/sac/sac.py)
 - [REDQ](examples/redq/redq.py)
270 changes: 125 additions & 145 deletions examples/a2c/a2c.py
@@ -2,164 +2,144 @@
 #
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
-import dataclasses
-
-import hydra
-import torch.cuda
-from hydra.core.config_store import ConfigStore
-from torchrl.envs.transforms import RewardScaling
-from torchrl.envs.utils import ExplorationType, set_exploration_type
-from torchrl.objectives.value import TD0Estimator
-from torchrl.record.loggers import generate_exp_name, get_logger
-from torchrl.trainers.helpers.collectors import (
-    make_collector_onpolicy,
-    OnPolicyCollectorConfig,
-)
-from torchrl.trainers.helpers.envs import (
-    correct_for_frame_skip,
-    EnvConfig,
-    initialize_observation_norm_transforms,
-    parallel_env_constructor,
-    retrieve_observation_norms_state_dict,
-    transformed_env_constructor,
-)
-from torchrl.trainers.helpers.logger import LoggerConfig
-from torchrl.trainers.helpers.losses import A2CLossConfig, make_a2c_loss
-from torchrl.trainers.helpers.models import A2CModelConfig, make_a2c_model
-from torchrl.trainers.helpers.trainers import make_trainer, TrainerConfig
-
-config_fields = [
-    (config_field.name, config_field.type, config_field)
-    for config_cls in (
-        TrainerConfig,
-        OnPolicyCollectorConfig,
-        EnvConfig,
-        A2CLossConfig,
-        A2CModelConfig,
-        LoggerConfig,
-    )
-    for config_field in dataclasses.fields(config_cls)
-]
-
-Config = dataclasses.make_dataclass(cls_name="Config", fields=config_fields)
-cs = ConfigStore.instance()
-cs.store(name="config", node=Config)
-
-
-@hydra.main(version_base=None, config_path="", config_name="config")
+"""A2C Example.
+
+This is a self-contained example of a A2C training script.
+
+Both state and pixel-based environments are supported.
+
+The helper functions are coded in the utils.py associated with this script.
+"""
+import hydra
+
+
+@hydra.main(config_path=".", config_name="config")
 def main(cfg: "DictConfig"):  # noqa: F821
 
-    cfg = correct_for_frame_skip(cfg)
-
-    if not isinstance(cfg.reward_scaling, float):
-        cfg.reward_scaling = 1.0
-
-    device = (
-        torch.device("cpu")
-        if torch.cuda.device_count() == 0
-        else torch.device("cuda:0")
-    )
-
-    exp_name = generate_exp_name("A2C", cfg.exp_name)
-    logger = get_logger(
-        logger_type=cfg.logger, logger_name="a2c_logging", experiment_name=exp_name
-    )
-    video_tag = exp_name if cfg.record_video else ""
-
-    key, init_env_steps, stats = None, None, None
-    if not cfg.vecnorm and cfg.norm_stats:
-        if not hasattr(cfg, "init_env_steps"):
-            raise AttributeError("init_env_steps missing from arguments.")
-        key = "pixels" if cfg.from_pixels else "observation_vector"
-        init_env_steps = cfg.init_env_steps
-        stats = {"loc": None, "scale": None}
-    elif cfg.from_pixels:
-        stats = {"loc": 0.5, "scale": 0.5}
-    proof_env = transformed_env_constructor(
-        cfg=cfg,
-        use_env_creator=False,
-        stats=stats,
-    )()
-    initialize_observation_norm_transforms(
-        proof_environment=proof_env, num_iter=init_env_steps, key=key
-    )
-    _, obs_norm_state_dict = retrieve_observation_norms_state_dict(proof_env)[0]
-
-    model = make_a2c_model(
-        proof_env,
-        cfg=cfg,
-        device=device,
-    )
-    actor_model = model.get_policy_operator()
-
-    loss_module = make_a2c_loss(model, cfg)
-    if cfg.gSDE:
-        with torch.no_grad(), set_exploration_type(ExplorationType.RANDOM):
-            # get dimensions to build the parallel env
-            proof_td = model(proof_env.reset().to(device))
-        action_dim_gsde, state_dim_gsde = proof_td.get("_eps_gSDE").shape[-2:]
-        del proof_td
-    else:
-        action_dim_gsde, state_dim_gsde = None, None
-
-    proof_env.close()
-    create_env_fn = parallel_env_constructor(
-        cfg=cfg,
-        obs_norm_state_dict=obs_norm_state_dict,
-        action_dim_gsde=action_dim_gsde,
-        state_dim_gsde=state_dim_gsde,
-    )
-
-    collector = make_collector_onpolicy(
-        make_env=create_env_fn,
-        actor_model_explore=actor_model,
-        cfg=cfg,
-    )
-
-    recorder = transformed_env_constructor(
-        cfg,
-        video_tag=video_tag,
-        norm_obs_only=True,
-        obs_norm_state_dict=obs_norm_state_dict,
-        logger=logger,
-        use_env_creator=False,
-    )()
-
-    # reset reward scaling
-    for t in recorder.transform:
-        if isinstance(t, RewardScaling):
-            t.scale.fill_(1.0)
-            t.loc.fill_(0.0)
-
-    trainer = make_trainer(
-        collector=collector,
-        loss_module=loss_module,
-        recorder=recorder,
-        target_net_updater=None,
-        policy_exploration=actor_model,
-        replay_buffer=None,
-        logger=logger,
-        cfg=cfg,
-    )
-
-    critic_model = model.get_value_operator()
-    advantage = TD0Estimator(
-        gamma=cfg.gamma,
-        value_network=critic_model,
-        average_rewards=True,
-        differentiable=True,
-    )
-    trainer.register_op(
-        "process_optim_batch",
-        torch.no_grad()(advantage),
-    )
-
-    final_seed = collector.set_seed(cfg.seed)
-    print(f"init seed: {cfg.seed}, final seed: {final_seed}")
-
-    trainer.train()
-    return (logger.log_dir, trainer._log_dict)
+    import torch
+    import tqdm
+    from utils import (
+        make_a2c_models,
+        make_collector,
+        make_logger,
+        make_loss,
+        make_optim,
+        make_test_env,
+    )
+
+    # Correct for frame_skip
+    cfg.collector.total_frames = cfg.collector.total_frames // cfg.env.frame_skip
+    cfg.collector.frames_per_batch = (
+        cfg.collector.frames_per_batch // cfg.env.frame_skip
+    )
+
+    model_device = cfg.optim.device
+    actor, critic = make_a2c_models(cfg)
+    actor = actor.to(model_device)
+    critic = critic.to(model_device)
+
+    collector = make_collector(cfg, policy=actor)
+    loss_module, adv_module = make_loss(
+        cfg.loss, actor_network=actor, value_network=critic
+    )
+    optim = make_optim(cfg.optim, actor_network=actor, value_network=critic)
+
+    batch_size = cfg.collector.total_frames * cfg.env.num_envs
+    total_network_updates = cfg.collector.total_frames // batch_size
+
+    scheduler = None
+    if cfg.optim.lr_scheduler:
+        scheduler = torch.optim.lr_scheduler.LinearLR(
+            optim, total_iters=total_network_updates, start_factor=1.0, end_factor=0.1
+        )
+
+    logger = None
+    if cfg.logger.backend:
+        logger = make_logger(cfg.logger)
+    test_env = make_test_env(cfg.env)
+    record_interval = cfg.logger.log_interval
+    pbar = tqdm.tqdm(total=cfg.collector.total_frames)
+    collected_frames = 0
+
+    # Main loop
+    r0 = None
+    l0 = None
+    for data in collector:
+
+        frames_in_batch = data.numel()
+        collected_frames += frames_in_batch * cfg.env.frame_skip
+        pbar.update(data.numel())
+        data_view = data.reshape(-1)
+
+        # Compute GAE
+        with torch.no_grad():
+            batch = adv_module(data_view)
+
+        # Normalize advantage
+        adv = batch.get("advantage")
+        loc = adv.mean().item()
+        scale = adv.std().clamp_min(1e-6).item()
+        adv = (adv - loc) / scale
+        batch.set("advantage", adv)
+
+        # Forward pass A2C loss
+        batch = batch.to(model_device)
+        loss = loss_module(batch)
+        loss_sum = loss["loss_critic"] + loss["loss_objective"] + loss["loss_entropy"]
+
+        # Backward pass + learning step
+        loss_sum.backward()
+        grad_norm = torch.nn.utils.clip_grad_norm_(
+            list(actor.parameters()) + list(critic.parameters()), max_norm=0.5
+        )
+        optim.step()
+        if scheduler is not None:
+            scheduler.step()
+        optim.zero_grad()
+
+        # Logging
+        if r0 is None:
+            r0 = data["next", "reward"].mean().item()
+        if l0 is None:
+            l0 = loss_sum.item()
+        pbar.set_description(
+            f"loss: {loss_sum.item(): 4.4f} (init: {l0: 4.4f}), reward: {data['next', 'reward'].mean(): 4.4f} (init={r0: 4.4f})"
+        )
+        if logger is not None:
+            for key, value in loss.items():
+                logger.log_scalar(key, value.item(), collected_frames)
+            logger.log_scalar("grad_norm", grad_norm.item(), collected_frames)
+            episode_rewards = data["next", "episode_reward"][data["next", "done"]]
+            if len(episode_rewards) > 0:
+                logger.log_scalar(
+                    "reward_training", episode_rewards.mean().item(), collected_frames
+                )
+        collector.update_policy_weights_()
+
+        # Test current policy
+        if (
+            logger is not None
+            and (collected_frames - frames_in_batch) // record_interval
+            < collected_frames // record_interval
+        ):
+
+            with torch.no_grad():
+                test_env.eval()
+                actor.eval()
+                # Generate a complete episode
+                td_test = test_env.rollout(
+                    policy=actor,
+                    max_steps=10_000_000,
+                    auto_reset=True,
+                    auto_cast_to_device=True,
+                    break_when_any_done=True,
+                ).clone()
+                logger.log_scalar(
+                    "reward_testing",
+                    td_test["next", "reward"].sum().item(),
+                    collected_frames,
+                )
+                actor.train()


 if __name__ == "__main__":
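The loss_module / adv_module pair returned by make_loss comes from the utils.py that accompanies the script and is not shown in this excerpt. A minimal sketch of what such a helper might look like, assuming it pairs torchrl's A2CLoss with a GAE advantage module; the loss_cfg field names below (gamma, gae_lambda, entropy_coef, critic_coef) are placeholders, not confirmed config keys:

# Hypothetical sketch of make_loss; the real helper lives in examples/a2c/utils.py.
from torchrl.objectives import A2CLoss
from torchrl.objectives.value import GAE


def make_loss(loss_cfg, actor_network, value_network):
    # Advantage module: writes "advantage" and "value_target" into the batch,
    # which the training loop above normalizes before computing the loss.
    adv_module = GAE(
        gamma=loss_cfg.gamma,
        lmbda=loss_cfg.gae_lambda,
        value_network=value_network,
        average_gae=False,
    )
    # A2C loss: produces the "loss_objective", "loss_critic" and "loss_entropy"
    # terms that the training loop sums before the backward pass.
    loss_module = A2CLoss(
        actor_network,
        value_network,
        entropy_coef=loss_cfg.entropy_coef,
        critic_coef=loss_cfg.critic_coef,
    )
    return loss_module, adv_module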
Binary file added examples/a2c/a2c_mujoco_halfcheetah.png

0 comments on commit 6c89a65
