Skip to content

Commit

Permalink
[Minor] Code quality improvements (pytorch#2140)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored Apr 30, 2024
1 parent c25ec59 commit 68101b0
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 5 deletions.
3 changes: 2 additions & 1 deletion examples/envs/gym-async-info-reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,8 @@ def step(self, action):
]

# Create an info reader: this object will read the info and write its content to the tensordict
reader = lambda info, tensordict: tensordict.set("field1", np.stack(info["field1"]))
def reader(info, tensordict):
return tensordict.set("field1", np.stack(info["field1"]))
env.set_info_dict_reader(info_dict_reader=reader)

# Print the info readers (there should be 2: one to read the terminal states and another to read the 'field1')
Expand Down
4 changes: 0 additions & 4 deletions torchrl/envs/libs/gym.py
Original file line number Diff line number Diff line change
Expand Up @@ -576,7 +576,6 @@ def __call__(cls, *args, **kwargs):
)
add_info_dict = False
if add_info_dict:
print("adding info dict reader", instance.observation_spec)
instance.set_info_dict_reader(
terminal_obs_reader(instance.observation_spec, backend=backend)
)
Expand Down Expand Up @@ -1539,9 +1538,6 @@ def __call__(self, info_dict, tensordict):
for i, obs in enumerate(terminal_obs):
# writes final_obs inplace with terminal_obs content
self._read_obs(obs, key[-1], final_obs_buffer, index=i)
print("self.batch_size", tensordict.shape)
print("final_obs_buffer", final_obs_buffer)
print("spec", item)
tensordict.set(key, final_obs_buffer)
return tensordict

Expand Down

0 comments on commit 68101b0

Please sign in to comment.