Skip to content

Commit

Permalink
Hydra integration (#202)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored Jun 24, 2022
1 parent b162180 commit 336b981
Show file tree
Hide file tree
Showing 18 changed files with 1,359 additions and 1,726 deletions.
1 change: 1 addition & 0 deletions .circleci/unittest/linux/scripts/environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -24,4 +24,5 @@ dependencies:
- scipy
- dm_control
- mujoco_py
- hydra-core
- pyrender
1 change: 1 addition & 0 deletions .circleci/unittest/linux_optdeps/scripts/environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,4 @@ dependencies:
- expecttest
- pyyaml
- scipy
- hydra-core
1 change: 1 addition & 0 deletions .circleci/unittest/linux_stable/scripts/environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,4 +25,5 @@ dependencies:
- scipy
- dm_control
- mujoco_py
- hydra-core
- pyrender
131 changes: 59 additions & 72 deletions examples/ddpg/ddpg.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,70 +3,57 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import dataclasses
import uuid
from datetime import datetime

from torchrl.envs import ParallelEnv, EnvCreator
from torchrl.envs.utils import set_exploration_mode
from torchrl.record import VideoRecorder

try:
import configargparse as argparse

_configargparse = True
except ImportError:
import argparse

_configargparse = False

import hydra
import torch.cuda
from hydra.core.config_store import ConfigStore
from torchrl.envs import ParallelEnv, EnvCreator
from torchrl.envs.transforms import RewardScaling, TransformedEnv
from torchrl.envs.utils import set_exploration_mode
from torchrl.modules import OrnsteinUhlenbeckProcessWrapper
from torchrl.record import VideoRecorder
from torchrl.trainers.helpers.collectors import (
make_collector_offpolicy,
parser_collector_args_offpolicy,
OffPolicyCollectorConfig,
)
from torchrl.trainers.helpers.envs import (
correct_for_frame_skip,
get_stats_random_rollout,
parallel_env_constructor,
parser_env_args,
transformed_env_constructor,
EnvConfig,
)
from torchrl.trainers.helpers.losses import make_ddpg_loss, parser_loss_args
from torchrl.trainers.helpers.losses import make_ddpg_loss, LossConfig
from torchrl.trainers.helpers.models import (
make_ddpg_actor,
parser_model_args_continuous,
DDPGModelConfig,
)
from torchrl.trainers.helpers.recorder import parser_recorder_args
from torchrl.trainers.helpers.recorder import RecorderConfig
from torchrl.trainers.helpers.replay_buffer import (
make_replay_buffer,
parser_replay_args,
ReplayArgsConfig,
)
from torchrl.trainers.helpers.trainers import make_trainer, parser_trainer_args


def make_args():
parser = argparse.ArgumentParser()
if _configargparse:
parser.add_argument(
"-c",
"--config",
required=True,
is_config_file=True,
help="config file path",
)
parser_trainer_args(parser)
parser_collector_args_offpolicy(parser)
parser_env_args(parser)
parser_loss_args(parser, algorithm="DDPG")
parser_model_args_continuous(parser, "DDPG")
parser_recorder_args(parser)
parser_replay_args(parser)
return parser


parser = make_args()
from torchrl.trainers.helpers.trainers import make_trainer, TrainerConfig

config_fields = [
(config_field.name, config_field.type, config_field)
for config_cls in (
TrainerConfig,
OffPolicyCollectorConfig,
EnvConfig,
LossConfig,
DDPGModelConfig,
RecorderConfig,
ReplayArgsConfig,
)
for config_field in dataclasses.fields(config_cls)
]
Config = dataclasses.make_dataclass(cls_name="Config", fields=config_fields)
cs = ConfigStore.instance()
cs.store(name="config", node=Config)

DEFAULT_REWARD_SCALING = {
"Hopper-v1": 5,
Expand All @@ -79,13 +66,14 @@ def make_args():
}


def main(args):
@hydra.main(version_base=None, config_path=None, config_name="config")
def main(cfg: "DictConfig"):
from torch.utils.tensorboard import SummaryWriter

args = correct_for_frame_skip(args)
cfg = correct_for_frame_skip(cfg)

if not isinstance(args.reward_scaling, float):
args.reward_scaling = DEFAULT_REWARD_SCALING.get(args.env_name, 5.0)
if not isinstance(cfg.reward_scaling, float):
cfg.reward_scaling = DEFAULT_REWARD_SCALING.get(cfg.env_name, 5.0)

device = (
torch.device("cpu")
Expand All @@ -96,50 +84,50 @@ def main(args):
exp_name = "_".join(
[
"DDPG",
args.exp_name,
cfg.exp_name,
str(uuid.uuid4())[:8],
datetime.now().strftime("%y_%m_%d-%H_%M_%S"),
]
)
writer = SummaryWriter(f"ddpg_logging/{exp_name}")
video_tag = exp_name if args.record_video else ""
video_tag = exp_name if cfg.record_video else ""

stats = None
if not args.vecnorm and args.norm_stats:
proof_env = transformed_env_constructor(args=args, use_env_creator=False)()
if not cfg.vecnorm and cfg.norm_stats:
proof_env = transformed_env_constructor(cfg=cfg, use_env_creator=False)()
stats = get_stats_random_rollout(
args, proof_env, key="next_pixels" if args.from_pixels else None
cfg, proof_env, key="next_pixels" if cfg.from_pixels else None
)
# make sure proof_env is closed
proof_env.close()
elif args.from_pixels:
elif cfg.from_pixels:
stats = {"loc": 0.5, "scale": 0.5}
proof_env = transformed_env_constructor(
args=args, use_env_creator=False, stats=stats
cfg=cfg, use_env_creator=False, stats=stats
)()

model = make_ddpg_actor(
proof_env,
args=args,
cfg=cfg,
device=device,
)
loss_module, target_net_updater = make_ddpg_loss(model, args)
loss_module, target_net_updater = make_ddpg_loss(model, cfg)

actor_model_explore = model[0]
if args.ou_exploration:
if args.gSDE:
if cfg.ou_exploration:
if cfg.gSDE:
raise RuntimeError("gSDE and ou_exploration are incompatible")
actor_model_explore = OrnsteinUhlenbeckProcessWrapper(
actor_model_explore,
annealing_num_steps=args.annealing_frames,
sigma=args.ou_sigma,
theta=args.ou_theta,
annealing_num_steps=cfg.annealing_frames,
sigma=cfg.ou_sigma,
theta=cfg.ou_theta,
).to(device)
if device == torch.device("cpu"):
# mostly for debugging
actor_model_explore.share_memory()

if args.gSDE:
if cfg.gSDE:
with torch.no_grad(), set_exploration_mode("random"):
# get dimensions to build the parallel env
proof_td = actor_model_explore(proof_env.reset().to(device))
Expand All @@ -150,7 +138,7 @@ def main(args):

proof_env.close()
create_env_fn = parallel_env_constructor(
args=args,
cfg=cfg,
stats=stats,
action_dim_gsde=action_dim_gsde,
state_dim_gsde=state_dim_gsde,
Expand All @@ -159,17 +147,17 @@ def main(args):
collector = make_collector_offpolicy(
make_env=create_env_fn,
actor_model_explore=actor_model_explore,
args=args,
cfg=cfg,
# make_env_kwargs=[
# {"device": device} if device >= 0 else {}
# for device in args.env_rendering_devices
# ],
)

replay_buffer = make_replay_buffer(device, args)
replay_buffer = make_replay_buffer(device, cfg)

recorder = transformed_env_constructor(
args,
cfg,
video_tag=video_tag,
norm_obs_only=True,
stats=stats,
Expand All @@ -178,7 +166,7 @@ def main(args):
)()

# remove video recorder from recorder to have matching state_dict keys
if args.record_video:
if cfg.record_video:
recorder_rm = TransformedEnv(recorder.base_env)
for transform in recorder.transform:
if not isinstance(transform, VideoRecorder):
Expand Down Expand Up @@ -208,7 +196,7 @@ def main(args):
actor_model_explore,
replay_buffer,
writer,
args,
cfg,
)

def select_keys(batch):
Expand All @@ -225,13 +213,12 @@ def select_keys(batch):

trainer.register_op("batch_process", select_keys)

final_seed = collector.set_seed(args.seed)
print(f"init seed: {args.seed}, final seed: {final_seed}")
final_seed = collector.set_seed(cfg.seed)
print(f"init seed: {cfg.seed}, final seed: {final_seed}")

trainer.train()
return (writer.log_dir, trainer._log_dict, trainer.state_dict())


if __name__ == "__main__":
args = parser.parse_args()
main(args)
main()
Loading

0 comments on commit 336b981

Please sign in to comment.