diff --git a/docs/component/strategy.rst b/docs/component/strategy.rst
index cdaf4d6b11..2dccc47fa7 100644
--- a/docs/component/strategy.rst
+++ b/docs/component/strategy.rst
@@ -161,12 +161,9 @@ Running backtest
start_time="2017-01-01", end_time="2020-08-01", strategy=strategy_obj
)
analysis = dict()
- analysis["excess_return_without_cost"] = risk_analysis(
- report_normal["return"] - report_normal["bench"], freq=analysis_freq
- )
- analysis["excess_return_with_cost"] = risk_analysis(
- report_normal["return"] - report_normal["bench"] - report_normal["cost"], freq=analysis_freq
- )
+ # default frequency will be daily (i.e. "day")
+ analysis["excess_return_without_cost"] = risk_analysis(report_normal["return"] - report_normal["bench"])
+ analysis["excess_return_with_cost"] = risk_analysis(report_normal["return"] - report_normal["bench"] - report_normal["cost"])
analysis_df = pd.concat(analysis) # type: pd.DataFrame
pprint(analysis_df)
diff --git a/examples/online_srv/online_management_simulate.py b/examples/online_srv/online_management_simulate.py
index bd7c4675d1..dccc56b682 100644
--- a/examples/online_srv/online_management_simulate.py
+++ b/examples/online_srv/online_management_simulate.py
@@ -2,7 +2,7 @@
# Licensed under the MIT License.
"""
-This example is about how can simulate the OnlineManager based on rolling tasks.
+This example is about how can simulate the OnlineManager based on rolling tasks.
"""
from pprint import pprint
@@ -15,6 +15,10 @@
from qlib.workflow.task.gen import RollingGen
from qlib.workflow.task.manage import TaskManager
from qlib.tests.config import CSI100_RECORD_LGB_TASK_CONFIG_ONLINE, CSI100_RECORD_XGBOOST_TASK_CONFIG_ONLINE
+import pandas as pd
+from qlib.contrib.evaluate import backtest_daily
+from qlib.contrib.evaluate import risk_analysis
+from qlib.contrib.strategy import TopkDropoutStrategy
class OnlineSimulationExample:
@@ -30,6 +34,7 @@ def __init__(
start_time="2018-09-10",
end_time="2018-10-31",
tasks=None,
+ trainer="TrainerR",
):
"""
Init OnlineManagerExample.
@@ -60,7 +65,13 @@ def __init__(
self.rolling_gen = RollingGen(
step=rolling_step, rtype=RollingGen.ROLL_SD, ds_extra_mod_func=None
) # The rolling tasks generator, ds_extra_mod_func is None because we just need to simulate to 2018-10-31 and needn't change the handler end time.
- self.trainer = TrainerRM(self.exp_name, self.task_pool) # Also can be TrainerR, TrainerRM, DelayTrainerR
+ if trainer == "TrainerRM":
+ self.trainer = TrainerRM(self.exp_name, self.task_pool)
+ elif trainer == "TrainerR":
+ self.trainer = TrainerR(self.exp_name)
+ else:
+ # TODO: support all the trainers: TrainerR, TrainerRM, DelayTrainerR
+ raise NotImplementedError(f"This type of input is not supported")
self.rolling_online_manager = OnlineManager(
RollingStrategy(exp_name, task_template=tasks, rolling_gen=self.rolling_gen),
trainer=self.trainer,
@@ -70,7 +81,8 @@ def __init__(
# Reset all things to the first status, be careful to save important data
def reset(self):
- TaskManager(self.task_pool).remove()
+ if isinstance(self.trainer, TrainerRM):
+ TaskManager(self.task_pool).remove()
exp = R.get_exp(experiment_name=self.exp_name)
for rid in exp.list_recorders():
exp.delete_recorder(rid)
@@ -84,7 +96,30 @@ def main(self):
print("========== collect results ==========")
print(self.rolling_online_manager.get_collector()())
print("========== signals ==========")
- print(self.rolling_online_manager.get_signals())
+ signals = self.rolling_online_manager.get_signals()
+ print(signals)
+ # Backtesting
+ # - the code is based on this example https://qlib.readthedocs.io/en/latest/component/strategy.html
+ CSI300_BENCH = "SH000903"
+ STRATEGY_CONFIG = {
+ "topk": 30,
+ "n_drop": 3,
+ "signal": signals.to_frame("score"),
+ }
+ strategy_obj = TopkDropoutStrategy(**STRATEGY_CONFIG)
+ report_normal, positions_normal = backtest_daily(
+ start_time=signals.index.get_level_values("datetime").min(),
+ end_time=signals.index.get_level_values("datetime").max(),
+ strategy=strategy_obj,
+ )
+ analysis = dict()
+ analysis["excess_return_without_cost"] = risk_analysis(report_normal["return"] - report_normal["bench"])
+ analysis["excess_return_with_cost"] = risk_analysis(
+ report_normal["return"] - report_normal["bench"] - report_normal["cost"]
+ )
+
+ analysis_df = pd.concat(analysis) # type: pd.DataFrame
+ pprint(analysis_df)
def worker(self):
# train tasks by other progress or machines for multiprocessing
diff --git a/qlib/contrib/model/gbdt.py b/qlib/contrib/model/gbdt.py
index 8601131853..f14205f888 100644
--- a/qlib/contrib/model/gbdt.py
+++ b/qlib/contrib/model/gbdt.py
@@ -71,6 +71,7 @@ def fit(
early_stopping_callback = lgb.early_stopping(
self.early_stopping_rounds if early_stopping_rounds is None else early_stopping_rounds
)
+ # NOTE: if you encounter error here. Please upgrade your lightgbm
verbose_eval_callback = lgb.log_evaluation(period=verbose_eval)
evals_result_callback = lgb.record_evaluation(evals_result)
self.model = lgb.train(