Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature add activation to BlockRNN (#2492) #2504

Merged
Prev Previous commit
Next Next commit
Feature add activation to BlockRNN (#2492)
* Add a check that raises an error when activation is None and the length of hidden_fc_sizes is greater than 0
  • Loading branch information
SzymonCogiel committed Aug 29, 2024
commit 56258964ff14c3233439a7152c70dad5dece05ed
12 changes: 11 additions & 1 deletion darts/models/forecasting/block_rnn_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,8 @@ def __init__(
dropout
Fraction of neurons affected by Dropout.
activation
The name of a torch.nn activation function to be applied between the layers of the fully connected network. Default: None.
The name of a torch.nn activation function to be applied between the layers of the fully connected network.
Default: None.
**kwargs
Optional arguments to initialize the pytorch_lightning.Module, pytorch_lightning.Trainer, and
Darts' :class:`TorchForecastingModel`.
Expand Down Expand Up @@ -448,6 +449,15 @@ def encode_year(idx):
logger=logger,
)

raise_if(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

With the comment from above, we could remove this check

activation is None
and hidden_fc_sizes is not None
and len(hidden_fc_sizes) > 0,
"The model contains hidden fully connected layers, but the activation function is not set; "
"please specify a valid activation function or remove the hidden layers.",
logger=logger,
)

# check we got right activation function specified:
if activation is not None:
try:
Expand Down
13 changes: 13 additions & 0 deletions darts/tests/models/forecasting/test_block_RNN.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,19 @@ def test_raise_if_activation_with_single_linear_layer(self):
**tfm_kwargs,
)

def test_raise_if_no_activation_with_hidden_fc_layers(self):
    """Constructing the model with hidden FC layers but no activation must raise."""
    # hidden_fc_sizes is non-empty while activation is None -> invalid combination.
    model_kwargs = dict(
        input_chunk_length=1,
        output_chunk_length=1,
        model="RNN",
        activation=None,
        hidden_fc_sizes=[10],
        n_epochs=1,
        random_state=42,
        **tfm_kwargs,
    )
    with pytest.raises(ValueError):
        BlockRNNModel(**model_kwargs)

def test_fit(self, tmpdir_module):
# Test basic fit()
model = BlockRNNModel(
Expand Down