# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import annotations

import torch.nn
import torch.optim
from torchrl.data import Composite
from torchrl.envs import (
    CatFrames,
    DoubleToFloat,
    EndOfLifeTransform,
    GrayScale,
    GymEnv,
    NoopResetEnv,
    Resize,
    RewardSum,
    SignTransform,
    StepCounter,
    ToTensorImage,
    TransformedEnv,
    VecNorm,
)
from torchrl.modules import ConvNet, MLP, QValueActor
from torchrl.record import VideoRecorder


# ====================================================================
# Environment utils
# --------------------------------------------------------------------


def make_env(env_name, frame_skip, device, is_test=False):
    # Base Atari environment, rendered from pixels, with categorical actions
    env = GymEnv(
        env_name,
        frame_skip=frame_skip,
        from_pixels=True,
        pixels_only=False,
        device=device,
        categorical_action_encoding=True,
    )
    env = TransformedEnv(env)
    # Apply a random number of no-op actions (up to 30) at reset
    env.append_transform(NoopResetEnv(noops=30, random=True))
    if not is_test:
        # During training, mark the loss of a life as an episode end and
        # clip rewards to their sign
        env.append_transform(EndOfLifeTransform())
        env.append_transform(SignTransform(in_keys=["reward"]))
    # Standard DQN preprocessing: 84x84 grayscale frames, stacked 4 at a time
    env.append_transform(ToTensorImage())
    env.append_transform(GrayScale())
    env.append_transform(Resize(84, 84))
    env.append_transform(CatFrames(N=4, dim=-3))
    env.append_transform(RewardSum())
    env.append_transform(StepCounter(max_steps=4500))
    env.append_transform(DoubleToFloat())
    env.append_transform(VecNorm(in_keys=["pixels"]))
    return env


# ====================================================================
# Model utils
# --------------------------------------------------------------------


def make_dqn_modules_pixels(proof_environment, device):

    # Define input shape
    input_shape = proof_environment.observation_spec["pixels"].shape
    env_specs = proof_environment.specs
    num_actions = env_specs["input_spec", "full_action_spec", "action"].space.n
    action_spec = env_specs["input_spec", "full_action_spec", "action"]

    # Define Q-Value Module: the classic DQN convolutional trunk followed by
    # a single-hidden-layer MLP head
    cnn = ConvNet(
        activation_class=torch.nn.ReLU,
        num_cells=[32, 64, 64],
        kernel_sizes=[8, 4, 3],
        strides=[4, 2, 1],
        device=device,
    )
    # Infer the flattened CNN output size with a dummy forward pass
    cnn_output = cnn(torch.ones(input_shape, device=device))
    mlp = MLP(
        in_features=cnn_output.shape[-1],
        activation_class=torch.nn.ReLU,
        out_features=num_actions,
        num_cells=[512],
        device=device,
    )
    qvalue_module = QValueActor(
        module=torch.nn.Sequential(cnn, mlp),
        spec=Composite(action=action_spec).to(device),
        in_keys=["pixels"],
    )
    return qvalue_module


def make_dqn_model(env_name, frame_skip, device):
    # Build the Q-value actor from a throwaway "proof" environment that only
    # serves to provide the observation and action specs
    proof_environment = make_env(env_name, frame_skip, device=device)
    qvalue_module = make_dqn_modules_pixels(proof_environment, device=device)
    del proof_environment
    return qvalue_module


# ====================================================================
# Evaluation utils
# --------------------------------------------------------------------


def eval_model(actor, test_env, num_episodes=3):
    test_rewards = torch.zeros(num_episodes, dtype=torch.float32)
    for i in range(num_episodes):
        # Roll out a full episode with the evaluation policy
        td_test = test_env.rollout(
            policy=actor,
            auto_reset=True,
            auto_cast_to_device=True,
            break_when_any_done=True,
            max_steps=10_000_000,
        )
        # Flush any pending video frames to disk
        test_env.apply(dump_video)
        # Keep the cumulative episode reward logged at the terminal step
        reward = td_test["next", "episode_reward"][td_test["next", "done"]]
        test_rewards[i] = reward.sum()
        del td_test
    return test_rewards.mean()


def dump_video(module):
    if isinstance(module, VideoRecorder):
        module.dump()
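

# --------------------------------------------------------------------
# Usage sketch (illustrative only): a minimal example of how these helpers
# might be wired together. The environment name and frame-skip value below
# are assumptions for demonstration, not values prescribed by this module.
if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # Build the Q-value actor; the proof environment is created and discarded
    # internally by make_dqn_model
    model = make_dqn_model("PongNoFrameskip-v4", frame_skip=4, device=device)
    # Build a separate test environment (no training-only transforms) and
    # evaluate the (untrained) policy for a few episodes
    test_env = make_env("PongNoFrameskip-v4", frame_skip=4, device=device, is_test=True)
    mean_reward = eval_model(model, test_env, num_episodes=3)
    print(f"Mean test reward: {mean_reward.item():.2f}")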