Skip to content

Commit

Permalink
Backtest Mypy (microsoft#1130)
Browse files Browse the repository at this point in the history
* Done

* Fix test errors

* Revert profit_attribution.py

* Minor

* A minor update on collect_data type hint

* Resolve PR comments

* Use black to format code

* Fix CI errors
  • Loading branch information
lihuoran authored Jun 28, 2022
1 parent 0c58469 commit b19087f
Show file tree
Hide file tree
Showing 17 changed files with 364 additions and 316 deletions.
2 changes: 1 addition & 1 deletion .mypy.ini
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[mypy]
exclude = (?x)(
^qlib/backtest
^qlib/backtest/high_performance_ds\.py$
| ^qlib/contrib
| ^qlib/data
| ^qlib/model
Expand Down
44 changes: 20 additions & 24 deletions qlib/backtest/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import copy
from pathlib import Path
from typing import TYPE_CHECKING, Generator, List, Optional, Tuple, Union
from typing import TYPE_CHECKING, Any, Generator, List, Optional, Tuple, Union

import pandas as pd

Expand All @@ -23,7 +23,6 @@
from .backtest import backtest_loop, collect_data_loop
from .decision import Order
from .exchange import Exchange
from .position import Position
from .utils import CommonInfrastructure

# make import more user-friendly by adding `from qlib.backtest import STH`
Expand All @@ -44,22 +43,23 @@ def get_exchange(
min_cost: float = 5.0,
limit_threshold: Union[Tuple[str, str], float, None] = None,
deal_price: Union[str, Tuple[str], List[str]] = None,
**kwargs,
**kwargs: Any,
) -> Exchange:
"""get_exchange
Parameters
----------
# exchange related arguments
exchange: Exchange(). It could be None or any types that are acceptable by `init_instance_by_config`.
exchange: Exchange
It could be None or any types that are acceptable by `init_instance_by_config`.
freq: str
frequency of data.
start_time: Union[pd.Timestamp, str]
closed start time for backtest.
end_time: Union[pd.Timestamp, str]
closed end time for backtest.
codes: list|str
codes: Union[list, str]
list stock_id list or a string of instruments (i.e. all, csi500, sse50)
subscribe_fields: list
subscribe fields.
Expand Down Expand Up @@ -151,28 +151,24 @@ def create_account_instance(
Postion type.
"""
if isinstance(account, (int, float)):
pos_kwargs = {"init_cash": account}
init_cash = account
position_dict = {}
elif isinstance(account, dict):
init_cash = account["cash"]
del account["cash"]
pos_kwargs = {
"init_cash": init_cash,
"position_dict": account,
}
init_cash = account.pop("cash")
position_dict = account
else:
raise ValueError("account must be in (int, float, Position)")
raise ValueError("account must be in (int, float, dict)")

kwargs = {
"init_cash": account,
"benchmark_config": {
return Account(
init_cash=init_cash,
position_dict=position_dict,
pos_type=pos_type,
benchmark_config={
"benchmark": benchmark,
"start_time": start_time,
"end_time": end_time,
},
"pos_type": pos_type,
}
kwargs.update(pos_kwargs)
return Account(**kwargs)
)


def get_strategy_executor(
Expand All @@ -181,7 +177,7 @@ def get_strategy_executor(
strategy: Union[str, dict, object, Path],
executor: Union[str, dict, object, Path],
benchmark: str = "SH000300",
account: Union[float, int, Position] = 1e9,
account: Union[float, int, dict] = 1e9,
exchange_kwargs: dict = {},
pos_type: str = "Position",
) -> Tuple[BaseStrategy, BaseExecutor]:
Expand Down Expand Up @@ -222,7 +218,7 @@ def backtest(
strategy: Union[str, dict, object, Path],
executor: Union[str, dict, object, Path],
benchmark: str = "SH000300",
account: Union[float, int, Position] = 1e9,
account: Union[float, int, dict] = 1e9,
exchange_kwargs: dict = {},
pos_type: str = "Position",
) -> Tuple[PortfolioMetrics, Indicator]:
Expand Down Expand Up @@ -285,7 +281,7 @@ def collect_data(
strategy: Union[str, dict, object, Path],
executor: Union[str, dict, object, Path],
benchmark: str = "SH000300",
account: Union[float, int, Position] = 1e9,
account: Union[float, int, dict] = 1e9,
exchange_kwargs: dict = {},
pos_type: str = "Position",
return_value: dict = None,
Expand Down Expand Up @@ -339,7 +335,7 @@ def format_decisions(

cur_freq = decisions[0].strategy.trade_calendar.get_freq()

res = (cur_freq, [])
res: Tuple[str, list] = (cur_freq, [])
last_dec_idx = 0
for i, dec in enumerate(decisions[1:], 1):
if dec.strategy.trade_calendar.get_freq() == cur_freq:
Expand Down
30 changes: 18 additions & 12 deletions qlib/backtest/account.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,15 @@
from __future__ import annotations

import copy
from typing import Dict, List, Tuple
from typing import Dict, List, Optional, Tuple, cast

import pandas as pd

from qlib.utils import init_instance_by_config

from .decision import BaseTradeDecision, Order
from .exchange import Exchange
from .high_performance_ds import BaseOrderIndicator
from .position import BasePosition
from .report import Indicator, PortfolioMetrics

Expand Down Expand Up @@ -104,7 +105,7 @@ def __init__(

self._pos_type = pos_type
self._port_metr_enabled = port_metr_enabled
self.benchmark_config = None # avoid no attribute error
self.benchmark_config: dict = {} # avoid no attribute error
self.init_vars(init_cash, position_dict, freq, benchmark_config)

def init_vars(self, init_cash: float, position_dict: dict, freq: str, benchmark_config: dict) -> None:
Expand All @@ -124,8 +125,8 @@ def init_vars(self, init_cash: float, position_dict: dict, freq: str, benchmark_
self.accum_info = AccumulatedInfo()

# 2) following variables are not shared between layers
self.portfolio_metrics = None
self.hist_positions = {}
self.portfolio_metrics: Optional[PortfolioMetrics] = None
self.hist_positions: Dict[pd.Timestamp, BasePosition] = {}
self.reset(freq=freq, benchmark_config=benchmark_config)

def is_port_metr_enabled(self) -> bool:
Expand Down Expand Up @@ -171,7 +172,7 @@ def reset(self, freq: str = None, benchmark_config: dict = None, port_metr_enabl

self.reset_report(self.freq, self.benchmark_config)

def get_hist_positions(self) -> dict:
def get_hist_positions(self) -> Dict[pd.Timestamp, BasePosition]:
return self.hist_positions

def get_cash(self) -> float:
Expand Down Expand Up @@ -230,13 +231,15 @@ def update_current_position(
"""
# update price for stock in the position and the profit from changed_price
# NOTE: updating position does not only serve portfolio metrics, it also serve the strategy
assert self.current_position is not None

if not self.current_position.skip_update():
stock_list = self.current_position.get_stock_list()
for code in stock_list:
# if suspend, no new price to be updated, profit is 0
if trade_exchange.check_stock_suspended(code, trade_start_time, trade_end_time):
continue
bar_close = trade_exchange.get_close(code, trade_start_time, trade_end_time)
bar_close = cast(float, trade_exchange.get_close(code, trade_start_time, trade_end_time))
self.current_position.update_stock_price(stock_id=code, price=bar_close)
# update holding day count
# NOTE: updating bar_count does not only serve portfolio metrics, it also serve the strategy
Expand All @@ -249,6 +252,8 @@ def update_portfolio_metrics(self, trade_start_time: pd.Timestamp, trade_end_tim
# for the first trade date, account_value - init_cash
# self.portfolio_metrics.is_empty() to judge is_first_trade_date
# get last_account_value, last_total_cost, last_total_turnover
assert self.portfolio_metrics is not None

if self.portfolio_metrics.is_empty():
last_account_value = self.init_cash
last_total_cost = 0
Expand Down Expand Up @@ -299,9 +304,9 @@ def update_indicator(
trade_exchange: Exchange,
atomic: bool,
outer_trade_decision: BaseTradeDecision,
trade_info: list = None,
inner_order_indicators: List[Dict[str, pd.Series]] = None,
decision_list: List[Tuple[BaseTradeDecision, pd.Timestamp, pd.Timestamp]] = None,
trade_info: list = [],
inner_order_indicators: List[BaseOrderIndicator] = [],
decision_list: List[Tuple[BaseTradeDecision, pd.Timestamp, pd.Timestamp]] = [],
indicator_config: dict = {},
) -> None:
"""update trade indicators and order indicators in each bar end"""
Expand Down Expand Up @@ -335,9 +340,9 @@ def update_bar_end(
trade_exchange: Exchange,
atomic: bool,
outer_trade_decision: BaseTradeDecision,
trade_info: list = None,
inner_order_indicators: List[Dict[str, pd.Series]] = None,
decision_list: List[Tuple[BaseTradeDecision, pd.Timestamp, pd.Timestamp]] = None,
trade_info: list = [],
inner_order_indicators: List[BaseOrderIndicator] = [],
decision_list: List[Tuple[BaseTradeDecision, pd.Timestamp, pd.Timestamp]] = [],
indicator_config: dict = {},
) -> None:
"""update account at each trading bar step
Expand Down Expand Up @@ -398,6 +403,7 @@ def update_bar_end(
def get_portfolio_metrics(self) -> Tuple[pd.DataFrame, dict]:
"""get the history portfolio_metrics and positions instance"""
if self.is_port_metr_enabled():
assert self.portfolio_metrics is not None
_portfolio_metrics = self.portfolio_metrics.generate_portfolio_metrics_dataframe()
_positions = self.get_hist_positions()
return _portfolio_metrics, _positions
Expand Down
9 changes: 6 additions & 3 deletions qlib/backtest/backtest.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from __future__ import annotations

from typing import TYPE_CHECKING, Generator, Optional, Tuple, Union
from typing import TYPE_CHECKING, Generator, Optional, Tuple, Union, cast

import pandas as pd

Expand Down Expand Up @@ -36,10 +36,13 @@ def backtest_loop(
indicator: Indicator
it computes the trading indicator
"""
return_value = {}
return_value: dict = {}
for _decision in collect_data_loop(start_time, end_time, trade_strategy, trade_executor, return_value):
pass
return return_value.get("portfolio_metrics"), return_value.get("indicator")

portfolio_metrics = cast(PortfolioMetrics, return_value.get("portfolio_metrics"))
indicator = cast(Indicator, return_value.get("indicator"))
return portfolio_metrics, indicator


def collect_data_loop(
Expand Down
40 changes: 23 additions & 17 deletions qlib/backtest/decision.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from enum import IntEnum

# try to fix circular imports when enabling type hints
from typing import TYPE_CHECKING, ClassVar, List, Optional, Tuple, Union
from typing import Generic, List, TYPE_CHECKING, Any, ClassVar, Optional, Tuple, TypeVar, Union, cast

from qlib.backtest.utils import TradeCalendarManager
from qlib.data.data import Cal
Expand All @@ -24,8 +24,11 @@
import pandas as pd


DecisionType = TypeVar("DecisionType")


class OrderDir(IntEnum):
# Order direction
# Order direction
SELL = 0
BUY = 1

Expand Down Expand Up @@ -65,7 +68,7 @@ class Order:
# - not tradable: the deal_amount == 0 , factor is None
# - the stock is suspended and the entire order fails. No cost for this order
# - dealt or partially dealt: deal_amount >= 0 and factor is not None
deal_amount: Optional[float] = None # `deal_amount` is a non-negative value
deal_amount: float = 0.0 # `deal_amount` is a non-negative value
factor: Optional[float] = None

# TODO:
Expand Down Expand Up @@ -281,7 +284,7 @@ def clip_time_range(self, start_time: pd.Timestamp, end_time: pd.Timestamp) -> T
return max(val_start, start_time), min(val_end, end_time)


class BaseTradeDecision:
class BaseTradeDecision(Generic[DecisionType]):
"""
Trade decisions ara made by strategy and executed by executor
Expand Down Expand Up @@ -316,20 +319,21 @@ def __init__(self, strategy: BaseStrategy, trade_range: Union[Tuple[int, int], T
"""
self.strategy = strategy
self.start_time, self.end_time = strategy.trade_calendar.get_step_time()
self.total_step = None # upper strategy has no knowledge about the sub executor before `_init_sub_trading`
if isinstance(trade_range, Tuple):
# upper strategy has no knowledge about the sub executor before `_init_sub_trading`
self.total_step: Optional[int] = None
if isinstance(trade_range, tuple):
# for Tuple[int, int]
trade_range = IdxTradeRange(*trade_range)
self.trade_range: TradeRange = trade_range
self.trade_range: Optional[TradeRange] = trade_range

def get_decision(self) -> List[object]:
def get_decision(self) -> List[DecisionType]:
"""
get the **concrete decision** (e.g. execution orders)
This will be called by the inner strategy
Returns
-------
List[object]:
List[DecisionType:
The decision result. Typically it is some orders
Example:
[]:
Expand Down Expand Up @@ -363,13 +367,13 @@ def update(self, trade_calendar: TradeCalendarManager) -> Optional[BaseTradeDeci
# purpose 2)
return self.strategy.update_trade_decision(self, trade_calendar)

def _get_range_limit(self, **kwargs) -> Tuple[int, int]:
def _get_range_limit(self, **kwargs: Any) -> Tuple[int, int]:
if self.trade_range is not None:
return self.trade_range(trade_calendar=kwargs.get("inner_calendar"))
return self.trade_range(trade_calendar=cast(TradeCalendarManager, kwargs.get("inner_calendar")))
else:
raise NotImplementedError("The decision didn't provide an index range")

def get_range_limit(self, **kwargs) -> Tuple[int, int]:
def get_range_limit(self, **kwargs: Any) -> Tuple[int, int]:
"""
return the expected step range for limiting the decision execution time
Both left and right are **closed**
Expand Down Expand Up @@ -421,6 +425,7 @@ def get_range_limit(self, **kwargs) -> Tuple[int, int]:
if getattr(self, "total_step", None) is not None:
# if `self.update` is called.
# Then the _start_idx, _end_idx should be clipped
assert self.total_step is not None
if _start_idx < 0 or _end_idx >= self.total_step:
logger = get_module_logger("decision")
logger.warning(
Expand Down Expand Up @@ -516,31 +521,32 @@ def mod_inner_decision(self, inner_trade_decision: BaseTradeDecision) -> None:
inner_trade_decision.trade_range = self.trade_range


class EmptyTradeDecision(BaseTradeDecision):
class EmptyTradeDecision(BaseTradeDecision[object]):
def get_decision(self) -> List[object]:
return []

def empty(self) -> bool:
return True


class TradeDecisionWO(BaseTradeDecision):
class TradeDecisionWO(BaseTradeDecision[Order]):
"""
Trade Decision (W)ith (O)rder.
Besides, the time_range is also included.
"""

def __init__(self, order_list: List[Order], strategy: BaseStrategy, trade_range: Tuple[int, int] = None):
def __init__(self, order_list: List[object], strategy: BaseStrategy, trade_range: Tuple[int, int] = None) -> None:
super().__init__(strategy, trade_range=trade_range)
self.order_list = order_list
self.order_list = cast(List[Order], order_list)
start, end = strategy.trade_calendar.get_step_time()
for o in order_list:
assert isinstance(o, Order)
if o.start_time is None:
o.start_time = start
if o.end_time is None:
o.end_time = end

def get_decision(self) -> List[object]:
def get_decision(self) -> List[Order]:
return self.order_list

def __repr__(self) -> str:
Expand Down
Loading

0 comments on commit b19087f

Please sign in to comment.