Skip to content

Commit

Permalink
Update run_all_model script
Browse files Browse the repository at this point in the history
  • Loading branch information
Derek-Wds committed Nov 23, 2020
1 parent 0c3f50e commit 27b573c
Show file tree
Hide file tree
Showing 11 changed files with 57 additions and 24 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -56,4 +56,4 @@ jobs:
- name: Test workflow by config
run: |
qrun examples/benchmarks/GBDT/workflow_config_gbdt.yaml
qrun examples/benchmarks/LightGBM/workflow_config_lightgbm.yaml
File renamed without changes.
66 changes: 50 additions & 16 deletions examples/run_all_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@

import os
import sys
import fire
import venv
import glob
import shutil
import tempfile
import statistics
from pathlib import Path
from subprocess import Popen, PIPE
from threading import Thread
Expand All @@ -18,9 +20,16 @@
from qlib.config import REG_CN
from qlib.workflow import R
from qlib.workflow.cli import workflow
from qlib.utils import exists_qlib_data

# init qlib
provider_uri = "~/.qlib/qlib_data/cn_data"
if not exists_qlib_data(provider_uri):
print(f"Qlib data is not found in {provider_uri}")
sys.path.append(str(Path(__file__).resolve().parent.parent.joinpath("scripts")))
from get_data import GetData

GetData().qlib_data(target_dir=provider_uri, region=REG_CN)
qlib.init(provider_uri=provider_uri, region=REG_CN)


Expand Down Expand Up @@ -152,6 +161,18 @@ def install_pip(self, context):
self.install_script(context, "pip", url)


# function to calculate the mean and std of a list in the results dictionary
def cal_mean_std(results) -> dict:
    """Summarise every metric's list of samples as a ``[mean, std]`` pair.

    :param results: nested mapping ``{name: {metric: [value, ...]}}``.
    :return: mapping with the same keys, where each list of samples is
        replaced by ``[mean, std]``:

        * two or more samples -> ``statistics.mean`` and the sample standard
          deviation ``statistics.stdev``;
        * exactly one sample  -> that sample and ``0`` (a sample standard
          deviation is undefined for a single value);
        * no samples          -> ``nan`` for both entries, so one empty list
          does not abort the summary of every other entry with an IndexError.
    """
    mean_std = dict()
    for fn, metrics in results.items():
        mean_std[fn] = dict()
        for metric, values in metrics.items():
            if len(values) > 1:
                mean = statistics.mean(values)
                std = statistics.stdev(values)
            elif len(values) == 1:
                # stdev() needs at least two data points; report zero spread.
                mean, std = values[0], 0
            else:
                # Nothing was recorded for this metric: mark it as missing.
                mean = std = float("nan")
            mean_std[fn][metric] = [mean, std]
    return mean_std


# function to get all the benchmark folders
def get_all_folders() -> dict:
folders = dict()
Expand All @@ -175,29 +196,37 @@ def get_all_results(folders) -> dict:
for fn in folders:
exp = R.get_exp(experiment_name=fn, create=False)
recorders = exp.list_recorders()
recorder = R.get_recorder(recorder_id=next(iter(recorders)), experiment_name=fn)
metrics = recorder.list_metrics()
results[fn] = {key: metrics[key] for key in metrics if "with_cost" in key}
result = dict()
result["annualized_return_with_cost"] = list()
result["information_ratio_with_cost"] = list()
result["max_drawdown_with_cost"] = list()
for recorder_id in recorders:
recorder = R.get_recorder(recorder_id=recorder_id, experiment_name=fn)
metrics = recorder.list_metrics()
result["annualized_return_with_cost"].append(metrics["excess_return_with_cost.annualized_return"])
result["information_ratio_with_cost"].append(metrics["excess_return_with_cost.information_ratio"])
result["max_drawdown_with_cost"].append(metrics["excess_return_with_cost.max_drawdown"])
results[fn] = result
return results


# function to generate and save markdown tables
def gen_and_save_md_table(results):
# function to generate and save markdown table
def gen_and_save_md_table(metrics):
table = "| Model Name | Annualized Return | Information Ratio | Max Drawdown |\n"
table += "|---|---|---|---|\n"
for fn in results:
ar = metrics[fn]["excess_return_with_cost.annualized_return"]
ir = metrics[fn]["excess_return_with_cost.information_ratio"]
md = metrics[fn]["excess_return_with_cost.max_drawdown"]
table += f"| {fn} | {ar:9.5f} | {ir:9.5f} | {md:9.5f} |\n"
for fn in metrics:
ar = metrics[fn]["annualized_return_with_cost"]
ir = metrics[fn]["information_ratio_with_cost"]
md = metrics[fn]["max_drawdown_with_cost"]
table += f"| {fn} | {ar[0]:9.4f}±{ar[1]:9.2f} | {ir[0]:9.4f}±{ir[1]:9.2f}| {md[0]:9.4f}±{md[1]:9.2f} |\n"
pprint(table)
with open("table.md", "w") as f:
f.write(table)
return table


