Skip to content

Commit

Permalink
Feat/static covs (#966)
Browse files Browse the repository at this point in the history
* added methods `from_longitudinal_dataframe` and `add_static_covariates`

* dataset adaption for static covs

* extended datasets for static covariates support and unified variable names

* adapted PLXCovariatesModules with static covariates

* adapted TFTModel for static covariate support

* added temporary fix for static covariates with scalers

* unittests for from_longitudinal_dataframe() and set_static_covariates

* updated dataset tests

* fixed all downstream issues from new static covariates in datasets

* added check for equal static covariates between fit and predict

* added tests for passing static covariates in TimeSeries methods

* added static covariate support for stacking TimeSeries

* transpose static covariates

* added method `static_covariates_values()`

* updated docs

* static covariate support for concatenation

* static covariate support for concatenation

* static covariates are now passed to the torch models

* non-numerical dtype support for static covariates

* added slicing support for static covariates

* multicomponent static covariate support for TFT

* added arithmetic static covariate support

* updated all timeseries methods/operations with static cov transfer

* applied suggestion from PR review part 1

* apply suggestions from code review part 2

Co-authored-by: Julien Herzen <julien@unit8.co>

* fix black issue from PR suggestion

Co-authored-by: Julien Herzen <julien@unit8.co>
  • Loading branch information
dennisbader and hrzn authored Jun 5, 2022
1 parent cad8055 commit 6c90980
Show file tree
Hide file tree
Showing 33 changed files with 2,163 additions and 503 deletions.
7 changes: 7 additions & 0 deletions darts/dataprocessing/transformers/boxcox.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,13 @@ def __init__(
For stochastic series, it is done jointly over all samples, effectively merging all samples of
a component in order to compute the transform.
Notes
-----
The scaler will not scale the series' static covariates. This has to be done either before constructing the
series, or later on by extracting the covariates, transforming the values and then reapplying them to the
series. For this, see the TimeSeries property `TimeSeries.static_covariates` and the method
`TimeSeries.with_static_covariates()`.
Parameters
----------
name
Expand Down
9 changes: 9 additions & 0 deletions darts/dataprocessing/transformers/scaler.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,13 @@ def __init__(
The transformation is applied independently for each dimension (component) of the time series,
effectively merging all samples of a component in order to compute the transform.
Notes
-----
The scaler will not scale the series' static covariates. This has to be done either before constructing the
series, or later on by extracting the covariates, transforming the values and then reapplying them to the
series. For this, see the TimeSeries property `TimeSeries.static_covariates` and the method
`TimeSeries.with_static_covariates()`.
Parameters
----------
scaler
Expand Down Expand Up @@ -106,6 +113,7 @@ def ts_transform(series: TimeSeries, transformer, **kwargs) -> TimeSeries:
values=transformed_vals,
fill_missing_dates=False,
columns=series.columns,
static_covariates=series.static_covariates,
)

@staticmethod
Expand All @@ -126,6 +134,7 @@ def ts_inverse_transform(
values=inv_transformed_vals,
fill_missing_dates=False,
columns=series.columns,
static_covariates=series.static_covariates,
)

@staticmethod
Expand Down
6 changes: 5 additions & 1 deletion darts/models/filtering/gaussian_process_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,4 +72,8 @@ def filter(self, series: TimeSeries, num_samples: int = 1) -> TimeSeries:
filtered_values = self.model.sample_y(times, n_samples=num_samples)

filtered_values = filtered_values.reshape(len(times), -1, num_samples)
return TimeSeries.from_times_and_values(series.time_index, filtered_values)
return TimeSeries.from_times_and_values(
series.time_index,
filtered_values,
static_covariates=series.static_covariates,
)
5 changes: 4 additions & 1 deletion darts/models/filtering/kalman_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,5 +250,8 @@ def filter(
).T

return TimeSeries.from_times_and_values(
series.time_index, sampled_outputs, columns=series.columns
series.time_index,
sampled_outputs,
columns=series.columns,
static_covariates=series.static_covariates,
)
4 changes: 3 additions & 1 deletion darts/models/filtering/moving_average.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,4 +49,6 @@ def filter(self, series: TimeSeries):
.rolling(window=self.window, min_periods=1, center=self.centered)
.mean()
)
return TimeSeries.from_dataframe(filtered_df)
return TimeSeries.from_dataframe(
filtered_df, static_covariates=series.static_covariates
)
8 changes: 6 additions & 2 deletions darts/models/forecasting/baselines.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,10 +158,14 @@ def ensemble(
) -> Union[TimeSeries, Sequence[TimeSeries]]:
if isinstance(predictions, Sequence):
return [
TimeSeries.from_series(p.pd_dataframe().sum(axis=1) / len(self.models))
TimeSeries.from_series(
p.pd_dataframe().sum(axis=1) / len(self.models),
static_covariates=p.static_covariates,
)
for p in predictions
]
else:
return TimeSeries.from_series(
predictions.pd_dataframe().sum(axis=1) / len(self.models)
predictions.pd_dataframe().sum(axis=1) / len(self.models),
static_covariates=predictions.static_covariates,
)
3 changes: 2 additions & 1 deletion darts/models/forecasting/block_rnn_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,8 @@ def __init__(
last = feature
self.fc = nn.Sequential(*feats)

def forward(self, x):
def forward(self, x_in: Tuple):
x, _ = x_in
# data is of size (batch_size, input_chunk_length, input_size)
batch_size = x.size(0)

Expand Down
2 changes: 2 additions & 0 deletions darts/models/forecasting/forecasting_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,6 +461,7 @@ def historical_forecasts(
return TimeSeries.from_times_and_values(
pd.DatetimeIndex(last_points_times, freq=series.freq * stride),
np.array(last_points_values),
static_covariates=series.static_covariates,
)
else:
return TimeSeries.from_times_and_values(
Expand All @@ -470,6 +471,7 @@ def historical_forecasts(
step=1,
),
np.array(last_points_values),
static_covariates=series.static_covariates,
)

return forecasts
Expand Down
6 changes: 5 additions & 1 deletion darts/models/forecasting/kalman_forecaster.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,11 @@ def _predict(

time_index = self._generate_new_dates(n)
placeholder_vals = np.zeros((n, self.training_series.width)) * np.nan
series_future = TimeSeries.from_times_and_values(time_index, placeholder_vals)
series_future = TimeSeries.from_times_and_values(
time_index,
placeholder_vals,
static_covariates=self.training_series.static_covariates,
)
whole_series = self.training_series.append(series_future)
filtered_series = self.darts_kf.filter(
whole_series, covariates=future_covariates, num_samples=num_samples
Expand Down
3 changes: 2 additions & 1 deletion darts/models/forecasting/nbeats.py
Original file line number Diff line number Diff line change
Expand Up @@ -490,7 +490,8 @@ def __init__(
self.stacks_list[-1].blocks[-1].backcast_linear_layer.requires_grad_(False)
self.stacks_list[-1].blocks[-1].backcast_g.requires_grad_(False)

def forward(self, x):
def forward(self, x_in: Tuple):
x, _ = x_in

# if x1, x2,... y1, y2... is one multivariate ts containing x and y, and a1, a2... one covariate ts
# we reshape into x1, y1, a1, x2, y2, a2... etc
Expand Down
3 changes: 2 additions & 1 deletion darts/models/forecasting/nhits.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,7 +417,8 @@ def __init__(
# on this params (the last block backcast is not part of the final output of the net).
self.stacks_list[-1].blocks[-1].backcast_linear_layer.requires_grad_(False)

def forward(self, x):
def forward(self, x_in: Tuple):
x, _ = x_in

# if x1, x2,... y1, y2... is one multivariate ts containing x and y, and a1, a2... one covariate ts
# we reshape into x1, y1, a1, x2, y2, a2... etc
Expand Down
80 changes: 45 additions & 35 deletions darts/models/forecasting/pl_forecasting_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,7 @@ def _sample_tiling(input_data_tuple, batch_sample_size):
def _is_probabilistic(self) -> bool:
return self.likelihood is not None

def _produce_predict_output(self, x):
def _produce_predict_output(self, x: Tuple):
if self.likelihood:
output = self(x)
return self.likelihood.sample(output)
Expand Down Expand Up @@ -342,12 +342,22 @@ def epochs_trained(self):

class PLPastCovariatesModule(PLForecastingModule, ABC):
def _produce_train_output(self, input_batch: Tuple):
past_target, past_covariate = input_batch
"""
Feeds PastCovariatesTorchModel with input and output chunks of a PastCovariatesSequentialDataset for
training.
Parameters
----------
input_batch
``(past_target, past_covariates, static_covariates)``
"""
past_target, past_covariates, static_covariates = input_batch
# Currently all our PastCovariates models require past target and covariates concatenated
inpt = (
torch.cat([past_target, past_covariate], dim=2)
if past_covariate is not None
else past_target
torch.cat([past_target, past_covariates], dim=2)
if past_covariates is not None
else past_target,
static_covariates,
)
return self(inpt)

Expand All @@ -363,13 +373,18 @@ def _get_batch_prediction(
n
prediction length
input_batch
(past_target, past_covariates, future_past_covariates)
``(past_target, past_covariates, future_past_covariates, static_covariates)``
roll_size
roll input arrays after every sequence by ``roll_size``. Initially, ``roll_size`` is equivalent to
``self.output_chunk_length``
"""
dim_component = 2
past_target, past_covariates, future_past_covariates = input_batch
(
past_target,
past_covariates,
future_past_covariates,
static_covariates,
) = input_batch

n_targets = past_target.shape[dim_component]
n_past_covs = (
Expand All @@ -381,7 +396,7 @@ def _get_batch_prediction(
dim=dim_component,
)

out = self._produce_predict_output(input_past)[
out = self._produce_predict_output((input_past, static_covariates))[
:, self.first_prediction_index :, :
]

Expand Down Expand Up @@ -430,7 +445,7 @@ def _get_batch_prediction(
] = future_past_covariates[:, left_past:right_past, :]

# take only last part of the output sequence where needed
out = self._produce_predict_output(input_past)[
out = self._produce_predict_output((input_past, static_covariates))[
:, self.first_prediction_index :, :
]
batch_prediction.append(out)
Expand Down Expand Up @@ -462,63 +477,56 @@ class PLMixedCovariatesModule(PLForecastingModule, ABC):
def _produce_train_output(
self, input_batch: Tuple
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Feeds MixedCovariatesTorchModel with input and output chunks of a MixedCovariatesSequentialDataset for
training.
Parameters
----------
input_batch
``(past_target, past_covariates, historic_future_covariates, future_covariates, static_covariates)``.
"""
return self(self._process_input_batch(input_batch))

def _process_input_batch(
self, input_batch
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
"""
Converts output of MixedCovariatesDataset (training dataset) into an input/past- and
output/future chunk.
Parameters
----------
input_batch
``(past_target, past_covariates, historic_future_covariates, future_covariates)``.
``(past_target, past_covariates, historic_future_covariates, future_covariates, static_covariates)``.
Returns
-------
tuple
``(x_past, x_future)`` the input/past and output/future chunks.
``(x_past, x_future, x_static)`` the input/past chunk, the output/future chunk and the static covariates.
"""

(
past_target,
past_covariates,
historic_future_covariates,
future_covariates,
static_covariates,
) = input_batch
dim_variable = 2

# TODO: implement static covariates
static_covariates = None

x_past = torch.cat(
[
tensor
for tensor in [
past_target,
past_covariates,
historic_future_covariates,
static_covariates,
]
if tensor is not None
],
dim=dim_variable,
)

x_future = None
if future_covariates is not None or static_covariates is not None:
x_future = torch.cat(
[
tensor
for tensor in [future_covariates, static_covariates]
if tensor is not None
],
dim=dim_variable,
)

return x_past, x_future
return x_past, future_covariates, static_covariates

def _get_batch_prediction(
self, n: int, input_batch: Tuple, roll_size: int
Expand All @@ -545,6 +553,7 @@ def _get_batch_prediction(
historic_future_covariates,
future_covariates,
future_past_covariates,
static_covariates,
) = input_batch

n_targets = past_target.shape[dim_component]
Expand All @@ -557,18 +566,19 @@ def _get_batch_prediction(
else 0
)

input_past, input_future = self._process_input_batch(
input_past, input_future, input_static = self._process_input_batch(
(
past_target,
past_covariates,
historic_future_covariates,
future_covariates[:, :roll_size, :]
if future_covariates is not None
else None,
static_covariates,
)
)

out = self._produce_predict_output(x=(input_past, input_future))[
out = self._produce_predict_output(x=(input_past, input_future, input_static))[
:, self.first_prediction_index :, :
]

Expand Down Expand Up @@ -636,9 +646,9 @@ def _get_batch_prediction(
input_future = future_covariates[:, left_future:right_future, :]

# take only last part of the output sequence where needed
out = self._produce_predict_output(x=(input_past, input_future))[
:, self.first_prediction_index :, :
]
out = self._produce_predict_output(
x=(input_past, input_future, input_static)
)[:, self.first_prediction_index :, :]

batch_prediction.append(out)
prediction_length += self.output_chunk_length
Expand Down
Loading

0 comments on commit 6c90980

Please sign in to comment.