diff --git a/qlib/backtest/executor.py b/qlib/backtest/executor.py index 664f33a3cd..afed973ba2 100644 --- a/qlib/backtest/executor.py +++ b/qlib/backtest/executor.py @@ -587,20 +587,18 @@ def _get_order_iterator(self, trade_decision: BaseTradeDecision) -> List[Order]: raise NotImplementedError(f"This type of input is not supported") return order_it - def _update_dealt_order_amount(self, order: Order) -> None: - """update date and dealt order amount in the day.""" - - now_deal_day = self.trade_calendar.get_step_time()[0].floor(freq="D") - if self.deal_day is None or now_deal_day > self.deal_day: - self.dealt_order_amount = defaultdict(float) - self.deal_day = now_deal_day - self.dealt_order_amount[order.stock_id] += order.deal_amount - def _collect_data(self, trade_decision: BaseTradeDecision, level: int = 0) -> Tuple[List[object], dict]: trade_start_time, _ = self.trade_calendar.get_step_time() execute_result: list = [] for order in self._get_order_iterator(trade_decision): + # Each time we move into a new date, clear `self.dealt_order_amount` since it only maintains intraday + # information. + now_deal_day = self.trade_calendar.get_step_time()[0].floor(freq="D") + if self.deal_day is None or now_deal_day > self.deal_day: + self.dealt_order_amount = defaultdict(float) + self.deal_day = now_deal_day + # execute the order. # NOTE: The trade_account will be changed in this function trade_val, trade_cost, trade_price = self.trade_exchange.deal_order( @@ -609,7 +607,9 @@ def _collect_data(self, trade_decision: BaseTradeDecision, level: int = 0) -> Tu dealt_order_amount=self.dealt_order_amount, ) execute_result.append((order, trade_val, trade_cost, trade_price)) - self._update_dealt_order_amount(order) + + self.dealt_order_amount[order.stock_id] += order.deal_amount + if self.verbose: print( "[I {:%Y-%m-%d %H:%M:%S}]: {} {}, price {:.2f}, amount {}, deal_amount {}, factor {}, " diff --git a/qlib/backtest/utils.py b/qlib/backtest/utils.py index f815d10554..595b2acccd 100644 --- a/qlib/backtest/utils.py +++ b/qlib/backtest/utils.py @@ -183,8 +183,8 @@ def get_range_idx(self, start_time: pd.Timestamp, end_time: pd.Timestamp) -> Tup Tuple[int, int]: the index of the range. **the left and right are closed** """ - left = np.searchsorted(self._calendar, start_time, side="right") - 1 - right = np.searchsorted(self._calendar, end_time, side="right") - 1 + left = int(np.searchsorted(self._calendar, start_time, side="right") - 1) + right = int(np.searchsorted(self._calendar, end_time, side="right") - 1) left -= self.start_index right -= self.start_index