Skip to content

Commit

Permalink
fix general padding for global warehouse image
Browse files Browse the repository at this point in the history
  • Loading branch information
LukasSchaefer committed Apr 7, 2022
1 parent 341fd73 commit d5e594b
Showing 1 changed file with 12 additions and 11 deletions.
23 changes: 12 additions & 11 deletions rware/warehouse.py
Original file line number Diff line number Diff line change
Expand Up @@ -945,7 +945,7 @@ def get_global_image(
:param pad_to_shape: if given than pad environment global image shape into this
shape (if doesn't fit throw exception)
"""
if recompute or not self.global_image:
if recompute or self.global_image is None:
layers = []
for layer_type in image_layers:
if layer_type == ImageLayer.SHELVES:
Expand Down Expand Up @@ -979,17 +979,18 @@ def get_global_image(
for ag in self.agents:
layer[ag.y, ag.x] = 0.0
layers.append(layer)
all_layers = np.stack(layers)
if pad_to_shape:
assert len(pad_to_shape) == len(all_layers.shape) == 3
padding_dims = [pad_dim - global_dim for pad_dim, global_dim in zip(pad_to_shape, all_layers.shape)]
start_dims = [pad_dim // 2 for pad_dim in padding_dims[1:]]
end_dims = [pad_dim // 2 if pad_dim % 2 == 0 else pad_dim // 2 + 1 for pad_dim in padding_dims[1:]]
self.global_image = np.stack(layers)
if pad_to_shape is not None:
padding_dims = [pad_dim - global_dim for pad_dim, global_dim in zip(pad_to_shape, self.global_image.shape)]
assert all([dim >= 0 for dim in padding_dims])
self.global_image = np.zeros(pad_to_shape)
self.global_image[:, start_dims[0]:-end_dims[0], start_dims[1]:-end_dims[1]] = all_layers
else:
self.global_image = all_layers
pad_before = [pad_dim // 2 for pad_dim in padding_dims]
pad_after = [pad_dim // 2 if pad_dim % 2 == 0 else pad_dim // 2 + 1 for pad_dim in padding_dims]
self.global_image = np.pad(
self.global_image,
pad_width=tuple(zip(pad_before, pad_after)),
mode='constant',
constant_values=0
)
return self.global_image


Expand Down

0 comments on commit d5e594b

Please sign in to comment.