Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat/local and global models in EnsembleModel #1745

Merged
merged 24 commits into from
May 26, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
c5786cb
feat: remove restriction in EnsembleModel, models can be a mix of Loc…
madtoinou May 4, 2023
9dc9001
feat: EnsembleModel accepts a mixture of local and global models for …
madtoinou May 5, 2023
7467a4c
feat: updated unittests
madtoinou May 5, 2023
0ca18ca
doc: fix typo in docstring, SeasonalityMode must be imported from dar…
madtoinou May 5, 2023
838b40c
doc: updated changelog
madtoinou May 5, 2023
bad67e0
Merge branch 'master' into feat/local-and-global-ensemble
madtoinou May 8, 2023
f3d060e
Merge branch 'master' into feat/local-and-global-ensemble
dennisbader May 15, 2023
608fb03
Merge branch 'master' into feat/local-and-global-ensemble
madtoinou May 16, 2023
5035cf4
feat: logger info when all the models in the ensemble do not support …
madtoinou May 17, 2023
e7fa8e4
fix: typo, using parenthesis to call proterty method
madtoinou May 17, 2023
e48e754
Apply suggestions from code review
madtoinou May 22, 2023
ce6606b
fix: made the covariates handling in ensemble model more transparent,…
madtoinou May 22, 2023
62397ef
Merge branch 'feat/local-and-global-ensemble' of https://github.com/u…
madtoinou May 22, 2023
fe80f65
Merge branch 'master' into feat/local-and-global-ensemble
madtoinou May 22, 2023
9011104
Merge branch 'master' into feat/local-and-global-ensemble
dennisbader May 23, 2023
a0cf04e
Update CHANGELOG.md
dennisbader May 23, 2023
0423418
Update CHANGELOG.md
dennisbader May 23, 2023
1efbcc1
Update CHANGELOG.md
dennisbader May 23, 2023
29d39be
Apply suggestions from code review
madtoinou May 24, 2023
93a9cd7
fix: addressed reviewer comments, added show_warning arg to ensemble_…
madtoinou May 24, 2023
d4a5370
fix typo
madtoinou May 24, 2023
79ee5ff
fix: improve warning synthax
madtoinou May 24, 2023
33fd7c5
Merge branch 'master' into feat/local-and-global-ensemble
madtoinou May 24, 2023
6a0970d
Merge branch 'master' into feat/local-and-global-ensemble
dennisbader May 26, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
feat: EnsembleModel accepts a mixture of local and global models for …
…single ts training/inference only. If provided, the covariates are passed only to the models supporting them (individually)
  • Loading branch information
madtoinou committed May 5, 2023
commit 9dc90010bb69f89d243bd32abe72dd8b6703c56c
15 changes: 6 additions & 9 deletions darts/models/forecasting/baselines.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,15 +184,12 @@ def fit(
future_covariates=future_covariates,
)
for model in self.models:
if self.is_global_ensemble:
kwargs = dict(series=series)
if model.supports_past_covariates:
kwargs["past_covariates"] = past_covariates
if model.supports_future_covariates:
kwargs["future_covariates"] = future_covariates
model.fit(**kwargs)
else:
model.fit(series=series)
kwargs = dict(series=series)
if model.supports_past_covariates:
kwargs["past_covariates"] = past_covariates
if model.supports_future_covariates:
kwargs["future_covariates"] = future_covariates
model.fit(**kwargs)

return self

Expand Down
46 changes: 36 additions & 10 deletions darts/models/forecasting/ensemble_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@ class EnsembleModel(GlobalForecastingModel):
Ensemble models take in a list of forecasting models and ensemble their predictions
to make a single one according to the rule defined by their `ensemble()` method.

If `future_covariates` or `past_covariates` are provided at training or inference time,
they will be passed only to the models supporting them.

Parameters
----------
models
Expand All @@ -36,11 +39,26 @@ def __init__(self, models: List[ForecastingModel]):
logger,
)

self.is_local_ensemble = all(
isinstance(model, LocalForecastingModel) for model in models
)
self.is_global_ensemble = all(
is_local_model = [isinstance(model, LocalForecastingModel) for model in models]
is_global_model = [
isinstance(model, GlobalForecastingModel) for model in models
]

self.is_local_ensemble = all(is_local_model)
self.is_global_ensemble = all(is_global_model)

raise_if_not(
all(
[
local_model or global_model
for local_model, global_model in zip(
is_local_model, is_global_model
)
]
),
"Cannot instantiate EnsembleModel with forecasting models that are neither local, "
"nor global. Also, please make sure that all the models in `models` are instantiated.",
madtoinou marked this conversation as resolved.
Show resolved Hide resolved
logger,
)

raise_if(
Expand All @@ -64,9 +82,14 @@ def fit(
Note that `EnsembleModel.fit()` does NOT call `fit()` on each of its constituent forecasting models.
It is left to classes inheriting from EnsembleModel to do so appropriately when overriding `fit()`
"""

is_single_series = isinstance(series, TimeSeries)

# local models OR mix of local and global models
raise_if(
self.is_local_ensemble and not isinstance(series, TimeSeries),
"The models are of type LocalForecastingModel, which does not support training on multiple series.",
not self.is_global_ensemble and not is_single_series,
"The models contain at least one LocalForecastingModel, which does not support training on multiple "
"series.",
logger,
)
raise_if(
Expand All @@ -75,8 +98,6 @@ def fit(
logger,
)

is_single_series = isinstance(series, TimeSeries)

# check that if timeseries is single series, that covariates are as well and vice versa
error_past_cov = False
error_future_cov = False
Expand Down Expand Up @@ -121,12 +142,17 @@ def _make_multiple_predictions(
num_samples: int = 1,
):
is_single_series = isinstance(series, TimeSeries) or series is None
# maximize covariate usage
predictions = [
model._predict_wrapper(
n=n,
series=series,
past_covariates=past_covariates,
future_covariates=future_covariates,
past_covariates=past_covariates
if model.supports_past_covariates
else None,
future_covariates=future_covariates
if model.supports_future_covariates
else None,
num_samples=num_samples,
)
for model in self.models
Expand Down
9 changes: 7 additions & 2 deletions darts/models/forecasting/regression_ensemble_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,10 +115,15 @@ def fit(
)

for model in self.models:
# maximize covariate usage
model._fit_wrapper(
series=forecast_training,
past_covariates=past_covariates,
future_covariates=future_covariates,
past_covariates=past_covariates
if model.supports_past_covariates
else None,
future_covariates=future_covariates
if model.supports_future_covariates
else None,
)

predictions = self._make_multiple_predictions(
Expand Down