-
Notifications
You must be signed in to change notification settings - Fork 326
/
catframes-in-buffer.py
99 lines (87 loc) · 2.52 KB
/
catframes-in-buffer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
from torchrl.data import LazyTensorStorage, ReplayBuffer
from torchrl.envs import (
CatFrames,
Compose,
DMControlEnv,
StepCounter,
ToTensorImage,
TransformedEnv,
UnsqueezeTransform,
)
# --- Frame stacking -------------------------------------------------------
frame_stack = 4  # how many frames are stacked together
stack_dim = -4  # dimension along which the stacking happens

# --- Replay buffer --------------------------------------------------------
max_size = 100_000  # maximum size of the buffer
training_batch_size = 32  # batch size used by the replay buffer

# --- Reproducibility ------------------------------------------------------
seed = 123  # value handed to ``env.set_seed`` in ``main``
def main() -> None:
    """Collect a rollout with stacked frames, store it, and print one sample.

    The function:

    1. builds a ``CatFrames`` transform configured from the module-level
       ``frame_stack`` / ``stack_dim`` settings;
    2. wraps a pixel-based DMControl ``cartpole``/``balance`` environment with
       ``ToTensorImage`` -> ``UnsqueezeTransform`` -> ``CatFrames`` ->
       ``StepCounter`` and seeds it with ``seed``;
    3. asks the ``CatFrames`` instance for a matching replay-buffer
       transform and sampler via ``make_rb_transform_and_sampler``;
    4. creates a ``ReplayBuffer`` backed by a CPU ``LazyTensorStorage`` that
       applies the image conversion, the unsqueeze and the returned
       transform to the sampled data;
    5. extends the buffer with a 1000-step rollout and prints one batch.
    """
    frame_stacker = CatFrames(
        N=frame_stack,
        dim=stack_dim,
        in_keys=["pixels_trsf"],
        out_keys=["pixels_trsf"],
    )

    # Environment side: raw pixels -> float image -> extra dim -> stacked.
    base_env = DMControlEnv(
        env_name="cartpole",
        task_name="balance",
        device="cpu",
        from_pixels=True,
        pixels_only=True,
    )
    env_to_image = ToTensorImage(
        from_int=True,
        dtype=torch.float32,
        in_keys=["pixels"],
        out_keys=["pixels_trsf"],
        shape_tolerant=True,
    )
    env_unsqueeze = UnsqueezeTransform(
        dim=stack_dim,
        in_keys=["pixels_trsf"],
        out_keys=["pixels_trsf"],
    )
    env = TransformedEnv(
        base_env,
        Compose(env_to_image, env_unsqueeze, frame_stacker, StepCounter()),
    )
    env.set_seed(seed)

    # Replay-buffer counterpart of the frame stacking, derived from the
    # same CatFrames instance that is used in the environment.
    rb_catframes, slice_sampler = frame_stacker.make_rb_transform_and_sampler(
        batch_size=training_batch_size,
        traj_key=("collector", "traj_ids"),
        strict_length=True,
    )

    # Buffer side: same conversion steps, applied to both the root and the
    # "next" pixel entries, followed by the CatFrames-provided transform.
    # Shape notes carried over from the original example:
    #   after ToTensorImage:     C W' H' -> C W' H' (unchanged due to
    #                            shape_tolerant)
    #   after UnsqueezeTransform: 1 C W' H'
    rb_to_image = ToTensorImage(
        from_int=True,
        dtype=torch.float32,
        in_keys=["pixels", ("next", "pixels")],
        out_keys=["pixels_trsf", ("next", "pixels_trsf")],
        shape_tolerant=True,
    )
    rb_unsqueeze = UnsqueezeTransform(
        dim=stack_dim,
        in_keys=["pixels_trsf", ("next", "pixels_trsf")],
        out_keys=["pixels_trsf", ("next", "pixels_trsf")],
    )
    buffer_pipeline = Compose(rb_to_image, rb_unsqueeze, rb_catframes)

    storage = LazyTensorStorage(max_size=max_size, device="cpu")
    rb = ReplayBuffer(
        storage=storage,
        sampler=slice_sampler,
        batch_size=training_batch_size,
        transform=buffer_pipeline,
    )

    # Keep stepping past episode ends so the rollout has the full length.
    rb.extend(env.rollout(1000, break_when_any_done=False))

    batch = rb.sample()
    print(batch)
if __name__ == "__main__":
    # Run the example only when executed as a script, not on import.
    main()