From bee05f56ef97c0f617432b800d47b408847ba057 Mon Sep 17 00:00:00 2001 From: Huoran Li Date: Mon, 19 Sep 2022 14:54:26 +0800 Subject: [PATCH] Migrate backtest logic from NT (#1263) * Backtest migration * Minor bug fix in test * Reorganize file to avoid loop import * Fix test SAOE bug * Remove unnecessary names * Resolve PR comments; remove private classes; * Fix CI error * Resolve PR comments * Refactor data interfaces * Remove convert_instance_config and change config * Pylint issue * Pylint issue * Fix tempfile warning * Resolve PR comments * Add more comments --- qlib/backtest/__init__.py | 8 +- qlib/rl/contrib/backtest.py | 231 ++++++++++++++++++ qlib/rl/contrib/naive_config_parser.py | 103 ++++++++ qlib/rl/contrib/utils.py | 29 +++ qlib/rl/data/base.py | 65 +++++ .../{order_execution => data}/integration.py | 2 +- .../data/{exchange_wrapper.py => native.py} | 46 +++- qlib/rl/data/pickle_styled.py | 75 +++--- qlib/rl/interpreter.py | 9 +- qlib/rl/order_execution/interpreter.py | 90 +++++-- qlib/rl/order_execution/network.py | 21 ++ qlib/rl/order_execution/simulator_qlib.py | 2 +- qlib/rl/order_execution/simulator_simple.py | 4 +- qlib/rl/order_execution/state.py | 9 +- qlib/rl/order_execution/strategy.py | 126 +++++++++- qlib/rl/trainer/__init__.py | 2 +- qlib/rl/utils/env_wrapper.py | 40 ++- tests/rl/test_qlib_simulator.py | 3 + tests/rl/test_saoe_simple.py | 47 ++-- 19 files changed, 794 insertions(+), 118 deletions(-) create mode 100644 qlib/rl/contrib/backtest.py create mode 100644 qlib/rl/contrib/naive_config_parser.py create mode 100644 qlib/rl/contrib/utils.py create mode 100644 qlib/rl/data/base.py rename qlib/rl/{order_execution => data}/integration.py (98%) rename qlib/rl/data/{exchange_wrapper.py => native.py} (66%) diff --git a/qlib/backtest/__init__.py b/qlib/backtest/__init__.py index e8fe73c5a2..81c6437d6d 100644 --- a/qlib/backtest/__init__.py +++ b/qlib/backtest/__init__.py @@ -114,7 +114,7 @@ def get_exchange( def create_account_instance( start_time: Union[pd.Timestamp, str], end_time: Union[pd.Timestamp, str], - benchmark: str, + benchmark: Optional[str], account: Union[float, int, dict], pos_type: str = "Position", ) -> Account: @@ -163,7 +163,9 @@ def create_account_instance( init_cash=init_cash, position_dict=position_dict, pos_type=pos_type, - benchmark_config={ + benchmark_config={} + if benchmark is None + else { "benchmark": benchmark, "start_time": start_time, "end_time": end_time, @@ -176,7 +178,7 @@ def get_strategy_executor( end_time: Union[pd.Timestamp, str], strategy: Union[str, dict, object, Path], executor: Union[str, dict, object, Path], - benchmark: str = "SH000300", + benchmark: Optional[str] = "SH000300", account: Union[float, int, dict] = 1e9, exchange_kwargs: dict = {}, pos_type: str = "Position", diff --git a/qlib/rl/contrib/backtest.py b/qlib/rl/contrib/backtest.py new file mode 100644 index 0000000000..709c050dfb --- /dev/null +++ b/qlib/rl/contrib/backtest.py @@ -0,0 +1,231 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+from __future__ import annotations + +import copy +import pickle +import sys +from pathlib import Path +from typing import Optional, Tuple, Union + +import numpy as np +import pandas as pd +import torch +from joblib import Parallel, delayed + +from qlib.backtest import collect_data_loop, get_strategy_executor +from qlib.backtest.decision import TradeRangeByTime +from qlib.backtest.executor import BaseExecutor, NestedExecutor, SimulatorExecutor +from qlib.backtest.high_performance_ds import BaseOrderIndicator +from qlib.rl.contrib.naive_config_parser import get_backtest_config_fromfile +from qlib.rl.contrib.utils import read_order_file +from qlib.rl.data.integration import init_qlib +from qlib.rl.utils.env_wrapper import CollectDataEnvWrapper + + +def _get_multi_level_executor_config( + strategy_config: dict, + cash_limit: float = None, + generate_report: bool = False, +) -> dict: + executor_config = { + "class": "SimulatorExecutor", + "module_path": "qlib.backtest.executor", + "kwargs": { + "time_per_step": "1min", + "verbose": False, + "trade_type": SimulatorExecutor.TT_PARAL if cash_limit is not None else SimulatorExecutor.TT_SERIAL, + "generate_report": generate_report, + "track_data": True, + }, + } + + freqs = list(strategy_config.keys()) + freqs.sort(key=lambda x: pd.Timedelta(x)) + for freq in freqs: + executor_config = { + "class": "NestedExecutor", + "module_path": "qlib.backtest.executor", + "kwargs": { + "time_per_step": freq, + "inner_strategy": strategy_config[freq], + "inner_executor": executor_config, + "track_data": True, + }, + } + + return executor_config + + +def _set_env_for_all_strategy(executor: BaseExecutor) -> None: + if isinstance(executor, NestedExecutor): + if hasattr(executor.inner_strategy, "set_env"): + env = CollectDataEnvWrapper() + env.reset() + executor.inner_strategy.set_env(env) + _set_env_for_all_strategy(executor.inner_executor) + + +def _convert_indicator_to_dataframe(indicator: dict) -> Optional[pd.DataFrame]: + record_list = [] + for time, value_dict in indicator.items(): + if isinstance(value_dict, BaseOrderIndicator): + # HACK: for qlib v0.8 + value_dict = value_dict.to_series() + try: + value_dict = {k: v for k, v in value_dict.items()} + if value_dict["ffr"].empty: + continue + except Exception: + value_dict = {k: v for k, v in value_dict.items() if k != "pa"} + value_dict = pd.DataFrame(value_dict) + value_dict["datetime"] = time + record_list.append(value_dict) + + if not record_list: + return None + + records: pd.DataFrame = pd.concat(record_list, 0).reset_index().rename(columns={"index": "instrument"}) + records = records.set_index(["instrument", "datetime"]) + return records + + +def _generate_report(decisions: list, report_dict: dict) -> dict: + report = {} + decision_details = pd.concat([d.details for d in decisions if hasattr(d, "details")]) + for key in ["1minute", "5minute", "30minute", "1day"]: + if key not in report_dict["indicator"]: + continue + report[key] = report_dict["indicator"][key] + report[key + "_obj"] = _convert_indicator_to_dataframe( + report_dict["indicator"][key + "_obj"].order_indicator_his + ) + cur_details = decision_details[decision_details.freq == key.rstrip("ute")].set_index(["instrument", "datetime"]) + if len(cur_details) > 0: + cur_details.pop("freq") + report[key + "_obj"] = report[key + "_obj"].join(cur_details, how="outer") + if "1minute" in report_dict["report"]: + report["simulator"] = report_dict["report"]["1minute"][0] + return report + + +def single( + backtest_config: dict, + orders: pd.DataFrame, + 
split: str = "stock", + cash_limit: float = None, + generate_report: bool = False, +) -> Union[Tuple[pd.DataFrame, dict], pd.DataFrame]: + if split == "stock": + stock_id = orders.iloc[0].instrument + init_qlib(backtest_config["qlib"], part=stock_id) + else: + day = orders.iloc[0].datetime + init_qlib(backtest_config["qlib"], part=day) + + trade_start_time = orders["datetime"].min() + trade_end_time = orders["datetime"].max() + stocks = orders.instrument.unique().tolist() + + top_strategy_config = { + "class": "FileOrderStrategy", + "module_path": "qlib.contrib.strategy.rule_strategy", + "kwargs": { + "file": orders, + "trade_range": TradeRangeByTime( + pd.Timestamp(backtest_config["start_time"]).time(), + pd.Timestamp(backtest_config["end_time"]).time(), + ), + }, + } + + top_executor_config = _get_multi_level_executor_config( + strategy_config=backtest_config["strategies"], + cash_limit=cash_limit, + generate_report=generate_report, + ) + + tmp_backtest_config = copy.deepcopy(backtest_config["exchange"]) + tmp_backtest_config.update( + { + "codes": stocks, + "freq": "1min", + } + ) + + strategy, executor = get_strategy_executor( + start_time=pd.Timestamp(trade_start_time), + end_time=pd.Timestamp(trade_end_time) + pd.DateOffset(1), + strategy=top_strategy_config, + executor=top_executor_config, + benchmark=None, + account=cash_limit if cash_limit is not None else int(1e12), + exchange_kwargs=tmp_backtest_config, + pos_type="Position" if cash_limit is not None else "InfPosition", + ) + _set_env_for_all_strategy(executor=executor) + + report_dict: dict = {} + decisions = list(collect_data_loop(trade_start_time, trade_end_time, strategy, executor, report_dict)) + + records = _convert_indicator_to_dataframe(report_dict["indicator"]["1day_obj"].order_indicator_his) + assert records is None or not np.isnan(records["ffr"]).any() + + if generate_report: + report = _generate_report(decisions, report_dict) + if split == "stock": + stock_id = orders.iloc[0].instrument + report = {stock_id: report} + else: + day = orders.iloc[0].datetime + report = {day: report} + return records, report + else: + return records + + +def backtest(backtest_config: dict) -> pd.DataFrame: + order_df = read_order_file(backtest_config["order_file"]) + + cash_limit = backtest_config["exchange"].pop("cash_limit") + generate_report = backtest_config["exchange"].pop("generate_report") + + stock_pool = order_df["instrument"].unique().tolist() + stock_pool.sort() + + mp_config = {"n_jobs": backtest_config["concurrency"], "verbose": 10, "backend": "multiprocessing"} + torch.set_num_threads(1) # https://github.com/pytorch/pytorch/issues/17199 + res = Parallel(**mp_config)( + delayed(single)( + backtest_config=backtest_config, + orders=order_df[order_df["instrument"] == stock].copy(), + split="stock", + cash_limit=cash_limit, + generate_report=generate_report, + ) + for stock in stock_pool + ) + + output_path = Path(backtest_config["output_dir"]) + if generate_report: + with (output_path / "report.pkl").open("wb") as f: + report = {} + for r in res: + report.update(r[1]) + pickle.dump(report, f) + res = pd.concat([r[0] for r in res], 0) + else: + res = pd.concat(res) + + res.to_csv(output_path / "summary.csv") + return res + + +if __name__ == "__main__": + import warnings + + warnings.filterwarnings("ignore", category=DeprecationWarning) + warnings.filterwarnings("ignore", category=RuntimeWarning) + + path = sys.argv[1] + backtest(get_backtest_config_fromfile(path)) diff --git a/qlib/rl/contrib/naive_config_parser.py 
b/qlib/rl/contrib/naive_config_parser.py
new file mode 100644
index 0000000000..eaf62636cc
--- /dev/null
+++ b/qlib/rl/contrib/naive_config_parser.py
@@ -0,0 +1,103 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+import os
+import platform
+import shutil
+import sys
+import tempfile
+from importlib import import_module
+
+import yaml
+
+
+def merge_a_into_b(a: dict, b: dict) -> dict:
+    b = b.copy()
+    for k, v in a.items():
+        if isinstance(v, dict) and k in b:
+            v.pop("_delete_", False)  # TODO: make this more elegant
+            b[k] = merge_a_into_b(v, b[k])
+        else:
+            b[k] = v
+    return b
+
+
+def check_file_exist(filename: str, msg_tmpl: str = 'file "{}" does not exist') -> None:
+    if not os.path.isfile(filename):
+        raise FileNotFoundError(msg_tmpl.format(filename))
+
+
+def parse_backtest_config(path: str) -> dict:
+    abs_path = os.path.abspath(path)
+    check_file_exist(abs_path)
+
+    file_ext_name = os.path.splitext(abs_path)[1]
+    if file_ext_name not in (".py", ".json", ".yaml", ".yml"):
+        raise IOError("Only py/yml/yaml/json types are supported now!")
+
+    with tempfile.TemporaryDirectory() as tmp_config_dir:
+        with tempfile.NamedTemporaryFile(dir=tmp_config_dir, suffix=file_ext_name) as tmp_config_file:
+            if platform.system() == "Windows":
+                tmp_config_file.close()
+
+            tmp_config_name = os.path.basename(tmp_config_file.name)
+            shutil.copyfile(abs_path, tmp_config_file.name)
+
+            if abs_path.endswith(".py"):
+                tmp_module_name = os.path.splitext(tmp_config_name)[0]
+                sys.path.insert(0, tmp_config_dir)
+                module = import_module(tmp_module_name)
+                sys.path.pop(0)
+
+                config = {k: v for k, v in module.__dict__.items() if not k.startswith("__")}
+
+                del sys.modules[tmp_module_name]
+            else:
+                config = yaml.safe_load(open(tmp_config_file.name))
+
+    if "_base_" in config:
+        base_file_name = config.pop("_base_")
+        if not isinstance(base_file_name, list):
+            base_file_name = [base_file_name]
+
+        for f in base_file_name:
+            base_config = parse_backtest_config(os.path.join(os.path.dirname(abs_path), f))
+            config = merge_a_into_b(a=config, b=base_config)
+
+    return config
+
+
+def _convert_all_list_to_tuple(config: dict) -> dict:
+    for k, v in config.items():
+        if isinstance(v, list):
+            config[k] = tuple(v)
+        elif isinstance(v, dict):
+            config[k] = _convert_all_list_to_tuple(v)
+    return config
+
+
+def get_backtest_config_fromfile(path: str) -> dict:
+    backtest_config = parse_backtest_config(path)
+
+    exchange_config_default = {
+        "open_cost": 0.0005,
+        "close_cost": 0.0015,
+        "min_cost": 5.0,
+        "trade_unit": 100.0,
+        "cash_limit": None,
+        "generate_report": False,
+    }
+    backtest_config["exchange"] = merge_a_into_b(a=backtest_config["exchange"], b=exchange_config_default)
+    backtest_config["exchange"] = _convert_all_list_to_tuple(backtest_config["exchange"])
+
+    backtest_config_default = {
+        "debug_single_stock": None,
+        "debug_single_day": None,
+        "concurrency": -1,
+        "multiplier": 1.0,
+        "output_dir": "outputs/",
+        # "runtime": {},
+    }
+    backtest_config = merge_a_into_b(a=backtest_config, b=backtest_config_default)
+
+    return backtest_config
diff --git a/qlib/rl/contrib/utils.py b/qlib/rl/contrib/utils.py
new file mode 100644
index 0000000000..cad25e0dba
--- /dev/null
+++ b/qlib/rl/contrib/utils.py
@@ -0,0 +1,29 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+ +from __future__ import annotations + +from pathlib import Path + +import pandas as pd + + +def read_order_file(order_file: Path | pd.DataFrame) -> pd.DataFrame: + if isinstance(order_file, pd.DataFrame): + return order_file + + order_file = Path(order_file) + + if order_file.suffix == ".pkl": + order_df = pd.read_pickle(order_file).reset_index() + elif order_file.suffix == ".csv": + order_df = pd.read_csv(order_file) + else: + raise TypeError(f"Unsupported order file type: {order_file}") + + if "date" in order_df.columns: + # legacy dataframe columns + order_df = order_df.rename(columns={"date": "datetime", "order_type": "direction"}) + order_df["datetime"] = order_df["datetime"].astype(str) + + return order_df diff --git a/qlib/rl/data/base.py b/qlib/rl/data/base.py new file mode 100644 index 0000000000..e258abe869 --- /dev/null +++ b/qlib/rl/data/base.py @@ -0,0 +1,65 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from __future__ import annotations + +from abc import abstractmethod + +import pandas as pd + + +class BaseIntradayBacktestData: + """ + Raw market data that is often used in backtesting (thus called BacktestData). + + Base class for all types of backtest data. Currently, each type of simulator has its corresponding backtest + data type. + """ + + @abstractmethod + def __repr__(self) -> str: + raise NotImplementedError + + @abstractmethod + def __len__(self) -> int: + raise NotImplementedError + + @abstractmethod + def get_deal_price(self) -> pd.Series: + raise NotImplementedError + + @abstractmethod + def get_volume(self) -> pd.Series: + raise NotImplementedError + + @abstractmethod + def get_time_index(self) -> pd.DatetimeIndex: + raise NotImplementedError + + +class BaseIntradayProcessedData: + """Processed market data after data cleanup and feature engineering. + + It contains both processed data for "today" and "yesterday", as some algorithms + might use the market information of the previous day to assist decision making. + """ + + today: pd.DataFrame + """Processed data for "today". + Number of records must be ``time_length``, and columns must be ``feature_dim``.""" + + yesterday: pd.DataFrame + """Processed data for "yesterday". 
+    Number of records must be ``time_length``, and columns must be ``feature_dim``."""
+
+
+class ProcessedDataProvider:
+    """Provider of processed data"""
+
+    def get_data(
+        self,
+        stock_id: str,
+        date: pd.Timestamp,
+        feature_dim: int,
+        time_index: pd.Index,
+    ) -> BaseIntradayProcessedData:
+        raise NotImplementedError
diff --git a/qlib/rl/order_execution/integration.py b/qlib/rl/data/integration.py
similarity index 98%
rename from qlib/rl/order_execution/integration.py
rename to qlib/rl/data/integration.py
index 07ca381613..d32ce49c82 100644
--- a/qlib/rl/order_execution/integration.py
+++ b/qlib/rl/data/integration.py
@@ -41,7 +41,7 @@ def __init__(
 
     @cachetools.cached(  # type: ignore
         cache=cachetools.LRUCache(100),
-        key=lambda stock_id, date, backtest: (stock_id, date.replace(hour=0, minute=0, second=0), backtest),
+        key=lambda _, stock_id, date, backtest: (stock_id, date.replace(hour=0, minute=0, second=0), backtest),
     )
     def get(self, stock_id: str, date: pd.Timestamp, backtest: bool = False) -> pd.DataFrame:
         start_time, end_time = date.replace(hour=0, minute=0, second=0), date.replace(hour=23, minute=59, second=59)
diff --git a/qlib/rl/data/exchange_wrapper.py b/qlib/rl/data/native.py
similarity index 66%
rename from qlib/rl/data/exchange_wrapper.py
rename to qlib/rl/data/native.py
index 94bb1dcbbd..eb612cf64e 100644
--- a/qlib/rl/data/exchange_wrapper.py
+++ b/qlib/rl/data/native.py
@@ -1,5 +1,6 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT License.
+from __future__ import annotations
 
 from typing import cast
 
@@ -8,10 +9,12 @@
 from qlib.backtest import Exchange, Order
 from qlib.backtest.decision import TradeRange, TradeRangeByTime
-from qlib.constant import ONE_DAY, EPS_T
+from qlib.constant import EPS_T, ONE_DAY
 from qlib.rl.order_execution.utils import get_ticks_slice
 from qlib.utils.index_data import IndexData
-from .pickle_styled import BaseIntradayBacktestData
+
+from .base import BaseIntradayBacktestData, BaseIntradayProcessedData, ProcessedDataProvider
+from .integration import fetch_features
 
 
 class IntradayBacktestData(BaseIntradayBacktestData):
@@ -74,7 +77,7 @@ def get_time_index(self) -> pd.DatetimeIndex:
     cache=cachetools.LRUCache(100),
     key=lambda order, _, __: order.key_by_day,
 )
-def load_qlib_backtest_data(
+def load_backtest_data(
     order: Order,
     trade_exchange: Exchange,
     trade_range: TradeRange,
@@ -108,3 +111,40 @@ def load_qlib_backtest_data(
         ticks_for_order=ticks_for_order,
     )
     return backtest_data
+
+
+class NTIntradayProcessedData(BaseIntradayProcessedData):
+    """Subclass of BaseIntradayProcessedData. Used to handle NT-style data."""
+
+    def __init__(
+        self,
+        stock_id: str,
+        date: pd.Timestamp,
+    ) -> None:
+        def _drop_stock_id(df: pd.DataFrame) -> pd.DataFrame:
+            return df.reset_index().drop(columns=["instrument"]).set_index(["datetime"])
+
+        self.today = _drop_stock_id(fetch_features(stock_id, date))
+        self.yesterday = _drop_stock_id(fetch_features(stock_id, date, yesterday=True))
+
+    def __repr__(self) -> str:
+        with pd.option_context("memory_usage", False, "display.max_info_columns", 1, "display.large_repr", "info"):
+            return f"{self.__class__.__name__}({self.today}, {self.yesterday})"
+
+
+@cachetools.cached(  # type: ignore
+    cache=cachetools.LRUCache(100),  # 100 * 50K = 5MB
+)
+def load_nt_intraday_processed_data(stock_id: str, date: pd.Timestamp) -> NTIntradayProcessedData:
+    return NTIntradayProcessedData(stock_id, date)
+
+
+class NTProcessedDataProvider(ProcessedDataProvider):
+    def get_data(
+        self,
+        stock_id: str,
+        date: pd.Timestamp,
+        feature_dim: int,
+        time_index: pd.Index,
+    ) -> BaseIntradayProcessedData:
+        return load_nt_intraday_processed_data(stock_id, date)
diff --git a/qlib/rl/data/pickle_styled.py b/qlib/rl/data/pickle_styled.py
index 43fe9dd5ad..ed62a4180d 100644
--- a/qlib/rl/data/pickle_styled.py
+++ b/qlib/rl/data/pickle_styled.py
@@ -19,7 +19,6 @@
 
 from __future__ import annotations
 
-from abc import abstractmethod
 from functools import lru_cache
 from pathlib import Path
 from typing import List, Sequence, cast
@@ -30,6 +29,7 @@
 from cachetools.keys import hashkey
 
 from qlib.backtest.decision import Order, OrderDir
+from qlib.rl.data.base import BaseIntradayBacktestData, BaseIntradayProcessedData, ProcessedDataProvider
 from qlib.typehint import Literal
 
 DealPriceType = Literal["bid_or_ask", "bid_or_ask_fill", "close"]
@@ -86,35 +86,6 @@ def _read_pickle(filename_without_suffix: Path) -> pd.DataFrame:
     return pd.read_pickle(_find_pickle(filename_without_suffix))
 
 
-class BaseIntradayBacktestData:
-    """
-    Raw market data that is often used in backtesting (thus called BacktestData).
-
-    Base class for all types of backtest data. Currently, each type of simulator has its corresponding backtest
-    data type.
-    """
-
-    @abstractmethod
-    def __repr__(self) -> str:
-        raise NotImplementedError
-
-    @abstractmethod
-    def __len__(self) -> int:
-        raise NotImplementedError
-
-    @abstractmethod
-    def get_deal_price(self) -> pd.Series:
-        raise NotImplementedError
-
-    @abstractmethod
-    def get_volume(self) -> pd.Series:
-        raise NotImplementedError
-
-    @abstractmethod
-    def get_time_index(self) -> pd.DatetimeIndex:
-        raise NotImplementedError
-
-
 class SimpleIntradayBacktestData(BaseIntradayBacktestData):
     """Backtest data for simple simulator"""
 
@@ -178,20 +149,8 @@ def get_time_index(self) -> pd.DatetimeIndex:
         return cast(pd.DatetimeIndex, self.data.index)
 
 
-class IntradayProcessedData:
-    """Processed market data after data cleanup and feature engineering.
-
-    It contains both processed data for "today" and "yesterday", as some algorithms
-    might use the market information of the previous day to assist decision making.
-    """
-
-    today: pd.DataFrame
-    """Processed data for "today".
-    Number of records must be ``time_length``, and columns must be ``feature_dim``."""
-
-    yesterday: pd.DataFrame
-    """Processed data for "yesterday".
-    Number of records must be ``time_length``, and columns must be ``feature_dim``."""
+class IntradayProcessedData(BaseIntradayProcessedData):
+    """Subclass of BaseIntradayProcessedData.
Used to handle Dataset Handler style data.""" def __init__( self, @@ -246,18 +205,40 @@ def load_simple_intraday_backtest_data( @cachetools.cached( # type: ignore cache=cachetools.LRUCache(100), # 100 * 50K = 5MB - key=lambda data_dir, stock_id, date, _, __: hashkey(data_dir, stock_id, date), + key=lambda data_dir, stock_id, date, feature_dim, time_index: hashkey(data_dir, stock_id, date), ) -def load_intraday_processed_data( +def load_pickled_intraday_processed_data( data_dir: Path, stock_id: str, date: pd.Timestamp, feature_dim: int, time_index: pd.Index, -) -> IntradayProcessedData: +) -> BaseIntradayProcessedData: return IntradayProcessedData(data_dir, stock_id, date, feature_dim, time_index) +class PickleProcessedDataProvider(ProcessedDataProvider): + def __init__(self, data_dir: Path) -> None: + super().__init__() + + self._data_dir = data_dir + + def get_data( + self, + stock_id: str, + date: pd.Timestamp, + feature_dim: int, + time_index: pd.Index, + ) -> BaseIntradayProcessedData: + return load_pickled_intraday_processed_data( + data_dir=self._data_dir, + stock_id=stock_id, + date=date, + feature_dim=feature_dim, + time_index=time_index, + ) + + def load_orders( order_path: Path, start_time: pd.Timestamp = None, diff --git a/qlib/rl/interpreter.py b/qlib/rl/interpreter.py index 61c9b83819..d2d81f81cd 100644 --- a/qlib/rl/interpreter.py +++ b/qlib/rl/interpreter.py @@ -3,16 +3,15 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar +from typing import Any, Generic, Optional, TYPE_CHECKING, TypeVar import numpy as np from qlib.typehint import final - from .simulator import ActType, StateType if TYPE_CHECKING: - from .utils.env_wrapper import EnvWrapper + from .utils.env_wrapper import BaseEnvWrapper import gym from gym import spaces @@ -40,7 +39,7 @@ class Interpreter: class StateInterpreter(Generic[StateType, ObsType], Interpreter): """State Interpreter that interpret execution result of qlib executor into rl env state""" - env: Optional[EnvWrapper] = None + env: Optional[BaseEnvWrapper] = None @property def observation_space(self) -> gym.Space: @@ -74,7 +73,7 @@ def interpret(self, simulator_state: StateType) -> ObsType: class ActionInterpreter(Generic[StateType, PolicyActType, ActType], Interpreter): """Action Interpreter that interpret rl agent action into qlib orders""" - env: Optional[EnvWrapper] = None + env: Optional[BaseEnvWrapper] = None @property def action_space(self) -> gym.Space: diff --git a/qlib/rl/order_execution/interpreter.py b/qlib/rl/order_execution/interpreter.py index 089fc553cf..0b89977491 100644 --- a/qlib/rl/order_execution/interpreter.py +++ b/qlib/rl/order_execution/interpreter.py @@ -4,15 +4,14 @@ from __future__ import annotations import math -from pathlib import Path -from typing import Any, List, cast +from typing import Any, List, Optional, cast import numpy as np import pandas as pd from gym import spaces from qlib.constant import EPS -from qlib.rl.data import pickle_styled +from qlib.rl.data.base import ProcessedDataProvider from qlib.rl.interpreter import ActionInterpreter, StateInterpreter from qlib.rl.order_execution.state import SAOEState from qlib.typehint import TypedDict @@ -25,6 +24,8 @@ "FullHistoryObs", ] +from qlib.utils import init_instance_by_config + def canonicalize(value: int | float | np.ndarray | pd.DataFrame | dict) -> np.ndarray | dict: """To 32-bit numeric types. 
Recursively.""" @@ -57,8 +58,6 @@ class FullHistoryStateInterpreter(StateInterpreter[SAOEState, FullHistoryObs]): Parameters ---------- - data_dir - Path to load data after feature engineering. max_step Total number of steps (an upper-bound estimation). For example, 390min / 30min-per-step = 13 steps. data_ticks @@ -66,21 +65,37 @@ class FullHistoryStateInterpreter(StateInterpreter[SAOEState, FullHistoryObs]): the total ticks is the length of day in minutes. data_dim Number of dimensions in data. + processed_data_provider + Provider of the processed data. """ - def __init__(self, data_dir: Path, max_step: int, data_ticks: int, data_dim: int) -> None: - self.data_dir = data_dir + # TODO: All implementations related to `data_dir` is coupled with the specific data format for that specific case. + # TODO: So it should be redesigned after the data interface is well-designed. + def __init__( + self, + max_step: int, + data_ticks: int, + data_dim: int, + processed_data_provider: dict | ProcessedDataProvider, + ) -> None: self.max_step = max_step self.data_ticks = data_ticks self.data_dim = data_dim + self.processed_data_provider: ProcessedDataProvider = init_instance_by_config( + processed_data_provider, + accept_types=ProcessedDataProvider, + ) def interpret(self, state: SAOEState) -> FullHistoryObs: - processed = pickle_styled.load_intraday_processed_data( - self.data_dir, - state.order.stock_id, - pd.Timestamp(state.order.start_time.date()), - self.data_dim, - state.ticks_index, + # TODO: This interpreter relies on EnvWrapper.status, so we have to give it a dummy EnvWrapper when running + # backtest. Currently, the dummy EnvWrapper is CollectDataEnvWrapper. We should find a more elegant + # way to decompose interpreter and EnvWrapper in the future. + + processed = self.processed_data_provider.get_data( + stock_id=state.order.stock_id, + date=pd.Timestamp(state.order.start_time.date()), + feature_dim=self.data_dim, + time_index=state.ticks_index, ) position_history = np.full(self.max_step + 1, 0.0, dtype=np.float32) @@ -96,15 +111,15 @@ def interpret(self, state: SAOEState) -> FullHistoryObs: FullHistoryObs, canonicalize( { - "data_processed": self._mask_future_info(processed.today, state.cur_time), - "data_processed_prev": processed.yesterday, - "acquiring": state.order.direction == state.order.BUY, - "cur_tick": min(int(np.sum(state.ticks_index < state.cur_time)), self.data_ticks - 1), - "cur_step": min(self.env.status["cur_step"], self.max_step - 1), - "num_step": self.max_step, - "target": state.order.amount, - "position": state.position, - "position_history": position_history[: self.max_step], + "data_processed": np.array(self._mask_future_info(processed.today, state.cur_time)), + "data_processed_prev": np.array(processed.yesterday), + "acquiring": _to_int32(state.order.direction == state.order.BUY), + "cur_tick": _to_int32(min(int(np.sum(state.ticks_index < state.cur_time)), self.data_ticks - 1)), + "cur_step": _to_int32(min(self.env.status["cur_step"], self.max_step - 1)), + "num_step": _to_int32(self.max_step), + "target": _to_float32(state.order.amount), + "position": _to_float32(state.position), + "position_history": _to_float32(position_history[: self.max_step]), }, ), ) @@ -162,6 +177,10 @@ def observation_space(self) -> spaces.Dict: return spaces.Dict(space) def interpret(self, state: SAOEState) -> CurrentStateObs: + # TODO: This interpreter relies on EnvWrapper.status, so we have to give it a dummy EnvWrapper when running + # backtest. 
Currently, the dummy EnvWrapper is CollectDataEnvWrapper. We should find a more elegant + # way to decompose interpreter and EnvWrapper in the future. + assert self.env is not None assert self.env.status["cur_step"] <= self.max_step obs = CurrentStateObs( @@ -184,20 +203,31 @@ class CategoricalActionInterpreter(ActionInterpreter[SAOEState, int, float]): Then when policy givens decision $x$, $a_x$ times order amount is the output. It can also be an integer $n$, in which case the list of length $n+1$ is auto-generated, i.e., $[0, 1/n, 2/n, \\ldots, n/n]$. + max_step + Total number of steps (an upper-bound estimation). For example, 390min / 30min-per-step = 13 steps. """ - def __init__(self, values: int | List[float]) -> None: + def __init__(self, values: int | List[float], max_step: Optional[int] = None) -> None: if isinstance(values, int): values = [i / values for i in range(0, values + 1)] self.action_values = values + self.max_step = max_step @property def action_space(self) -> spaces.Discrete: return spaces.Discrete(len(self.action_values)) def interpret(self, state: SAOEState, action: int) -> float: + # TODO: This interpreter relies on EnvWrapper.status, so we have to give it a dummy EnvWrapper when running + # backtest. Currently, the dummy EnvWrapper is CollectDataEnvWrapper. We should find a more elegant + # way to decompose interpreter and EnvWrapper in the future. + assert 0 <= action < len(self.action_values) - return min(state.position, state.order.amount * self.action_values[action]) + assert self.env is not None + if self.max_step is not None and self.env.status["cur_step"] >= self.max_step - 1: + return state.position + else: + return min(state.position, state.order.amount * self.action_values[action]) class TwapRelativeActionInterpreter(ActionInterpreter[SAOEState, float, float]): @@ -214,7 +244,19 @@ def action_space(self) -> spaces.Box: return spaces.Box(0, np.inf, shape=(), dtype=np.float32) def interpret(self, state: SAOEState, action: float) -> float: + # TODO: This interpreter relies on EnvWrapper.status, so we have to give it a dummy EnvWrapper when running + # backtest. Currently, the dummy EnvWrapper is CollectDataEnvWrapper. We should find a more elegant + # way to decompose interpreter and EnvWrapper in the future. 
+ assert self.env is not None estimated_total_steps = math.ceil(len(state.ticks_for_order) / state.ticks_per_step) twap_volume = state.position / (estimated_total_steps - self.env.status["cur_step"]) return min(state.position, twap_volume * action) + + +def _to_int32(val): + return np.array(int(val), dtype=np.int32) + + +def _to_float32(val): + return np.array(val, dtype=np.float32) diff --git a/qlib/rl/order_execution/network.py b/qlib/rl/order_execution/network.py index 3d0279559e..d6a11189cf 100644 --- a/qlib/rl/order_execution/network.py +++ b/qlib/rl/order_execution/network.py @@ -117,3 +117,24 @@ def forward(self, batch: Batch) -> torch.Tensor: out = torch.cat(sources, -1) return self.fc(out) + + +class Attention(nn.Module): + def __init__(self, in_dim, out_dim): + super().__init__() + self.q_net = nn.Linear(in_dim, out_dim) + self.k_net = nn.Linear(in_dim, out_dim) + self.v_net = nn.Linear(in_dim, out_dim) + + def forward(self, Q, K, V): + q = self.q_net(Q) + k = self.k_net(K) + v = self.v_net(V) + + attn = torch.einsum("ijk,ilk->ijl", q, k) + attn = attn.to(Q.device) + attn_prob = torch.softmax(attn, dim=-1) + + attn_vec = torch.einsum("ijk,ikl->ijl", attn_prob, v) + + return attn_vec diff --git a/qlib/rl/order_execution/simulator_qlib.py b/qlib/rl/order_execution/simulator_qlib.py index 3002fd333e..718c2ba572 100644 --- a/qlib/rl/order_execution/simulator_qlib.py +++ b/qlib/rl/order_execution/simulator_qlib.py @@ -11,7 +11,7 @@ from qlib.backtest.executor import NestedExecutor from qlib.rl.simulator import Simulator -from .integration import init_qlib +from qlib.rl.data.integration import init_qlib from .state import SAOEState, SAOEStateAdapter from .strategy import SAOEStrategy diff --git a/qlib/rl/order_execution/simulator_simple.py b/qlib/rl/order_execution/simulator_simple.py index f95aeebad0..17efb4b093 100644 --- a/qlib/rl/order_execution/simulator_simple.py +++ b/qlib/rl/order_execution/simulator_simple.py @@ -18,10 +18,10 @@ # TODO: Integrating Qlib's native data with simulator_simple -__all__ = ["SingleAssetOrderExecution"] +__all__ = ["SingleAssetOrderExecutionSimple"] -class SingleAssetOrderExecution(Simulator[Order, SAOEState, float]): +class SingleAssetOrderExecutionSimple(Simulator[Order, SAOEState, float]): """Single-asset order execution (SAOE) simulator. As there's no "calendar" in the simple simulator, ticks are used to trade. 
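For quick orientation, here is a minimal usage sketch of the renamed simulator, mirroring ``test_simulator_stop_twap`` in the updated tests at the end of this patch. ``BACKTEST_DATA_DIR`` is a placeholder for a local directory of pickle-styled backtest data (it is defined in the tests, not shipped by this patch):

    import pandas as pd
    from qlib.backtest.decision import Order
    from qlib.rl.order_execution.simulator_simple import SingleAssetOrderExecutionSimple

    # Sell 13.0 units of AAL over the full trading day (direction 0 = SELL).
    order = Order("AAL", 13.0, 0, pd.Timestamp("2013-12-11 00:00:00"), pd.Timestamp("2013-12-11 23:59:59"))
    simulator = SingleAssetOrderExecutionSimple(order, BACKTEST_DATA_DIR)

    # 390 one-minute ticks / 30 ticks per step = 13 steps; trade 1.0 unit per step (a TWAP-like schedule).
    for _ in range(13):
        simulator.step(1.0)

    state = simulator.get_state()  # SAOEState: cur_time, position, execution history, metrics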
diff --git a/qlib/rl/order_execution/state.py b/qlib/rl/order_execution/state.py
index d6bbeaea5a..a46928ee89 100644
--- a/qlib/rl/order_execution/state.py
+++ b/qlib/rl/order_execution/state.py
@@ -3,6 +3,7 @@
 
 from __future__ import annotations
 
+import typing
 from typing import cast, NamedTuple, Optional, Tuple
 
 import numpy as np
@@ -10,11 +11,13 @@
 from qlib.backtest import Exchange, Order
 from qlib.backtest.executor import BaseExecutor
 from qlib.constant import EPS, ONE_MIN, REG_CN
-from qlib.rl.data.exchange_wrapper import IntradayBacktestData
-from qlib.rl.data.pickle_styled import BaseIntradayBacktestData
 from qlib.rl.order_execution.utils import dataframe_append, price_advantage
+from qlib.typehint import TypedDict
 from qlib.utils.time import get_day_min_idx_range
-from typing_extensions import TypedDict
+
+if typing.TYPE_CHECKING:
+    from qlib.rl.data.base import BaseIntradayBacktestData
+    from qlib.rl.data.native import IntradayBacktestData
 
 
 def _get_all_timestamps(
diff --git a/qlib/rl/order_execution/strategy.py b/qlib/rl/order_execution/strategy.py
index 4a85bc76ed..ecc879bf51 100644
--- a/qlib/rl/order_execution/strategy.py
+++ b/qlib/rl/order_execution/strategy.py
@@ -5,17 +5,23 @@
 
 import collections
 from types import GeneratorType
-from typing import Any, Optional, Union, cast, Dict, Generator
+from typing import Any, cast, Dict, Generator, Optional, Union
 
 import pandas as pd
+import torch
+from tianshou.data import Batch
+from tianshou.policy import BasePolicy
 
 from qlib.backtest import CommonInfrastructure, Order
 from qlib.backtest.decision import BaseTradeDecision, TradeDecisionWO, TradeRange
 from qlib.backtest.utils import LevelInfrastructure
 from qlib.constant import ONE_MIN
-from qlib.rl.data.exchange_wrapper import load_qlib_backtest_data
-from qlib.rl.order_execution.state import SAOEStateAdapter, SAOEState
+from qlib.rl.data.native import load_backtest_data
+from qlib.rl.interpreter import ActionInterpreter, StateInterpreter
+from qlib.rl.order_execution.state import SAOEState, SAOEStateAdapter
+from qlib.rl.utils.env_wrapper import BaseEnvWrapper
 from qlib.strategy.base import RLStrategy
+from qlib.utils import init_instance_by_config
 
 
 class SAOEStrategy(RLStrategy):
@@ -41,7 +47,7 @@ def __init__(
         self._last_step_range = (0, 0)
 
     def _create_qlib_backtest_adapter(self, order: Order, trade_range: TradeRange) -> SAOEStateAdapter:
-        backtest_data = load_qlib_backtest_data(order, self.trade_exchange, trade_range)
+        backtest_data = load_backtest_data(order, self.trade_exchange, trade_range)
 
         return SAOEStateAdapter(
             order=order,
@@ -106,7 +112,10 @@ def generate_trade_decision(
 
         return decision
 
-    def _generate_trade_decision(self, execute_result: list = None) -> Generator[Any, Any, BaseTradeDecision]:
+    def _generate_trade_decision(
+        self,
+        execute_result: list = None,
+    ) -> Union[BaseTradeDecision, Generator[Any, Any, BaseTradeDecision]]:
         raise NotImplementedError
 
 
@@ -146,3 +155,110 @@ def reset(self, outer_trade_decision: BaseTradeDecision = None, **kwargs: Any) -> None:
             order_list = outer_trade_decision.order_list
             assert len(order_list) == 1
             self._order = order_list[0]
+
+
+class SAOEIntStrategy(SAOEStrategy):
+    """(SAOE)state-based strategy with (Int)erpreters."""
+
+    def __init__(
+        self,
+        policy: dict | BasePolicy,
+        state_interpreter: dict | StateInterpreter,
+        action_interpreter: dict | ActionInterpreter,
+        network: object = None,  # TODO: add accurate typehint later.
+        outer_trade_decision: BaseTradeDecision = None,
+        level_infra: LevelInfrastructure = None,
+        common_infra: CommonInfrastructure = None,
+        backtest: bool = False,
+        **kwargs: Any,
+    ) -> None:
+        super(SAOEIntStrategy, self).__init__(
+            policy=policy,
+            outer_trade_decision=outer_trade_decision,
+            level_infra=level_infra,
+            common_infra=common_infra,
+            **kwargs,
+        )
+
+        self._backtest = backtest
+
+        self._state_interpreter: StateInterpreter = init_instance_by_config(
+            state_interpreter,
+            accept_types=StateInterpreter,
+        )
+        self._action_interpreter: ActionInterpreter = init_instance_by_config(
+            action_interpreter,
+            accept_types=ActionInterpreter,
+        )
+
+        if isinstance(policy, dict):
+            assert network is not None
+
+            if isinstance(network, dict):
+                network["kwargs"].update(
+                    {
+                        "obs_space": self._state_interpreter.observation_space,
+                    }
+                )
+                network_inst = init_instance_by_config(network)
+            else:
+                network_inst = network
+
+            policy["kwargs"].update(
+                {
+                    "obs_space": self._state_interpreter.observation_space,
+                    "action_space": self._action_interpreter.action_space,
+                    "network": network_inst,
+                }
+            )
+            self._policy = init_instance_by_config(policy)
+        elif isinstance(policy, BasePolicy):
+            self._policy = policy
+        else:
+            raise ValueError(f"Unsupported policy type: {type(policy)}.")
+
+        if self._policy is not None:
+            self._policy.eval()
+
+    def set_env(self, env: BaseEnvWrapper) -> None:
+        # TODO: This method sets the EnvWrapper for the interpreters, since they rely on it. We should
+        #  decouple the interpreters from EnvWrapper in the future and remove this method after that.
+
+        self._env = env
+        self._state_interpreter.env = self._action_interpreter.env = self._env
+
+    def reset(self, outer_trade_decision: BaseTradeDecision = None, **kwargs: Any) -> None:
+        super().reset(outer_trade_decision=outer_trade_decision, **kwargs)
+
+        # In backtest, env.reset() needs to be called manually since there is no outer trainer to call it
+        if self._backtest:
+            self._env.reset()
+
+    def _generate_trade_decision(self, execute_result: list = None) -> BaseTradeDecision:
+        states = []
+        obs_batch = []
+        for decision in self.outer_trade_decision.get_decision():
+            order = cast(Order, decision)
+            state = self.get_saoe_state_by_order(order)
+
+            states.append(state)
+            obs_batch.append({"obs": self._state_interpreter.interpret(state)})
+
+        with torch.no_grad():
+            policy_out = self._policy(Batch(obs_batch))
+        act = policy_out.act.numpy() if torch.is_tensor(policy_out.act) else policy_out.act
+        exec_vols = [self._action_interpreter.interpret(s, a) for s, a in zip(states, act)]
+
+        # In backtest, env.step() needs to be called manually since there is no outer trainer to call it
+        if self._backtest:
+            self._env.step(None)
+
+        oh = self.trade_exchange.get_order_helper()
+        order_list = []
+        for decision, exec_vol in zip(self.outer_trade_decision.get_decision(), exec_vols):
+            if exec_vol != 0:
+                order = cast(Order, decision)
+                order_list.append(oh.create(order.stock_id, exec_vol, order.direction))
+
+        return TradeDecisionWO(order_list=order_list, strategy=self)
diff --git a/qlib/rl/trainer/__init__.py b/qlib/rl/trainer/__init__.py
index efce804c41..0a197b3781 100644
--- a/qlib/rl/trainer/__init__.py
+++ b/qlib/rl/trainer/__init__.py
@@ -4,6 +4,6 @@
 
 """Train, test, inference utilities."""
 
 from .api import backtest, train
-from .callbacks import EarlyStopping, Checkpoint
+from .callbacks import Checkpoint, EarlyStopping
 from .trainer import Trainer
 from .vessel import TrainingVessel,
TrainingVesselBase diff --git a/qlib/rl/utils/env_wrapper.py b/qlib/rl/utils/env_wrapper.py index 529bfe5973..f082f3b013 100644 --- a/qlib/rl/utils/env_wrapper.py +++ b/qlib/rl/utils/env_wrapper.py @@ -4,7 +4,7 @@ from __future__ import annotations import weakref -from typing import Any, Callable, Dict, Generic, Iterable, Iterator, Optional, Tuple, cast +from typing import Any, Callable, cast, Dict, Generic, Iterable, Iterator, Optional, Tuple import gym from gym import Space @@ -14,7 +14,6 @@ from qlib.rl.reward import Reward from qlib.rl.simulator import ActType, InitialStateType, Simulator, StateType from qlib.typehint import TypedDict - from .finite_env import generate_nan_observation from .log import LogCollector, LogLevel @@ -49,9 +48,24 @@ class EnvWrapperStatus(TypedDict): reward_history: list -class EnvWrapper( +class BaseEnvWrapper( gym.Env[ObsType, PolicyActType], Generic[InitialStateType, StateType, ActType, ObsType, PolicyActType], +): + """Base env wrapper for RL environments. It has two implementations: + - EnvWrapper: Qlib-based RL environment used in training. + - CollectDataEnvWrapper: Dummy environment used in collect_data_loop. + """ + + def __init__(self) -> None: + self.status: EnvWrapperStatus = cast(EnvWrapperStatus, None) + + def render(self, mode: str = "human") -> None: + raise NotImplementedError("Render is not implemented in BaseEnvWrapper.") + + +class EnvWrapper( + BaseEnvWrapper[InitialStateType, StateType, ActType, ObsType, PolicyActType], ): """Qlib-based RL environment, subclassing ``gym.Env``. A wrapper of components, including simulator, state-interpreter, action-interpreter, reward. @@ -115,6 +129,8 @@ def __init__( # 3. Avoid circular reference. # 4. When the components get serialized, we can throw away the env without any burden. # (though this part is not implemented yet) + super().__init__() + for obj in [state_interpreter, action_interpreter, reward_fn, aux_info_collector]: if obj is not None: obj.env = weakref.proxy(self) # type: ignore @@ -247,5 +263,19 @@ def step(self, policy_action: PolicyActType, **kwargs: Any) -> Tuple[ObsType, fl info_dict = InfoDict(log=self.logger.logs(), aux_info=aux_info) return obs, rew, done, info_dict - def render(self, mode: str = "human") -> None: - raise NotImplementedError("Render is not implemented in EnvWrapper.") + +class CollectDataEnvWrapper(BaseEnvWrapper[InitialStateType, StateType, ActType, ObsType, PolicyActType]): + """Dummy EnvWrapper for collect_data_loop. 
It only has minimum interfaces to support the collect_data_loop.""" + + def reset(self, **kwargs: Any) -> None: + self.status = EnvWrapperStatus( + cur_step=0, + done=False, + initial_state=None, + obs_history=[], + action_history=[], + reward_history=[], + ) + + def step(self, policy_action: Any = None, **kwargs: Any) -> None: + self.status["cur_step"] += 1 diff --git a/tests/rl/test_qlib_simulator.py b/tests/rl/test_qlib_simulator.py index b7d548e9ea..14bf8b5a11 100644 --- a/tests/rl/test_qlib_simulator.py +++ b/tests/rl/test_qlib_simulator.py @@ -11,6 +11,7 @@ from qlib.backtest.executor import SimulatorExecutor from qlib.rl.order_execution import CategoricalActionInterpreter from qlib.rl.order_execution.simulator_qlib import SingleAssetOrderExecution +from qlib.rl.utils.env_wrapper import CollectDataEnvWrapper TOTAL_POSITION = 2100.0 @@ -192,6 +193,8 @@ def test_interpreter() -> None: order = get_order() simulator = get_simulator(order) interpreter_action = CategoricalActionInterpreter(values=NUM_EXECUTION) + interpreter_action.env = CollectDataEnvWrapper() + interpreter_action.env.reset() NUM_STEPS = 7 state = simulator.get_state() diff --git a/tests/rl/test_saoe_simple.py b/tests/rl/test_saoe_simple.py index 78df41690a..22bd039096 100644 --- a/tests/rl/test_saoe_simple.py +++ b/tests/rl/test_saoe_simple.py @@ -16,9 +16,11 @@ from qlib.config import C from qlib.log import set_log_with_config from qlib.rl.data import pickle_styled +from qlib.rl.data.pickle_styled import PickleProcessedDataProvider from qlib.rl.order_execution import * from qlib.rl.trainer import backtest, train from qlib.rl.utils import ConsoleWriter, CsvWriter, EnvWrapperStatus +from qlib.rl.utils.env_wrapper import CollectDataEnvWrapper pytestmark = pytest.mark.skipif(sys.version_info < (3, 8), reason="Pickle styled data only supports Python >= 3.8") @@ -40,16 +42,15 @@ def test_pickle_data_inspect(): data = pickle_styled.load_simple_intraday_backtest_data(BACKTEST_DATA_DIR, "AAL", "2013-12-11", "close", 0) assert len(data) == 390 - data = pickle_styled.load_intraday_processed_data( - DATA_DIR / "processed", "AAL", "2013-12-11", 5, data.get_time_index() - ) + provider = PickleProcessedDataProvider(DATA_DIR / "processed") + data = provider.get_data("AAL", "2013-12-11", 5, data.get_time_index()) assert len(data.today) == len(data.yesterday) == 390 def test_simulator_first_step(): order = Order("AAL", 30.0, 0, pd.Timestamp("2013-12-11 00:00:00"), pd.Timestamp("2013-12-11 23:59:59")) - simulator = SingleAssetOrderExecution(order, BACKTEST_DATA_DIR) + simulator = SingleAssetOrderExecutionSimple(order, BACKTEST_DATA_DIR) state = simulator.get_state() assert state.cur_time == pd.Timestamp("2013-12-11 09:30:00") assert state.position == 30.0 @@ -83,7 +84,7 @@ def test_simulator_first_step(): def test_simulator_stop_twap(): order = Order("AAL", 13.0, 0, pd.Timestamp("2013-12-11 00:00:00"), pd.Timestamp("2013-12-11 23:59:59")) - simulator = SingleAssetOrderExecution(order, BACKTEST_DATA_DIR) + simulator = SingleAssetOrderExecutionSimple(order, BACKTEST_DATA_DIR) for _ in range(13): simulator.step(1.0) @@ -106,10 +107,10 @@ def test_simulator_stop_early(): order = Order("AAL", 1.0, 1, pd.Timestamp("2013-12-11 00:00:00"), pd.Timestamp("2013-12-11 23:59:59")) with pytest.raises(ValueError): - simulator = SingleAssetOrderExecution(order, BACKTEST_DATA_DIR) + simulator = SingleAssetOrderExecutionSimple(order, BACKTEST_DATA_DIR) simulator.step(2.0) - simulator = SingleAssetOrderExecution(order, BACKTEST_DATA_DIR) + simulator = 
SingleAssetOrderExecutionSimple(order, BACKTEST_DATA_DIR) simulator.step(1.0) with pytest.raises(AssertionError): @@ -119,7 +120,7 @@ def test_simulator_stop_early(): def test_simulator_start_middle(): order = Order("AAL", 15.0, 1, pd.Timestamp("2013-12-11 10:15:00"), pd.Timestamp("2013-12-11 15:44:59")) - simulator = SingleAssetOrderExecution(order, BACKTEST_DATA_DIR) + simulator = SingleAssetOrderExecutionSimple(order, BACKTEST_DATA_DIR) assert len(simulator.ticks_for_order) == 330 assert simulator.cur_time == pd.Timestamp("2013-12-11 10:15:00") simulator.step(2.0) @@ -138,7 +139,7 @@ def test_simulator_start_middle(): def test_interpreter(): order = Order("AAL", 15.0, 1, pd.Timestamp("2013-12-11 10:15:00"), pd.Timestamp("2013-12-11 15:44:59")) - simulator = SingleAssetOrderExecution(order, BACKTEST_DATA_DIR) + simulator = SingleAssetOrderExecutionSimple(order, BACKTEST_DATA_DIR) assert len(simulator.ticks_for_order) == 330 assert simulator.cur_time == pd.Timestamp("2013-12-11 10:15:00") @@ -146,7 +147,7 @@ def test_interpreter(): class EmulateEnvWrapper(NamedTuple): status: EnvWrapperStatus - interpreter = FullHistoryStateInterpreter(FEATURE_DATA_DIR, 13, 390, 5) + interpreter = FullHistoryStateInterpreter(13, 390, 5, PickleProcessedDataProvider(FEATURE_DATA_DIR)) interpreter_step = CurrentStepStateInterpreter(13) interpreter_action = CategoricalActionInterpreter(20) interpreter_action_twap = TwapRelativeActionInterpreter() @@ -185,6 +186,10 @@ class EmulateEnvWrapper(NamedTuple): assert np.sum(obs["data_processed"][60:]) == 0 # second step: action + interpreter_action.env = CollectDataEnvWrapper() + interpreter_action_twap.env = CollectDataEnvWrapper() + interpreter_action.env.reset() + interpreter_action_twap.env.reset() action = interpreter_action(simulator.get_state(), 1) assert action == 15 / 20 @@ -219,13 +224,13 @@ def test_network_sanity(): # we won't check the correctness of networks here order = Order("AAL", 15.0, 1, pd.Timestamp("2013-12-11 9:30:00"), pd.Timestamp("2013-12-11 15:59:59")) - simulator = SingleAssetOrderExecution(order, BACKTEST_DATA_DIR) + simulator = SingleAssetOrderExecutionSimple(order, BACKTEST_DATA_DIR) assert len(simulator.ticks_for_order) == 390 class EmulateEnvWrapper(NamedTuple): status: EnvWrapperStatus - interpreter = FullHistoryStateInterpreter(FEATURE_DATA_DIR, 13, 390, 5) + interpreter = FullHistoryStateInterpreter(13, 390, 5, PickleProcessedDataProvider(FEATURE_DATA_DIR)) action_interp = CategoricalActionInterpreter(13) wrapper_status_kwargs = dict(initial_state=order, obs_history=[], action_history=[], reward_history=[]) @@ -253,13 +258,15 @@ def test_twap_strategy(finite_env_type): orders = pickle_styled.load_orders(ORDER_DIR) assert len(orders) == 248 - state_interp = FullHistoryStateInterpreter(FEATURE_DATA_DIR, 13, 390, 5) + state_interp = FullHistoryStateInterpreter(13, 390, 5, PickleProcessedDataProvider(FEATURE_DATA_DIR)) action_interp = TwapRelativeActionInterpreter() + action_interp.env = CollectDataEnvWrapper() + action_interp.env.reset() policy = AllOne(state_interp.observation_space, action_interp.action_space) csv_writer = CsvWriter(Path(__file__).parent / ".output") backtest( - partial(SingleAssetOrderExecution, data_dir=BACKTEST_DATA_DIR, ticks_per_step=30), + partial(SingleAssetOrderExecutionSimple, data_dir=BACKTEST_DATA_DIR, ticks_per_step=30), state_interp, action_interp, orders, @@ -282,15 +289,17 @@ def test_cn_ppo_strategy(): orders = pickle_styled.load_orders(CN_ORDER_DIR, start_time=pd.Timestamp("9:31"), 
end_time=pd.Timestamp("14:58")) assert len(orders) == 40 - state_interp = FullHistoryStateInterpreter(CN_FEATURE_DATA_DIR, 8, 240, 6) + state_interp = FullHistoryStateInterpreter(8, 240, 6, PickleProcessedDataProvider(CN_FEATURE_DATA_DIR)) action_interp = CategoricalActionInterpreter(4) + action_interp.env = CollectDataEnvWrapper() + action_interp.env.reset() network = Recurrent(state_interp.observation_space) policy = PPO(network, state_interp.observation_space, action_interp.action_space, 1e-4) policy.load_state_dict(torch.load(CN_POLICY_WEIGHTS_DIR / "ppo_recurrent_30min.pth", map_location="cpu")) csv_writer = CsvWriter(Path(__file__).parent / ".output") backtest( - partial(SingleAssetOrderExecution, data_dir=CN_BACKTEST_DATA_DIR, ticks_per_step=30), + partial(SingleAssetOrderExecutionSimple, data_dir=CN_BACKTEST_DATA_DIR, ticks_per_step=30), state_interp, action_interp, orders, @@ -313,13 +322,15 @@ def test_ppo_train(): orders = pickle_styled.load_orders(CN_ORDER_DIR, start_time=pd.Timestamp("9:31"), end_time=pd.Timestamp("14:58")) assert len(orders) == 40 - state_interp = FullHistoryStateInterpreter(CN_FEATURE_DATA_DIR, 8, 240, 6) + state_interp = FullHistoryStateInterpreter(8, 240, 6, PickleProcessedDataProvider(CN_FEATURE_DATA_DIR)) action_interp = CategoricalActionInterpreter(4) + action_interp.env = CollectDataEnvWrapper() + action_interp.env.reset() network = Recurrent(state_interp.observation_space) policy = PPO(network, state_interp.observation_space, action_interp.action_space, 1e-4) train( - partial(SingleAssetOrderExecution, data_dir=CN_BACKTEST_DATA_DIR, ticks_per_step=30), + partial(SingleAssetOrderExecutionSimple, data_dir=CN_BACKTEST_DATA_DIR, ticks_per_step=30), state_interp, action_interp, orders,
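Since ``parse_backtest_config`` accepts py/yml/yaml/json files and resolves ``_base_`` inheritance via ``merge_a_into_b``, the new entry point can be driven by a plain Python config. The sketch below only uses key names that the code in this patch actually reads (``qlib``, ``order_file``, ``start_time``/``end_time``, ``exchange``, ``strategies``, ``concurrency``, ``output_dir``); every value, path, and the strategy choice are illustrative placeholders, not part of this patch:

    # backtest_config.py -- hypothetical input for get_backtest_config_fromfile()
    _base_ = "base_config.py"  # optional; the base file is merged underneath this one

    qlib = {"provider_uri_1min": "/path/to/qlib/1min/data"}  # forwarded to init_qlib(); contents illustrative
    order_file = "orders.csv"  # consumed by read_order_file(); a pickled DataFrame also works
    start_time = "9:30"        # intraday bounds, turned into TradeRangeByTime in single()
    end_time = "14:57"
    exchange = {
        # merged over the defaults from get_backtest_config_fromfile():
        # open_cost=0.0005, close_cost=0.0015, min_cost=5.0, trade_unit=100.0
        "cash_limit": None,        # None selects TT_SERIAL trading and InfPosition
        "generate_report": False,
    }
    strategies = {
        # one inner strategy per frequency; _get_multi_level_executor_config() wraps
        # each entry in a NestedExecutor on top of the innermost 1min SimulatorExecutor
        "30min": {
            "class": "TWAPStrategy",
            "module_path": "qlib.contrib.strategy.rule_strategy",
            "kwargs": {},
        },
    }
    concurrency = 16          # joblib n_jobs in backtest(); one job per stock
    output_dir = "outputs/"   # summary.csv (and report.pkl when generate_report) are written here

With the ``__main__`` block added above, such a file can be run as ``python -m qlib.rl.contrib.backtest backtest_config.py``.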