-
Notifications
You must be signed in to change notification settings - Fork 85
/
env.py
510 lines (414 loc) · 17.7 KB
/
env.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
# Copyright 2022 InstaDeep Ltd. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import cached_property
from typing import Optional, Sequence, Tuple
import chex
import jax
import jax.numpy as jnp
import matplotlib
from numpy.typing import NDArray
from jumanji import specs
from jumanji.env import Environment
from jumanji.environments.packing.flat_pack.generator import (
InstanceGenerator,
RandomFlatPackGenerator,
)
from jumanji.environments.packing.flat_pack.reward import CellDenseReward, RewardFn
from jumanji.environments.packing.flat_pack.types import Observation, State
from jumanji.environments.packing.flat_pack.utils import compute_grid_dim, rotate_block
from jumanji.environments.packing.flat_pack.viewer import FlatPackViewer
from jumanji.types import TimeStep, restart, termination, transition
from jumanji.viewer import Viewer
class FlatPack(Environment[State, specs.MultiDiscreteArray, Observation]):
    """The FlatPack environment with a configurable number of row and column blocks.

    Here the goal of an agent is to completely fill an empty grid by placing all
    available blocks. It can be thought of as a discrete 2D version of the `BinPack`
    environment.

    - observation: `Observation`
        - grid: jax array (int) of shape (num_rows, num_cols) with the
            current state of the grid.
        - blocks: jax array (int) of shape (num_blocks, 3, 3) with the blocks to
            be placed on the grid. Here each block is a 2D array with shape (3, 3).
        - action_mask: jax array (bool) of shape
            (num_blocks, 4, num_rows - 2, num_cols - 2) showing which blocks can be
            placed where on the grid. This mask covers every block, every rotation
            and every placement location (top-left corner) of the block on the grid.

    - action: jax array (int32) of shape (4,)
        multi discrete array containing the move to perform
        (block to place, number of rotations, row coordinate, column coordinate).

    - reward: jax array (float) of shape (), could be either:
        - cell dense: the number of non-zero cells in a placed block normalised by the
            total number of cells in a grid. This will be a value in the range [0, 1],
            that is to say that the agent will optimise for the maximum area to fill on
            the grid.
        - block dense: each placed block will receive a reward of 1./num_blocks. This
            will be a value in the range [0, 1], that is to say that the agent will
            optimise for the maximum number of blocks placed on the grid.
        - sparse: 1 if the grid is completely filled, otherwise 0 at each timestep.

    - episode termination:
        - if all blocks have been placed on the board.
        - if the agent has taken `num_blocks` steps in the environment.
        Only the step count is checked: each step places at most one block, so
        placing all `num_blocks` blocks takes at least `num_blocks` steps.

    - state: `State`
        - num_blocks: jax array (int32) of shape () with the
            number of blocks in the environment.
        - blocks: jax array (int32) of shape (num_blocks, 3, 3) with the blocks to
            be placed on the grid. Here each block is a 2D array with shape (3, 3).
        - action_mask: jax array (bool) of shape
            (num_blocks, 4, num_rows - 2, num_cols - 2) showing which blocks can be
            placed where on the grid, for every rotation and placement location.
        - placed_blocks: jax array (bool) of shape (num_blocks,) showing which blocks
            have been placed on the grid.
        - grid: jax array (int32) of shape (num_rows, num_cols) with the
            current state of the grid.
        - step_count: jax array (int32) of shape () with the number of steps taken
            in the environment.
        - key: jax array of shape (2,) with the random key used for board
            generation.

    ```python
    from jumanji.environments import FlatPack
    env = FlatPack()
    key = jax.random.PRNGKey(0)
    state, timestep = jax.jit(env.reset)(key)
    env.render(state)
    action = env.action_spec.generate_value()
    state, timestep = jax.jit(env.step)(state, action)
    env.render(state)
    ```
    """

    def __init__(
        self,
        generator: Optional[InstanceGenerator] = None,
        reward_fn: Optional[RewardFn] = None,
        viewer: Optional[Viewer[State]] = None,
    ):
        """Initializes the FlatPack environment.

        Args:
            generator: Instance generator for the environment, default to
                `RandomFlatPackGenerator` with a grid of 5 blocks per row and column.
            reward_fn: Reward function for the environment, default to `CellDenseReward`.
            viewer: Viewer for rendering the environment, default to a
                `FlatPackViewer` with render_mode="human".
        """
        # `or` short-circuits, so the default generator is only built when needed.
        self.generator = generator or RandomFlatPackGenerator(
            num_row_blocks=5,
            num_col_blocks=5,
        )
        self.num_row_blocks = self.generator.num_row_blocks
        self.num_col_blocks = self.generator.num_col_blocks
        self.num_blocks = self.num_row_blocks * self.num_col_blocks
        self.num_rows, self.num_cols = (
            compute_grid_dim(self.num_row_blocks),
            compute_grid_dim(self.num_col_blocks),
        )
        self.reward_fn = reward_fn or CellDenseReward()
        self.viewer = viewer or FlatPackViewer("FlatPack", self.num_blocks, render_mode="human")
        super().__init__()

    def __repr__(self) -> str:
        return (
            f"FlatPack environment with a grid size of ({self.num_rows}x{self.num_cols}) "
            f"with {self.num_row_blocks} row blocks, {self.num_col_blocks} column "
            f"blocks. Each block has dimension (3x3)."
        )

    def reset(
        self,
        key: chex.PRNGKey,
    ) -> Tuple[State, TimeStep[Observation]]:
        """Resets the environment.

        Args:
            key: PRNG key for generating a new instance.

        Returns:
            a tuple of the initial environment state and a time step.
        """
        grid_state = self.generator(key)

        obs = self._observation_from_state(grid_state)
        timestep = restart(observation=obs)

        return grid_state, timestep

    def step(self, state: State, action: chex.Array) -> Tuple[State, TimeStep[Observation]]:
        """Steps the environment.

        An action that is illegal according to `state.action_mask` leaves the grid
        and the placed blocks unchanged, but still increments the step count.

        Args:
            state: current state of the environment.
            action: action to take, i.e. (block index, number of rotations,
                row coordinate, column coordinate).

        Returns:
            a tuple of the next environment state and a time step.
        """
        # Unpack the action and select the chosen block.
        block_idx, rotation, row_idx, col_idx = action
        chosen_block = state.blocks[block_idx]

        # Rotate the chosen block and lay it on an otherwise empty grid.
        chosen_block = rotate_block(chosen_block, rotation)
        grid_block = self._expand_block_to_grid(chosen_block, row_idx, col_idx)

        action_is_legal = state.action_mask[block_idx, rotation, row_idx, col_idx]

        # If the action is legal, add the block to the grid and mark it as placed.
        new_grid = jax.lax.cond(
            action_is_legal,
            lambda: state.grid + grid_block,
            lambda: state.grid,
        )
        placed_blocks = jax.lax.cond(
            action_is_legal,
            lambda: state.placed_blocks.at[block_idx].set(True),
            lambda: state.placed_blocks,
        )

        new_action_mask = self._make_action_mask(new_grid, state.blocks, placed_blocks)

        next_state = State(
            grid=new_grid,
            blocks=state.blocks,
            action_mask=new_action_mask,
            num_blocks=state.num_blocks,
            key=state.key,
            step_count=state.step_count + 1,
            placed_blocks=placed_blocks,
        )

        done = self._is_done(next_state)
        next_obs = self._observation_from_state(next_state)

        reward = self.reward_fn(state, grid_block, next_state, action_is_legal, done)

        timestep = jax.lax.cond(
            done,
            termination,
            transition,
            reward,
            next_obs,
        )

        return next_state, timestep

    def render(self, state: State) -> Optional[NDArray]:
        """Render a given state of the environment.

        Args:
            state: `State` object containing the current environment state.

        Returns:
            whatever the viewer's `render` returns (an array or None).
        """
        return self.viewer.render(state)

    def animate(
        self,
        states: Sequence[State],
        interval: int = 200,
        save_path: Optional[str] = None,
    ) -> matplotlib.animation.FuncAnimation:
        """Create an animation from a sequence of states.

        Args:
            states: sequence of `State` corresponding to subsequent timesteps.
            interval: delay between frames in milliseconds, default to 200.
            save_path: the path where the animation file should be saved. If it is None,
                the plot will not be saved.

        Returns:
            animation that can export to gif, mp4, or render with HTML.
        """
        return self.viewer.animate(states, interval, save_path)

    def close(self) -> None:
        """Perform any necessary cleanup.

        Environments will automatically `close()` themselves when
        garbage collected or when the program exits.
        """
        self.viewer.close()

    @cached_property
    def observation_spec(self) -> specs.Spec[Observation]:
        """Returns the observation spec of the environment.

        Returns:
            Spec for each field in the observation:
            - grid: BoundedArray (int) of shape (num_rows, num_cols).
            - blocks: BoundedArray (int) of shape (num_blocks, 3, 3).
            - action_mask: BoundedArray (bool) of shape
                (num_blocks, 4, num_rows-2, num_cols-2).
        """
        grid = specs.BoundedArray(
            shape=(self.num_rows, self.num_cols),
            minimum=0,
            maximum=self.num_blocks,
            dtype=jnp.int32,
            name="grid",
        )

        blocks = specs.BoundedArray(
            shape=(self.num_blocks, 3, 3),
            minimum=0,
            maximum=self.num_blocks,
            dtype=jnp.int32,
            name="blocks",
        )

        action_mask = specs.BoundedArray(
            shape=(
                self.num_blocks,
                4,
                self.num_rows - 2,
                self.num_cols - 2,
            ),
            minimum=False,
            maximum=True,
            dtype=bool,
            name="action_mask",
        )

        return specs.Spec(
            Observation,
            "ObservationSpec",
            grid=grid,
            blocks=blocks,
            action_mask=action_mask,
        )

    @cached_property
    def action_spec(self) -> specs.MultiDiscreteArray:
        """Specifications of the action expected by the `FlatPack` environment.

        Returns:
            MultiDiscreteArray of shape (4,) whose number of values per entry is
            [num_blocks, 4, num_rows - 2, num_cols - 2]:
            - block index: int between 0 and num_blocks - 1 (inclusive).
            - number of rotations: int between 0 and 3 (inclusive).
            - row position: int between 0 and num_rows - 3 (inclusive).
            - column position: int between 0 and num_cols - 3 (inclusive).
        """
        # A 3x3 block's top-left corner can be placed in num_rows - 2 rows and
        # num_cols - 2 columns without leaving the grid.
        num_row_positions = self.num_rows - 2
        num_col_positions = self.num_cols - 2

        return specs.MultiDiscreteArray(
            num_values=jnp.array([self.num_blocks, 4, num_row_positions, num_col_positions]),
            name="action",
        )

    def _is_done(self, state: State) -> chex.Array:
        """Checks whether the episode is over, i.e. whether the number of steps
        taken has reached the number of blocks.

        Args:
            state: current state of the environment.

        Returns:
            boolean scalar array, True if the environment is done, False otherwise.
        """
        return state.step_count >= state.num_blocks

    def _is_legal_action(
        self,
        action: chex.Numeric,
        grid: chex.Array,
        placed_blocks: chex.Array,
        grid_mask_block: chex.Array,
    ) -> chex.Array:
        """Checks whether placing a block is legal given the current grid.

        An action is legal if the chosen block has not been placed yet and the
        block does not overlap with any block already on the grid.

        Args:
            action: action taken; only its first entry (the block index) is used.
            grid: current state of the grid.
            placed_blocks: boolean array indicating which blocks have been placed.
            grid_mask_block: grid with ones where the current block would be placed
                and zeros elsewhere.

        Returns:
            boolean scalar array, True if the action is legal, False otherwise.
        """
        block_idx, _, _, _ = action

        # Occupied cells count as 1; any cell summing to 2 is an overlap.
        placed_mask = (grid > 0) + grid_mask_block
        legal: chex.Array = (~placed_blocks[block_idx]) & (jnp.max(placed_mask) <= 1)

        return legal

    def _get_ones_like_expanded_block(self, grid_block: chex.Array) -> chex.Array:
        """Returns an int32 grid with ones where `grid_block` is non-zero and
        zeros elsewhere."""
        return (grid_block != 0).astype(jnp.int32)

    def _expand_block_to_grid(
        self,
        block: chex.Array,
        row_coord: chex.Numeric,
        col_coord: chex.Numeric,
    ) -> chex.Array:
        """Places a block on a grid of zeroes with the same size as the grid.

        Args:
            block: block to place on the grid.
            row_coord: row coordinate on the grid where the top left corner
                of the block will be placed.
            col_coord: column coordinate on the grid where the top left corner
                of the block will be placed.

        Returns:
            Grid of zeroes with the block's values where the block is placed.
        """
        # Make an empty grid for placing the block on.
        grid_with_block = jnp.zeros((self.num_rows, self.num_cols), dtype=jnp.int32)
        place_location = (row_coord, col_coord)
        grid_with_block = jax.lax.dynamic_update_slice(grid_with_block, block, place_location)

        return grid_with_block

    def _observation_from_state(self, state: State) -> Observation:
        """Creates an observation from a state.

        Args:
            state: State to create an observation from.

        Returns:
            An observation.
        """
        return Observation(
            grid=state.grid,
            action_mask=state.action_mask,
            blocks=state.blocks,
        )

    def _expand_all_blocks_to_grids(
        self,
        blocks: chex.Array,
        block_idxs: chex.Array,
        rotations: chex.Array,
        row_coords: chex.Array,
        col_coords: chex.Array,
    ) -> chex.Array:
        """Takes multiple blocks and their corresponding rotations and positions,
        and generates a grid for each block.

        Args:
            blocks: array of possible blocks.
            block_idxs: array of indices of the blocks to place.
            rotations: array of rotations, one per entry of `block_idxs`.
            row_coords: array of row coordinates, one per entry of `block_idxs`.
            col_coords: array of column coordinates, one per entry of `block_idxs`.

        Returns:
            Array of shape (len(block_idxs), num_rows, num_cols) where each grid
            has ones where the corresponding block would be placed and zeros elsewhere.
        """
        batch_expand_block_to_board = jax.vmap(self._expand_block_to_grid)

        all_possible_blocks = blocks[block_idxs]
        rotated_blocks = jax.vmap(rotate_block)(all_possible_blocks, rotations)
        grids = batch_expand_block_to_board(rotated_blocks, row_coords, col_coords)

        batch_get_ones_like_expanded_block = jax.vmap(
            self._get_ones_like_expanded_block, in_axes=0
        )
        grids = batch_get_ones_like_expanded_block(grids)
        return grids

    def _make_action_mask(
        self, grid: chex.Array, blocks: chex.Array, placed_blocks: chex.Array
    ) -> chex.Array:
        """Create a mask of possible actions based on the current state of the grid.

        Args:
            grid: current state of the grid.
            blocks: array of all blocks.
            placed_blocks: boolean array of blocks that have already been placed.

        Returns:
            Boolean array of shape (num_blocks, 4, num_rows - 2, num_cols - 2) that is
            True for every (block, rotation, row, col) action that is legal.
        """
        num_blocks, num_rotations, num_placement_rows, num_placement_cols = (
            self.num_blocks,
            4,
            self.num_rows - 2,
            self.num_cols - 2,
        )

        # Enumerate every (block, rotation, row, col) combination.
        blocks_grid, rotations_grid, rows_grid, cols_grid = jnp.meshgrid(
            jnp.arange(num_blocks),
            jnp.arange(num_rotations),
            jnp.arange(num_placement_rows),
            jnp.arange(num_placement_cols),
            indexing="ij",
        )

        grid_mask_pieces = self._expand_all_blocks_to_grids(
            blocks,
            blocks_grid.flatten(),
            rotations_grid.flatten(),
            rows_grid.flatten(),
            cols_grid.flatten(),
        )

        batch_is_legal_action = jax.vmap(self._is_legal_action, in_axes=(0, None, None, 0))

        all_actions = jnp.stack(
            (blocks_grid, rotations_grid, rows_grid, cols_grid), axis=-1
        ).reshape(-1, 4)

        legal_actions = batch_is_legal_action(
            all_actions,
            grid,
            placed_blocks,
            grid_mask_pieces,
        )

        legal_actions = legal_actions.reshape(
            num_blocks, num_rotations, num_placement_rows, num_placement_cols
        )

        # Set all currently placed blocks to False in the mask. `_is_legal_action`
        # already rejects placed blocks; broadcasting keeps this explicit guarantee
        # without materialising a full-size tiled mask.
        placed_blocks_mask = placed_blocks.reshape((num_blocks, 1, 1, 1))
        legal_actions = jnp.where(placed_blocks_mask, False, legal_actions)

        return legal_actions