Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat/local and global models in EnsembleModel #1745

Merged
merged 24 commits into from
May 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
c5786cb
feat: remove restriction in EnsembleModel, models can be a mix of Loc…
madtoinou May 4, 2023
9dc9001
feat: EnsembleModel accepts a mixture of local and global models for …
madtoinou May 5, 2023
7467a4c
feat: updated unittests
madtoinou May 5, 2023
0ca18ca
doc: fix typo in docstring, SeasonalityMode must be imported from dar…
madtoinou May 5, 2023
838b40c
doc: updated changelog
madtoinou May 5, 2023
bad67e0
Merge branch 'master' into feat/local-and-global-ensemble
madtoinou May 8, 2023
f3d060e
Merge branch 'master' into feat/local-and-global-ensemble
dennisbader May 15, 2023
608fb03
Merge branch 'master' into feat/local-and-global-ensemble
madtoinou May 16, 2023
5035cf4
feat: logger info when all the models in the ensemble do not support …
madtoinou May 17, 2023
e7fa8e4
fix: typo, using parenthesis to call proterty method
madtoinou May 17, 2023
e48e754
Apply suggestions from code review
madtoinou May 22, 2023
ce6606b
fix: made the covariates handling in ensemble model more transparent,…
madtoinou May 22, 2023
62397ef
Merge branch 'feat/local-and-global-ensemble' of https://github.com/u…
madtoinou May 22, 2023
fe80f65
Merge branch 'master' into feat/local-and-global-ensemble
madtoinou May 22, 2023
9011104
Merge branch 'master' into feat/local-and-global-ensemble
dennisbader May 23, 2023
a0cf04e
Update CHANGELOG.md
dennisbader May 23, 2023
0423418
Update CHANGELOG.md
dennisbader May 23, 2023
1efbcc1
Update CHANGELOG.md
dennisbader May 23, 2023
29d39be
Apply suggestions from code review
madtoinou May 24, 2023
93a9cd7
fix: addressed reviewer comments, added show_warning arg to ensemble_…
madtoinou May 24, 2023
d4a5370
fix typo
madtoinou May 24, 2023
79ee5ff
fix: improve warning synthax
madtoinou May 24, 2023
33fd7c5
Merge branch 'master' into feat/local-and-global-ensemble
madtoinou May 24, 2023
6a0970d
Merge branch 'master' into feat/local-and-global-ensemble
dennisbader May 26, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 7 additions & 4 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,22 @@ We do our best to avoid the introduction of breaking changes,
but cannot always guarantee backwards compatibility. Changes that may **break code which uses a previous release of Darts** are marked with a "🔴".

