diff --git a/qlib/backtest/high_performance_ds.py b/qlib/backtest/high_performance_ds.py index bb75ca8f6a..b185f0d517 100644 --- a/qlib/backtest/high_performance_ds.py +++ b/qlib/backtest/high_performance_ds.py @@ -528,6 +528,9 @@ def apply(self, func: Callable): def reindex(self, index, fill_value): return self.__class__(self.metric.reindex(index, fill_value=fill_value)) + def __repr__(self): + return repr(self.metric) + class PandasOrderIndicator(BaseOrderIndicator): """ @@ -567,6 +570,9 @@ def sum_all_indicators(order_indicator, indicators: list, metrics: Union[str, Li tmp_metric = tmp_metric.add(indicator.data[metric], fill_value) order_indicator.assign(metric, tmp_metric.metric) + def __repr__(self): + return repr(self.data) + class NumpyOrderIndicator(BaseOrderIndicator): """ @@ -605,3 +611,6 @@ def sum_all_indicators(order_indicator, indicators: list, metrics: Union[str, Li for indicator in indicators: tmp_metric = tmp_metric.add(indicator.data[metric], fill_value) order_indicator.data[metric] = tmp_metric + + def __repr__(self): + return repr(self.data) diff --git a/qlib/backtest/report.py b/qlib/backtest/report.py index d76ad07d19..b4b9c5f2e8 100644 --- a/qlib/backtest/report.py +++ b/qlib/backtest/report.py @@ -11,7 +11,7 @@ from qlib.backtest.exchange import Exchange from qlib.backtest.order import BaseTradeDecision, Order, OrderDir -from .high_performance_ds import PandasOrderIndicator, NumpyOrderIndicator, SingleMetric +from .high_performance_ds import BaseOrderIndicator, PandasOrderIndicator, NumpyOrderIndicator, SingleMetric from ..tests.config import CSI300_BENCH from ..utils.resam import get_higher_eq_freq_feature, resam_ts_data from .order import IdxTradeRange @@ -255,7 +255,7 @@ def __init__(self, order_indicator_cls=NumpyOrderIndicator): # order indicator is metrics for a single order for a specific step self.order_indicator_his = OrderedDict() - self.order_indicator = self.order_indicator_cls() + self.order_indicator: BaseOrderIndicator = self.order_indicator_cls() # trade indicator is metrics for all orders for a specific step self.trade_indicator_his = OrderedDict() @@ -265,7 +265,7 @@ def __init__(self, order_indicator_cls=NumpyOrderIndicator): # def reset(self, trade_calendar: TradeCalendarManager): def reset(self): - self.order_indicator = self.order_indicator_cls() + self.order_indicator: BaseOrderIndicator = self.order_indicator_cls() self.trade_indicator = OrderedDict() # self._trade_calendar = trade_calendar diff --git a/qlib/utils/index_data.py b/qlib/utils/index_data.py index c8d6bebeef..79e2f08e34 100644 --- a/qlib/utils/index_data.py +++ b/qlib/utils/index_data.py @@ -280,7 +280,7 @@ def __call__(self, other): self_data_method = getattr(self.obj.data, self.method_name) if isinstance(other, (int, float, np.number)): - return self.obj.__class__(self_data_method(other)) + return self.obj.__class__(self_data_method(other), *self.obj.indices) elif isinstance(other, self.obj.__class__): other_aligned = self.obj._align_indices(other) return self.obj.__class__(self_data_method(other_aligned.data), *self.obj.indices) @@ -450,6 +450,12 @@ def mean(self, axis=None): def isna(self): return self.__class__(np.isnan(self.data), *self.indices) + def fillna(self, value=0.0, inplace: bool = False): + if inplace: + self.data = np.nan_to_num(self.data, nan=value) + else: + return self.__class__(np.nan_to_num(self.data, nan=value), *self.indices) + def count(self): return len(self.data[~np.isnan(self.data)]) @@ -507,6 +513,8 @@ def reindex(self, index: Index, fill_value=np.NaN): ---------- new_index : list new index + fill_value: + what value to fill if index is missing Returns ------- @@ -531,7 +539,7 @@ def add(self, other: "SingleData", fill_value=0): common_index, _ = common_index.sort() tmp_data1 = self.reindex(common_index, fill_value) tmp_data2 = other.reindex(common_index, fill_value) - return tmp_data1 + tmp_data2 + return tmp_data1.fillna(fill_value) + tmp_data2.fillna(fill_value) def to_dict(self): """convert SingleData to dict. diff --git a/tests/misc/test_index_data.py b/tests/misc/test_index_data.py index c7a80fb0f7..010b32847a 100644 --- a/tests/misc/test_index_data.py +++ b/tests/misc/test_index_data.py @@ -99,6 +99,19 @@ def test_ops(self): sd1 = idd.SingleData([1, 2, 3, 4], index=["foo", "bar", "f", "g"]) sd2 = idd.SingleData([1, 2, 3, 4], index=["foo", "bar", "f", "g"]) print(sd1 + sd2) + new_sd = sd2 * 2 + self.assertTrue(new_sd.index == sd2.index) + + sd1 = idd.SingleData([1, 2, None, 4], index=["foo", "bar", "f", "g"]) + sd2 = idd.SingleData([1, 2, 3, None], index=["foo", "bar", "f", "g"]) + self.assertTrue(np.isnan((sd1 + sd2).iloc[3])) + self.assertTrue(sd1.add(sd2).sum() == 13) + + def test_todo(self): + pass + # here are some examples which do not affect the current system, but it is weird not to support it + # sd2 = idd.SingleData([1, 2, 3, 4], index=["foo", "bar", "f", "g"]) + # 2 * sd2 if __name__ == "__main__":