main_dmc.py
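
# Entry point: trains Dreamer on a DeepMind Control (DMC) task using rlpyt's
# serial sampler and minibatch runners.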
import datetime
import os
import argparse
import torch
from rlpyt.samplers.collections import TrajInfo
from rlpyt.runners.minibatch_rl import MinibatchRlEval, MinibatchRl
from rlpyt.samplers.serial.sampler import SerialSampler
from rlpyt.utils.logging.context import logger_context
from dreamer.agents.dmc_dreamer_agent import DMCDreamerAgent
from dreamer.algos.dreamer_algo import Dreamer
from dreamer.envs.dmc import DeepMindControl
from dreamer.envs.time_limit import TimeLimit
from dreamer.envs.action_repeat import ActionRepeat
from dreamer.envs.normalize_actions import NormalizeActions
from dreamer.envs.wrapper import make_wapper


def build_and_train(
    log_dir,
    game="cartpole_balance",
    run_ID=0,
    cuda_idx=None,
    eval=False,
    save_model="last",
    load_model_path=None,
):
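    # If a checkpoint path is given, restore the agent and optimizer state
    # saved in params.pkl; otherwise start from scratch.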
    params = torch.load(load_model_path) if load_model_path else {}
    agent_state_dict = params.get("agent_state_dict")
    optimizer_state_dict = params.get("optimizer_state_dict")
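    # Build the env factory: DeepMindControl wrapped with action repeat,
    # action normalization, and a time limit of 1000 underlying env steps
    # (the helper is spelled make_wapper in dreamer.envs.wrapper).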
    action_repeat = 2
    factory_method = make_wapper(
        DeepMindControl,
        [ActionRepeat, NormalizeActions, TimeLimit],
        [dict(amount=action_repeat), dict(), dict(duration=1000 / action_repeat)],
    )
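    # Serial (single-process) sampler: one training env stepping once per
    # iteration (batch_T=1, batch_B=1) and 10 envs for evaluation.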
    sampler = SerialSampler(
        EnvCls=factory_method,
        TrajInfoCls=TrajInfo,
        env_kwargs=dict(name=game),
        eval_env_kwargs=dict(name=game),
        batch_T=1,
        batch_B=1,
        max_decorrelation_steps=0,
        eval_n_envs=10,
        eval_max_steps=int(10e3),
        eval_max_trajectories=5,
    )
    algo = Dreamer(initial_optim_state_dict=optimizer_state_dict)  # Run with defaults.
    agent = DMCDreamerAgent(
        train_noise=0.3,
        eval_noise=0,
        expl_type="additive_gaussian",
        expl_min=None,
        expl_decay=None,
        initial_model_state_dict=agent_state_dict,
    )
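    # MinibatchRlEval interleaves offline evaluation with training; plain
    # MinibatchRl only logs online statistics.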
    runner_cls = MinibatchRlEval if eval else MinibatchRl
    runner = runner_cls(
        algo=algo,
        agent=agent,
        sampler=sampler,
        n_steps=5e6,
        log_interval_steps=1e3,
        affinity=dict(cuda_idx=cuda_idx),
    )
    config = dict(game=game)
    name = "dreamer_" + game
    with logger_context(
        log_dir,
        run_ID,
        name,
        config,
        snapshot_mode=save_model,
        override_prefix=True,
        use_summary_writer=True,
    ):
        runner.train()


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
parser.add_argument("--game", help="DMC game", default="cartpole_balance")
parser.add_argument(
"--run-ID", help="run identifier (logging)", type=int, default=0
)
parser.add_argument("--cuda-idx", help="gpu to use ", type=int, default=None)
parser.add_argument("--eval", action="store_true")
parser.add_argument(
"--save-model",
help="save model",
type=str,
default="last",
choices=["all", "none", "gap", "last"],
)
parser.add_argument(
"--load-model-path", help="load model from path", type=str
) # path to params.pkl
    default_log_dir = os.path.join(
        os.path.dirname(__file__),
        "data",
        "local",
        datetime.datetime.now().strftime("%Y%m%d"),
    )
    parser.add_argument("--log-dir", type=str, default=default_log_dir)
    args = parser.parse_args()
    log_dir = os.path.abspath(args.log_dir)
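    # Auto-increment the run ID so an existing run_<ID> directory under
    # log_dir is never overwritten.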
    i = args.run_ID
    while os.path.exists(os.path.join(log_dir, "run_" + str(i))):
        print(f"run {i} already exists.")
        i += 1
    print(f"Using run id = {i}")
    args.run_ID = i
    build_and_train(
        log_dir,
        game=args.game,
        run_ID=args.run_ID,
        cuda_idx=args.cuda_idx,
        eval=args.eval,
        save_model=args.save_model,
        load_model_path=args.load_model_path,
    )
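
# Example invocation (GPU index and task name are illustrative):
#   python main_dmc.py --game cartpole_balance --cuda-idx 0 --eval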