## [Unreleased](https://github.com/unit8co/darts/tree/master)
**Fixed**
- Fixed a bug when loading the weights of a model trained with encoders or a Likelihood. [#1744](https://github.com/unit8co/darts/pull/1744) by [Antoine Madrona](https://github.com/madtoinou).

[Full Changelog](https://github.com/unit8co/darts/compare/0.24.0...master)

### For users of the library:

**Improved**
- Added support for `PathLike` to the `save()` and `load()` functions of `ForecastingModel`. [#1754](https://github.com/unit8co/darts/pull/1754) by [Simon Sudrich](https://github.com/sudrich).
- Fixed an issue with `TorchForecastingModel.load_from_checkpoint()` not properly loading the loss function and metrics. [#1749](https://github.com/unit8co/darts/pull/1749) by [Antoine Madrona](https://github.com/madtoinou).
- General model improvements:
- Added support for `PathLike` to the `save()` and `load()` functions of all non-deep learning based models. [#1754](https://github.com/unit8co/darts/pull/1754) by [Simon Sudrich](https://github.com/sudrich).
- Improvements to `EnsembleModel`:
- Model creation parameter `forecasting_models` now supports a mix of `LocalForecastingModel` and `GlobalForecastingModel` (single `TimeSeries` training/inference only, due to the local models). [#1745](https://github.com/unit8co/darts/pull/1745) by [Antoine Madrona](https://github.com/madtoinou).
- Future and past covariates can now be used even if `forecasting_models` have different covariates support. The covariates passed to `fit()`/`predict()` are used only by models that support it. [#1745](https://github.com/unit8co/darts/pull/1745) by [Antoine Madrona](https://github.com/madtoinou).

**Fixed**
- Fixed an issue not considering original component names for `TimeSeries.plot()` when providing a label prefix. [#1783](https://github.com/unit8co/darts/pull/1783) by [Simon Sudrich](https://github.com/sudrich).
- Fixed an issue with `TorchForecastingModel.load_from_checkpoint()` not properly loading the loss function and metrics. [#1749](https://github.com/unit8co/darts/pull/1749) by [Antoine Madrona](https://github.com/madtoinou).
- Fixed a bug when loading the weights of a `TorchForecastingModel` trained with encoders or a Likelihood. [#1744](https://github.com/unit8co/darts/pull/1744) by [Antoine Madrona](https://github.com/madtoinou).

## [0.24.0](https://github.com/unit8co/darts/tree/0.24.0) (2023-04-12)
### For users of the library:
Expand Down
31 changes: 20 additions & 11 deletions darts/models/forecasting/baselines.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,14 +163,26 @@ def predict(self, n: int, num_samples: int = 1, verbose: bool = False):

class NaiveEnsembleModel(EnsembleModel):
def __init__(
self, models: Union[List[LocalForecastingModel], List[GlobalForecastingModel]]
self,
models: Union[List[LocalForecastingModel], List[GlobalForecastingModel]],
show_warnings: bool = True,
):
"""Naive combination model

Naive implementation of `EnsembleModel`
Returns the average of all predictions of the constituent models

If `future_covariates` or `past_covariates` are provided at training or inference time,
they will be passed only to the models supporting them.

Parameters
----------
models
List of forecasting models whose predictions to ensemble
show_warnings
Whether to show warnings related to models covariates support.
"""
super().__init__(models)
super().__init__(models=models, show_warnings=show_warnings)

def fit(
self,
Expand All @@ -184,15 +196,12 @@ def fit(
future_covariates=future_covariates,
)
for model in self.models:
if self.is_global_ensemble:
kwargs = dict(series=series)
if model.supports_past_covariates:
kwargs["past_covariates"] = past_covariates
if model.supports_future_covariates:
kwargs["future_covariates"] = future_covariates
model.fit(**kwargs)
else:
model.fit(series=series)
kwargs = dict(series=series)
if model.supports_past_covariates:
kwargs["past_covariates"] = past_covariates
if model.supports_future_covariates:
kwargs["future_covariates"] = future_covariates
model.fit(**kwargs)

return self

Expand Down
129 changes: 102 additions & 27 deletions darts/models/forecasting/ensemble_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from darts.logging import get_logger, raise_if, raise_if_not
from darts.models.forecasting.forecasting_model import (
ForecastingModel,
GlobalForecastingModel,
LocalForecastingModel,
)
Expand All @@ -22,31 +23,43 @@ class EnsembleModel(GlobalForecastingModel):
Ensemble models take in a list of forecasting models and ensemble their predictions
to make a single one according to the rule defined by their `ensemble()` method.

If `future_covariates` or `past_covariates` are provided at training or inference time,
they will be passed only to the models supporting them.

Parameters
----------
models
List of forecasting models whose predictions to ensemble
show_warnings
Whether to show warnings related to models covariates support.
"""

def __init__(
self, models: Union[List[LocalForecastingModel], List[GlobalForecastingModel]]
):
def __init__(self, models: List[ForecastingModel], show_warnings: bool = True):
raise_if_not(
isinstance(models, list) and models,
"Cannot instantiate EnsembleModel with an empty list of models",
logger,
)

is_local_ensemble = all(
isinstance(model, LocalForecastingModel) for model in models
)
self.is_global_ensemble = all(
is_local_model = [isinstance(model, LocalForecastingModel) for model in models]
is_global_model = [
isinstance(model, GlobalForecastingModel) for model in models
)
]

self.is_local_ensemble = all(is_local_model)
self.is_global_ensemble = all(is_global_model)

raise_if_not(
is_local_ensemble or self.is_global_ensemble,
"All models must be of the same type: either GlobalForecastingModel, or LocalForecastingModel.",
all(
[
local_model or global_model
for local_model, global_model in zip(
is_local_model, is_global_model
)
]
),
"All models must be of type `GlobalForecastingModel`, or `LocalForecastingModel`. "
"Also, make sure that all models in `forecasting_model/models` are instantiated.",
logger,
)

Expand All @@ -60,6 +73,27 @@ def __init__(
super().__init__()
self.models = models

if show_warnings:
if (
self.supports_past_covariates
and not self._full_past_covariates_support()
):
logger.warning(
"Some models in the ensemble do not support past covariates, the past covariates will be "
"provided only to the models supporting them when calling fit()` or `predict()`. "
"To hide these warnings, set `show_warnings=False`."
)

if (
self.supports_future_covariates
and not self._full_future_covariates_support()
):
logger.warning(
"Some models in the ensemble do not support future covariates, the future covariates will be "
"provided only to the models supporting them when calling `fit()` or `predict()`. "
"To hide these warnings, set `show_warnings=False`."
)

def fit(
self,
series: Union[TimeSeries, Sequence[TimeSeries]],
Expand All @@ -71,34 +105,37 @@ def fit(
Note that `EnsembleModel.fit()` does NOT call `fit()` on each of its constituent forecasting models.
It is left to classes inheriting from EnsembleModel to do so appropriately when overriding `fit()`
"""

is_single_series = isinstance(series, TimeSeries)

# local models OR mix of local and global models
raise_if(
not self.is_global_ensemble and not isinstance(series, TimeSeries),
"The models are of type LocalForecastingModel, which does not support training on multiple series.",
logger,
)
raise_if(
not self.is_global_ensemble and past_covariates is not None,
"The models are of type LocalForecastingModel, which does not support past covariates.",
not self.is_global_ensemble and not is_single_series,
"The models contain at least one LocalForecastingModel, which does not support training on multiple "
"series.",
logger,
)

is_single_series = isinstance(series, TimeSeries)

# check that if timeseries is single series, than covariates are as well and vice versa
error = False
# check that if timeseries is single series, that covariates are as well and vice versa
error_past_cov = False
error_future_cov = False

if past_covariates is not None:
error = is_single_series != isinstance(past_covariates, TimeSeries)
error_past_cov = is_single_series != isinstance(past_covariates, TimeSeries)

if future_covariates is not None:
error = is_single_series != isinstance(future_covariates, TimeSeries)
error_future_cov = is_single_series != isinstance(
future_covariates, TimeSeries
)

raise_if(
error,
"Both series and covariates have to be either univariate or multivariate.",
error_past_cov or error_future_cov,
"Both series and covariates have to be either single TimeSeries or sequences of TimeSeries.",
logger,
)

self._verify_past_future_covariates(past_covariates, future_covariates)

super().fit(series, past_covariates, future_covariates)

return self
Expand All @@ -125,12 +162,17 @@ def _make_multiple_predictions(
num_samples: int = 1,
):
is_single_series = isinstance(series, TimeSeries) or series is None
# maximize covariate usage
predictions = [
model._predict_wrapper(
n=n,
series=series,
past_covariates=past_covariates,
future_covariates=future_covariates,
past_covariates=past_covariates
if model.supports_past_covariates
else None,
future_covariates=future_covariates
if model.supports_future_covariates
else None,
num_samples=num_samples,
)
for model in self.models
Expand Down Expand Up @@ -160,6 +202,8 @@ def predict(
verbose=verbose,
)

self._verify_past_future_covariates(past_covariates, future_covariates)

predictions = self._make_multiple_predictions(
n=n,
series=series,
Expand Down Expand Up @@ -229,3 +273,34 @@ def find_max_lag_or_none(lag_id, aggregator) -> Optional[int]:

def _is_probabilistic(self) -> bool:
return all([model._is_probabilistic() for model in self.models])

@property
def supports_past_covariates(self) -> bool:
return any([model.supports_past_covariates for model in self.models])

@property
def supports_future_covariates(self) -> bool:
return any([model.supports_future_covariates for model in self.models])

def _full_past_covariates_support(self) -> bool:
madtoinou marked this conversation as resolved.
Show resolved Hide resolved
return all([model.supports_past_covariates for model in self.models])

def _full_future_covariates_support(self) -> bool:
return all([model.supports_future_covariates for model in self.models])

def _verify_past_future_covariates(self, past_covariates, future_covariates):
"""
Verify that any non-None covariates comply with the model type.
"""
raise_if(
past_covariates is not None and not self.supports_past_covariates,
"`past_covariates` were provided to an `EnsembleModel` but none of its "
"base models support such covariates.",
logger,
)
raise_if(
future_covariates is not None and not self.supports_future_covariates,
"`future_covariates` were provided to an `EnsembleModel` but none of its "
"base models support such covariates.",
logger,
)
28 changes: 18 additions & 10 deletions darts/models/forecasting/regression_ensemble_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,7 @@

from darts.logging import get_logger, raise_if, raise_if_not
from darts.models.forecasting.ensemble_model import EnsembleModel
from darts.models.forecasting.forecasting_model import (
GlobalForecastingModel,
LocalForecastingModel,
)
from darts.models.forecasting.forecasting_model import ForecastingModel
from darts.models.forecasting.linear_regression_model import LinearRegressionModel
from darts.models.forecasting.regression_model import RegressionModel
from darts.timeseries import TimeSeries
Expand All @@ -23,11 +20,10 @@
class RegressionEnsembleModel(EnsembleModel):
def __init__(
self,
forecasting_models: Union[
List[LocalForecastingModel], List[GlobalForecastingModel]
],
forecasting_models: List[ForecastingModel],
regression_train_n_points: int,
regression_model=None,
show_warnings: bool = True,
):
"""
Use a regression model for ensembling individual models' predictions.
Expand All @@ -38,6 +34,11 @@ def __init__(
as in :class:`RegressionModel`, where the regression model is used to produce forecasts based on the
lagged series.

If `future_covariates` or `past_covariates` are provided at training or inference time,
they will be passed only to the forecasting models supporting them.

The regression model does not leverage the covariates passed to ``fit()`` and ``predict()``.

Parameters
----------
forecasting_models
Expand All @@ -47,8 +48,10 @@ def __init__(
regression_model
Any regression model with ``predict()`` and ``fit()`` methods (e.g. from scikit-learn)
Default: ``darts.model.LinearRegressionModel(fit_intercept=False)``
show_warnings
Whether to show warnings related to forecasting_models covariates support.
"""
super().__init__(forecasting_models)
super().__init__(models=forecasting_models, show_warnings=show_warnings)
if regression_model is None:
regression_model = LinearRegressionModel(
lags=None, lags_future_covariates=[0], fit_intercept=False
Expand Down Expand Up @@ -115,10 +118,15 @@ def fit(
)

for model in self.models:
# maximize covariate usage
model._fit_wrapper(
series=forecast_training,
past_covariates=past_covariates,
future_covariates=future_covariates,
past_covariates=past_covariates
if model.supports_past_covariates
else None,
future_covariates=future_covariates
if model.supports_future_covariates
else None,
)

predictions = self._make_multiple_predictions(
Expand Down
8 changes: 8 additions & 0 deletions darts/models/forecasting/regression_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -802,6 +802,14 @@ def lagged_feature_names(self) -> Optional[List[str]]:
def __str__(self):
return self.model.__str__()

@property
def supports_past_covariates(self) -> bool:
return len(self.lags.get("past", [])) > 0

@property
def supports_future_covariates(self) -> bool:
return len(self.lags.get("future", [])) > 0

@property
def supports_static_covariates(self) -> bool:
return True
Expand Down
2 changes: 1 addition & 1 deletion darts/models/forecasting/theta.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def __init__(

`season_mode` must be a ``SeasonalityMode`` Enum member.

You can access the Enum with ``from darts import SeasonalityMode``.
You can access the Enum with ``from darts.utils.utils import SeasonalityMode``.
madtoinou marked this conversation as resolved.
Show resolved Hide resolved

Parameters
----------
Expand Down
Loading