Skip to content

Commit

Permalink
[END] add loader for .tsf (forecasting benchmark) files (#1934)
Browse files Browse the repository at this point in the history
  • Loading branch information
rakshitha123 authored Feb 11, 2022
1 parent 2825392 commit c02c099
Show file tree
Hide file tree
Showing 6 changed files with 278 additions and 2 deletions.
1 change: 1 addition & 0 deletions .all-contributorsrc
Original file line number Diff line number Diff line change
Expand Up @@ -1442,6 +1442,7 @@
"avatar_url": "https://avatars.githubusercontent.com/u/7654679?v=4",
"profile": "https://github.com/rakshitha123",
"contributions": [
"code",
"doc"
]
},
Expand Down
2 changes: 1 addition & 1 deletion MANIFEST.in
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,5 @@ include *.md
recursive-include examples *
recursive-include sktime *.py
recursive-include sktime *.c *.h *.pyx *.pxd *.pxi *.tp
recursive-include sktime/datasets *.csv *.csv.gz *.arff *.arff.gz *.txt *.ts *.tsv
recursive-include sktime/datasets *.csv *.csv.gz *.arff *.arff.gz *.txt *.ts *.tsv *.tsf
include LICENSE
4 changes: 4 additions & 0 deletions sktime/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@
"write_ndarray_to_tsfile",
"write_results_to_uea_format",
"write_tabular_transformation_to_arff",
"load_tsf_to_dataframe",
"load_unit_test_tsf",
]

from sktime.datasets._data_io import (
Expand All @@ -42,6 +44,7 @@
load_from_tsfile,
load_from_tsfile_to_dataframe,
load_from_ucr_tsv_to_dataframe,
load_tsf_to_dataframe,
make_multi_index_dataframe,
write_dataframe_to_tsfile,
write_ndarray_to_tsfile,
Expand All @@ -67,5 +70,6 @@
load_shampoo_sales,
load_UCR_UEA_dataset,
load_unit_test,
load_unit_test_tsf,
load_uschange,
)
188 changes: 188 additions & 0 deletions sktime/datasets/_data_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
import tempfile
import textwrap
import zipfile
from datetime import datetime
from distutils.util import strtobool
from urllib.request import urlretrieve

import numpy as np
Expand Down Expand Up @@ -1676,3 +1678,189 @@ def write_ndarray_to_tsfile(
file.write(f"{a}{missing_values}")
file.write("\n") # open a new line
file.close()


def load_tsf_to_dataframe(
full_file_path_and_name,
replace_missing_vals_with="NaN",
value_column_name="series_value",
):
"""
Convert the contents in a .tsf file into a dataframe.
This code was extracted from
https://github.com/rakshitha123/TSForecasting/blob
/master/utils/data_loader.py.
Parameters
----------
full_file_path_and_name: str
The full path to the .tsf file.
replace_missing_vals_with: str, default="NAN"
A term to indicate the missing values in series in the returning dataframe.
value_column_name: str, default="series_value"
Any name that is preferred to have as the name of the column containing series
values in the returning dataframe.
Returns
-------
loaded_data: pd.DataFrame
The converted dataframe containing the time series.
frequency: str
The frequency of the dataset.
forecast_horizon: int
The expected forecast horizon of the dataset.
contain_missing_values: bool
Whether the dataset contains missing values or not.
contain_equal_length: bool
Whether the series have equal lengths or not.
"""
col_names = []
col_types = []
all_data = {}
line_count = 0
frequency = None
forecast_horizon = None
contain_missing_values = None
contain_equal_length = None
found_data_tag = False
found_data_section = False
started_reading_data_section = False

with open(full_file_path_and_name, "r", encoding="cp1252") as file:
for line in file:
# Strip white space from start/end of line
line = line.strip()

if line:
if line.startswith("@"): # Read meta-data
if not line.startswith("@data"):
line_content = line.split(" ")
if line.startswith("@attribute"):
if (
len(line_content) != 3
): # Attributes have both name and type
raise Exception("Invalid meta-data specification.")

col_names.append(line_content[1])
col_types.append(line_content[2])
else:
if (
len(line_content) != 2
): # Other meta-data have only values
raise Exception("Invalid meta-data specification.")

if line.startswith("@frequency"):
frequency = line_content[1]
elif line.startswith("@horizon"):
forecast_horizon = int(line_content[1])
elif line.startswith("@missing"):
contain_missing_values = bool(
strtobool(line_content[1])
)
elif line.startswith("@equallength"):
contain_equal_length = bool(strtobool(line_content[1]))

else:
if len(col_names) == 0:
raise Exception(
"Missing attribute section. "
"Attribute section must come before data."
)

found_data_tag = True
elif not line.startswith("#"):
if len(col_names) == 0:
raise Exception(
"Missing attribute section. "
"Attribute section must come before data."
)
elif not found_data_tag:
raise Exception("Missing @data tag.")
else:
if not started_reading_data_section:
started_reading_data_section = True
found_data_section = True
all_series = []

for col in col_names:
all_data[col] = []

full_info = line.split(":")

if len(full_info) != (len(col_names) + 1):
raise Exception("Missing attributes/values in series.")

series = full_info[len(full_info) - 1]
series = series.split(",")

if len(series) == 0:
raise Exception(
"A given series should contains a set "
"of comma separated numeric values."
"At least one numeric value should be there "
"in a series. "
"Missing values should be indicated with ? symbol"
)

numeric_series = []

for val in series:
if val == "?":
numeric_series.append(replace_missing_vals_with)
else:
numeric_series.append(float(val))

if numeric_series.count(replace_missing_vals_with) == len(
numeric_series
):
raise Exception(
"All series values are missing. "
"A given series should contains a set "
"of comma separated numeric values."
"At least one numeric value should be there "
"in a series."
)

all_series.append(pd.Series(numeric_series).array)

for i in range(len(col_names)):
att_val = None
if col_types[i] == "numeric":
att_val = int(full_info[i])
elif col_types[i] == "string":
att_val = str(full_info[i])
elif col_types[i] == "date":
att_val = datetime.strptime(
full_info[i], "%Y-%m-%d %H-%M-%S"
)
else:
# Currently, the code supports only
# numeric, string and date types.
# Extend this as required.
raise Exception("Invalid attribute type.")

if att_val is None:
raise Exception("Invalid attribute value.")
else:
all_data[col_names[i]].append(att_val)

line_count = line_count + 1

if line_count == 0:
raise Exception("Empty file.")
if len(col_names) == 0:
raise Exception("Missing attribute section.")
if not found_data_section:
raise Exception("Missing series information under data section.")

all_data[value_column_name] = all_series
loaded_data = pd.DataFrame(all_data)

return (
loaded_data,
frequency,
forecast_horizon,
contain_missing_values,
contain_equal_length,
)
42 changes: 41 additions & 1 deletion sktime/datasets/_single_problem_loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
"load_gun_point_segmentation",
"load_electric_devices_segmentation",
"load_macroeconomic",
"load_unit_test_tsf",
]

