Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat/static covs #966

Merged
merged 29 commits into from
Jun 5, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
125c078
added methods ``from_longitudinal_dataframe` and `add_static_covariates`
dennisbader Apr 13, 2022
fde974e
dataset adaption for static covs
dennisbader Apr 23, 2022
d6d4885
extended datasets for static covariates support and unified variable …
dennisbader May 19, 2022
4717ff1
adapted PLXCovariatesModules with static covariates
dennisbader May 19, 2022
5b6b781
adapted TFTModel for static covariate support
dennisbader May 20, 2022
55eaf3b
added temporary fix for static covariates with scalers
dennisbader May 20, 2022
3ced3da
Merge branch 'master' into feat/static_covs
dennisbader May 20, 2022
29924f4
unittests for from_longitudinal_dataframe() and set_static_covariates
dennisbader May 24, 2022
079d969
updated dataset tests
dennisbader May 24, 2022
3511b81
fixed all downstream issues from new static covariates in datasets
dennisbader May 27, 2022
eacaf3b
added check for equal static covariates between fit and predict
dennisbader May 28, 2022
55c5090
added tests for passing static covariates in TimeSeries methods
dennisbader May 28, 2022
cc07f5f
added static covariate support for stacking TimeSeries
dennisbader May 28, 2022
0aacd5a
transpose static covariates
dennisbader May 29, 2022
2845f86
added method `static_covariates_values()`
dennisbader May 29, 2022
2ac58e4
updated docs
dennisbader May 29, 2022
a6fa4fb
static covariate support for concatenation
dennisbader May 30, 2022
a4ba617
static covariate support for concatenation
dennisbader May 30, 2022
0586b7d
static covariates are now passed to the torch models
dennisbader May 30, 2022
c18e806
non-numerical dtype support for static covariates
dennisbader May 31, 2022
a048ecc
added slicing support for static covariates
dennisbader May 31, 2022
3661385
multicomponent static covariate support for TFT
dennisbader May 31, 2022
5b9258b
Merge branch 'master' into feat/static_covs
dennisbader May 31, 2022
3a9ad83
added arithmetic static covariate support
dennisbader May 31, 2022
d00c08d
Merge branch 'master' into feat/static_covs
dennisbader Jun 3, 2022
f5fa989
updated all timeseries methods/operations with static cov transfer
dennisbader Jun 4, 2022
41adf3f
applied suggestion from PR review part 1
dennisbader Jun 4, 2022
6dc7ff8
apply suggestions from code review part 2
dennisbader Jun 4, 2022
d001e17
fix black issue from PR suggestion
dennisbader Jun 4, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
multicomponent static covariate support for TFT
  • Loading branch information
dennisbader committed May 31, 2022
commit 3661385fe74bddff4c70a059e180fa99ceec9000
20 changes: 16 additions & 4 deletions darts/models/forecasting/tft_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ def __init__(
self,
output_dim: Tuple[int, int],
variables_meta: Dict[str, Dict[str, List[str]]],
num_static_components: int,
hidden_size: Union[int, List[int]] = 16,
lstm_layers: int = 1,
num_attention_heads: int = 4,
Expand All @@ -60,6 +61,9 @@ def __init__(
shape of output given by (n_targets, loss_size). (loss_size corresponds to nr_params in other models).
variables_meta : Dict[str, Dict[str, List[str]]]
dict containing variable enocder, decoder variable names for mapping tensors in `_TFTModule.forward()`
num_static_components
the number of static components (not variables) of the input target series. This is either equal to the
number of target components or 1.
hidden_size : int
hidden state size of the TFT. It is the main hyper-parameter and common across the internal TFT
architecture.
Expand Down Expand Up @@ -90,6 +94,7 @@ def __init__(

self.n_targets, self.loss_size = output_dim
self.variables_meta = variables_meta
self.num_static_components = num_static_components
self.hidden_size = hidden_size
self.hidden_continuous_size = hidden_continuous_size
self.lstm_layers = lstm_layers
Expand All @@ -113,7 +118,11 @@ def __init__(
# # processing inputs
# continuous variable processing
self.prescalers_linear = {
name: nn.Linear(1, self.hidden_continuous_size) for name in self.reals
name: nn.Linear(
1 if name not in self.static_variables else self.num_static_components,
self.hidden_continuous_size,
)
for name in self.reals
}

static_input_sizes = {
Expand Down Expand Up @@ -412,8 +421,7 @@ def forward(
# Embedding and variable selection
if self.static_variables:
static_embedding = {
name: x_static[:, 0, i].unsqueeze(-1)
for i, name in enumerate(self.static_variables)
name: x_static[:, :, i] for i, name in enumerate(self.static_variables)
}
static_embedding, static_covariate_var = self.static_covariates_vsn(
static_embedding
Expand Down Expand Up @@ -864,9 +872,13 @@ def _create_model(self, train_sample: MixedCovariatesTrainTensorType) -> nn.Modu
dict.fromkeys(static_input)
)

n_static_components = (
len(static_covariates) if static_covariates is not None else 0
)
return _TFTModule(
variables_meta=variables_meta,
output_dim=self.output_dim,
variables_meta=variables_meta,
num_static_components=n_static_components,
hidden_size=self.hidden_size,
lstm_layers=self.lstm_layers,
dropout=self.dropout,
Expand Down
31 changes: 17 additions & 14 deletions darts/tests/models/forecasting/test_TFT.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import pandas as pd
import pytest

from darts import TimeSeries
from darts import TimeSeries, concatenate
from darts.dataprocessing.transformers import Scaler
from darts.logging import get_logger
from darts.tests.base_test_class import DartsBaseTestClass
Expand Down Expand Up @@ -166,36 +166,39 @@ def test_mixed_covariates_and_accuracy(self):
)

def test_static_covariates_support(self):
target = tg.sine_timeseries(length=2, freq="h")
target = target.with_static_covariates(
pd.Series([0.0, 1.0], index=["st1", "st2"])
target_multi = concatenate(
[tg.sine_timeseries(length=10, freq="h")] * 2, axis=1
)

target_multi = target_multi.with_static_covariates(
pd.DataFrame([[0.0, 1.0], [2.0, 3.0]], index=["st1", "st2"])
)

# should work with cyclic encoding for time index
model = TFTModel(
input_chunk_length=1,
output_chunk_length=1,
input_chunk_length=3,
output_chunk_length=4,
add_encoders={"cyclic": {"future": "hour"}},
pl_trainer_kwargs={"fast_dev_run": True},
dennisbader marked this conversation as resolved.
Show resolved Hide resolved
)
model.fit(target, verbose=False)
model.fit(target_multi, verbose=False)
assert len(model.model.static_variables) == len(
target.static_covariates.columns
target_multi.static_covariates.columns
)

model.predict(n=1, series=target, verbose=False)
model.predict(n=1, series=target_multi, verbose=False)

# raise an error when trained with static covariates of wrong dimensionality
target = target.with_static_covariates(
pd.concat([target.static_covariates] * 2, axis=1)
target_multi = target_multi.with_static_covariates(
pd.concat([target_multi.static_covariates] * 2, axis=1)
)
with pytest.raises(ValueError):
model.predict(n=1, series=target, verbose=False)
model.predict(n=1, series=target_multi, verbose=False)

# raise an error when trained with static covariates and trying to predict without
target = target.with_static_covariates(None)
target_multi = target_multi.with_static_covariates(None)
with pytest.raises(ValueError):
model.predict(n=1, series=target, verbose=False)
model.predict(n=1, series=target_multi, verbose=False)

def helper_generate_multivariate_case_data(self, season_length, n_repeat):
"""generates multivariate test case data. Target series is a sine wave stacked with a repeating
Expand Down