Feat/local and global models in EnsembleModel (#1745)
* feat: remove restriction in EnsembleModel, models can be a mix of Local and Global models

* feat: EnsembleModel accepts a mixture of local and global models for single ts training/inference only. If provided, the covariates are passed only to the models supporting them (individually)

* feat: updated unittests

* doc: fix typo in docstring, SeasonalityMode must be imported from darts.utils.utils

* doc: updated changelog

* feat: logger info when all the models in the ensemble do not support the same covariates, to make the behavior more transparent

* fix: typo, using parentheses to call property method

* Apply suggestions from code review

Co-authored-by: Dennis Bader <dennis.bader@gmx.ch>

* fix: made the covariates handling in ensemble model more transparent, added tests for covariates support in ensemble models

* Update CHANGELOG.md

* Update CHANGELOG.md

* Update CHANGELOG.md

* Apply suggestions from code review

Co-authored-by: Dennis Bader <dennis.bader@gmx.ch>

* fix: addressed reviewer comments, added show_warning arg to ensemble_model, support_*_covariates for RegressionModels rely on the lags attribute

* fix typo

* fix: improve warning syntax

---------

Co-authored-by: Dennis Bader <dennis.bader@gmx.ch>
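
The relaxed type check described in the commit message can be illustrated with a minimal, self-contained sketch. The classes below are hypothetical stand-ins rather than the darts classes themselves, and `EnsembleSketch` and `check_series` are names invented for illustration; only the validation logic mirrors the diff.

```python
from typing import List


class ForecastingModel:
    """Hypothetical stand-in for the darts base class."""


class LocalForecastingModel(ForecastingModel):
    """Stand-in for a model that trains on a single series only."""


class GlobalForecastingModel(ForecastingModel):
    """Stand-in for a model that can train on several series."""


class EnsembleSketch:
    """Illustrative only: accepts any mix of local and global model instances."""

    def __init__(self, models: List[ForecastingModel]):
        if not (isinstance(models, list) and models):
            raise ValueError("Cannot instantiate an ensemble with an empty list of models")

        is_local = [isinstance(m, LocalForecastingModel) for m in models]
        is_global = [isinstance(m, GlobalForecastingModel) for m in models]

        # Each entry must be an instance of one of the two families;
        # passing a class instead of an instance fails this check.
        if not all(loc or glob for loc, glob in zip(is_local, is_global)):
            raise ValueError("All models must be instantiated local or global models")

        self.is_local_ensemble = all(is_local)
        self.is_global_ensemble = all(is_global)
        self.models = models

    def check_series(self, series) -> None:
        # Assumption: a list or tuple stands in for Sequence[TimeSeries].
        is_single_series = not isinstance(series, (list, tuple))
        if not self.is_global_ensemble and not is_single_series:
            raise ValueError(
                "The models contain at least one local model, which does not "
                "support training on multiple series."
            )
```

A mixed ensemble therefore has both flags set to `False` and is restricted to a single series, which matches the "single ts training/inference only" note above.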
madtoinou and dennisbader authored May 26, 2023
1 parent 52f4004 commit ee53a83
Showing 8 changed files with 324 additions and 94 deletions.
11 changes: 7 additions & 4 deletions CHANGELOG.md
@@ -5,19 +5,22 @@ We do our best to avoid the introduction of breaking changes,
but cannot always guarantee backwards compatibility. Changes that may **break code which uses a previous release of Darts** are marked with a "🔴".

## [Unreleased](https://github.com/unit8co/darts/tree/master)
**Fixed**
- Fixed a bug when loading the weights of a model trained with encoders or a Likelihood. [#1744](https://github.com/unit8co/darts/pull/1744) by [Antoine Madrona](https://github.com/madtoinou).

[Full Changelog](https://github.com/unit8co/darts/compare/0.24.0...master)

### For users of the library:

**Improved**
- Added support for `PathLike` to the `save()` and `load()` functions of `ForecastingModel`. [#1754](https://github.com/unit8co/darts/pull/1754) by [Simon Sudrich](https://github.com/sudrich).
- Fixed an issue with `TorchForecastingModel.load_from_checkpoint()` not properly loading the loss function and metrics. [#1749](https://github.com/unit8co/darts/pull/1749) by [Antoine Madrona](https://github.com/madtoinou).
- General model improvements:
- Added support for `PathLike` to the `save()` and `load()` functions of all non-deep learning based models. [#1754](https://github.com/unit8co/darts/pull/1754) by [Simon Sudrich](https://github.com/sudrich).
- Improvements to `EnsembleModel`:
- Model creation parameter `forecasting_models` now supports a mix of `LocalForecastingModel` and `GlobalForecastingModel` (single `TimeSeries` training/inference only, due to the local models). [#1745](https://github.com/unit8co/darts/pull/1745) by [Antoine Madrona](https://github.com/madtoinou).
- Future and past covariates can now be used even if `forecasting_models` have different covariates support. The covariates passed to `fit()`/`predict()` are used only by models that support it. [#1745](https://github.com/unit8co/darts/pull/1745) by [Antoine Madrona](https://github.com/madtoinou).

**Fixed**
- Fixed an issue not considering original component names for `TimeSeries.plot()` when providing a label prefix. [#1783](https://github.com/unit8co/darts/pull/1783) by [Simon Sudrich](https://github.com/sudrich).
- Fixed an issue with `TorchForecastingModel.load_from_checkpoint()` not properly loading the loss function and metrics. [#1749](https://github.com/unit8co/darts/pull/1749) by [Antoine Madrona](https://github.com/madtoinou).
- Fixed a bug when loading the weights of a `TorchForecastingModel` trained with encoders or a Likelihood. [#1744](https://github.com/unit8co/darts/pull/1744) by [Antoine Madrona](https://github.com/madtoinou).

## [0.24.0](https://github.com/unit8co/darts/tree/0.24.0) (2023-04-12)
### For users of the library:
31 changes: 20 additions & 11 deletions darts/models/forecasting/baselines.py
@@ -163,14 +163,26 @@ def predict(self, n: int, num_samples: int = 1, verbose: bool = False):

class NaiveEnsembleModel(EnsembleModel):
def __init__(
self, models: Union[List[LocalForecastingModel], List[GlobalForecastingModel]]
self,
models: Union[List[LocalForecastingModel], List[GlobalForecastingModel]],
show_warnings: bool = True,
):
"""Naive combination model
Naive implementation of `EnsembleModel`
Returns the average of all predictions of the constituent models
If `future_covariates` or `past_covariates` are provided at training or inference time,
they will be passed only to the models supporting them.
Parameters
----------
models
List of forecasting models whose predictions to ensemble
show_warnings
Whether to show warnings related to models covariates support.
"""
super().__init__(models)
super().__init__(models=models, show_warnings=show_warnings)

def fit(
self,
@@ -184,15 +196,12 @@ def fit(
future_covariates=future_covariates,
)
for model in self.models:
if self.is_global_ensemble:
kwargs = dict(series=series)
if model.supports_past_covariates:
kwargs["past_covariates"] = past_covariates
if model.supports_future_covariates:
kwargs["future_covariates"] = future_covariates
model.fit(**kwargs)
else:
model.fit(series=series)
kwargs = dict(series=series)
if model.supports_past_covariates:
kwargs["past_covariates"] = past_covariates
if model.supports_future_covariates:
kwargs["future_covariates"] = future_covariates
model.fit(**kwargs)

return self
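
The per-model covariate routing in the new `fit()` loop can be sketched in isolation. `StubModel` and `fit_all` are hypothetical names used only for this illustration; the keyword-building logic follows the diff above, assuming each model exposes `supports_past_covariates` and `supports_future_covariates`.

```python
from typing import Any, Dict, List


class StubModel:
    """Hypothetical model that records the keyword arguments passed to fit()."""

    def __init__(self, past: bool, future: bool):
        self.supports_past_covariates = past
        self.supports_future_covariates = future
        self.fit_kwargs: Dict[str, Any] = {}

    def fit(self, **kwargs):
        self.fit_kwargs = kwargs
        return self


def fit_all(models: List[StubModel], series, past_covariates=None, future_covariates=None):
    """Fit every model, forwarding each covariate only to models that support it."""
    for model in models:
        kwargs = dict(series=series)
        if model.supports_past_covariates:
            kwargs["past_covariates"] = past_covariates
        if model.supports_future_covariates:
            kwargs["future_covariates"] = future_covariates
        model.fit(**kwargs)
    return models
```

Building the keyword dictionary per model means a model without covariate support never receives the argument at all, so its `fit()` signature does not need to accept it.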

129 changes: 102 additions & 27 deletions darts/models/forecasting/ensemble_model.py
@@ -8,6 +8,7 @@

from darts.logging import get_logger, raise_if, raise_if_not
from darts.models.forecasting.forecasting_model import (
ForecastingModel,
GlobalForecastingModel,
LocalForecastingModel,
)
@@ -22,31 +23,43 @@ class EnsembleModel(GlobalForecastingModel):
Ensemble models take in a list of forecasting models and ensemble their predictions
to make a single one according to the rule defined by their `ensemble()` method.
If `future_covariates` or `past_covariates` are provided at training or inference time,
they will be passed only to the models supporting them.
Parameters
----------
models
List of forecasting models whose predictions to ensemble
show_warnings
Whether to show warnings related to models covariates support.
"""

def __init__(
self, models: Union[List[LocalForecastingModel], List[GlobalForecastingModel]]
):
def __init__(self, models: List[ForecastingModel], show_warnings: bool = True):
raise_if_not(
isinstance(models, list) and models,
"Cannot instantiate EnsembleModel with an empty list of models",
logger,
)

is_local_ensemble = all(
isinstance(model, LocalForecastingModel) for model in models
)
self.is_global_ensemble = all(
is_local_model = [isinstance(model, LocalForecastingModel) for model in models]
is_global_model = [
isinstance(model, GlobalForecastingModel) for model in models
)
]

self.is_local_ensemble = all(is_local_model)
self.is_global_ensemble = all(is_global_model)

raise_if_not(
is_local_ensemble or self.is_global_ensemble,
"All models must be of the same type: either GlobalForecastingModel, or LocalForecastingModel.",
all(
[
local_model or global_model
for local_model, global_model in zip(
is_local_model, is_global_model
)
]
),
"All models must be of type `GlobalForecastingModel`, or `LocalForecastingModel`. "
"Also, make sure that all models in `forecasting_model/models` are instantiated.",
logger,
)

@@ -60,6 +73,27 @@ def __init__(
super().__init__()
self.models = models

if show_warnings:
if (
self.supports_past_covariates
and not self._full_past_covariates_support()
):
logger.warning(
"Some models in the ensemble do not support past covariates, the past covariates will be "
"provided only to the models supporting them when calling `fit()` or `predict()`. "
"To hide these warnings, set `show_warnings=False`."
)

if (
self.supports_future_covariates
and not self._full_future_covariates_support()
):
logger.warning(
"Some models in the ensemble do not support future covariates, the future covariates will be "
"provided only to the models supporting them when calling `fit()` or `predict()`. "
"To hide these warnings, set `show_warnings=False`."
)

def fit(
self,
series: Union[TimeSeries, Sequence[TimeSeries]],
@@ -71,34 +105,37 @@ def fit(
Note that `EnsembleModel.fit()` does NOT call `fit()` on each of its constituent forecasting models.
It is left to classes inheriting from EnsembleModel to do so appropriately when overriding `fit()`
"""

is_single_series = isinstance(series, TimeSeries)

# local models OR mix of local and global models
raise_if(
not self.is_global_ensemble and not isinstance(series, TimeSeries),
"The models are of type LocalForecastingModel, which does not support training on multiple series.",
logger,
)
raise_if(
not self.is_global_ensemble and past_covariates is not None,
"The models are of type LocalForecastingModel, which does not support past covariates.",
not self.is_global_ensemble and not is_single_series,
"The models contain at least one LocalForecastingModel, which does not support training on multiple "
"series.",
logger,
)

is_single_series = isinstance(series, TimeSeries)

# check that if timeseries is single series, than covariates are as well and vice versa
error = False
# check that if timeseries is single series, that covariates are as well and vice versa
error_past_cov = False
error_future_cov = False

if past_covariates is not None:
error = is_single_series != isinstance(past_covariates, TimeSeries)
error_past_cov = is_single_series != isinstance(past_covariates, TimeSeries)

if future_covariates is not None:
error = is_single_series != isinstance(future_covariates, TimeSeries)
error_future_cov = is_single_series != isinstance(
future_covariates, TimeSeries
)

raise_if(
error,
"Both series and covariates have to be either univariate or multivariate.",
error_past_cov or error_future_cov,
"Both series and covariates have to be either single TimeSeries or sequences of TimeSeries.",
logger,
)

self._verify_past_future_covariates(past_covariates, future_covariates)

super().fit(series, past_covariates, future_covariates)

return self
@@ -125,12 +162,17 @@ def _make_multiple_predictions(
num_samples: int = 1,
):
is_single_series = isinstance(series, TimeSeries) or series is None
# maximize covariate usage
predictions = [
model._predict_wrapper(
n=n,
series=series,
past_covariates=past_covariates,
future_covariates=future_covariates,
past_covariates=past_covariates
if model.supports_past_covariates
else None,
future_covariates=future_covariates
if model.supports_future_covariates
else None,
num_samples=num_samples,
)
for model in self.models
@@ -160,6 +202,8 @@ def predict(
verbose=verbose,
)

self._verify_past_future_covariates(past_covariates, future_covariates)

predictions = self._make_multiple_predictions(
n=n,
series=series,
@@ -229,3 +273,34 @@ def find_max_lag_or_none(lag_id, aggregator) -> Optional[int]:

def _is_probabilistic(self) -> bool:
return all([model._is_probabilistic() for model in self.models])

@property
def supports_past_covariates(self) -> bool:
return any([model.supports_past_covariates for model in self.models])

@property
def supports_future_covariates(self) -> bool:
return any([model.supports_future_covariates for model in self.models])

def _full_past_covariates_support(self) -> bool:
return all([model.supports_past_covariates for model in self.models])

def _full_future_covariates_support(self) -> bool:
return all([model.supports_future_covariates for model in self.models])

def _verify_past_future_covariates(self, past_covariates, future_covariates):
"""
Verify that any non-None covariates comply with the model type.
"""
raise_if(
past_covariates is not None and not self.supports_past_covariates,
"`past_covariates` were provided to an `EnsembleModel` but none of its "
"base models support such covariates.",
logger,
)
raise_if(
future_covariates is not None and not self.supports_future_covariates,
"`future_covariates` were provided to an `EnsembleModel` but none of its "
"base models support such covariates.",
logger,
)
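
The difference between "at least one model supports covariates" (`any`) and "every model supports them" (`all`) drives both the warnings and the validation added in this file. The following is a minimal sketch under stated assumptions: `Stub`, `CovariateSupportSketch`, `verify`, and the `messages` list are hypothetical and exist only so the behaviour can be inspected without darts.

```python
import logging
from typing import List

logger = logging.getLogger("ensemble_sketch")


class Stub:
    """Hypothetical model exposing only the two support flags."""

    def __init__(self, past: bool, future: bool):
        self.supports_past_covariates = past
        self.supports_future_covariates = future


class CovariateSupportSketch:
    def __init__(self, models: List[Stub], show_warnings: bool = True):
        self.models = models
        self.messages: List[str] = []  # kept for inspection; not part of darts
        if show_warnings:
            for kind in ("past", "future"):
                # Warn only on partial support: some models use the covariates, some do not.
                if self._any_support(kind) and not self._full_support(kind):
                    msg = (
                        f"Some models in the ensemble do not support {kind} covariates; "
                        "they will be provided only to the models supporting them."
                    )
                    self.messages.append(msg)
                    logger.warning(msg)

    def _flags(self, kind: str) -> List[bool]:
        return [getattr(m, f"supports_{kind}_covariates") for m in self.models]

    def _any_support(self, kind: str) -> bool:
        return any(self._flags(kind))

    def _full_support(self, kind: str) -> bool:
        return all(self._flags(kind))

    def verify(self, past_covariates=None, future_covariates=None) -> None:
        # Reject covariates that no base model could consume.
        if past_covariates is not None and not self._any_support("past"):
            raise ValueError("past_covariates provided but no base model supports them")
        if future_covariates is not None and not self._any_support("future"):
            raise ValueError("future_covariates provided but no base model supports them")
```

With this split, partially supported covariates produce a warning that can be silenced, while covariates that no model can use raise an error.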
28 changes: 18 additions & 10 deletions darts/models/forecasting/regression_ensemble_model.py
@@ -8,10 +8,7 @@

from darts.logging import get_logger, raise_if, raise_if_not
from darts.models.forecasting.ensemble_model import EnsembleModel
from darts.models.forecasting.forecasting_model import (
GlobalForecastingModel,
LocalForecastingModel,
)
from darts.models.forecasting.forecasting_model import ForecastingModel
from darts.models.forecasting.linear_regression_model import LinearRegressionModel
from darts.models.forecasting.regression_model import RegressionModel
from darts.timeseries import TimeSeries
@@ -23,11 +20,10 @@
class RegressionEnsembleModel(EnsembleModel):
def __init__(
self,
forecasting_models: Union[
List[LocalForecastingModel], List[GlobalForecastingModel]
],
forecasting_models: List[ForecastingModel],
regression_train_n_points: int,
regression_model=None,
show_warnings: bool = True,
):
"""
Use a regression model for ensembling individual models' predictions.
@@ -38,6 +34,11 @@ def __init__(
as in :class:`RegressionModel`, where the regression model is used to produce forecasts based on the
lagged series.
If `future_covariates` or `past_covariates` are provided at training or inference time,
they will be passed only to the forecasting models supporting them.
The regression model does not leverage the covariates passed to ``fit()`` and ``predict()``.
Parameters
----------
forecasting_models
@@ -47,8 +48,10 @@ def __init__(
regression_model
Any regression model with ``predict()`` and ``fit()`` methods (e.g. from scikit-learn)
Default: ``darts.model.LinearRegressionModel(fit_intercept=False)``
show_warnings
Whether to show warnings related to forecasting_models covariates support.
"""
super().__init__(forecasting_models)
super().__init__(models=forecasting_models, show_warnings=show_warnings)
if regression_model is None:
regression_model = LinearRegressionModel(
lags=None, lags_future_covariates=[0], fit_intercept=False
@@ -115,10 +118,15 @@ def fit(
)

for model in self.models:
# maximize covariate usage
model._fit_wrapper(
series=forecast_training,
past_covariates=past_covariates,
future_covariates=future_covariates,
past_covariates=past_covariates
if model.supports_past_covariates
else None,
future_covariates=future_covariates
if model.supports_future_covariates
else None,
)

predictions = self._make_multiple_predictions(
8 changes: 8 additions & 0 deletions darts/models/forecasting/regression_model.py
@@ -802,6 +802,14 @@ def lagged_feature_names(self) -> Optional[List[str]]:
def __str__(self):
return self.model.__str__()

@property
def supports_past_covariates(self) -> bool:
return len(self.lags.get("past", [])) > 0

@property
def supports_future_covariates(self) -> bool:
return len(self.lags.get("future", [])) > 0

@property
def supports_static_covariates(self) -> bool:
return True
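
The lags-based properties added above can be tried out with a small stand-in. This sketch assumes that the `lags` attribute is a dictionary which only contains the `"past"` and `"future"` keys when the corresponding lags were configured; the `"target"` key, the constructor, and the class name `LaggedModelSketch` are assumptions made for illustration, not taken from the diff.

```python
from typing import Dict, List, Optional


class LaggedModelSketch:
    """Hypothetical stand-in for a lag-based regression model."""

    def __init__(
        self,
        lags: Optional[List[int]] = None,
        lags_past_covariates: Optional[List[int]] = None,
        lags_future_covariates: Optional[List[int]] = None,
    ):
        # Assumption: a key is stored only when the matching lags are given.
        self.lags: Dict[str, List[int]] = {}
        if lags is not None:
            self.lags["target"] = lags
        if lags_past_covariates is not None:
            self.lags["past"] = lags_past_covariates
        if lags_future_covariates is not None:
            self.lags["future"] = lags_future_covariates

    @property
    def supports_past_covariates(self) -> bool:
        return len(self.lags.get("past", [])) > 0

    @property
    def supports_future_covariates(self) -> bool:
        return len(self.lags.get("future", [])) > 0
```

Under this reading, covariate support is a property of each configured instance rather than of the class, which is what lets the ensemble route covariates per model.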
2 changes: 1 addition & 1 deletion darts/models/forecasting/theta.py
@@ -40,7 +40,7 @@ def __init__(
`season_mode` must be a ``SeasonalityMode`` Enum member.
You can access the Enum with ``from darts import SeasonalityMode``.
You can access the Enum with ``from darts.utils.utils import SeasonalityMode``.
Parameters
----------