# function to run all the models
def run():
def run(times=1):
"""
Please be aware that this function can only work under Linux. MacOS and Windows will be supported in the future.
Any PR to enhance this method is highly welcomed.
Expand Down Expand Up @@ -225,6 +254,7 @@ def run():
nopip=False,
verbose=False,
)
# run all the models for the requested number of iterations
for fn in folders:
# create env
temp_dir = tempfile.mkdtemp()
Expand All @@ -246,16 +276,20 @@ def run():
os.system(f"{python_path} -m pip install --upgrade cython") # TODO: FIX ME!
os.system(f"{python_path} -m pip install -e git+https://github.com/you-n-g/qlib#egg=pyqlib") # TODO: FIX ME!
sys.stderr.write("\n")
# run workflow_by_config
sys.stderr.write(f"Running the model: {fn}...\n")
os.system(f"{python_path} {env_path / 'src/pyqlib/qlib/workflow/cli.py'} {yaml_path} {fn}")
sys.stderr.write("\n")
# run workflow_by_config multiple times
for i in range(times):
sys.stderr.write(f"Running the model: {fn} for iteration {i+1}...\n")
os.system(f"{python_path} {env_path / 'src/pyqlib/qlib/workflow/cli.py'} {yaml_path} {fn}")
sys.stderr.write("\n")
# remove env
sys.stderr.write(f"Deleting the environment: {env_path}...\n")
shutil.rmtree(env_path)
# getting all results
sys.stderr.write(f"Retrieving results...\n")
results = get_all_results(folders)
# calculating the mean and std
sys.stderr.write(f"Calculating the mean and std of results...\n")
results = cal_mean_std(results)
# generating md table
sys.stderr.write(f"Generating markdown table...\n")
gen_and_save_md_table(results)
Expand All @@ -264,7 +298,7 @@ def run():
if __name__ == "__main__":
rc = 1
try:
run() # run all the model
fire.Fire(run) # run all the model
rc = 0
except Exception as e:
print("Error: %s" % e, file=sys.stderr)
Expand Down
2 changes: 1 addition & 1 deletion examples/workflow_by_code.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@
" print(f\"Qlib data is not found in {provider_uri}\")\n",
" sys.path.append(str(Path.cwd().parent.joinpath(\"scripts\")))\n",
" from get_data import GetData\n",
" GetData().qlib_data(target_dir=provider_uri, region=\"cn\")\n",
" GetData().qlib_data(target_dir=provider_uri, region=REG_CN)\n",
"qlib.init(provider_uri=provider_uri, region=REG_CN)"
]
},
Expand Down
2 changes: 1 addition & 1 deletion examples/workflow_by_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
sys.path.append(str(Path(__file__).resolve().parent.parent.joinpath("scripts")))
from get_data import GetData

GetData().qlib_data(target_dir=provider_uri, region="cn")
GetData().qlib_data(target_dir=provider_uri, region=REG_CN)

qlib.init(provider_uri=provider_uri, region=REG_CN)

Expand Down
2 changes: 1 addition & 1 deletion examples/workflow_by_code_finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
sys.path.append(str(Path(__file__).resolve().parent.parent.joinpath("scripts")))
from get_data import GetData

GetData().qlib_data(target_dir=provider_uri, region="cn")
GetData().qlib_data(target_dir=provider_uri, region=REG_CN)

qlib.init(provider_uri=provider_uri, region=REG_CN)

Expand Down
2 changes: 1 addition & 1 deletion examples/workflow_by_code_gats.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
sys.path.append(str(Path(__file__).resolve().parent.parent.joinpath("scripts")))
from get_data import GetData

GetData().qlib_data(target_dir=provider_uri, region="cn")
GetData().qlib_data(target_dir=provider_uri, region=REG_CN)

qlib.init(provider_uri=provider_uri, region=REG_CN)

Expand Down
2 changes: 1 addition & 1 deletion examples/workflow_by_code_gru.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
sys.path.append(str(Path(__file__).resolve().parent.parent.joinpath("scripts")))
from get_data import GetData

GetData().qlib_data(target_dir=provider_uri, region="cn")
GetData().qlib_data(target_dir=provider_uri, region=REG_CN)

qlib.init(provider_uri=provider_uri, region=REG_CN)

Expand Down
2 changes: 1 addition & 1 deletion examples/workflow_by_code_lstm.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
sys.path.append(str(Path(__file__).resolve().parent.parent.joinpath("scripts")))
from get_data import GetData

GetData().qlib_data(target_dir=provider_uri, region="cn")
GetData().qlib_data(target_dir=provider_uri, region=REG_CN)

qlib.init(provider_uri=provider_uri, region=REG_CN)

Expand Down
1 change: 0 additions & 1 deletion qlib/workflow/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import fire
import pandas as pd
import ruamel.yaml as yaml
from qlib.config import REG_CN
from qlib.utils import init_instance_by_config, flatten_dict
from qlib.workflow import R
from qlib.workflow.record_temp import SignalRecord
Expand Down

0 comments on commit 27b573c

Please sign in to comment.