-
Notifications
You must be signed in to change notification settings - Fork 326
/
gym_conversion_examples.py
125 lines (93 loc) · 3.3 KB
/
gym_conversion_examples.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
This script gives some examples of gym environment conversion with Dict, Tuple and Sequence spaces.
"""
import gymnasium as gym
from gymnasium import spaces
from torchrl.envs import GymWrapper
# Two-action discrete space; every example environment in this file assigns
# this same instance to its ``action_space`` attribute.
action_space = spaces.Discrete(n=2)
class BaseEnv(gym.Env):
    """Minimal environment base whose observations are random space samples.

    Subclasses are expected to set ``observation_space`` and ``action_space``
    in their ``__init__``; this class only provides ``step`` and ``reset``.
    """

    def step(self, action):
        """Ignore ``action`` and return a freshly sampled transition.

        Returns:
            A 5-tuple ``(observation, reward, terminated, truncated, info)``
            where the observation is sampled from ``self.observation_space``,
            the reward is always ``1``, both flags are ``False`` and the info
            dict is empty.
        """
        return self.observation_space.sample(), 1, False, False, {}

    def reset(self, *, seed=None, options=None, **kwargs):
        """Start a new episode and return ``(observation, info)``.

        Args:
            seed: Optional seed. When given, it seeds both the environment RNG
                (through ``gym.Env.reset``) and the observation space, so that
                the sampled observations are reproducible.
            options: Accepted for Gymnasium API compatibility; unused.
            **kwargs: Any other keyword arguments are accepted and ignored,
                as in the previous ``reset(**kwargs)`` signature.

        Returns:
            A sample from ``self.observation_space`` and an empty info dict.
        """
        # Gymnasium asks custom environments to forward the seed to the base
        # class so that ``self.np_random`` is initialised correctly.
        super().reset(seed=seed)
        if seed is not None:
            # Observations come from the space's own RNG, not ``np_random``,
            # so the space has to be seeded too for reset(seed=...) to matter.
            self.observation_space.seed(seed)
        return self.observation_space.sample(), {}
class SimpleEnv(BaseEnv):
    """Example environment with a single ``Box`` observation of shape ``(2,)``."""

    def __init__(self):
        self.action_space = action_space
        self.observation_space = spaces.Box(low=-1, high=1, shape=(2,))


# Make the environment constructible through ``gym.make("SimpleEnv-v0")``.
gym.register(id="SimpleEnv-v0", entry_point=SimpleEnv)
class SimpleEnvWithDict(BaseEnv):
    """Example environment whose observation is a ``Dict`` of two ``Box`` spaces.

    The entries are ``obs0`` (shape ``(2,)``) and ``obs1`` (shape ``(3,)``).
    """

    def __init__(self):
        first = spaces.Box(low=-1, high=1, shape=(2,))
        second = spaces.Box(low=-1, high=1, shape=(3,))
        self.observation_space = spaces.Dict(obs0=first, obs1=second)
        self.action_space = action_space


# Make the environment constructible through ``gym.make("SimpleEnvWithDict-v0")``.
gym.register(id="SimpleEnvWithDict-v0", entry_point=SimpleEnvWithDict)
class SimpleEnvWithTuple(BaseEnv):
    """Example environment whose observation is a ``Tuple`` of two ``Box`` spaces.

    The members have shapes ``(2,)`` and ``(3,)``, in that order.
    """

    def __init__(self):
        members = (
            spaces.Box(low=-1, high=1, shape=(2,)),
            spaces.Box(low=-1, high=1, shape=(3,)),
        )
        self.observation_space = spaces.Tuple(members)
        self.action_space = action_space


# Make the environment constructible through ``gym.make("SimpleEnvWithTuple-v0")``.
gym.register(id="SimpleEnvWithTuple-v0", entry_point=SimpleEnvWithTuple)
class SimpleEnvWithSequence(BaseEnv):
    """Example environment whose observation is a stacked ``Sequence`` of ``Box``.

    Each element of the sequence is a ``Box`` of shape ``(2,)``.
    """

    def __init__(self):
        element = spaces.Box(low=-1, high=1, shape=(2,))
        # Only stack=True is currently allowed
        self.observation_space = spaces.Sequence(element, stack=True)
        self.action_space = action_space


# Make the environment constructible through
# ``gym.make("SimpleEnvWithSequence-v0")``.
gym.register(id="SimpleEnvWithSequence-v0", entry_point=SimpleEnvWithSequence)
class SimpleEnvWithSequenceOfTuple(BaseEnv):
    """Example environment whose observation is a stacked ``Sequence`` of ``Tuple``.

    Each element of the sequence is a ``Tuple`` of two ``Box`` spaces with
    shapes ``(2,)`` and ``(3,)``.
    """

    def __init__(self):
        element = spaces.Tuple(
            (
                spaces.Box(low=-1, high=1, shape=(2,)),
                spaces.Box(low=-1, high=1, shape=(3,)),
            )
        )
        # Only stack=True is currently allowed
        self.observation_space = spaces.Sequence(element, stack=True)
        self.action_space = action_space


# Make the environment constructible through
# ``gym.make("SimpleEnvWithSequenceOfTuple-v0")``.
gym.register(
    id="SimpleEnvWithSequenceOfTuple-v0",
    entry_point=SimpleEnvWithSequenceOfTuple,
)
class SimpleEnvWithTupleOfSequences(BaseEnv):
    """Example environment whose observation is a ``Tuple`` of stacked ``Sequence``.

    The first sequence holds ``Box`` elements of shape ``(2,)`` and the second
    holds ``Box`` elements of shape ``(3,)``.
    """

    def __init__(self):
        # Only stack=True is currently allowed (applies to both sequences).
        short_seq = spaces.Sequence(
            spaces.Box(low=-1, high=1, shape=(2,)),
            stack=True,
        )
        long_seq = spaces.Sequence(
            spaces.Box(low=-1, high=1, shape=(3,)),
            stack=True,
        )
        self.observation_space = spaces.Tuple((short_seq, long_seq))
        self.action_space = action_space


# Make the environment constructible through
# ``gym.make("SimpleEnvWithTupleOfSequences-v0")``.
gym.register(
    id="SimpleEnvWithTupleOfSequences-v0",
    entry_point=SimpleEnvWithTupleOfSequences,
)
if __name__ == "__main__":
    # Base names of the environments registered above (without the "-v0"
    # version suffix), in the order they are demonstrated.
    env_names = (
        "SimpleEnv",
        "SimpleEnvWithDict",
        "SimpleEnvWithTuple",
        "SimpleEnvWithSequence",
        "SimpleEnvWithSequenceOfTuple",
        "SimpleEnvWithTupleOfSequences",
    )
    for envname in env_names:
        print("\n\nEnv =", envname)
        # Build the gymnasium env from its registered id, then wrap it for
        # TorchRL.
        env = gym.make(f"{envname}-v0")
        env_torchrl = GymWrapper(env)
        # Collect a short rollout and show the resulting data structure.
        rollout = env_torchrl.rollout(10, return_contiguous=False)
        print(rollout)