fix MixedCovTorchModels multi TS predictions with n<ocl (#2374)
* fix MixedCovTorchModels multi TS predictions with n<ocl

* update changelog
dennisbader authored May 6, 2024
1 parent 2430903 commit 26d1e1e
Showing 4 changed files with 30 additions and 60 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -12,6 +12,7 @@ but cannot always guarantee backwards compatibility. Changes that may **break co

**Fixed**
- Fixed a bug where `n_steps_between` did not work properly with custom business frequencies. This affected metrics computation. [#2357](https://github.com/unit8co/darts/pull/2357) by [Dennis Bader](https://github.com/dennisbader).
- Fixed a bug when calling `predict()` with a `MixedCovariatesTorchModel` (e.g. TiDE, N/DLinear, ...), `n < output_chunk_length`, and a list of series with length `len(series) < n`, where `predict()` did not return the correct number of series. [#2374](https://github.com/unit8co/darts/pull/2374) by [Dennis Bader](https://github.com/dennisbader).

**Dependencies**
- Improvements to linting via updated pre-commit configurations: [#2324](https://github.com/unit8co/darts/pull/2324) by [Jirka Borovec](https://github.com/borda).
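For context, a minimal sketch of the scenario this entry describes, assuming `TiDEModel` as the `MixedCovariatesTorchModel`; the series length, chunk sizes, and epoch count below are illustrative, not taken from the commit:

```python
import numpy as np

from darts import TimeSeries
from darts.models import TiDEModel

# a short toy target series
series = TimeSeries.from_values(np.arange(24, dtype=np.float32))

model = TiDEModel(input_chunk_length=8, output_chunk_length=4, n_epochs=1)
model.fit(series)

# the affected case: n < output_chunk_length (3 < 4) and a list of series
# shorter than n (len(series_list) == 2 < 3)
preds = model.predict(n=3, series=[series, series])
assert len(preds) == 2                  # one forecast per input series
assert all(len(p) == 3 for p in preds)  # each forecast truncated to n points
```

With the fix, both assertions hold; previously this call path could return the wrong number of forecasts or forecasts of the wrong length.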
10 changes: 6 additions & 4 deletions darts/models/forecasting/pl_forecasting_module.py
@@ -830,13 +830,15 @@ def _get_batch_prediction(
         batch_prediction = [out[:, :roll_size, :]]
         prediction_length = roll_size
 
-        while prediction_length < n:
-            # we want the last prediction to end exactly at `n` into the future.
+        # predict at least `output_chunk_length` points, so that we use the most recent target values
+        min_n = n if n >= self.output_chunk_length else self.output_chunk_length
+        while prediction_length < min_n:
+            # we want the last prediction to end exactly at `min_n` into the future.
             # this means we may have to truncate the previous prediction and step
             # back the roll size for the last chunk
-            if prediction_length + self.output_chunk_length > n:
+            if prediction_length + self.output_chunk_length > min_n:
                 spillover_prediction_length = (
-                    prediction_length + self.output_chunk_length - n
+                    prediction_length + self.output_chunk_length - min_n
                 )
                 roll_size -= spillover_prediction_length
                 prediction_length -= spillover_prediction_length
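The intent of the clamp, in isolation: the autoregressive loop always rolls out at least one full `output_chunk_length` block (so the most recent target values are used), and the result is truncated back to `n` afterwards. A standalone sketch of the loop arithmetic, using a hypothetical helper name and plain Python for illustration only:

```python
def rolled_out_length(n: int, output_chunk_length: int, roll_size: int) -> int:
    """Points produced by the batch-prediction loop before the final
    truncation back to `n` (simplified from `_get_batch_prediction`)."""
    # predict at least `output_chunk_length` points
    min_n = n if n >= output_chunk_length else output_chunk_length
    prediction_length = roll_size  # the first forecast block is already made
    while prediction_length < min_n:
        # last chunk: step back so the prediction ends exactly at `min_n`
        if prediction_length + output_chunk_length > min_n:
            spillover = prediction_length + output_chunk_length - min_n
            prediction_length -= spillover
        prediction_length += output_chunk_length
    return prediction_length

assert rolled_out_length(n=2, output_chunk_length=3, roll_size=3) == 3  # clamped up to ocl
assert rolled_out_length(n=7, output_chunk_length=3, roll_size=3) == 7  # ends exactly at n
```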
56 changes: 0 additions & 56 deletions darts/models/forecasting/torch_forecasting_model.py
@@ -2779,62 +2779,6 @@ def extreme_lags(
             None,
         )
 
-    def predict(
-        self,
-        n: int,
-        series: Optional[Union[TimeSeries, Sequence[TimeSeries]]] = None,
-        past_covariates: Optional[Union[TimeSeries, Sequence[TimeSeries]]] = None,
-        future_covariates: Optional[Union[TimeSeries, Sequence[TimeSeries]]] = None,
-        trainer: Optional[pl.Trainer] = None,
-        batch_size: Optional[int] = None,
-        verbose: Optional[bool] = None,
-        n_jobs: int = 1,
-        roll_size: Optional[int] = None,
-        num_samples: int = 1,
-        num_loader_workers: int = 0,
-        mc_dropout: bool = False,
-        predict_likelihood_parameters: bool = False,
-        show_warnings: bool = True,
-    ) -> Union[TimeSeries, Sequence[TimeSeries]]:
-        # since we have future covariates, the inference dataset for future input must be at least of length
-        # `output_chunk_length`. If not, we would have to step back which causes past input to be shorter than
-        # `input_chunk_length`.
-
-        if n >= self.output_chunk_length:
-            return super().predict(
-                n=n,
-                series=series,
-                past_covariates=past_covariates,
-                future_covariates=future_covariates,
-                trainer=trainer,
-                batch_size=batch_size,
-                verbose=verbose,
-                n_jobs=n_jobs,
-                roll_size=roll_size,
-                num_samples=num_samples,
-                num_loader_workers=num_loader_workers,
-                mc_dropout=mc_dropout,
-                predict_likelihood_parameters=predict_likelihood_parameters,
-                show_warnings=show_warnings,
-            )
-        else:
-            return super().predict(
-                n=self.output_chunk_length,
-                series=series,
-                past_covariates=past_covariates,
-                future_covariates=future_covariates,
-                trainer=trainer,
-                batch_size=batch_size,
-                verbose=verbose,
-                n_jobs=n_jobs,
-                roll_size=roll_size,
-                num_samples=num_samples,
-                num_loader_workers=num_loader_workers,
-                mc_dropout=mc_dropout,
-                predict_likelihood_parameters=predict_likelihood_parameters,
-                show_warnings=show_warnings,
-            )[:n]
-
 
 class SplitCovariatesTorchModel(TorchForecastingModel, ABC):
     def _build_train_dataset(
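Why deleting this override is the actual fix: `super().predict(...)` returns a single `TimeSeries` for a single input series but a list of `TimeSeries` for a list input, so the trailing `[:n]` behaved differently depending on the input — it trimmed time steps for a single series, but for a list it sliced the list itself, dropping forecasts when the list was longer than `n` and leaving each forecast untruncated when it was shorter. A minimal illustration of the pitfall, with plain lists standing in for the darts objects:

```python
n = 2

# single input: `[:n]` trims the forecast to n points (intended)
single_forecast = [10.0, 11.0, 12.0]       # one forecast, output_chunk_length = 3
print(single_forecast[:n])                 # [10.0, 11.0]

# list input: `[:n]` slices the list of forecasts instead
many_forecasts = [[10.0, 11.0, 12.0]] * 4  # four series, three points each
print(many_forecasts[:n])                  # only 2 of 4 forecasts survive

few_forecasts = [[10.0, 11.0, 12.0]]       # one series (len < n)
print(few_forecasts[:n])                   # all kept, but still 3 points, not n=2
```

With the override removed, the `min_n` clamp added to `_get_batch_prediction` covers the `n < output_chunk_length` case, and each forecast is truncated individually.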
23 changes: 23 additions & 0 deletions darts/tests/models/forecasting/test_torch_forecasting_model.py
@@ -1668,6 +1668,29 @@ def test_output_shift(self, config):
         _ = model_fc_shift.predict(n=ocl, **add_covs)
         assert f"provided {cov_name} covariates at dataset index" in str(err.value)
 
+    @pytest.mark.parametrize("config", itertools.product(models, [2, 3, 4]))
+    def test_multi_ts_prediction(self, config):
+        (model_cls, model_kwargs), n = config
+        model_kwargs = copy.deepcopy(model_kwargs)
+        model_kwargs["output_chunk_length"] = 3
+        series = tg.linear_timeseries(
+            length=model_kwargs["input_chunk_length"]
+            + model_kwargs["output_chunk_length"]
+        )
+        model = model_cls(**model_kwargs)
+        model.fit(series)
+        # test with more series than `n`
+        n_series_more = 5
+        pred = model.predict(n=n, series=[series] * n_series_more)
+        assert len(pred) == n_series_more
+        assert all(len(p) == n for p in pred)
+
+        # test with fewer series than `n`
+        n_series_less = 1
+        pred = model.predict(n=n, series=[series] * n_series_less)
+        assert len(pred) == n_series_less
+        assert all(len(p) == n for p in pred)
+
     def helper_equality_encoders(
         self, first_encoders: Dict[str, Any], second_encoders: Dict[str, Any]
     ):