import os
Expand All @@ -42,7 +43,11 @@
import pandas as pd
import statsmodels.api as sm

from sktime.datasets._data_io import _load_dataset, _load_provided_dataset
from sktime.datasets._data_io import (
_load_dataset,
_load_provided_dataset,
load_tsf_to_dataframe,
)

DIRNAME = "data"
MODULE = os.path.dirname(__file__)
Expand Down Expand Up @@ -863,3 +868,38 @@ def load_macroeconomic():
y = y.drop(columns=["year", "quarter", "time"])
y.name = "US Macroeconomic Data"
return y


def load_unit_test_tsf():
"""
Load tsf UnitTest dataset.
Returns
-------
loaded_data: pd.DataFrame
The converted dataframe containing the time series.
frequency: str
The frequency of the dataset.
forecast_horizon: int
The expected forecast horizon of the dataset.
contain_missing_values: bool
Whether the dataset contains missing values or not.
contain_equal_length: bool
Whether the series have equal lengths or not.
"""
path = os.path.join(MODULE, DIRNAME, "UnitTest", "UnitTest_Tsf_Loader.tsf")
(
loaded_data,
frequency,
forecast_horizon,
contain_missing_values,
contain_equal_length,
) = load_tsf_to_dataframe(path)

return (
loaded_data,
frequency,
forecast_horizon,
contain_missing_values,
contain_equal_length,
)
43 changes: 43 additions & 0 deletions sktime/datasets/tests/test_data_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
load_from_long_to_dataframe,
load_from_tsfile,
load_from_tsfile_to_dataframe,
load_tsf_to_dataframe,
load_uschange,
write_dataframe_to_tsfile,
)
Expand Down Expand Up @@ -1080,3 +1081,45 @@ def test_write_dataframe_to_ts_fail(tmp_path):
path=str(tmp_path),
problem_name="GunPoint",
)


def test_load_tsf_to_dataframe():
"""Test function for loading tsf format."""
data_path = os.path.join(
os.path.dirname(sktime.__file__),
"datasets/data/UnitTest/UnitTest_Tsf_Loader.tsf",
)

df, frequency, horizon, missing_values, equal_length = load_tsf_to_dataframe(
data_path
)

test_df = pd.DataFrame(
{
"series_name": ["T1", "T2", "T3"],
"start_timestamp": [
pd.Timestamp(year=1979, month=1, day=1),
pd.Timestamp(year=1979, month=1, day=1),
pd.Timestamp(year=1973, month=1, day=1),
],
"series_value": [
[
25092.2284,
24271.5134,
25828.9883,
27697.5047,
27956.2276,
29924.4321,
30216.8321,
],
[887896.51, 887068.98, 971549.04],
[227921, 230995, 183635, 238605, 254186],
],
}
)

assert_frame_equal(df, test_df)
assert frequency == "yearly"
assert horizon == 4
assert missing_values is False
assert equal_length is False

0 comments on commit c02c099

Please sign in to comment.