diff --git a/docs/component/strategy.rst b/docs/component/strategy.rst index cdaf4d6b11..2dccc47fa7 100644 --- a/docs/component/strategy.rst +++ b/docs/component/strategy.rst @@ -161,12 +161,9 @@ Running backtest start_time="2017-01-01", end_time="2020-08-01", strategy=strategy_obj ) analysis = dict() - analysis["excess_return_without_cost"] = risk_analysis( - report_normal["return"] - report_normal["bench"], freq=analysis_freq - ) - analysis["excess_return_with_cost"] = risk_analysis( - report_normal["return"] - report_normal["bench"] - report_normal["cost"], freq=analysis_freq - ) + # default frequency will be daily (i.e. "day") + analysis["excess_return_without_cost"] = risk_analysis(report_normal["return"] - report_normal["bench"]) + analysis["excess_return_with_cost"] = risk_analysis(report_normal["return"] - report_normal["bench"] - report_normal["cost"]) analysis_df = pd.concat(analysis) # type: pd.DataFrame pprint(analysis_df) diff --git a/examples/online_srv/online_management_simulate.py b/examples/online_srv/online_management_simulate.py index bd7c4675d1..dccc56b682 100644 --- a/examples/online_srv/online_management_simulate.py +++ b/examples/online_srv/online_management_simulate.py @@ -2,7 +2,7 @@ # Licensed under the MIT License. """ -This example is about how can simulate the OnlineManager based on rolling tasks. +This example is about how can simulate the OnlineManager based on rolling tasks. """ from pprint import pprint @@ -15,6 +15,10 @@ from qlib.workflow.task.gen import RollingGen from qlib.workflow.task.manage import TaskManager from qlib.tests.config import CSI100_RECORD_LGB_TASK_CONFIG_ONLINE, CSI100_RECORD_XGBOOST_TASK_CONFIG_ONLINE +import pandas as pd +from qlib.contrib.evaluate import backtest_daily +from qlib.contrib.evaluate import risk_analysis +from qlib.contrib.strategy import TopkDropoutStrategy class OnlineSimulationExample: @@ -30,6 +34,7 @@ def __init__( start_time="2018-09-10", end_time="2018-10-31", tasks=None, + trainer="TrainerR", ): """ Init OnlineManagerExample. @@ -60,7 +65,13 @@ def __init__( self.rolling_gen = RollingGen( step=rolling_step, rtype=RollingGen.ROLL_SD, ds_extra_mod_func=None ) # The rolling tasks generator, ds_extra_mod_func is None because we just need to simulate to 2018-10-31 and needn't change the handler end time. - self.trainer = TrainerRM(self.exp_name, self.task_pool) # Also can be TrainerR, TrainerRM, DelayTrainerR + if trainer == "TrainerRM": + self.trainer = TrainerRM(self.exp_name, self.task_pool) + elif trainer == "TrainerR": + self.trainer = TrainerR(self.exp_name) + else: + # TODO: support all the trainers: TrainerR, TrainerRM, DelayTrainerR + raise NotImplementedError(f"This type of input is not supported") self.rolling_online_manager = OnlineManager( RollingStrategy(exp_name, task_template=tasks, rolling_gen=self.rolling_gen), trainer=self.trainer, @@ -70,7 +81,8 @@ def __init__( # Reset all things to the first status, be careful to save important data def reset(self): - TaskManager(self.task_pool).remove() + if isinstance(self.trainer, TrainerRM): + TaskManager(self.task_pool).remove() exp = R.get_exp(experiment_name=self.exp_name) for rid in exp.list_recorders(): exp.delete_recorder(rid) @@ -84,7 +96,30 @@ def main(self): print("========== collect results ==========") print(self.rolling_online_manager.get_collector()()) print("========== signals ==========") - print(self.rolling_online_manager.get_signals()) + signals = self.rolling_online_manager.get_signals() + print(signals) + # Backtesting + # - the code is based on this example https://qlib.readthedocs.io/en/latest/component/strategy.html + CSI300_BENCH = "SH000903" + STRATEGY_CONFIG = { + "topk": 30, + "n_drop": 3, + "signal": signals.to_frame("score"), + } + strategy_obj = TopkDropoutStrategy(**STRATEGY_CONFIG) + report_normal, positions_normal = backtest_daily( + start_time=signals.index.get_level_values("datetime").min(), + end_time=signals.index.get_level_values("datetime").max(), + strategy=strategy_obj, + ) + analysis = dict() + analysis["excess_return_without_cost"] = risk_analysis(report_normal["return"] - report_normal["bench"]) + analysis["excess_return_with_cost"] = risk_analysis( + report_normal["return"] - report_normal["bench"] - report_normal["cost"] + ) + + analysis_df = pd.concat(analysis) # type: pd.DataFrame + pprint(analysis_df) def worker(self): # train tasks by other progress or machines for multiprocessing diff --git a/qlib/contrib/model/gbdt.py b/qlib/contrib/model/gbdt.py index 8601131853..f14205f888 100644 --- a/qlib/contrib/model/gbdt.py +++ b/qlib/contrib/model/gbdt.py @@ -71,6 +71,7 @@ def fit( early_stopping_callback = lgb.early_stopping( self.early_stopping_rounds if early_stopping_rounds is None else early_stopping_rounds ) + # NOTE: if you encounter error here. Please upgrade your lightgbm verbose_eval_callback = lgb.log_evaluation(period=verbose_eval) evals_result_callback = lgb.record_evaluation(evals_result) self.model = lgb.train(