Feature add activation to BlockRNN (#2492) #2504

Merged
34 changes: 31 additions & 3 deletions darts/models/forecasting/block_rnn_model.py
@@ -30,6 +30,7 @@ def __init__(
        nr_params: int,
        num_layers_out_fc: Optional[List] = None,
        dropout: float = 0.0,
        activation: str = "ReLU",
        **kwargs,
    ):
"""This class allows to create custom block RNN modules that can later be used with Darts'
@@ -63,6 +64,8 @@ def __init__(
            This network connects the last hidden layer of the PyTorch RNN module to the output.
        dropout
            The fraction of neurons that are dropped in all-but-last RNN layers.
        activation
            The name of the activation function to be applied between the layers of the fully connected network.
        **kwargs
            all parameters required for :class:`darts.models.forecasting.pl_forecasting_module.PLForecastingModule`
            base class.
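
For orientation, this base class is meant to be subclassed. Below is a hedged sketch of a custom module built on it; the class name, the GRU encoder, and the single-layer decoder are illustrative assumptions, not part of this PR:

```python
import torch.nn as nn

from darts.models.forecasting.block_rnn_model import CustomBlockRNNModule
from darts.models.forecasting.pl_forecasting_module import io_processor


class MyCustomRNN(CustomBlockRNNModule):
    """Hypothetical subclass: a GRU encoder with a single Linear decoder layer."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # input_size, hidden_dim, num_layers, out_len, target_size and nr_params
        # are all set by the base class __init__ above.
        self.rnn = nn.GRU(self.input_size, self.hidden_dim, self.num_layers, batch_first=True)
        self.decoder = nn.Linear(self.hidden_dim, self.out_len * self.target_size * self.nr_params)

    @io_processor
    def forward(self, x_in):
        x, _ = x_in
        _, hidden = self.rnn(x)
        out = self.decoder(hidden[-1])
        # Block RNN modules must return (batch, output_chunk_length, target_size, nr_params).
        return out.view(x.size(0), self.out_len, self.target_size, self.nr_params)
```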
@@ -77,6 +80,7 @@
        self.nr_params = nr_params
        self.num_layers_out_fc = [] if num_layers_out_fc is None else num_layers_out_fc
        self.dropout = dropout
        self.activation = activation
        self.out_len = self.output_chunk_length

    @io_processor
@@ -105,6 +109,7 @@ class _BlockRNNModule(CustomBlockRNNModule):
    def __init__(
        self,
        name: str,
        activation: Optional[str] = None,
        **kwargs,
    ):
        """PyTorch module implementing a block RNN to be used in `BlockRNNModel`.
@@ -116,6 +121,7 @@ def __init__(

        This module uses an RNN to encode the input sequence, and subsequently uses a fully connected
        network as the decoder which takes as input the last hidden state of the encoder RNN.
        Optionally, a non-linear activation function can be applied between the layers of the fully connected network.
        The final output of the decoder is a sequence of length `output_chunk_length`. In this sense,
        the `_BlockRNNModule` produces 'blocks' of forecasts at a time (which is different
        from `_RNNModule` used by the `RNNModel`).
@@ -124,6 +130,9 @@
        ----------
        name
            The name of the specific PyTorch RNN module ("RNN", "GRU" or "LSTM").
        activation
            The name of the activation function to be applied between the layers of the fully connected network.
            Options include "ReLU", "Sigmoid", "Tanh", or None for no activation. Default: None.
        **kwargs
            all parameters required for the :class:`darts.models.forecasting.CustomBlockRNNModule` base class.

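For reference, the string-to-module lookup used in the implementation below resolves the activation name against `torch.nn`. A minimal standalone sketch of that mechanism:

```python
import torch.nn as nn

# "ReLU" resolves to the nn.ReLU class; instantiating it yields the module.
# Any class name exported by torch.nn works the same way ("Tanh", "Sigmoid",
# "LeakyReLU", ...); a name torch.nn does not define raises AttributeError.
activation_function = getattr(nn, "ReLU")()
print(activation_function)  # ReLU()
```
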
@@ -155,10 +164,15 @@ def __init__(
         # to the output of desired length
         last = self.hidden_dim
         feats = []
-        for feature in self.num_layers_out_fc + [
-            self.out_len * self.target_size * self.nr_params
-        ]:
+        for index, feature in enumerate(
+            self.num_layers_out_fc + [self.out_len * self.target_size * self.nr_params]
+        ):
             feats.append(nn.Linear(last, feature))
+
+            # Add activation only between layers, but not on the final layer
+            if activation and index < len(self.num_layers_out_fc):
+                activation_function = getattr(nn, activation)()
+                feats.append(activation_function)
             last = feature
         self.fc = nn.Sequential(*feats)

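To make the loop above concrete: with illustrative values `hidden_dim=25`, `num_layers_out_fc=[64, 32]`, an output size of `out_len * target_size * nr_params = 6`, and `activation="ReLU"`, the decoder it assembles is equivalent to:

```python
import torch.nn as nn

# Activations are inserted after each hidden Linear layer, but not after the
# final layer that maps to the flattened forecast of size out_len * target_size * nr_params.
fc = nn.Sequential(
    nn.Linear(25, 64),
    nn.ReLU(),
    nn.Linear(64, 32),
    nn.ReLU(),
    nn.Linear(32, 6),  # no trailing activation on the output layer
)
```
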
@@ -195,6 +209,7 @@ def __init__(
        n_rnn_layers: int = 1,
        hidden_fc_sizes: Optional[List] = None,
        dropout: float = 0.0,
        activation: str = "ReLU",
        **kwargs,
    ):
        """Block Recurrent Neural Network Model (RNNs).
@@ -243,6 +258,9 @@
            Sizes of hidden layers connecting the last hidden layer of the RNN module to the output, if any.
        dropout
            Fraction of neurons affected by Dropout.
        activation
            The name of a torch.nn activation function to be applied between the layers of the fully connected network.
            Default: "ReLU".
        **kwargs
            Optional arguments to initialize the pytorch_lightning.Module, pytorch_lightning.Trainer, and
            Darts' :class:`TorchForecastingModel`.
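
As a usage sketch of the new parameter (the dataset and hyperparameter values here are illustrative, not taken from this PR):

```python
from darts.datasets import AirPassengersDataset
from darts.models import BlockRNNModel

series = AirPassengersDataset().load()

# Pass any torch.nn activation class name; it is applied between the fully
# connected decoder layers (hidden_fc_sizes), not inside the RNN itself.
model = BlockRNNModel(
    input_chunk_length=12,
    output_chunk_length=6,
    model="LSTM",
    hidden_fc_sizes=[64, 32],
    activation="Tanh",
    n_epochs=5,
)
model.fit(series)
forecast = model.predict(n=6)
```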
@@ -435,6 +453,7 @@ def encode_year(idx):
        self.hidden_dim = hidden_dim
        self.n_rnn_layers = n_rnn_layers
        self.dropout = dropout
        self.activation = activation

    @property
    def supports_multivariate(self) -> bool:
@@ -464,6 +483,15 @@ def _create_model(self, train_sample: Tuple[torch.Tensor]) -> torch.nn.Module:
            num_layers=self.n_rnn_layers,
            num_layers_out_fc=hidden_fc_sizes,
            dropout=self.dropout,
            activation=self.activation,
            **self.pl_module_params,
            **kwargs,
        )

    def _check_ckpt_parameters(self, tfm_save):
        # new parameters were added that would otherwise break loading weights
        new_params = ["activation"]
        for param in new_params:
            if param not in tfm_save.model_params:
                tfm_save.model_params[param] = "ReLU"
        super()._check_ckpt_parameters(tfm_save)
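
The `_check_ckpt_parameters` override keeps checkpoint loading backward compatible: models saved before this change have no `activation` entry in their stored parameters, so the new default ("ReLU") is backfilled before the parent class compares saved and current parameters. A self-contained sketch of the same backfill idea, where the plain dict stands in for `tfm_save.model_params`:

```python
# Hypothetical hyperparameters restored from a pre-activation checkpoint.
saved_model_params = {"hidden_dim": 25, "n_rnn_layers": 1, "dropout": 0.0}

# Backfill each newly introduced parameter with the current default
# so the parameter check against the new model class passes.
for param in ["activation"]:
    saved_model_params.setdefault(param, "ReLU")

assert saved_model_params["activation"] == "ReLU"
```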
29 changes: 22 additions & 7 deletions darts/tests/models/forecasting/test_block_RNN.py
@@ -85,31 +85,46 @@ def test_creation(self):
        model1.fit(self.series)
        preds1 = model1.predict(n=3)

-        # can create from a custom class itself
+        # can create from valid module name with ReLU activation
         model2 = BlockRNNModel(
             input_chunk_length=1,
             output_chunk_length=1,
-            model=ModuleValid1,
+            model="RNN",
+            activation="ReLU",
+            hidden_fc_sizes=[10],
             n_epochs=1,
             random_state=42,
             **tfm_kwargs,
         )
         model2.fit(self.series)
         preds2 = model2.predict(n=3)
-        np.testing.assert_array_equal(preds1.all_values(), preds2.all_values())
+        assert preds1.values().shape == preds2.values().shape

+        # can create from a custom class itself
         model3 = BlockRNNModel(
             input_chunk_length=1,
             output_chunk_length=1,
-            model=ModuleValid2,
+            model=ModuleValid1,
             n_epochs=1,
             random_state=42,
             **tfm_kwargs,
         )
         model3.fit(self.series)
-        preds3 = model2.predict(n=3)
-        assert preds3.all_values().shape == preds2.all_values().shape
-        assert preds3.time_index.equals(preds2.time_index)
+        preds3 = model3.predict(n=3)
+        np.testing.assert_array_equal(preds1.all_values(), preds3.all_values())

+        model4 = BlockRNNModel(
+            input_chunk_length=1,
+            output_chunk_length=1,
+            model=ModuleValid2,
+            n_epochs=1,
+            random_state=42,
+            **tfm_kwargs,
+        )
+        model4.fit(self.series)
+        preds4 = model4.predict(n=3)
+        assert preds4.all_values().shape == preds3.all_values().shape
+        assert preds4.time_index.equals(preds3.time_index)

    def test_fit(self, tmpdir_module):
        # Test basic fit()