Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat/local and global models in EnsembleModel #1745

Merged
merged 24 commits into from
May 26, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
c5786cb
feat: remove restriction in EnsembleModel, models can be a mix of Loc…
madtoinou May 4, 2023
9dc9001
feat: EnsembleModel accepts a mixture of local and global models for …
madtoinou May 5, 2023
7467a4c
feat: updated unittests
madtoinou May 5, 2023
0ca18ca
doc: fix typo in docstring, SeasonalityMode must be imported from dar…
madtoinou May 5, 2023
838b40c
doc: updated changelog
madtoinou May 5, 2023
bad67e0
Merge branch 'master' into feat/local-and-global-ensemble
madtoinou May 8, 2023
f3d060e
Merge branch 'master' into feat/local-and-global-ensemble
dennisbader May 15, 2023
608fb03
Merge branch 'master' into feat/local-and-global-ensemble
madtoinou May 16, 2023
5035cf4
feat: logger info when all the models in the ensemble do not support …
madtoinou May 17, 2023
e7fa8e4
fix: typo, using parenthesis to call proterty method
madtoinou May 17, 2023
e48e754
Apply suggestions from code review
madtoinou May 22, 2023
ce6606b
fix: made the covariates handling in ensemble model more transparent,…
madtoinou May 22, 2023
62397ef
Merge branch 'feat/local-and-global-ensemble' of https://github.com/u…
madtoinou May 22, 2023
fe80f65
Merge branch 'master' into feat/local-and-global-ensemble
madtoinou May 22, 2023
9011104
Merge branch 'master' into feat/local-and-global-ensemble
dennisbader May 23, 2023
a0cf04e
Update CHANGELOG.md
dennisbader May 23, 2023
0423418
Update CHANGELOG.md
dennisbader May 23, 2023
1efbcc1
Update CHANGELOG.md
dennisbader May 23, 2023
29d39be
Apply suggestions from code review
madtoinou May 24, 2023
93a9cd7
fix: addressed reviewer comments, added show_warning arg to ensemble_…
madtoinou May 24, 2023
d4a5370
fix typo
madtoinou May 24, 2023
79ee5ff
fix: improve warning synthax
madtoinou May 24, 2023
33fd7c5
Merge branch 'master' into feat/local-and-global-ensemble
madtoinou May 24, 2023
6a0970d
Merge branch 'master' into feat/local-and-global-ensemble
dennisbader May 26, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
fix: made the covariates handling in ensemble model more transparent,…
… added tests for covariates support in ensemble models
  • Loading branch information
