
Commit

[Minor] merge main
LTluttmann committed May 29, 2024
2 parents c2c0b8f + 6074225 commit 7fdf21c
Showing 16 changed files with 772 additions and 36 deletions.
256 changes: 234 additions & 22 deletions examples/other/1-mtvrp.ipynb

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions rl4co/envs/__init__.py
@@ -47,6 +47,7 @@
"tsp": TSPEnv,
"smtwtp": SMTWTPEnv,
"mdcpdp": MDCPDPEnv,
"mtvrp": MTVRPEnv,
}


2 changes: 1 addition & 1 deletion rl4co/envs/routing/mtvrp/env.py
@@ -372,7 +372,7 @@ def _check_c1(feature="demand_linehaul"):
_check_c1("demand_linehaul")
_check_c1("demand_backhaul")

def load_data(fpath, batch_size=[], scale=False):
def load_data(self, fpath, batch_size=[], scale=False):
"""Dataset loading from file
Normalize demand by capacity to be in [0, 1]
"""
1 change: 1 addition & 0 deletions rl4co/models/__init__.py
@@ -32,6 +32,7 @@
from rl4co.models.zoo.l2d import L2DAttnPolicy, L2DModel, L2DPolicy, L2DPPOModel
from rl4co.models.zoo.matnet import MatNet, MatNetPolicy
from rl4co.models.zoo.mdam import MDAM, MDAMPolicy
from rl4co.models.zoo.mvmoe import MVMoE_AM, MVMoE_POMO
from rl4co.models.zoo.n2s import N2S, N2SPolicy
from rl4co.models.zoo.nargnn import NARGNNPolicy
from rl4co.models.zoo.pomo import POMO
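For orientation, a minimal construction sketch tying the newly exported MVMoE models to the MTVRPEnv registered above; the constructor pattern (environment passed first) is assumed from the other rl4co zoo models rather than shown in this commit:

from rl4co.envs import MTVRPEnv
from rl4co.models import MVMoE_AM

env = MTVRPEnv()        # multi-task VRP environment added in this commit
model = MVMoE_AM(env)   # MoE-augmented attention model; default hyperparameters assumed

Training would then follow the same Lightning-based flow used by the other zoo models.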
72 changes: 71 additions & 1 deletion rl4co/models/nn/attention.py
@@ -9,6 +9,7 @@

from einops import rearrange

from rl4co.models.nn.moe import MoE
from rl4co.utils import get_pylogger

log = get_pylogger(__name__)
@@ -247,6 +248,7 @@ def __init__(
out_bias: bool = False,
check_nan: bool = True,
sdpa_fn: Optional[Callable] = None,
**kwargs,
):
super(PointerAttention, self).__init__()
self.num_heads = num_heads
@@ -270,7 +272,7 @@ def forward(self, query, key, value, logit_key, attn_mask=None):
"""
# Compute inner multi-head attention with no projections.
heads = self._inner_mha(query, key, value, attn_mask)
glimpse = self.project_out(heads)
glimpse = self._project_out(heads)

# Batch matrix multiplication to compute logits (batch_size, num_steps, graph_size)
# bmm is slightly faster than einsum and matmul
@@ -302,6 +304,74 @@ def _inner_mha(self, query, key, value, attn_mask):
def _make_heads(self, v):
return rearrange(v, "... g (h s) -> ... h g s", h=self.num_heads)

def _project_out(self, out):
return self.project_out(out)


class PointerAttnMoE(PointerAttention):
"""Calculate logits given query, key and value and logit key.
This follows the pointer mechanism of Vinyals et al. (2015) (https://arxiv.org/abs/1506.03134),
and the MoE gating mechanism of Zhou et al. (2024) <https://arxiv.org/abs/2405.01029>.
Note:
With Flash Attention, masking is not supported
Performs the following:
1. Apply cross attention to get the heads
2. Project heads to get glimpse
3. Compute attention score between glimpse and logit key
Args:
embed_dim: total dimension of the model
num_heads: number of heads
mask_inner: whether to mask inner attention
linear_bias: whether to use bias in linear projection
check_nan: whether to check for NaNs in logits
sdpa_fn: scaled dot product attention function (SDPA) implementation
moe_kwargs: Keyword arguments for MoE
"""

def __init__(
self,
embed_dim: int,
num_heads: int,
mask_inner: bool = True,
out_bias: bool = False,
check_nan: bool = True,
sdpa_fn: Optional[Callable] = None,
moe_kwargs: Optional[dict] = None,
):
super(PointerAttnMoE, self).__init__(
embed_dim, num_heads, mask_inner, out_bias, check_nan, sdpa_fn
)
self.moe_kwargs = moe_kwargs

self.project_out = None
self.project_out_moe = MoE(
embed_dim, embed_dim, num_neurons=[], out_bias=out_bias, **moe_kwargs
)
if self.moe_kwargs["light_version"]:
self.dense_or_moe = nn.Linear(embed_dim, 2, bias=False)
self.project_out = nn.Linear(embed_dim, embed_dim, bias=out_bias)

def _project_out(self, out):
"""Implementation of Hierarchical Gating based on Zhou et al. (2024) <https://arxiv.org/abs/2405.01029>."""
if self.moe_kwargs["light_version"]:
probs = F.softmax(
self.dense_or_moe(out.view(-1, out.size(-1)).mean(dim=0, keepdim=True)),
dim=-1,
)
selected = probs.multinomial(1).squeeze(0)
out = (
self.project_out_moe(out)
if selected.item() == 1
else self.project_out(out)
)
glimpse = out * probs.squeeze(0)[selected]
else:
glimpse = self.project_out_moe(out)
return glimpse
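A construction sketch for the gated decoder attention defined above. Only "light_version" is read directly here; the remaining keys (number of experts, top-k, noisy gating) are assumed names for what the MoE module accepts and are not confirmed by this diff:

moe_kwargs = {"light_version": True, "num_experts": 4, "k": 2, "noisy_gating": True}  # expert keys assumed
pointer = PointerAttnMoE(embed_dim=128, num_heads=8, moe_kwargs=moe_kwargs)

With the light version, the dense_or_moe gate samples once per decoding step whether to route the glimpse through the shared dense projection or the MoE projection, and rescales the chosen output by the sampled gate probability; without it, every step is projected by the MoE.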


# Deprecated
class LogitAttention(PointerAttention):
39 changes: 39 additions & 0 deletions rl4co/models/nn/env_embeddings/context.py
@@ -32,6 +32,7 @@ def env_context_embedding(env_name: str, config: dict) -> nn.Module:
"mtsp": MTSPContext,
"smtwtp": SMTWTPContext,
"mdcpdp": MDCPDPContext,
"mtvrp": MTVRPContext,
}

if env_name not in embedding_registry:
@@ -326,3 +327,41 @@ def forward(self, h, td):
busy_proj = self.proj_busy(busy_for.unsqueeze(-1))
# (b m e)
return h + busy_proj


class MTVRPContext(VRPContext):
"""Context embedding for Multi-Task VRPEnv.
Project the following to the embedding space:
- current node embedding
- remaining_linehaul_capacity (vehicle_capacity - used_capacity_linehaul)
- remaining_backhaul_capacity (vehicle_capacity - used_capacity_backhaul)
- current time
- current_route_length
- open route indicator
"""

def __init__(self, embed_dim):
super(VRPContext, self).__init__(
embed_dim=embed_dim, step_context_dim=embed_dim + 5
)

def _state_embedding(self, embeddings, td):
remaining_linehaul_capacity = (
td["vehicle_capacity"] - td["used_capacity_linehaul"]
)
remaining_backhaul_capacity = (
td["vehicle_capacity"] - td["used_capacity_backhaul"]
)
current_time = td["current_time"]
current_route_length = td["current_route_length"]
open_route = td["open_route"]
return torch.cat(
[
remaining_linehaul_capacity,
remaining_backhaul_capacity,
current_time,
current_route_length,
open_route,
],
-1,
)
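A brief shape sketch for the concatenation above; the [batch, 1] layout per scalar state feature is assumed from the parent VRPContext:

# five [B, 1] features -> cat(..., -1) -> [B, 5]; together with the gathered
# node embedding [B, embed_dim] this matches step_context_dim = embed_dim + 5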
1 change: 1 addition & 0 deletions rl4co/models/nn/env_embeddings/dynamic.py
@@ -34,6 +34,7 @@ def env_dynamic_embedding(env_name: str, config: dict) -> nn.Module:
"smtwtp": StaticEmbedding,
"jssp": JSSPDynamicEmbedding,
"fjsp": JSSPDynamicEmbedding,
"mtvrp": StaticEmbedding,
}

if env_name not in embedding_registry:
33 changes: 33 additions & 0 deletions rl4co/models/nn/env_embeddings/init.py
@@ -35,6 +35,7 @@ def env_init_embedding(env_name: str, config: dict) -> nn.Module:
"mdcpdp": MDCPDPInitEmbedding,
"fjsp": FJSPInitEmbedding,
"jssp": FJSPInitEmbedding,
"mtvrp": MTVRPInitEmbedding,
}

if env_name not in embedding_registry:
@@ -489,3 +490,35 @@ def forward(self, td: TensorDict):
# edgeweights for matnet
matnet_edge_weights = proc_times.transpose(1, 2) / self.scaling_factor
return ops_emb, ma_emb, matnet_edge_weights


class MTVRPInitEmbedding(VRPInitEmbedding):
def __init__(self, embed_dim, linear_bias=True, node_dim: int = 7):
# node_dim = 7: x, y, demand_linehaul, demand_backhaul, tw start, tw end, service time
super(MTVRPInitEmbedding, self).__init__(embed_dim, linear_bias, node_dim)

def forward(self, td):
depot, cities = td["locs"][:, :1, :], td["locs"][:, 1:, :]
demand_linehaul, demand_backhaul = (
td["demand_linehaul"][..., 1:],
td["demand_backhaul"][..., 1:],
)
service_time = td["service_time"][..., 1:]
time_windows = td["time_windows"][..., 1:, :]
# [!] if a problem has no time-window constraint its window is [0, inf]; convert it to [0, 0] here (out of place, so td is not modified)
time_windows = torch.nan_to_num(time_windows, posinf=0.0)

# embeddings
depot_embedding = self.init_embed_depot(depot)
node_embeddings = self.init_embed(
torch.cat(
(
cities,
demand_linehaul[..., None],
demand_backhaul[..., None],
time_windows,
service_time[..., None],
),
-1,
)
)
return torch.cat((depot_embedding, node_embeddings), -2)
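For readers tallying the concatenation against node_dim = 7 declared above:

# cities (x, y) -> 2, demand_linehaul -> 1, demand_backhaul -> 1,
# time_windows (start, end) -> 2, service_time -> 1  =>  7 features per node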
23 changes: 14 additions & 9 deletions rl4co/models/nn/graph/attnnet.py
@@ -4,6 +4,8 @@

from torch import Tensor

from rl4co.models.nn.mlp import MLP
from rl4co.models.nn.moe import MoE
from rl4co.models.nn.attention import MultiHeadAttention
from rl4co.models.nn.ops import Normalization, SkipConnection
from rl4co.utils.pylogger import get_pylogger
@@ -20,6 +22,7 @@ class MultiHeadAttentionLayer(nn.Sequential):
feedforward_hidden: dimension of the hidden layer in the feed-forward layer
normalization: type of normalization to use (batch, layer, none)
sdpa_fn: scaled dot product attention function (SDPA)
moe_kwargs: Keyword arguments for MoE
"""

def __init__(
@@ -30,21 +33,20 @@ def __init__(
normalization: Optional[str] = "batch",
bias: bool = True,
sdpa_fn: Optional[Callable] = None,
moe_kwargs: Optional[dict] = None,
):
num_neurons = [feedforward_hidden] if feedforward_hidden > 0 else []
if moe_kwargs is not None:
ffn = MoE(embed_dim, embed_dim, num_neurons=num_neurons, **moe_kwargs)
else:
ffn = MLP(input_dim=embed_dim, output_dim=embed_dim, num_neurons=num_neurons, hidden_act="ReLU")

super(MultiHeadAttentionLayer, self).__init__(
SkipConnection(
MultiHeadAttention(embed_dim, num_heads, bias=bias, sdpa_fn=sdpa_fn)
),
Normalization(embed_dim, normalization),
SkipConnection(
nn.Sequential(
nn.Linear(embed_dim, feedforward_hidden),
nn.ReLU(),
nn.Linear(feedforward_hidden, embed_dim),
)
if feedforward_hidden > 0
else nn.Linear(embed_dim, embed_dim)
),
SkipConnection(ffn),
Normalization(embed_dim, normalization),
)

@@ -60,6 +62,7 @@ class GraphAttentionNetwork(nn.Module):
normalization: type of normalization to use (batch, layer, none)
feedforward_hidden: dimension of the hidden layer in the feed-forward layer
sdpa_fn: scaled dot product attention function (SDPA)
moe_kwargs: Keyword arguments for MoE
"""

def __init__(
@@ -70,6 +73,7 @@ def __init__(
normalization: str = "batch",
feedforward_hidden: int = 512,
sdpa_fn: Optional[Callable] = None,
moe_kwargs: Optional[dict] = None,
):
super(GraphAttentionNetwork, self).__init__()

@@ -81,6 +85,7 @@ def __init__(
feedforward_hidden=feedforward_hidden,
normalization=normalization,
sdpa_fn=sdpa_fn,
moe_kwargs=moe_kwargs,
)
for _ in range(num_layers)
)
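A construction sketch for the MoE-augmented encoder; the leading parameter names (num_heads, embed_dim, num_layers) and the expert-related keys are assumptions, since only the moe_kwargs plumbing is visible in this hunk:

encoder = GraphAttentionNetwork(
    num_heads=8,
    embed_dim=128,
    num_layers=3,
    moe_kwargs={"num_experts": 4, "k": 2, "noisy_gating": True},  # assumed MoE kwargs
)

Leaving moe_kwargs=None keeps the MLP feed-forward branch, which mirrors the previous Linear-ReLU-Linear block.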