
Feature add activation to BlockRNN (#2492) #2504

Merged

Conversation

@SzymonCogiel (Contributor) commented Aug 17, 2024

Checklist before merging this PR:

  • Mentioned all issues that this PR fixes or addresses.
  • Summarized the updates of this PR under Summary.
  • Added an entry under Unreleased in the Changelog.

Fixes #2492.

Summary

  • Added support for specifying PyTorch activation functions (ReLU, Sigmoid, Tanh, or None) in the BlockRNNModel.
  • Ensured that activation functions are applied between fully connected layers, but not as the final layer.
  • Implemented a check to raise an error if an activation function is set but the model only contains one linear layer.
  • Updated documentation to reflect the new activation parameter and usage examples.
  • Added test cases to verify the correct application of activation functions and to handle edge cases.
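The layer layout described in the summary (activation applied between fully connected layers, but never after the final one, with an error when an activation is set but only one linear layer exists) can be sketched in plain Python. The function name and spec format below are illustrative assumptions, not the actual darts implementation:

```python
def build_fc_spec(in_size, hidden_fc_sizes, out_size, activation=None):
    """Return an ordered layer spec: ('linear', in, out) or ('activation', name).

    Illustrative sketch only -- names and structure are assumptions,
    not the darts BlockRNNModel internals.
    """
    sizes = [in_size] + list(hidden_fc_sizes) + [out_size]
    if activation is not None and len(sizes) == 2:
        # activation set, but the network has a single linear layer,
        # so the activation would never be applied
        raise ValueError(
            "An activation was given but the model has no hidden FC layers."
        )
    spec = []
    last = len(sizes) - 2
    for i in range(len(sizes) - 1):
        spec.append(("linear", sizes[i], sizes[i + 1]))
        # activation goes between layers, never after the final one
        if activation is not None and i < last:
            spec.append(("activation", activation))
    return spec
```

For example, `build_fc_spec(16, [8], 4, "ReLU")` yields linear → ReLU → linear, with no trailing activation.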

Other Information

Consider using a logging warning instead of raising an error if an activation function is set but the model only contains one linear layer.
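The warning-instead-of-error alternative suggested above could look like the following. This is a hypothetical helper using the stdlib `warnings` module rather than darts' own logging utilities:

```python
import warnings


def check_activation(activation, num_linear_layers, strict=False):
    """Warn (or raise, if strict) when an activation is set but there is
    only one linear layer, so the activation would never be applied.

    Hypothetical helper -- not part of the darts codebase.
    """
    if activation is not None and num_linear_layers <= 1:
        msg = (
            f"activation={activation!r} is set but the model contains a "
            "single linear layer; the activation will never be applied."
        )
        if strict:
            raise ValueError(msg)
        warnings.warn(msg)
```

With `strict=False` the model still trains and the user is merely notified; with `strict=True` it reproduces the error behaviour described in the summary.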


codecov bot commented Aug 19, 2024

Codecov Report

Attention: Patch coverage is 58.33333% with 5 lines in your changes missing coverage. Please review.

Project coverage is 93.75%. Comparing base (26c5f39) to head (e94b3c7).
Report is 1 commit behind head on master.

Files with missing lines Patch % Lines
darts/models/forecasting/block_rnn_model.py 58.33% 5 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##           master    #2504      +/-   ##
==========================================
- Coverage   93.79%   93.75%   -0.05%     
==========================================
  Files         139      139              
  Lines       14741    14738       -3     
==========================================
- Hits        13827    13818       -9     
- Misses        914      920       +6     


@eschibli (Contributor) commented:

Should a warning be raised if `activation` is None and `hidden_fc_sizes` isn't empty? That would always be suboptimal, and users who don't read the changelog might not be aware that the default activation is None.

@madtoinou (Collaborator) left a comment:

Thank you for the PR, some minor comments about the sanity checks.

Outdated review threads (resolved):
  • darts/models/forecasting/block_rnn_model.py (3 threads)
  • darts/tests/models/forecasting/test_block_RNN.py
SzymonCogiel and others added 4 commits August 28, 2024 17:09
Co-authored-by: madtoinou <32447896+madtoinou@users.noreply.github.com>
Co-authored-by: madtoinou <32447896+madtoinou@users.noreply.github.com>
* Add a check that raises an error when activation is None and hidden_fc_sizes is greater than 0
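The check described in this commit can be sketched as follows (hypothetical function name; the real code uses darts' internal validation helpers):

```python
def validate_activation(activation, hidden_fc_sizes):
    """Raise when hidden FC layers are requested without an activation.

    Illustrative sketch only. Stacked linear layers with no non-linearity
    between them collapse into a single linear map, so an activation is
    required whenever hidden_fc_sizes is non-empty.
    """
    if activation is None and len(hidden_fc_sizes) > 0:
        raise ValueError(
            "hidden_fc_sizes is non-empty but no activation was given; "
            "stacked linear layers without an activation are redundant."
        )
```

Note that later review comments moved to a default activation instead, which made this check unnecessary.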
@dennisbader (Collaborator) left a comment:

Hi @SzymonCogiel and thanks for this great PR 🚀 We're close; I just had a couple more suggestions. Mainly, I would set a default activation to make things easier for users with multiple fc layers.

Review threads:
  • darts/tests/models/forecasting/test_block_RNN.py (outdated, resolved)
  • darts/models/forecasting/block_rnn_model.py
@@ -430,11 +449,40 @@ def encode_year(idx):
logger=logger,
)

raise_if(
Inline comment (Collaborator):
With the comment from above, we could remove this check

),
logger=logger,
)
raise_if(
Inline comment (Collaborator):
and we could remove this as well

Outdated review thread (resolved): darts/models/forecasting/block_rnn_model.py
@dennisbader (Collaborator) left a comment:

Thanks for the updates @SzymonCogiel. A few last adaptations and then we're ready to merge :)

Outdated review threads (resolved):
  • darts/models/forecasting/block_rnn_model.py (2 threads)
SzymonCogiel and others added 2 commits September 5, 2024 16:47
Co-authored-by: Dennis Bader <dennis.bader@gmx.ch>
* Revert docstring _BlockRNNModule
@dennisbader (Collaborator) left a comment:

Looks good now, thanks @SzymonCogiel 🚀

@dennisbader dennisbader merged commit 38c066b into unit8co:master Sep 13, 2024
9 checks passed
eschibli pushed a commit to eschibli/darts that referenced this pull request Oct 8, 2024
* Feature add activation to BlockRNN (unit8co#2492)


* Update darts/models/forecasting/block_rnn_model.py

Co-authored-by: madtoinou <32447896+madtoinou@users.noreply.github.com>

* Update darts/models/forecasting/block_rnn_model.py

Co-authored-by: madtoinou <32447896+madtoinou@users.noreply.github.com>

* Feature add activation to BlockRNN (unit8co#2492)

* Add a check that raises an error when activation is None and hidden_fc_sizes is greater than 0

* Update darts/models/forecasting/block_rnn_model.py

Co-authored-by: Dennis Bader <dennis.bader@gmx.ch>

* Feature add activation to BlockRNN (unit8co#2492)

* _check_ckpt_parameters
* Remove redundant raise_if

* Update darts/models/forecasting/block_rnn_model.py

Co-authored-by: Dennis Bader <dennis.bader@gmx.ch>

* Feature add activation to BlockRNN (unit8co#2492)

* Revert docstring _BlockRNNModule

---------

Co-authored-by: madtoinou <32447896+madtoinou@users.noreply.github.com>
Co-authored-by: Dennis Bader <dennis.bader@gmx.ch>
Development

Successfully merging this pull request may close these issues.

[BUG] No activations in BlockRNNModel output MLP
4 participants