madtoinou committed May 22, 2023
commit ce6606b594b0187926f3bd3c7c676c42cab9658d
3 changes: 3 additions & 0 deletions darts/models/forecasting/baselines.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,9 @@ def __init__(

Naive implementation of `EnsembleModel`
Returns the average of all predictions of the constituent models

If `future_covariates` or `past_covariates` are provided at training or inference time,
they will be passed only to the models supporting them.
"""
super().__init__(models)

Expand Down
61 changes: 41 additions & 20 deletions darts/models/forecasting/ensemble_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,21 @@ def __init__(self, models: List[ForecastingModel]):
super().__init__()
self.models = models

if self.supports_past_covariates and not self._full_past_covariates_support():
madtoinou marked this conversation as resolved.
Show resolved Hide resolved
logger.info(
"Some models in the ensemble do not support past covariates, the past covariates will be "
"provided only to the models supporting them when calling fit/predict."
)

if (
self.supports_future_covariates
and not self._full_future_covariates_support()
):
logger.info(
"Some models in the ensemble do not support future covariates, the future covariates will be "
"provided only to the models supporting them when calling fit/predict."
)

def fit(
self,
series: Union[TimeSeries, Sequence[TimeSeries]],
Expand Down Expand Up @@ -116,8 +131,7 @@ def fit(
logger,
)

# inform user that covariates will be passed only to the models supporting them
self._logger_info_covariates_handling(past_covariates, future_covariates)
self._verify_past_future_covariates(past_covariates, future_covariates)

super().fit(series, past_covariates, future_covariates)

Expand Down Expand Up @@ -185,8 +199,7 @@ def predict(
verbose=verbose,
)

# inform user that covariates will be passed only to the models supporting them
self._logger_info_covariates_handling(past_covariates, future_covariates)
self._verify_past_future_covariates(past_covariates, future_covariates)

predictions = self._make_multiple_predictions(
n=n,
Expand Down Expand Up @@ -258,25 +271,33 @@ def find_max_lag_or_none(lag_id, aggregator) -> Optional[int]:
def _is_probabilistic(self) -> bool:
return all([model._is_probabilistic() for model in self.models])

@property
def supports_past_covariates(self) -> bool:
return any([model.supports_past_covariates for model in self.models])

@property
def supports_future_covariates(self) -> bool:
return any([model.supports_future_covariates for model in self.models])

def _full_past_covariates_support(self) -> bool:
madtoinou marked this conversation as resolved.
Show resolved Hide resolved
return all([model.supports_past_covariates for model in self.models])

def _full_future_covariates_support(self) -> bool:
return all([model.supports_future_covariates for model in self.models])

def _logger_info_covariates_handling(
self,
past_covariates: Union[TimeSeries, Sequence[TimeSeries]],
future_covariates: Union[TimeSeries, Sequence[TimeSeries]],
):
if past_covariates is not None and not self._full_past_covariates_support():
logger.info(
"Some models in the ensemble do not support past covariates, the past covariates will be "
"provided only to the models supporting them."
)

if future_covariates is not None and not self._full_future_covariates_support():
logger.info(
"Some models in the ensemble do not support future covariates, the future covariates will be "
"provided only to the models supporting them."
)
def _verify_past_future_covariates(self, past_covariates, future_covariates):
"""
Verify that any non-None covariates comply with the model type.
"""
raise_if(
past_covariates is not None and not self.supports_past_covariates,
"Some past_covariates have been provided to a EnsembleModel containing no models "
"supporting such covariates.",
madtoinou marked this conversation as resolved.
Show resolved Hide resolved
logger,
)
raise_if(
future_covariates is not None and not self.supports_future_covariates,
"Some future_covariates have been provided to a Ensemble model containing no models "
"supporting such covariates.",
madtoinou marked this conversation as resolved.
Show resolved Hide resolved
logger,
)
14 changes: 7 additions & 7 deletions darts/models/forecasting/regression_ensemble_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,7 @@

from darts.logging import get_logger, raise_if, raise_if_not
from darts.models.forecasting.ensemble_model import EnsembleModel
from darts.models.forecasting.forecasting_model import (
GlobalForecastingModel,
LocalForecastingModel,
)
from darts.models.forecasting.forecasting_model import ForecastingModel
from darts.models.forecasting.linear_regression_model import LinearRegressionModel
from darts.models.forecasting.regression_model import RegressionModel
from darts.timeseries import TimeSeries
Expand All @@ -23,9 +20,7 @@
class RegressionEnsembleModel(EnsembleModel):
def __init__(
self,
forecasting_models: Union[
List[LocalForecastingModel], List[GlobalForecastingModel]
],
forecasting_models: List[ForecastingModel],
regression_train_n_points: int,
regression_model=None,
):
Expand All @@ -38,6 +33,11 @@ def __init__(
as in :class:`RegressionModel`, where the regression model is used to produce forecasts based on the
lagged series.

If `future_covariates` or `past_covariates` are provided at training or inference time,
they will be passed only to the forecasting models supporting them.

The regression model does not leverage the covariates passed to ``fit()`` and ``predict()``.

Parameters
----------
forecasting_models
Expand Down
24 changes: 23 additions & 1 deletion darts/tests/models/forecasting/test_ensemble_models.py
madtoinou marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
NaiveEnsembleModel,
NaiveSeasonal,
RegressionEnsembleModel,
StatsForecastAutoARIMA,
Theta,
)
from darts.tests.base_test_class import DartsBaseTestClass
Expand Down Expand Up @@ -188,7 +189,8 @@ def test_call_predict_global_models_multivariate_input_with_covariates(self):

@unittest.skipUnless(TORCH_AVAILABLE, "requires torch")
def test_input_models_mixed(self):
naive_ensemble = NaiveEnsembleModel([NaiveDrift(), RNNModel(12)])
# NaiveDrift is local, RNNModel is global
naive_ensemble = NaiveEnsembleModel([NaiveDrift(), RNNModel(12, n_epochs=1)])
# ensemble is neither local, nor global
self.assertFalse(naive_ensemble.is_local_ensemble)
self.assertFalse(naive_ensemble.is_global_ensemble)
Expand All @@ -197,6 +199,26 @@ def test_input_models_mixed(self):
with self.assertRaises(ValueError):
naive_ensemble.fit([self.series1, self.series2])

@unittest.skipUnless(TORCH_AVAILABLE, "requires torch")
def test_mixed_models_with_covariates(self):
dennisbader marked this conversation as resolved.
Show resolved Hide resolved
naive_ensemble_one_covs = NaiveEnsembleModel(
[NaiveDrift(), RNNModel(12, n_epochs=1)]
)
# none of the models support past covariates
with self.assertRaises(ValueError):
naive_ensemble_one_covs.fit(self.series1, past_covariates=self.series2)
# only RNN supports future covariates
naive_ensemble_one_covs.fit(self.series1, future_covariates=self.series2)

naive_ensemble_future_covs = NaiveEnsembleModel(
[StatsForecastAutoARIMA(), RNNModel(12, n_epochs=1)]
)
# none of the models support past covariates
with self.assertRaises(ValueError):
naive_ensemble_future_covs.fit(self.series1, past_covariates=self.series2)
# both models supports future covariates
naive_ensemble_future_covs.fit(self.series1, future_covariates=self.series2)

def test_fit_multivar_ts_with_local_models(self):
naive = NaiveEnsembleModel(
[NaiveDrift(), NaiveSeasonal(), Theta(), ExponentialSmoothing()]
Expand Down
151 changes: 111 additions & 40 deletions examples/00-quickstart.ipynb

Large diffs are not rendered by default.