Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENH] Notebook and Template For Global Forecasting API #6699

Merged
merged 128 commits into from
Aug 10, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
128 commits
Select commit Hold shift + click to select a range
0278280
pytorch-forecasting first draft
XinyuWuu Mar 28, 2024
20c52b0
Merge branch 'main' into global_pytorch_forecasting
XinyuWuu Mar 28, 2024
4aaa022
set None params to default value
XinyuWuu Mar 29, 2024
7dd2a2f
convert X, y to TimeSeriesDataSet in fit
XinyuWuu Mar 30, 2024
670f53c
fix monotone_constaints is None
XinyuWuu Mar 31, 2024
f68fb1c
add dataset_params
XinyuWuu Mar 31, 2024
f0ac35c
add to_dataloader_params
XinyuWuu Mar 31, 2024
473d606
train validation split by max_prediction_length
XinyuWuu Mar 31, 2024
17b02e2
fix kwargs overwrite in model.from_dataset
XinyuWuu Mar 31, 2024
861182a
fix soft dependencies import error
XinyuWuu Mar 31, 2024
1e233bd
data convertion in predict
XinyuWuu Apr 1, 2024
1bd91f2
Merge branch 'main' into global_pytorch_forecasting
XinyuWuu Apr 1, 2024
ef946ca
fix unsupported operand type(s) for |
XinyuWuu Apr 1, 2024
abefa3a
fix kwargs loss after reset
XinyuWuu Apr 6, 2024
1341b36
fix output y name
XinyuWuu Apr 6, 2024
f78e635
add comments
XinyuWuu Apr 14, 2024
d3fb3c8
Merge branch 'main' into global_pytorch_forecasting
XinyuWuu Apr 14, 2024
8ee461a
rename GlobalBaseForecaster to BaseGlobalForecaster
XinyuWuu Apr 19, 2024
f72fc41
add global_forecaster tag
XinyuWuu May 5, 2024
af4fce4
add BaseGlobalForecaster to BASE_CLASS_REGISTER
XinyuWuu May 5, 2024
6525392
add TestAllGlobalForecasters
XinyuWuu May 5, 2024
c3c43de
BaseGlobalForecaster as an exception in test_inheritance
XinyuWuu May 5, 2024
17ec882
Merge branch 'main' into global_pytorch_forecasting
XinyuWuu May 5, 2024
e7e7205
set _tags before init (add "global_forecaster" tag)
XinyuWuu May 5, 2024
80e4005
add capability:global_forecasting to PytorchForecastingTFT
XinyuWuu May 8, 2024
ec800bf
fix global_forecasting check in base class
XinyuWuu May 8, 2024
4adf5b6
register capability:global_forecasting tag
XinyuWuu May 8, 2024
5383bcd
test_global_forecasting_tag
XinyuWuu May 8, 2024
b5de8f0
add pytorch-forecasting to pyproject.toml
XinyuWuu May 8, 2024
433acbc
test_pridect_signature
XinyuWuu May 9, 2024
71f1477
fix empty param dict not work
XinyuWuu May 14, 2024
9361b3c
fix no pd.Series support for y_inner_mtype
XinyuWuu May 14, 2024
1ada234
add requires_X tag for PytorchForecastingTFT
XinyuWuu May 14, 2024
0824d27
fix time index not integer error
XinyuWuu May 15, 2024
922a6c8
add get_test_params for PytorchForecastingTFT
XinyuWuu May 17, 2024
5e15bfd
test_global_fit_predict_insample
XinyuWuu May 17, 2024
9f6dc57
Merge branch 'main' into global_pytorch_forecasting
XinyuWuu May 17, 2024
d96e54e
fix soft dependencies in get_test_param
XinyuWuu May 17, 2024
8cc520d
add target to time_varying_unknown_reals
XinyuWuu May 20, 2024
ac9a211
add PytorchForecastingNBeats
XinyuWuu May 22, 2024
2769321
big fix to support defferent input
XinyuWuu May 22, 2024
aeb0862
Merge branch 'main' into global_pytorch_forecasting
XinyuWuu May 22, 2024
944f9b7
fix encoder length to long
XinyuWuu May 23, 2024
9097abf
fix absolute fh to max prediction length
XinyuWuu May 23, 2024
e642e53
fix x is pd.Series
XinyuWuu May 23, 2024
6f5ceda
add y to _vectorize in predict
XinyuWuu May 23, 2024
58f688f
dummy X for TFT if X is None
XinyuWuu May 23, 2024
ec7d9c7
fix QuantileLoss not pass test_set_params
XinyuWuu May 27, 2024
0e7e22a
self._y in predict if y=None
XinyuWuu May 27, 2024
386e01b
capability:pred_int:insample
XinyuWuu May 27, 2024
0adadd0
fix fh not continue
XinyuWuu May 29, 2024
7b55f08
fix column name in _series_to_frame
XinyuWuu May 29, 2024
f19e59e
fix overwrite self._X self._y
XinyuWuu May 29, 2024
fac46a0
move global_forecasting tag check to the top
XinyuWuu May 29, 2024
4addd77
extend y and concat x
XinyuWuu May 29, 2024
3852184
wordaround for max_prediction_length=1 problem in CI
XinyuWuu May 30, 2024
07acbd6
test_global_forecasting_multiindex_hier
XinyuWuu May 30, 2024
2ac25f1
test_global_forecasting_multiindex
XinyuWuu May 30, 2024
6aec47d
test_global_forecasting_series
XinyuWuu May 30, 2024
c2781a2
test_global_forecasting_no_X
XinyuWuu May 30, 2024
7511ed1
Merge branch 'main' into global_pytorch_forecasting
XinyuWuu May 30, 2024
df2303f
add version constraint
XinyuWuu May 31, 2024
9d940e7
clean fix
XinyuWuu Jun 2, 2024
9ca9709
Merge branch 'main' into global_pytorch_forecasting
XinyuWuu Jun 2, 2024
ce39727
Merge branch 'main' into pr/6228
fkiraly Jun 3, 2024
7c8907d
Update pytorchforecasting.py
fkiraly Jun 3, 2024
caa0b14
Update forecasting.rst
XinyuWuu Jun 3, 2024
a332c61
Merge branch 'main' into pr/6228
fkiraly Jun 3, 2024
b21a336
fix for review by benHeid
XinyuWuu Jun 4, 2024
7f3c9c1
Merge branch 'main' into global_pytorch_forecasting
XinyuWuu Jun 4, 2024
3a60a67
[AUTOMATED] update CONTRIBUTORS.md
Jun 4, 2024
58c9ab4
add DeepAR and NHiTS
XinyuWuu Jun 6, 2024
49978d0
Merge branch 'main' into pr/6228
fkiraly Jun 7, 2024
54530db
Merge branch 'global_pytorch_forecasting' of https://github.com/Xinyu…
fkiraly Jun 7, 2024
415a67e
test_load_model_from_disk
XinyuWuu Jun 10, 2024
677ecdb
Merge branch 'main' into global_pytorch_forecasting
XinyuWuu Jun 10, 2024
ba49fb3
Merge branch 'global_pytorch_forecasting' into DeepAR_NHiTS
XinyuWuu Jun 10, 2024
3890cf8
global forecast docstring in predict
XinyuWuu Jun 11, 2024
bdab849
improve CI test
XinyuWuu Jun 11, 2024
fdfa0fb
Merge branch 'global_pytorch_forecasting' into DeepAR_NHiTS
XinyuWuu Jun 11, 2024
3beb5e3
improve CI test
XinyuWuu Jun 11, 2024
aed8476
add underscore to BaseGlobalForecaster
XinyuWuu Jun 11, 2024
c6bfddd
NotImplementedError for in sample predict
XinyuWuu Jun 11, 2024
c6728f6
set tag dict explicitly
XinyuWuu Jun 11, 2024
c16c918
add model to docs
XinyuWuu Jun 11, 2024
87ba617
Merge branch 'main' into global_pytorch_forecasting
XinyuWuu Jun 12, 2024
d5ee476
pd.Dataframe inner type
XinyuWuu Jun 12, 2024
3587590
fix NaN in target column if fh is not continuous
XinyuWuu Jun 12, 2024
d439955
Merge branch 'global_pytorch_forecasting' into DeepAR_NHiTS
XinyuWuu Jun 12, 2024
a7a36ed
fix FileNotFoundError in CI
XinyuWuu Jun 13, 2024
7416f33
global forecast notebook
XinyuWuu Jun 13, 2024
dcb5f7f
Merge branch 'main' into global_pytorch_forecasting
XinyuWuu Jun 13, 2024
d75f7f8
Merge branch 'global_pytorch_forecasting' into DeepAR_NHiTS
XinyuWuu Jun 13, 2024
fa20a44
Merge branch 'main' into pr/6228
fkiraly Jun 13, 2024
49638af
Merge branch 'global_pytorch_forecasting' of https://github.com/Xinyu…
fkiraly Jun 13, 2024
d0ebd75
install pytorch_forecasting in notebook
XinyuWuu Jun 14, 2024
3a47566
add examples
XinyuWuu Jun 14, 2024
59a4b5c
fix non-continuous fh in extend_y
XinyuWuu Jun 14, 2024
4d2ce8c
Merge branch 'main' into global_pytorch_forecasting
XinyuWuu Jun 14, 2024
ad06c43
Merge branch 'global_pytorch_forecasting' into DeepAR_NHiTS
XinyuWuu Jun 14, 2024
ab00927
random log dir in CI
XinyuWuu Jun 15, 2024
14c8c93
Merge branch 'global_pytorch_forecasting' into DeepAR_NHiTS
XinyuWuu Jun 15, 2024
0814177
random log dir in CI
XinyuWuu Jun 15, 2024
16b47b4
fix no attribute '_random_log_dir'
XinyuWuu Jun 17, 2024
ee2da0a
remove time verification from test_load_model_from_disk
XinyuWuu Jun 17, 2024
741fd61
is_all_out_of_sample
XinyuWuu Jun 17, 2024
eaea795
skit test_score for DeepAR
XinyuWuu Jun 18, 2024
75694c1
docstring for model_path and random_log_path
XinyuWuu Jun 18, 2024
10a59ac
Merge branch 'main' into global_pytorch_forecasting
XinyuWuu Jun 19, 2024
176328a
Merge branch 'global_pytorch_forecasting' into DeepAR_NHiTS
XinyuWuu Jun 19, 2024
24df5b8
deterministic for DeepAR
XinyuWuu Jun 19, 2024
729b0c8
Merge branch 'main' into DeepAR_NHiTS
XinyuWuu Jun 22, 2024
bf63de0
Merge branch 'main' into DeepAR_NHiTS
XinyuWuu Jun 27, 2024
4f392da
address review from benHeid
XinyuWuu Jul 1, 2024
1a3496c
remove notebook
XinyuWuu Jul 1, 2024
786f84a
notebook from #6551
XinyuWuu Jul 1, 2024
e17ece5
super simple template
XinyuWuu Jul 1, 2024
93a5f22
Merge branch 'main' into notebook-and-template
XinyuWuu Aug 1, 2024
2394be8
global to panel in 01c
XinyuWuu Aug 1, 2024
4bcf504
refine 01d
XinyuWuu Aug 1, 2024
6be62b3
merge 01d to 01c
XinyuWuu Aug 1, 2024
c32aaa0
fix ruff format error
XinyuWuu Aug 1, 2024
19cb933
broadcasting in template
XinyuWuu Aug 1, 2024
1a33227
fix dependency error
XinyuWuu Aug 1, 2024
7c10164
docstring for broadcasting
XinyuWuu Aug 2, 2024
bbddf4d
Merge branch 'main' into pr/6699
fkiraly Aug 10, 2024
f3ea04b
minor changes
fkiraly Aug 10, 2024
3d210f8
reformat
fkiraly Aug 10, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
add comments
  • Loading branch information
