Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix/dlinear and nlinear use_static_cov. with multivariate series #2070

Merged
merged 7 commits into from
Nov 18, 2023
Next Next commit
feat: added test to check that use_static_covariates covers all possi…
…ble static covariates representations
  • Loading branch information
madtoinou committed Nov 16, 2023
commit 64164a2049ef54c31ded087bc3b79b338ae1ae45
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
from copy import deepcopy
from itertools import product
from unittest.mock import ANY, patch

import numpy as np
Expand Down Expand Up @@ -206,6 +207,15 @@ class TestGlobalForecastingModels:
target = sine_1_ts + sine_2_ts + linear_ts + sine_3_ts
target_past, target_future = target.split_after(split_ratio)

# various ts with different static covariates representations
ts_w_static_cov = tg.linear_timeseries(length=80).with_static_covariates(
pd.Series([1, 2])
)
ts_shared_static_cov = ts_w_static_cov.stack(tg.sine_timeseries(length=80))
ts_comps_static_cov = ts_shared_static_cov.with_static_covariates(
pd.DataFrame([[0, 1], [2, 3]], columns=["st1", "st2"])
)

@pytest.mark.parametrize("config", models_cls_kwargs_errs)
def test_save_model_parameters(self, config):
# model creation parameters were saved before. check if re-created model has same params as original
Expand Down Expand Up @@ -450,6 +460,37 @@ def test_future_covariates(self):
with pytest.raises(ValueError):
model.predict(n=161, future_covariates=self.covariates)

@pytest.mark.parametrize(
"model_cls,ts",
product(
[TFTModel, DLinearModel, NLinearModel, TiDEModel],
[ts_w_static_cov, ts_shared_static_cov, ts_comps_static_cov],
),
)
def test_use_static_covariates(self, model_cls, ts):
"""
Check that both static covariates representations are supported (component-specific and shared)
for both uni- and multivariate series when fitting the model.
Also check that the static covariates are present in the forecasted series
"""
model = model_cls(
input_chunk_length=IN_LEN,
output_chunk_length=OUT_LEN,
random_state=0,
use_static_covariates=True,
n_epochs=1,
**tfm_kwargs,
)
# must provide mandatory future_covariates to TFTModel
model.fit(
series=ts,
future_covariates=self.sine_1_ts
if model.supports_future_covariates
else None,
)
pred = model.predict(OUT_LEN)
assert pred.static_covariates.equals(ts.static_covariates)

def test_batch_predictions(self):
# predicting multiple time series at once needs to work for arbitrary batch sizes
# univariate case
Expand Down