Feature add activation to BlockRNN (#2492) #2504

Merged
59 changes: 54 additions & 5 deletions darts/models/forecasting/block_rnn_model.py
@@ -10,7 +10,7 @@
import torch
import torch.nn as nn

from darts.logging import get_logger, raise_log
from darts.logging import get_logger, raise_if, raise_log
from darts.models.forecasting.pl_forecasting_module import (
PLPastCovariatesModule,
io_processor,
@@ -30,6 +30,7 @@ def __init__(
nr_params: int,
num_layers_out_fc: Optional[List] = None,
dropout: float = 0.0,
activation: Optional[str] = None,
**kwargs,
):
"""This class allows to create custom block RNN modules that can later be used with Darts'
@@ -63,6 +64,8 @@ def __init__(
This network connects the last hidden layer of the PyTorch RNN module to the output.
dropout
The fraction of neurons that are dropped in all-but-last RNN layers.
activation
The name of the activation function to be applied between the layers of the fully connected network.
**kwargs
all parameters required for :class:`darts.models.forecasting.pl_forecasting_module.PLForecastingModule`
base class.
@@ -77,6 +80,7 @@ def __init__(
self.nr_params = nr_params
self.num_layers_out_fc = [] if num_layers_out_fc is None else num_layers_out_fc
self.dropout = dropout
self.activation = activation
self.out_len = self.output_chunk_length

@io_processor
@@ -105,6 +109,7 @@ class _BlockRNNModule(CustomBlockRNNModule):
def __init__(
self,
name: str,
activation: Optional[str] = None,
**kwargs,
):
"""PyTorch module implementing a block RNN to be used in `BlockRNNModel`.
@@ -116,6 +121,7 @@ def __init__(

This module uses an RNN to encode the input sequence, and subsequently uses a fully connected
network as the decoder which takes as input the last hidden state of the encoder RNN.
Optionally, a non-linear activation function can be applied between the layers of the fully connected network.
The final output of the decoder is a sequence of length `output_chunk_length`. In this sense,
the `_BlockRNNModule` produces 'blocks' of forecasts at a time (which is different
from `_RNNModule` used by the `RNNModel`).
@@ -124,6 +130,9 @@ def __init__(
----------
name
The name of the specific PyTorch RNN module ("RNN", "GRU" or "LSTM").
activation
The name of the activation function to be applied between the layers of the fully connected network.
Options include "ReLU", "Sigmoid", "Tanh", or None for no activation. Default: None.
**kwargs
all parameters required for the :class:`darts.models.forecasting.CustomBlockRNNModule` base class.

Expand All @@ -135,7 +144,8 @@ def __init__(
Outputs
-------
y of shape `(batch_size, output_chunk_length, target_size, nr_params)`
Tensor containing the prediction at the last time step of the sequence.
Tensor containing the prediction at the last time step of the sequence, where optional activation
functions may have been applied between the layers of the fully connected network.
"""

super().__init__(**kwargs)
Expand All @@ -155,10 +165,15 @@ def __init__(
# to the output of desired length
last = self.hidden_dim
feats = []
for feature in self.num_layers_out_fc + [
self.out_len * self.target_size * self.nr_params
]:
for index, feature in enumerate(
self.num_layers_out_fc + [self.out_len * self.target_size * self.nr_params]
):
feats.append(nn.Linear(last, feature))

# Add activation only between layers, but not on the final layer
if activation and index < len(self.num_layers_out_fc):
activation_function = getattr(nn, activation)()
feats.append(activation_function)
last = feature
self.fc = nn.Sequential(*feats)
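
For illustration (not part of this diff): a minimal, standalone sketch of the decoder construction above, assuming hypothetical sizes — hidden_dim=25, num_layers_out_fc=[10], and an output of out_len * target_size * nr_params = 3 — with activation="ReLU". The activation is inserted after every Linear layer except the last one.

import torch.nn as nn

hidden_dim, num_layers_out_fc, out_features = 25, [10], 3  # hypothetical sizes
activation = "ReLU"

last = hidden_dim
feats = []
for index, feature in enumerate(num_layers_out_fc + [out_features]):
    feats.append(nn.Linear(last, feature))
    # activation goes between layers only, never after the final Linear
    if activation and index < len(num_layers_out_fc):
        feats.append(getattr(nn, activation)())
    last = feature
fc = nn.Sequential(*feats)
print(fc)  # Sequential(Linear(25, 10), ReLU(), Linear(10, 3))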

@@ -195,6 +210,7 @@ def __init__(
n_rnn_layers: int = 1,
hidden_fc_sizes: Optional[List] = None,
dropout: float = 0.0,
activation: Optional[str] = None,
**kwargs,
):
"""Block Recurrent Neural Network Model (RNNs).
@@ -243,6 +259,9 @@ def __init__(
Sizes of hidden layers connecting the last hidden layer of the RNN module to the output, if any.
dropout
Fraction of neurons affected by Dropout.
activation
The name of a torch.nn activation function to be applied between the layers of the fully connected network.
Default: None.
**kwargs
Optional arguments to initialize the pytorch_lightning.Module, pytorch_lightning.Trainer, and
Darts' :class:`TorchForecastingModel`.
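
For context (not part of the diff): a hedged usage sketch of the new parameter from the user-facing side, assuming a toy sine series and illustrative hyperparameters; `activation` only has an effect together with a non-empty `hidden_fc_sizes`, since it is inserted between the fully connected layers of the decoder.

from darts.models import BlockRNNModel
from darts.utils.timeseries_generation import sine_timeseries

series = sine_timeseries(length=100)  # toy series for illustration
model = BlockRNNModel(
    input_chunk_length=12,
    output_chunk_length=6,
    model="LSTM",
    hidden_fc_sizes=[16],   # hidden FC layers in the decoder
    activation="ReLU",      # applied between those FC layers
    n_epochs=1,
)
model.fit(series)
pred = model.predict(n=6)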
@@ -430,11 +449,40 @@ def encode_year(idx):
logger=logger,
)

Collaborator: With the comment from above, we could remove this check

raise_if(
activation is None
and hidden_fc_sizes is not None
and len(hidden_fc_sizes) > 0,
"The model contains hidden fully connected layers, but the activation function is not set; "
"please specify a valid activation function or remove the hidden layers.",
logger=logger,
)

# check that a valid activation function was specified:
if activation is not None:
try:
getattr(nn, activation)
except AttributeError:
raise_log(
ValueError(
f"Invalid activation function: {activation}. "
"Please use a valid torch.nn activation function name."
),
logger=logger,
)
Collaborator: and we could remove this as well

raise_if(
activation and (hidden_fc_sizes is None or len(hidden_fc_sizes) == 0),
"The activation function has been set, but the model does not contain hidden fully connected layers; "
"either set `activation=None` or increase `hidden_fc_sizes`.",
logger=logger,
)

self.rnn_type_or_module = model
self.hidden_fc_sizes = hidden_fc_sizes
self.hidden_dim = hidden_dim
self.n_rnn_layers = n_rnn_layers
self.dropout = dropout
self.activation = activation

@property
def supports_multivariate(self) -> bool:
@@ -464,6 +512,7 @@ def _create_model(self, train_sample: Tuple[torch.Tensor]) -> torch.nn.Module:
num_layers=self.n_rnn_layers,
num_layers_out_fc=hidden_fc_sizes,
dropout=self.dropout,
activation=self.activation,
**self.pl_module_params,
**kwargs,
)
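
Aside (not part of the diff): the validation above resolves the activation by name with getattr on torch.nn, so any attribute of torch.nn with that name passes the check (e.g. "ReLU", "Tanh", "Sigmoid", "LeakyReLU"). A quick standalone sketch of that lookup:

import torch.nn as nn

for name in ("ReLU", "Tanh", "NotAnActivation"):
    try:
        print(name, "->", getattr(nn, name)())
    except AttributeError:
        print(name, "-> not a valid torch.nn activation name")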
84 changes: 77 additions & 7 deletions darts/tests/models/forecasting/test_block_RNN.py
@@ -85,31 +85,86 @@ def test_creation(self):
model1.fit(self.series)
preds1 = model1.predict(n=3)

# can create from a custom class itself
# can create from valid module name with ReLU activation
model2 = BlockRNNModel(
input_chunk_length=1,
output_chunk_length=1,
model=ModuleValid1,
model="RNN",
activation="ReLU",
hidden_fc_sizes=[10],
n_epochs=1,
random_state=42,
**tfm_kwargs,
)
model2.fit(self.series)
preds2 = model2.predict(n=3)
np.testing.assert_array_equal(preds1.all_values(), preds2.all_values())
assert preds1.values().shape == preds2.values().shape

# can create from a custom class itself
model3 = BlockRNNModel(
input_chunk_length=1,
output_chunk_length=1,
model=ModuleValid2,
model=ModuleValid1,
n_epochs=1,
random_state=42,
**tfm_kwargs,
)
model3.fit(self.series)
preds3 = model2.predict(n=3)
assert preds3.all_values().shape == preds2.all_values().shape
assert preds3.time_index.equals(preds2.time_index)
preds3 = model3.predict(n=3)
np.testing.assert_array_equal(preds1.all_values(), preds3.all_values())

model4 = BlockRNNModel(
input_chunk_length=1,
output_chunk_length=1,
model=ModuleValid2,
n_epochs=1,
random_state=42,
**tfm_kwargs,
)
model4.fit(self.series)
preds4 = model4.predict(n=3)
assert preds4.all_values().shape == preds3.all_values().shape
assert preds4.time_index.equals(preds3.time_index)

def test_invalid_activation(self):
with pytest.raises(
ValueError, match="Invalid activation function: InvalidActivation"
):
BlockRNNModel(
input_chunk_length=1,
output_chunk_length=1,
model="RNN",
activation="InvalidActivation",
hidden_fc_sizes=[10],
n_epochs=1,
random_state=42,
**tfm_kwargs,
)

def test_raise_if_activation_with_single_linear_layer(self):
with pytest.raises(ValueError):
BlockRNNModel(
input_chunk_length=1,
output_chunk_length=1,
model="RNN",
activation="ReLU",
n_epochs=1,
random_state=42,
**tfm_kwargs,
)

def test_raise_if_no_activation_with_hidden_fc_layers(self):
with pytest.raises(ValueError):
BlockRNNModel(
input_chunk_length=1,
output_chunk_length=1,
model="RNN",
activation=None,
hidden_fc_sizes=[10],
n_epochs=1,
random_state=42,
**tfm_kwargs,
)

def test_fit(self, tmpdir_module):
# Test basic fit()
@@ -180,3 +235,18 @@ def helper_test_pred_length(self, pytorch_model, series):

def test_pred_length(self):
self.helper_test_pred_length(BlockRNNModel, self.series)

def test_varied_chunk_lengths(self):
model = BlockRNNModel(
input_chunk_length=5,
output_chunk_length=3,
n_epochs=2,
activation="ReLU",
hidden_fc_sizes=[10],
random_state=42,
**tfm_kwargs,
)
model.fit(self.series[:50])
pred = model.predict(3)
assert len(pred) == 3
assert pred.values().shape == (3, 1)