XinyuWuu committed Apr 14, 2024
commit f78e635082944b32e4ef9d0fba426560f69e112f
2 changes: 1 addition & 1 deletion sktime/forecasting/base/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2604,7 +2604,7 @@ def predict(self, fh=None, X=None, y=None):
# check fh and coerce to ForecastingHorizon, if not already passed in fit
fh = self._check_fh(fh)

# check global tag
# check global forecasting tag
gf = self.get_tag(
"capability:global_forecasting", tag_value_default=None, raise_error=False
)
Expand Down
21 changes: 21 additions & 0 deletions sktime/forecasting/base/adapters/_pytorchforecasting.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,9 @@ def _fit(
"""
self._dataset_params = _none_check(self.dataset_params, {})
self._max_prediction_length = fh.to_relative()[-1]
# store the target column name
self.y_name = y.columns[-1]
# convert data to pytorch-forecasting datasets
training, validation = self._Xy_to_dataset(
X, y, self._dataset_params, self._max_prediction_length
)
Expand Down Expand Up @@ -175,9 +177,11 @@ def _predict(
guaranteed to have a single column/variable
Point predictions
"""
# convert data to pytorch-forecasting datasets
training, validation = self._Xy_to_dataset(
X, y, self._dataset_params, self._max_prediction_length
)
# load model from checkpoint
best_model_path = self._trainer.checkpoint_callback.best_model_path
best_model = self.algorithm_class.load_from_checkpoint(best_model_path)
predictions = best_model.predict(
Expand All @@ -186,6 +190,7 @@ def _predict(
return_index=True,
return_decoder_lengths=True,
)
# convert pytorch-forecasting predictions to dataframe
output = self._predictions_to_dataframe(
predictions, self._max_prediction_length
)
Expand All @@ -201,19 +206,24 @@ def _Xy_to_dataset(
):
from pytorch_forecasting.data import TimeSeriesDataSet

# X, y must have same index
assert (X.index == y.index).all()
# warning! X will be modified
data = X.join(y, on=X.index.names)
index_names = data.index.names
index_lens = index_names.__len__()
# reset multi index to normal columns
data = data.reset_index(level=list(range(index_lens)))
training_cutoff = data[index_names[-1]].max() - max_prediction_length
# infer time_idx column, target column and instances from data
_dataset_params = {
"data": data[data[index_names[-1]] <= training_cutoff],
"time_idx": index_names[-1],
"target": data.columns[-1],
"group_ids": index_names[0:-1],
}
_dataset_params.update(dataset_params)
# overwrite max_prediction_length
_dataset_params["max_prediction_length"] = int(max_prediction_length)
training = TimeSeriesDataSet(**_dataset_params)
validation = TimeSeriesDataSet.from_dataset(
Expand All @@ -222,22 +232,33 @@ def _Xy_to_dataset(
return training, validation

def _predictions_to_dataframe(self, predictions, max_prediction_length):
# output is the predictions
output = predictions.output.cpu().numpy()
# index will be combined with output
index = predictions.index
# in pytorch-forecasting predictions, the first index is the time_idx
columns_names = index.columns.to_list()
time_idx = columns_names.pop(0)
# make time_idx the last index
columns_names.append(time_idx)
# in pytorch-forecasting predictions,
# the index only contains the start timepoint.
data = index.loc[index.index.repeat(max_prediction_length)].reset_index(
drop=True
)
# make time_idx the last index
data = data.reindex(columns=columns_names)
# add the target column at the end
data[self.y_name] = output.flatten()
# correct the time_idx after repeating
# assume the time_idx column is continuous integers
for i in range(output.shape[0]):
start_idx = i * max_prediction_length
start_time = data.loc[start_idx, time_idx]
data.loc[
start_idx : start_idx + max_prediction_length - 1, time_idx
] = list(range(start_time, start_time + max_prediction_length))
# set the instance columns to multi index
data.set_index(columns_names, inplace=True)
return data

Expand Down