feat(connector): single agent (instadeepai#119)
Co-authored-by: Daniel <57721552+dluo96@users.noreply.github.com>
Co-authored-by: Clément Bonnet <56230714+clement-bonnet@users.noreply.github.com>
3 people authored May 12, 2023
1 parent 328fba9 commit e2feb60
Showing 13 changed files with 91 additions and 274 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -85,7 +85,7 @@ problems.
| 🏭 JobShop (Job Shop Scheduling Problem) | Packing | `JobShop-v0` | [code](https://github.com/instadeepai/jumanji/tree/main/jumanji/environments/packing/job_shop/) | [doc](https://instadeepai.github.io/jumanji/environments/job_shop/) |
| 🎒 Knapsack | Packing | `Knapsack-v1` | [code](https://github.com/instadeepai/jumanji/tree/main/jumanji/environments/packing/knapsack/) | [doc](https://instadeepai.github.io/jumanji/environments/knapsack/) |
| 🧹 Cleaner | Routing | `Cleaner-v0` | [code](https://github.com/instadeepai/jumanji/tree/main/jumanji/environments/routing/cleaner/) | [doc](https://instadeepai.github.io/jumanji/environments/cleaner/) |
| :link: Connector | Routing | `Connector-v0` | [code](https://github.com/instadeepai/jumanji/tree/main/jumanji/environments/routing/connector/) | [doc](https://instadeepai.github.io/jumanji/environments/connector/) |
| :link: Connector | Routing | `Connector-v1` | [code](https://github.com/instadeepai/jumanji/tree/main/jumanji/environments/routing/connector/) | [doc](https://instadeepai.github.io/jumanji/environments/connector/) |
| 🚚 CVRP (Capacitated Vehicle Routing Problem) | Routing | `CVRP-v1` | [code](https://github.com/instadeepai/jumanji/tree/main/jumanji/environments/routing/cvrp/) | [doc](https://instadeepai.github.io/jumanji/environments/cvrp/) |
| :mag: Maze | Routing | `Maze-v0` | [code](https://github.com/instadeepai/jumanji/tree/main/jumanji/environments/routing/maze/) | [doc](https://instadeepai.github.io/jumanji/environments/maze/) |
| 🐍 Snake | Routing | `Snake-v1` | [code](https://github.com/instadeepai/jumanji/tree/main/jumanji/environments/routing/snake/) | [doc](https://instadeepai.github.io/jumanji/environments/snake/) |
2 changes: 0 additions & 2 deletions docs/api/environments/connector.md
@@ -4,8 +4,6 @@
- init
- observation_spec
- action_spec
- reward_spec
- discount_spec
- reset
- step
- render
46 changes: 10 additions & 36 deletions docs/environments/connector.md
@@ -13,26 +13,14 @@ to allow each other to connect to their own targets without overlapping.
An episode ends when all agents have connected to their targets or no agents can make any further
moves due to being blocked.

> ⚠️ Warning
>
> This environment is multi-agent, i.e. the observation, action and action mask are batched on the
> agent dimension.
>
> - If used in a multi-agent RL setting, one can directly vmap the agents' inference functions on
> the observation they receive or unpack the observation and give it to each agent manually, e.g.
> `agents_obs = [jax.tree_util.tree_map(lambda x: x[i] if x.ndim>0 else x, obs) for i in range(len(obs.grid))]`.
>
> - If used in a single-agent RL setting, one can use `jumanji.wrappers.MultiToSingleWrapper` to
> make it a single-agent environment.

## Observation
At each step observation contains 3 items: a grid for each agent, an action mask for each agent and
At each step observation contains 3 items: a grid, an action mask for each agent and
the episode step count.

- `grid`: jax array (int32) of shape `(num_agents, grid_size, grid_size)`, a 2D matrix for each
agent that represents pairs of points that need to be connected from the perspective of each
agent. The **position** of an agent has to connect to its **target**, leaving a **path** behind
- `grid`: jax array (int32) of shape `(grid_size, grid_size)`, a 2D matrix that represents pairs
of points that need to be connected. Each agent has three types of points: **position**,
**target** and **path** which are represented by different numbers on the grid. The
**position** of an agent has to connect to its **target**, leaving a **path** behind
it as it moves across the grid forming its route. Each agent connects to only 1 target.

- `action_mask`: jax array (bool) of shape `(num_agents, 5)`, indicates which actions each agent
@@ -43,15 +31,14 @@ the episode step count.


### Encoding
Each agent has 3 components represented in the observation space: position, target, and path. Each
Each agent has 3 components represented in the observation space: **position**, **target**, and **path**. Each
agent in the environment is assigned an integer for each of its components.

- Positions are encoded starting from 2 in multiples of 3: 2, 5, 8, …

- Targets are encoded starting from 3 in multiples of 3: 3, 6, 9, …

- Paths appear in the location of the head once it moves, starting from 1 in
multiples of 3: 1, 4, 7, …
- Paths appear in the location of the head once it moves, starting from 1 in multiples of 3: 1, 4, 7, …

Every group of 3 corresponds to 1 agent: (1,2,3), (4,5,6), …

@@ -62,7 +49,7 @@ Agent2[path=4, position=5, target=6]
Agent3[path=7, position=8, target=9]
```

For example, on a 6x6 grid, the starting observation is shown below.
For example, on a 6x6 grid, a possible observation is shown below.

```
[[ 2 0 3 0 0 0]
@@ -73,31 +60,18 @@ For example, on a 6x6 grid, the starting observation is shown below.
[ 0 0 6 7 7 7]]
```
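
The mapping from agent index to cell codes can be written down directly. Below is a minimal sketch of that encoding rule; it is not part of this commit, and the helper names `agent_codes` and `cells_with` are hypothetical.

```python
import jax.numpy as jnp


def agent_codes(agent_id: int) -> tuple:
    """(path, position, target) codes for a zero-indexed agent: (1, 2, 3), (4, 5, 6), ..."""
    return 3 * agent_id + 1, 3 * agent_id + 2, 3 * agent_id + 3


def cells_with(grid: jnp.ndarray, code: int) -> jnp.ndarray:
    """(row, col) coordinates of every grid cell holding the given code."""
    return jnp.argwhere(grid == code)


# A 3x3 toy grid (not the 6x6 example above): agent 0's head (2) sits next to
# its path (1), and its target (3) is in the opposite corner.
toy_grid = jnp.array(
    [[1, 2, 0],
     [0, 0, 0],
     [0, 0, 3]]
)
path, position, target = agent_codes(0)  # (1, 2, 3)
print(cells_with(toy_grid, position))    # [[0 1]]: row and column of the head
```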

### Current Agent (multi-agent)

Given that this is a multi-agent environment, each agent gets its own observation; thus, we must
have a way to represent the current agent, so that the actor/learner knows which agent its actions
will apply to. The current agent is always encoded as `(1,2,3)` in the observations. However, this
notion of current agent only exists in the observations, in the state agent 0 is always encoded
as `(1,2,3)`.

The implementation shifts all other agents' values to make the `(1,2,3)` values represent the
current agent, so in each agent’s observation it will be represented by `(1,2,3)`.
This means that the agent with the values `(4,5,6)` will always be the next agent to act.


## Action
The action space is a `MultiDiscreteArray` of shape `(num_agents,)` of integer values in the range
of `[0, 4]`. Each value corresponds to an agent moving in 1 of 4 cardinal directions or taking the
no-op action. That is, [0, 1, 2, 3, 4] -> [No Op, Up, Right, Down, Left].
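
For orientation, here is a minimal usage sketch of this action interface, assuming the `Connector-v1` registration from this commit with its default of 5 agents:

```python
import jax
import jax.numpy as jnp
import jumanji

env = jumanji.make("Connector-v1")
state, timestep = env.reset(jax.random.PRNGKey(0))

# One integer action per agent, i.e. shape (num_agents,) = (5,); here every agent takes No Op.
actions = jnp.zeros((5,), dtype=jnp.int32)
state, timestep = env.step(state, actions)
```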


## Reward
The reward is **dense**: +1.0 for each agent that connects at that step and -0.03 for each agent that has not
The reward is **dense**: +1.0 per agent that connects at that step and -0.03 per agent that has not
connected yet.

Rewards are provided in the shape `(num_agents,)` so that each agent can have a reward.
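
Purely as an illustration of the arithmetic described above (this is not the environment's actual `RewardFn`, and the flag names below are hypothetical), the per-agent reward terms can be sketched as:

```python
import jax.numpy as jnp


def dense_reward_sketch(connected_before: jnp.ndarray,
                        connected_after: jnp.ndarray) -> jnp.ndarray:
    """Per-agent terms: +1.0 on the step an agent connects, -0.03 while still unconnected.

    Both inputs are bool arrays of shape (num_agents,).
    """
    newly_connected = connected_after & ~connected_before
    return jnp.where(newly_connected, 1.0, 0.0) + jnp.where(~connected_after, -0.03, 0.0)


# Agent 0 connects this step, agent 1 connected earlier, agent 2 has not connected yet.
before = jnp.array([False, True, False])
after = jnp.array([True, True, False])
print(dense_reward_sketch(before, after))  # [ 1.    0.   -0.03]
```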


## Registered Versions 📖
- `Connector-v0`, grid size of 10 and 5 agents.
- `Connector-v1`, grid size of 10 and 5 agents.
2 changes: 1 addition & 1 deletion jumanji/__init__.py
@@ -68,7 +68,7 @@
register(id="Cleaner-v0", entry_point="jumanji.environments:Cleaner")

# Connector with grid size of 10 and 5 agents.
register(id="Connector-v0", entry_point="jumanji.environments:Connector")
register(id="Connector-v1", entry_point="jumanji.environments:Connector")

# CVRP with 20 randomly generated nodes, a maximum capacity of 30,
# a maximum demand for each node of 10, and a dense reward function.
81 changes: 21 additions & 60 deletions jumanji/environments/routing/connector/env.py
@@ -40,26 +40,23 @@
is_valid_position,
move_agent,
move_position,
switch_perspective,
)
from jumanji.environments.routing.connector.viewer import ConnectorViewer
from jumanji.types import TimeStep, restart, termination, transition
from jumanji.viewer import Viewer


class Connector(Environment[State]):
"""The `Connector` environment is a multi-agent gridworld problem where each agent must connect a
start to a target. However, when moving through this gridworld the agent leaves an impassable
trail behind it. Therefore, agents must connect to their targets without overlapping the routes
taken by any other agent.
"""The `Connector` environment is a gridworld problem where multiple pairs of points (sets)
must be connected without overlapping the paths taken by any other set. This is achieved
by allowing certain points to move to an adjacent cell at each step. However, each time a
point moves it leaves an impassable trail behind it. The goal is to connect all sets.
- observation - `Observation`
- action mask: jax array (bool) of shape (num_agents, 5).
- step_count: jax array (int32) of shape ()
the current episode step.
- grid: jax array (int32) of shape (num_agents, size, size)
- each 2d array (size, size) along axis 0 is the agent's local observation.
- agents have ids from 0 to (num_agents - 1)
- grid: jax array (int32) of shape (grid_size, grid_size)
- with 2 agents you might have a grid like this:
4 0 1
5 0 1
@@ -68,24 +65,21 @@ class Connector(Environment[State]):
the bottom right corner and is aiming to get to the middle bottom cell. Agent 2
started in the top left and moved down once towards its target in the bottom left.
This would just be agent 0's view, the numbers would be flipped for agent 1's view.
So the full observation would be of shape (2, 3, 3).
- action: jax array (int32) of shape (num_agents,):
- can take the values [0,1,2,3,4] which correspond to [No Op, Up, Right, Down, Left].
- each value in the array corresponds to an agent's action.
- reward: jax array (float) of shape ():
- dense: each agent is given 1.0 if it connects on that step, otherwise 0.0. Additionally,
each agent that has not connected receives a penalty reward of -0.03.
- dense: reward is 1 for each successful connection on that step. Additionally,
each pair of points that have not connected receives a penalty reward of -0.03.
- episode termination: if an agent can't move, or the time limit is reached, or the agent
connects to its target, it is considered done. Once all agents are done, the episode
terminates. The timestep discounts are of shape (num_agents,).
- episode termination:
- all agents either can't move (no available actions) or have connected to their target.
- the time limit is reached.
- state: State:
- key: jax PRNG key used to randomly spawn agents and targets.
- grid: jax array (int32) of shape (size, size) which corresponds to agent 0's observation.
- grid: jax array (int32) of shape (grid_size, grid_size) giving the observation.
- step_count: jax array (int32) of shape () number of steps elapsed in the current episode.
```python
@@ -147,14 +141,12 @@ def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep[Observation]]:
state.agents, state.grid
)
observation = Observation(
grid=self._obs_from_grid(state.grid),
grid=state.grid,
action_mask=action_mask,
step_count=state.step_count,
)
extras = self._get_extras(state)
timestep = restart(
observation=observation, extras=extras, shape=(self.num_agents,)
)
timestep = restart(observation=observation, extras=extras)
return state, timestep

def step(
@@ -180,31 +172,26 @@ def step(
grid=grid, step_count=state.step_count + 1, agents=agents, key=state.key
)

# Construct timestep: get observations, rewards, discounts
grids = self._obs_from_grid(grid)
# Construct timestep: get reward, legal actions and done
reward = self._reward_fn(state, action, new_state)
action_mask = jax.vmap(self._get_action_mask, (0, None))(agents, grid)
observation = Observation(
grid=grids, action_mask=action_mask, step_count=new_state.step_count
grid=grid, action_mask=action_mask, step_count=new_state.step_count
)

dones = jax.vmap(connected_or_blocked)(agents, action_mask)
discount = jnp.asarray(jnp.logical_not(dones), dtype=float)
done = jnp.all(jax.vmap(connected_or_blocked)(agents, action_mask))
extras = self._get_extras(new_state)
timestep = jax.lax.cond(
dones.all() | (new_state.step_count >= self.time_limit),
done | (new_state.step_count >= self.time_limit),
lambda: termination(
reward=reward,
observation=observation,
extras=extras,
shape=self.num_agents,
),
lambda: transition(
reward=reward,
observation=observation,
discount=discount,
extras=extras,
shape=self.num_agents,
),
)

@@ -271,12 +258,6 @@ def _step_agent(

return new_agent, new_grid

def _obs_from_grid(self, grid: chex.Array) -> chex.Array:
"""Gets the observation vector for all agents."""
return jax.vmap(switch_perspective, (None, 0, None))(
grid, self._agent_ids, self.num_agents
)

def _get_action_mask(self, agent: Agent, grid: chex.Array) -> chex.Array:
"""Gets an agent's action mask."""
# Don't check action 0 because no-op is always valid
@@ -344,12 +325,12 @@ def observation_spec(self) -> specs.Spec[Observation]:
Returns:
Spec for the `Observation` whose fields are:
- grid: BoundedArray (int32) of shape (num_agents, grid_size, grid_size).
- grid: BoundedArray (int32) of shape (grid_size, grid_size).
- action_mask: BoundedArray (bool) of shape (num_agents, 5).
- step_count: BoundedArray (int32) of shape ().
"""
grid = specs.BoundedArray(
shape=(self.num_agents, self.grid_size, self.grid_size),
shape=(self.grid_size, self.grid_size),
dtype=jnp.int32,
name="grid",
minimum=0,
@@ -380,8 +361,8 @@ def observation_spec(self) -> specs.Spec[Observation]:
def action_spec(self) -> specs.MultiDiscreteArray:
"""Returns the action spec for the Connector environment.
5 actions: [0,1,2,3,4] -> [No Op, Up, Right, Down, Left]. Since this is a multi-agent
environment, the environment expects an array of actions of shape (num_agents,).
5 actions: [0,1,2,3,4] -> [No Op, Up, Right, Down, Left]. Since this is an environment with
a multi-dimensional action space, it expects an array of actions of shape (num_agents,).
Returns:
action_spec: `MultiDiscreteArray` of shape (num_agents,).
Expand All @@ -391,23 +372,3 @@ def action_spec(self) -> specs.MultiDiscreteArray:
dtype=jnp.int32,
name="action",
)

def reward_spec(self) -> specs.Array:
"""
Returns:
reward_spec: a `specs.Array` spec of shape (num_agents,). One for each agent.
"""
return specs.Array(shape=(self.num_agents,), dtype=float, name="reward")

def discount_spec(self) -> specs.BoundedArray:
"""
Returns:
discount_spec: a `specs.Array` spec of shape (num_agents,). One for each agent
"""
return specs.BoundedArray(
shape=(self.num_agents,),
dtype=float,
minimum=0.0,
maximum=1.0,
name="discount",
)
53 changes: 6 additions & 47 deletions jumanji/environments/routing/connector/env_test.py
@@ -54,8 +54,8 @@ def test_connector__reset(connector: Connector, key: jax.random.KeyArray) -> Non
assert all(is_head_on_grid(state.agents, state.grid))
assert all(is_target_on_grid(state.agents, state.grid))

assert jnp.array_equal(timestep.discount, jnp.ones(connector.num_agents))
assert jnp.array_equal(timestep.reward, jnp.zeros(connector.num_agents))
assert timestep.discount == 1.0
assert timestep.reward == 0.0
assert timestep.step_type == StepType.FIRST


@@ -91,7 +91,7 @@ def test_connector__step_connected(
chex.assert_trees_all_equal(real_state2, state2)

assert timestep.step_type == StepType.LAST
assert jnp.array_equal(timestep.discount, jnp.zeros(connector.num_agents))
assert jnp.array_equal(timestep.discount, jnp.asarray(0))
reward = connector._reward_fn(real_state1, action2, real_state2)
assert jnp.array_equal(timestep.reward, reward)

@@ -143,7 +143,7 @@ def test_connector__step_blocked(

assert jnp.array_equal(state.grid, expected_grid)
assert timestep.step_type == StepType.LAST
assert jnp.array_equal(timestep.discount, jnp.zeros(connector.num_agents))
assert jnp.array_equal(timestep.discount, jnp.asarray(0))

assert all(is_head_on_grid(state.agents, state.grid))
assert all(is_target_on_grid(state.agents, state.grid))
@@ -162,12 +162,12 @@ def test_connector__step_horizon(connector: Connector, state: State) -> None:
state, timestep = step_fn(state, actions)

assert timestep.step_type != StepType.LAST
assert jnp.array_equal(timestep.discount, jnp.ones(connector.num_agents))
assert jnp.array_equal(timestep.discount, jnp.asarray(1))

# step 5
state, timestep = step_fn(state, actions)
assert timestep.step_type == StepType.LAST
assert jnp.array_equal(timestep.discount, jnp.zeros(connector.num_agents))
assert jnp.array_equal(timestep.discount, jnp.asarray(0))


def test_connector__step_agents_collision(
@@ -230,47 +230,6 @@ def test_connector__does_not_smoke(connector: Connector) -> None:
check_env_does_not_smoke(connector)


def test_connector__obs_from_grid(
connector: Connector,
grid: chex.Array,
path0: int,
path1: int,
path2: int,
targ0: int,
targ1: int,
targ2: int,
posi0: int,
posi1: int,
posi2: int,
) -> None:
"""Tests that observations are correctly generated given the grid."""
observations = connector._obs_from_grid(grid)

expected_agent_1 = jnp.array(
[
[EMPTY, EMPTY, targ2, EMPTY, EMPTY, EMPTY],
[EMPTY, EMPTY, posi2, path2, path2, EMPTY],
[EMPTY, EMPTY, EMPTY, targ1, posi1, EMPTY],
[targ0, EMPTY, posi0, EMPTY, path1, EMPTY],
[EMPTY, EMPTY, path0, EMPTY, path1, EMPTY],
[EMPTY, EMPTY, path0, EMPTY, EMPTY, EMPTY],
]
)
expected_agent_2 = jnp.array(
[
[EMPTY, EMPTY, targ1, EMPTY, EMPTY, EMPTY],
[EMPTY, EMPTY, posi1, path1, path1, EMPTY],
[EMPTY, EMPTY, EMPTY, targ0, posi0, EMPTY],
[targ2, EMPTY, posi2, EMPTY, path0, EMPTY],
[EMPTY, EMPTY, path2, EMPTY, path0, EMPTY],
[EMPTY, EMPTY, path2, EMPTY, EMPTY, EMPTY],
]
)

expected_obs = jnp.stack([grid, expected_agent_1, expected_agent_2])
assert jnp.array_equal(expected_obs, observations)


def test_connector__get_action_mask(state: State, connector: Connector) -> None:
"""Validates the action masking."""
action_masks = jax.vmap(connector._get_action_mask, (0, None))(