diff --git a/gs_quant/backtests/actions.py b/gs_quant/backtests/actions.py index 4bb8e5f2..426e26c3 100644 --- a/gs_quant/backtests/actions.py +++ b/gs_quant/backtests/actions.py @@ -15,31 +15,38 @@ """ from collections import namedtuple -from typing import TypeVar +import copy +from dataclasses import dataclass, field +from dataclasses_json import dataclass_json +from typing import TypeVar, Callable from gs_quant.backtests.backtest_utils import * from gs_quant.backtests.backtest_objects import ConstantTransactionModel, TransactionModel -from gs_quant.risk.transform import Transformer from gs_quant.base import Priceable -from gs_quant.markets.securities import * +from gs_quant.common import RiskMeasure from gs_quant.markets.portfolio import Portfolio +from gs_quant.markets.securities import * +from gs_quant.risk.transform import Transformer from gs_quant.target.backtests import BacktestTradingQuantityType action_count = 1 +def default_transaction_cost(obj): + return field(default_factory=lambda: copy.copy(obj)) + + +@dataclass_json +@dataclass class Action(object): - def __init__(self, name: str = None): - self._needs_scaling = False - self._calc_type = CalcType.simple - self._risk = None - global action_count - if name is None: - self._name = 'Action{}'.format(action_count) - action_count += 1 - else: - self._name = name - self._transaction_cost = ConstantTransactionModel(0) + _needs_scaling = False + _calc_type = CalcType.simple + _risk = None + _transaction_cost = ConstantTransactionModel(0) + name = None + + def __post_init__(self): + self.set_name(self.name) @property def calc_type(self): @@ -49,49 +56,52 @@ def calc_type(self): def risk(self): return self._risk + def set_name(self, name: str): + global action_count + if self.name is None: + self.name = 'Action{}'.format(action_count) + action_count += 1 + @property def transaction_cost(self): return self._transaction_cost + @transaction_cost.setter + def transaction_cost(self, value): + self._transaction_cost = value + TAction = TypeVar('TAction', bound='Action') +@dataclass_json +@dataclass class AddTradeAction(Action): - def __init__(self, - priceables: Union[Priceable, Iterable[Priceable]], - trade_duration: Union[str, dt.date, dt.timedelta] = None, - name: str = None, - transaction_cost: TransactionModel = ConstantTransactionModel(0)): - """ - create an action which adds a trade when triggered. The trades are resolved on the trigger date (state) and - last until the trade_duration if specified or for all future dates if not. - :param priceables: a priceable or a list of pricables. - :param trade_duration: an instrument attribute eg. 'expiration_date' or a date or a tenor or timedelta - if left as None the - trade will be added for all future dates - :param name: optional additional name to the priceable name - :param transaction_cost: optional a cash amount paid for each transaction, paid on both enter and exit - """ - super().__init__(name) - - self._priceables = [] + """ + create an action which adds a trade when triggered. The trades are resolved on the trigger date (state) and + last until the trade_duration if specified or for all future dates if not. + :param priceables: a priceable or a list of pricables. + :param trade_duration: an instrument attribute eg. 'expiration_date' or a date or a tenor or timedelta + if left as None the + trade will be added for all future dates + :param name: optional additional name to the priceable name + :param transaction_cost: optional a cash amount paid for each transaction, paid on both enter and exit + """ + + priceables: Union[Priceable, Iterable[Priceable]] = None + trade_duration: Union[str, dt.date, dt.timedelta] = None + name: str = None + transaction_cost: TransactionModel = default_transaction_cost(ConstantTransactionModel()) + + def __post_init__(self): self._dated_priceables = {} - self._trade_duration = trade_duration - self._transaction_cost = transaction_cost - for i, p in enumerate(make_list(priceables)): + named_priceables = [] + for i, p in enumerate(make_list(self.priceables)): if p.name is None: - self._priceables.append(p.clone(name=f'{self._name}_Priceable{i}')) + named_priceables.append(p.clone(name=f'{self.name}_Priceable{i}')) else: - self._priceables.append(p.clone(name=f'{self._name}_{p.name}')) - - @property - def priceables(self): - return self._priceables - - @property - def trade_duration(self): - return self._trade_duration + named_priceables.append(p.clone(name=f'{self.name}_{p.name}')) + self.priceables = named_priceables def set_dated_priceables(self, state, priceables): self._dated_priceables[state] = make_list(priceables) @@ -100,10 +110,6 @@ def set_dated_priceables(self, state, priceables): def dated_priceables(self): return self._dated_priceables - @property - def transaction_cost(self): - return self._transaction_cost - AddTradeActionInfo = namedtuple('AddTradeActionInfo', 'scaling') EnterPositionQuantityScaledActionInfo = namedtuple('EnterPositionQuantityScaledActionInfo', 'not_applicable') @@ -112,14 +118,10 @@ def transaction_cost(self): RebalanceActionInfo = namedtuple('RebalanceActionInfo', 'not_applicable') +@dataclass_json +@dataclass class EnterPositionQuantityScaledAction(Action): - def __init__(self, - priceables: Union[Priceable, Iterable[Priceable]], - trade_duration: Union[str, dt.date, dt.timedelta] = None, - name: str = None, - trade_quantity: float = 1, - trade_quantity_type: Union[BacktestTradingQuantityType, str] = BacktestTradingQuantityType.quantity): - """ + """ create an action which enters trades when triggered. The trades are executed with specified quantity and last until the trade_duration if specified, or for all future dates if not. :param priceables: a priceable or a list of pricables. @@ -129,160 +131,101 @@ def __init__(self, :param trade_quantity: the amount, in units of trade_quantity_type to be traded :param trade_quantity_type: the quantity type used to scale trade. eg. quantity for units, notional for underlier notional - """ - super().__init__(name) - self._priceables = [] - self._trade_duration = trade_duration - for i, p in enumerate(make_list(priceables)): + """ + priceables: Union[Priceable, Iterable[Priceable]] = None + trade_duration: Union[str, dt.date, dt.timedelta] = None + name: str = None + trade_quantity: float = 1 + trade_quantity_type: Union[BacktestTradingQuantityType, str] = BacktestTradingQuantityType.quantity + + def __post_init__(self): + named_priceables = [] + for i, p in enumerate(make_list(self.priceables)): if p.name is None: - self._priceables.append(p.clone(name=f'{self._name}_Priceable{i}')) + named_priceables.append(p.clone(name=f'{self.name}_Priceable{i}')) else: - self._priceables.append(p.clone(name=f'{self._name}_{p.name}')) - self._trade_quantity = trade_quantity - self._trade_quantity_type = trade_quantity_type - - @property - def priceables(self): - return self._priceables - - @property - def trade_duration(self): - return self._trade_duration - - @property - def trade_quantity(self): - return self._trade_quantity - - @property - def trade_quantity_type(self): - return self._trade_quantity_type + named_priceables.append(p.clone(name=f'{self.name}_{p.name}')) + self.priceables = named_priceables +@dataclass_json +@dataclass class ExitPositionAction(Action): - def __init__(self, name: str = None): - """ - Fully exit all held positions - :param name: optional name of the action - """ - super().__init__(name) + name: str = None +@dataclass_json +@dataclass class ExitTradeAction(Action): - def __init__(self, priceable_names: Union[str, Iterable[str]] = None, name: str = None): - """ - Fully exit all held positions - :param priceable_names: optional string or list of strings of priceable names - :param name: optional name of the action - """ - super().__init__(name) - self._priceables_names = make_list(priceable_names) + priceable_names: Union[str, Iterable[str]] = None + name: str = None - @property - def priceable_names(self): - return self._priceables_names + def __post_init__(self): + self.priceables_names = make_list(self.priceable_names) +@dataclass_json +@dataclass class ExitAllPositionsAction(ExitTradeAction): - def __init__(self, name: str = None): - """ - Fully exit all held positions - :param name: optional name of the action - """ - super().__init__([], name) + """ + Fully exit all held positions + """ + def __post_init__(self): self._calc_type = CalcType.path_dependent +@dataclass_json +@dataclass class HedgeAction(Action): - def __init__(self, risk, priceables: Priceable = None, trade_duration: str = None, name: str = None, - csa_term: str = None, scaling_parameter: str = 'notional_amount', - transaction_cost: TransactionModel = ConstantTransactionModel(0), - risk_transformation: Transformer = None): - super().__init__(name) + risk: RiskMeasure = None + priceables: Optional[Priceable] = None + trade_duration: Union[str, dt.date, dt.timedelta] = None + name: str = None + csa_term: str = None + scaling_parameter: str = 'notional_amount' + transaction_cost: TransactionModel = default_transaction_cost(ConstantTransactionModel()) + risk_transformation: Transformer = None + + def __post_init__(self): self._calc_type = CalcType.semi_path_dependent - self._priceable = priceables - self._risk = risk - self._trade_duration = trade_duration - self._csa_term = csa_term - self._scaling_parameter = scaling_parameter - self._transaction_cost = transaction_cost - self._risk_transformation = risk_transformation - if isinstance(priceables, Portfolio): - trades = [] - for i, priceable in enumerate(priceables): + if isinstance(self.priceables, Portfolio): + named_priceables = [] + for i, priceable in enumerate(self.priceables): if priceable.name is None: - trades.append(priceable.clone(name=f'{self._name}_Priceable{i}')) + named_priceables.append(priceable.clone(name=f'{self.name}_Priceable{i}')) else: - trades.append(priceable.clone(name=f'{self._name}_{priceable.name}')) - self._priceable = Portfolio(trades) + named_priceables.append(priceable.clone(name=f'{self.name}_{priceable.name}')) + named_priceable = Portfolio(named_priceables) + elif isinstance(self.priceables, Priceable): + if self.priceables.name is None: + named_priceable = self.priceables.clone(name=f'{self.name}_Priceable0') + else: + named_priceable = self.priceables.clone(name=f'{self.name}_{self.priceables.name}') else: - if priceables is not None: - if priceables.name is None: - self._priceable = priceables.clone(name=f'{self._name}_Priceable0') - else: - self._priceable = priceables.clone(name=f'{self._name}_{priceables.name}') - - @property - def trade_duration(self): - return self._trade_duration + raise RuntimeError('hedge action only accepts one trade or one portfolio') + self.priceables = named_priceable @property def priceable(self): - return self._priceable - - @property - def risk(self): - return self._risk - - @property - def csa_term(self): - return self._csa_term - - @property - def scaling_parameter(self): - return self._scaling_parameter - - @property - def transaction_cost(self): - return self._transaction_cost - - @property - def risk_transformation(self): - return self._risk_transformation + return self.priceables +@dataclass_json +@dataclass class RebalanceAction(Action): - def __init__(self, priceable: Priceable, size_parameter, method, - transaction_cost: TransactionModel = ConstantTransactionModel(0)): - super().__init__() + priceable: Priceable = None + size_parameter: Union[str, float] = None + method: Callable = None + transaction_cost: TransactionModel = default_transaction_cost(ConstantTransactionModel()) + name: str = None + + def __post_init__(self): self._calc_type = CalcType.path_dependent - self._size_parameter = size_parameter - self._method = method - self._transaction_cost = transaction_cost - if priceable.unresolved is None: + if self.priceable.unresolved is None: raise ValueError("Please specify a resolved priceable to rebalance.") - if priceable is not None: - if priceable.name is None: - self._priceable = priceable.clone(name=f'{self._name}_Priceable0') + if self.priceable is not None: + if self.priceable.name is None: + self.priceable = self.priceable.clone(name=f'{self.name}_Priceable0') else: - self._priceable = priceable.clone(name=f'{self._name}_{priceable.name}') - - @property - def priceable(self): - return self._priceable - - @property - def size_parameter(self): - return self._size_parameter - - @property - def method(self): - return self._method - - @property - def args(self): - return self._args - - @property - def transaction_cost(self): - return self._transaction_cost + self.priceable = self.priceable.clone(name=f'{self.name}_{self.priceable.name}') diff --git a/gs_quant/backtests/backtest_objects.py b/gs_quant/backtests/backtest_objects.py index d9e62187..e82fe6de 100644 --- a/gs_quant/backtests/backtest_objects.py +++ b/gs_quant/backtests/backtest_objects.py @@ -17,14 +17,18 @@ from abc import ABC from collections import defaultdict from copy import deepcopy +from dataclasses import dataclass +from dataclasses_json import dataclass_json from queue import Queue as FifoQueue from typing import Iterable, TypeVar, Optional import numpy as np import pandas as pd +from gs_quant.common import RiskMeasure from gs_quant.instrument import Cash from gs_quant.markets.portfolio import Portfolio +from gs_quant.backtests.backtest_utils import make_list from gs_quant.backtests.core import ValuationMethod from gs_quant.backtests.data_handler import DataHandler from gs_quant.backtests.event import FillEvent @@ -39,17 +43,22 @@ class BaseBacktest(ABC): TBaseBacktest = TypeVar('TBaseBacktest', bound='BaseBacktest') +@dataclass_json +@dataclass class BackTest(BaseBacktest): - def __init__(self, strategy, states, risks): + strategy: object + states: Iterable + risks: Iterable[RiskMeasure] + + def __post_init__(self): self._portfolio_dict = defaultdict(Portfolio) # portfolio by state self._cash_dict = {} # cash by state self._hedges = defaultdict(list) # list of Hedge by date self._cash_payments = defaultdict(list) # list of cash payments (entry, unwind) self._transaction_costs = defaultdict(int) # list of transaction costs by date - self._strategy = deepcopy(strategy) # the strategy definition - self._states = states # list of states + self.strategy = deepcopy(self.strategy) # the strategy definition self._results = defaultdict(list) - self._risks = tuple(risks) # list of risks to calculate + self.risks = make_list(self.risks) # list of risks to calculate self._calc_calls = 0 self._calculations = 0 @@ -85,10 +94,6 @@ def hedges(self): def hedges(self, hedges): self._hedges = hedges - @property - def states(self): - return self._states - @property def results(self): return self._results @@ -96,10 +101,6 @@ def results(self): def set_results(self, date, results): self._results[date] = results - @property - def risks(self): - return self._risks - def add_results(self, date, results, replace=False): if date in self._results and len(self._results[date]) and not replace: self._results[date] += results @@ -204,7 +205,8 @@ def strategy_as_time_series(self): risk_measure_table = risk_measure_table.rename(columns={'pricing_date': 'Pricing Date', 'instrument_name': 'Instrument Name'}) risk_measure_table = risk_measure_table.set_index(['Pricing Date', 'Instrument Name']) - risk_measure_table.columns = pd.MultiIndex.from_product([['Risk Measures'], risk_measure_table.columns]) + risk_measure_table.columns = pd.MultiIndex.from_product( + [['Risk Measures'], [str(col) for col in risk_measure_table.columns]]) risk_and_cp_joined = risk_measure_table.join(cp_table, how='outer') @@ -260,40 +262,35 @@ def __init__(self, self.exit_payment = exit_payment +@dataclass_json() +@dataclass class TransactionModel: def get_cost(self, state, backtest, info) -> float: pass +@dataclass_json() +@dataclass class ConstantTransactionModel(TransactionModel): - def __init__(self, cost): - self._cost = cost + cost: float = 0 def get_cost(self, state, backtest, info) -> float: - return self._cost + return self.cost +@dataclass_json +@dataclass class PredefinedAssetBacktest(BaseBacktest): - """ - :param data_handler: holds all the data required to run the backtest - :param performance: backtest values - :param cash_asset: currently restricted to USD non-accrual - :param holdings: a dictionary keyed by instruments with quantity values - :param historical_holdings: holdings for each backtest date - :param orders: a list of all the orders generated - :param initial_value: the initial value of the index - :param results: a dictionary which can be used to store intermediate results - """ - - def __init__(self, data_handler: DataHandler, initial_value: float): - self.data_handler = data_handler + data_handler: DataHandler + initial_value: float + + def __post_init__(self): self.performance = pd.Series(dtype=float) self.cash_asset = Cash('USD') self.holdings = defaultdict(float) self.historical_holdings = pd.Series(dtype=float) self.historical_weights = pd.Series(dtype=float) self.orders = [] - self.initial_value = initial_value self.results = {} def set_start_date(self, start: dt.date): diff --git a/gs_quant/backtests/data_sources.py b/gs_quant/backtests/data_sources.py index 21f88bee..e3aba71a 100644 --- a/gs_quant/backtests/data_sources.py +++ b/gs_quant/backtests/data_sources.py @@ -14,6 +14,8 @@ under the License. """ +from dataclasses import dataclass +from dataclasses_json import dataclass_json import datetime as dt from enum import Enum import numpy as np @@ -32,52 +34,69 @@ class MissingDataStrategy(Enum): fail = 'fail' +@dataclass_json +@dataclass class DataSource: def get_data(self, state): raise RuntimeError("Implemented by subclass") + def get_data_range(self, start: Union[dt.date, dt.datetime], + end: Union[dt.date, dt.datetime, int]): + raise RuntimeError("Implemented by subclass") + +@dataclass_json +@dataclass class GsDataSource(DataSource): - def __init__(self, data_set: str, asset_id: str, min_date: dt.date = None, max_date: dt.date = None, - value_header: str = 'rate'): - self._data_set = data_set - self._asset_id = asset_id - self._min_date = min_date - self._max_date = max_date - self._value_header = value_header - self._loaded_data = None + data_set: str + asset_id: str + min_date: dt.date = None + max_date: dt.date = None + value_header: str = 'rate' + + def __post_init__(self): + self.loaded_data = None def get_data(self, state: Union[dt.date, dt.datetime] = None): - if self._loaded_data is None: - ds = Dataset(self._data_set) - if self._min_date: - self._loaded_data = ds.get_data(self._min_date, self._max_date, assetId=(self._asset_id,)) + if self.loaded_data is None: + ds = Dataset(self.data_set) + if self.min_date: + self.loaded_data = ds.get_data(self.min_date, self.max_date, assetId=(self.asset_id,)) else: - return ds.get_data(state, state, assetId=(self._asset_id,))[self._value_header] - return self._loaded_data[self._value_header].at[pd.to_datetime(state)] + return ds.get_data(state, state, assetId=(self.asset_id,))[self.value_header] + return self.loaded_data[self.value_header].at[pd.to_datetime(state)] + + def get_data_range(self, start: Union[dt.date, dt.datetime], end: Union[dt.date, dt.datetime, int]): + if self.loaded_data is None: + ds = Dataset(self.data_set) + if self.min_date: + self.loaded_data = ds.get_data(self.min_date, self.max_date, assetId=(self.asset_id,)) + else: + self.loaded_data = ds.get_data(start, self.max_date, assetId=(self.asset_id,)) + if isinstance(end, int): + return self.loaded_data.loc[self.loaded_data.index < start].tail(end) + return self.loaded_data.loc[(start < self.loaded_data.index) & (self.loaded_data.index <= end)] +@dataclass_json +@dataclass class GenericDataSource(DataSource): - def __init__(self, data_set: pd.Series, missing_data_strategy: MissingDataStrategy = MissingDataStrategy.fail): - """ - A data source which holds a pandas series indexed by date or datetime - :param data_set: a pandas dataframe indexed by date or datetime - :param missing_data_strategy: MissingDataStrategy which defines behaviour if data is missing, will only take - effect if using get_data, gat_data_range has no expectations of the number of - expected data points. - """ - self._data_set = data_set - self._missing_data_strategy = missing_data_strategy - self._tz_aware = isinstance(self._data_set.index[0], - dt.datetime) and self._data_set.index[0].tzinfo is not None - if self._missing_data_strategy == MissingDataStrategy.interpolate: - self._data_set.interpolate() - elif self._missing_data_strategy == MissingDataStrategy.fill_forward: - self._data_set.ffill() + """ + A data source which holds a pandas series indexed by date or datetime + :param data_set: a pandas dataframe indexed by date or datetime + :param missing_data_strategy: MissingDataStrategy which defines behaviour if data is missing, will only take + effect if using get_data, get_data_range has no expectations of the number of + expected data points. + """ + data_set: pd.Series + missing_data_strategy: MissingDataStrategy = MissingDataStrategy.fail + + def __post_init__(self): + self._tz_aware = isinstance(self.data_set.index[0], dt.datetime) and self.data_set.index[0].tzinfo is not None def get_data(self, state: Union[dt.date, dt.datetime, Iterable]): """ - Get the value of the dataset at a time or date. If a list of dates or times is provided return the avg value + Get the value of the dataset at a time or date. If a list of dates or times is provided return list of values :param state: a date, datetime or a list of dates or datetimes :return: float value """ @@ -86,25 +105,25 @@ def get_data(self, state: Union[dt.date, dt.datetime, Iterable]): if self._tz_aware and (state.tzinfo is None or state.tzinfo.utcoffset(state) is None): state = pytz.utc.localize(state) - if pd.Timestamp(state) in self._data_set: - return self._data_set[pd.Timestamp(state)] - elif state in self._data_set or self._missing_data_strategy == MissingDataStrategy.fail: - return self._data_set[state] + if pd.Timestamp(state) in self.data_set: + return self.data_set[pd.Timestamp(state)] + elif state in self.data_set or self.missing_data_strategy == MissingDataStrategy.fail: + return self.data_set[state] else: - if isinstance(self._data_set.index, pd.DatetimeIndex): - self._data_set.at[pd.to_datetime(state)] = np.nan - self._data_set.sort_index(inplace=True) + if isinstance(self.data_set.index, pd.DatetimeIndex): + self.data_set.at[pd.to_datetime(state)] = np.nan + self.data_set.sort_index(inplace=True) else: - self._data_set.at[state] = np.nan - self._data_set.sort_index() - if self._missing_data_strategy == MissingDataStrategy.interpolate: - self._data_set = self._data_set.interpolate() - elif self._missing_data_strategy == MissingDataStrategy.fill_forward: - self._data_set = self._data_set.ffill() + self.data_set.at[state] = np.nan + self.data_set.sort_index() + if self.missing_data_strategy == MissingDataStrategy.interpolate: + self.data_set = self.data_set.interpolate() + elif self.missing_data_strategy == MissingDataStrategy.fill_forward: + self.data_set = self.data_set.ffill() else: - raise RuntimeError(f'unrecognised missing data strategy: {str(self._missing_data_strategy)}') - return self._data_set[pd.to_datetime(state)] if isinstance(self._data_set.index, - pd.DatetimeIndex) else self._data_set[state] + raise RuntimeError(f'unrecognised missing data strategy: {str(self.missing_data_strategy)}') + return self.data_set[pd.to_datetime(state)] if isinstance(self.data_set.index, + pd.DatetimeIndex) else self.data_set[state] def get_data_range(self, start: Union[dt.date, dt.datetime], end: Union[dt.date, dt.datetime, int]): @@ -115,13 +134,16 @@ def get_data_range(self, start: Union[dt.date, dt.datetime], start date :return: pd.Series """ + if isinstance(end, int): - return self._data_set.loc[self._data_set.index < start].tail(end) - return self._data_set.loc[(start < self._data_set.index) & (self._data_set.index <= end)] + return self.data_set.loc[self.data_set.index < start].tail(end) + return self.data_set.loc[(start < self.data_set.index) & (self.data_set.index <= end)] +@dataclass_json +@dataclass class DataManager: - def __init__(self): + def __post_init__(self): self._data_sources = {} def add_data_source(self, series: Union[pd.Series, DataSource], data_freq: DataFrequency, diff --git a/gs_quant/backtests/equity_vol_engine.py b/gs_quant/backtests/equity_vol_engine.py index e0de02a4..6ab4116f 100644 --- a/gs_quant/backtests/equity_vol_engine.py +++ b/gs_quant/backtests/equity_vol_engine.py @@ -35,7 +35,7 @@ def is_synthetic_forward(priceable): is_syn_fwd = is_portfolio if is_portfolio: is_size_two = len(priceable) == 2 - is_syn_fwd &= is_size_two + is_syn_fwd = is_size_two if is_size_two: has_two_eq_options = isinstance(priceable[0], EqOption) and isinstance(priceable[1], EqOption) is_syn_fwd &= has_two_eq_options diff --git a/gs_quant/backtests/generic_engine.py b/gs_quant/backtests/generic_engine.py index 44373c1e..8cbe280f 100644 --- a/gs_quant/backtests/generic_engine.py +++ b/gs_quant/backtests/generic_engine.py @@ -504,7 +504,7 @@ def run_backtest(self, strategy, start=None, end=None, frequency='1m', states=No """ - logging.info(f'Starting Backtest: Building Date Schedule - {dt.datetime.now()}') + logger.info(f'Starting Backtest: Building Date Schedule - {dt.datetime.now()}') self._tracing_enabled = Tracer.get_instance().active_span is not None self._pricing_context_params = {'show_progress': show_progress, 'csa_term': csa_term, @@ -558,15 +558,15 @@ def __run(self, strategy, start, end, frequency, states, risks, initial_value, r backtest = BackTest(strategy, strategy_pricing_dates, risks) - logging.info('Resolving initial portfolio') + logger.info('Resolving initial portfolio') with self._trace('Resolve initial portfolio'): self._resolve_initial_portfolio(strategy, backtest, strategy_start_date, strategy_pricing_dates) - logging.info('Building simple and semi-deterministic triggers and actions') + logger.info('Building simple and semi-deterministic triggers and actions') self._build_simple_and_semi_triggers_and_actions(strategy, backtest, strategy_pricing_dates) - logging.info(f'Filtering strategy calculations to run from {strategy_start_date} to {strategy_end_date}') + logger.info(f'Filtering strategy calculations to run from {strategy_start_date} to {strategy_end_date}') backtest.portfolio_dict = defaultdict(Portfolio, {k: backtest.portfolio_dict[k] for k in backtest.portfolio_dict if strategy_start_date <= k <= strategy_end_date}) @@ -574,12 +574,11 @@ def __run(self, strategy, start, end, frequency, states, risks, initial_value, r for k in backtest.hedges if strategy_start_date <= k <= strategy_end_date}) - logging.info('Pricing simple and semi-deterministic triggers and actions') + logger.info('Pricing simple and semi-deterministic triggers and actions') with self._trace('Pricing semi-det Triggers'): self._price_semi_det_triggers(backtest, risks) - logging.info('Scaling semi-deterministic triggers and actions and calculating path dependent triggers ' - 'and actions') + logger.info('Scaling semi-determ triggers and actions and calculating path dependent triggers and actions') for d in strategy_pricing_dates: with self._trace('Process date') as scope: if scope: @@ -592,7 +591,7 @@ def __run(self, strategy, start, end, frequency, states, risks, initial_value, r with self._trace('Handle Cash'): self._handle_cash(backtest, risks, price_risk, strategy_pricing_dates, strategy_end_date, initial_value) - logging.info(f'Finished Backtest:- {dt.datetime.now()}') + logger.info(f'Finished Backtest:- {dt.datetime.now()}') return backtest def _resolve_initial_portfolio(self, strategy, backtest, strategy_start_date, strategy_pricing_dates): @@ -662,7 +661,7 @@ def _price_semi_det_triggers(self, backtest, risks): p.results = port.calc(tuple(risks)) def _process_triggers_and_actions_for_date(self, d, strategy, backtest, risks): - logging.info(f'{d}: Processing triggers and actions') + logger.info(f'{d}: Processing triggers and actions') # path dependent for trigger in strategy.triggers: if trigger.calc_type == CalcType.path_dependent: @@ -746,7 +745,7 @@ def _process_triggers_and_actions_for_date(self, d, strategy, backtest, risks): backtest.cash_payments[hedge.exit_payment.effective_date].append(hedge.exit_payment) def _calc_new_trades(self, backtest, risks): - logging.info('Calculating and scaling newly added portfolio positions') + logger.info('Calculating and scaling newly added portfolio positions') # test to see if new trades have been added and calc with PricingContext(): backtest.calc_calls += 1 @@ -762,7 +761,7 @@ def _calc_new_trades(self, backtest, risks): leaves = [] for leaf in portfolio: if leaf.name not in trades_for_date: - logging.info(f'{day}: new portfolio position {leaf} scheduled for calculation') + logger.info(f'{day}: new portfolio position {leaf} scheduled for calculation') leaves.append(leaf) if len(leaves): @@ -770,12 +769,12 @@ def _calc_new_trades(self, backtest, risks): leaves_by_date[day] = Portfolio(leaves).calc(tuple(risks)) backtest.calculations += len(leaves) * len(risks) - logging.info('Processing results for newly added portfolio positions') + logger.info('Processing results for newly added portfolio positions') for day, leaves in leaves_by_date.items(): backtest.add_results(day, leaves) def _handle_cash(self, backtest, risks, price_risk, strategy_pricing_dates, strategy_end_date, initial_value): - logging.info('Calculating prices for cash payments') + logger.info('Calculating prices for cash payments') # run any additional calcs to handle cash scaling (e.g. unwinds) cash_results = {} cash_trades_by_date = defaultdict(list) diff --git a/gs_quant/backtests/predefined_asset_engine.py b/gs_quant/backtests/predefined_asset_engine.py index 315f964e..490b3520 100644 --- a/gs_quant/backtests/predefined_asset_engine.py +++ b/gs_quant/backtests/predefined_asset_engine.py @@ -50,14 +50,14 @@ def generate_orders(self, state: dt.datetime, backtest: PredefinedAssetBacktest, quantity=quantity, generation_time=state, execution_datetime=state, - source=self.action._name)) + source=self.action.name)) if isinstance(self.action.trade_duration, dt.timedelta): # create close order orders.append(OrderAtMarket(instrument=pricable, quantity=quantity * -1, generation_time=state, execution_datetime=state + self.action.trade_duration, - source=self.action._name)) + source=self.action.name)) return orders def apply_action(self, state: dt.datetime, backtest: PredefinedAssetBacktest, info=None): diff --git a/gs_quant/backtests/strategy.py b/gs_quant/backtests/strategy.py index 37d9d23a..a1746d8e 100644 --- a/gs_quant/backtests/strategy.py +++ b/gs_quant/backtests/strategy.py @@ -14,6 +14,8 @@ under the License. """ +from dataclasses import dataclass +from dataclasses_json import dataclass_json from typing import Tuple from gs_quant.backtests.triggers import * @@ -26,27 +28,23 @@ backtest_engines = [GenericEngine(), PredefinedAssetEngine(), EquityVolEngine()] +@dataclass_json +@dataclass class Strategy(object): """ A strategy object on which one may run a backtest """ + initial_portfolio: Optional[Tuple[Priceable, ...]] + triggers: Union[Trigger, Iterable[Trigger]] - def __init__(self, initial_portfolio: Optional[Tuple[Priceable, ...]], triggers: Union[Trigger, Iterable[Trigger]]): - self._initial_portfolio = make_list(initial_portfolio) - self._triggers = make_list(triggers) + def __post_init__(self): + self.initial_portfolio = make_list(self.initial_portfolio) + self.triggers = make_list(self.triggers) + self.risks = self.get_risks() - @property - def triggers(self): - return self._triggers - - @property - def initial_portfolio(self): - return self._initial_portfolio - - @property - def risks(self): + def get_risks(self): risk_list = [] - for t in self._triggers: + for t in self.triggers: risk_list += t.risks if t.risks is not None else [] return risk_list diff --git a/gs_quant/backtests/triggers.py b/gs_quant/backtests/triggers.py index 6e5f0f05..e17c5f59 100644 --- a/gs_quant/backtests/triggers.py +++ b/gs_quant/backtests/triggers.py @@ -15,12 +15,14 @@ """ +from dataclasses import dataclass, field +from dataclasses_json import dataclass_json from typing import Optional -import warnings from gs_quant.backtests.actions import Action, AddTradeAction, AddTradeActionInfo from gs_quant.backtests.backtest_objects import BackTest, PredefinedAssetBacktest from gs_quant.backtests.backtest_utils import make_list, CalcType from gs_quant.backtests.data_sources import * +from gs_quant.base import field_metadata from gs_quant.datetime.relative_date import RelativeDateSchedule from gs_quant.risk.transform import Transformer from gs_quant.risk import RiskMeasure @@ -37,109 +39,88 @@ class AggType(Enum): ANY_OF = 2 +@dataclass_json +@dataclass class TriggerRequirements(object): - def __init__(self): - pass + pass +@dataclass_json +@dataclass class PeriodicTriggerRequirements(TriggerRequirements): - def __init__(self, start_date: dt.date = None, end_date: dt.date = None, frequency: str = None, - calendar: str = None): - super().__init__() - self.start_date = start_date - self.end_date = end_date - self.frequency = frequency - self.calendar = calendar + start_date: Optional[dt.date] = field(default=None, metadata=field_metadata) + end_date: Optional[dt.date] = field(default=None, metadata=field_metadata) + frequency: Optional[str] = field(default=None, metadata=field_metadata) + calendar: Optional[str] = field(default=None, metadata=field_metadata) +@dataclass_json +@dataclass class IntradayTriggerRequirements(TriggerRequirements): - def __init__(self, start_time: dt.datetime, end_time: dt.datetime, frequency: str): - super().__init__() - self.start_time = start_time - self.end_time = end_time - self.frequency = frequency + start_time: Optional[dt.datetime] = field(default=None, metadata=field_metadata) + end_time: Optional[dt.datetime] = field(default=None, metadata=field_metadata) + frequency: Optional[str] = field(default=None, metadata=field_metadata) +@dataclass_json +@dataclass class MktTriggerRequirements(TriggerRequirements): - def __init__(self, data_source: DataSource, trigger_level: float, direction: TriggerDirection): - super().__init__() - self.data_source = data_source - self.trigger_level = trigger_level - self.direction = direction + data_source: DataSource = field(default=None, metadata=field_metadata) + trigger_level: float = field(default=None, metadata=field_metadata) + direction: TriggerDirection = field(default=None, metadata=field_metadata) +@dataclass_json +@dataclass class RiskTriggerRequirements(TriggerRequirements): - def __init__(self, risk: RiskMeasure, trigger_level: float, direction: TriggerDirection, - risk_transformation: Optional[Transformer] = None): - super().__init__() - self.risk = risk - self.trigger_level = trigger_level - self.direction = direction - self.risk_transformation = risk_transformation + risk: RiskMeasure = field(default=None, metadata=field_metadata) + trigger_level: float = field(default=None, metadata=field_metadata) + direction: TriggerDirection = field(default=None, metadata=field_metadata) + risk_transformation: Optional[Transformer] = field(default=None, metadata=field_metadata) +@dataclass_json +@dataclass class AggregateTriggerRequirements(TriggerRequirements): - def __init__(self, triggers: Iterable[object], aggregate_type: AggType = AggType.ALL_OF): - super().__init__() - self.triggers = triggers - self.aggregate_type = aggregate_type + triggers: Iterable = field(default=None, metadata=field_metadata) + aggregate_type: AggType = field(default=AggType.ALL_OF, metadata=field_metadata) +@dataclass_json +@dataclass class NotTriggerRequirements(TriggerRequirements): - def __init__(self, trigger: object): - super().__init__() - self.trigger = trigger + trigger: object = field(default=None, metadata=field_metadata) +@dataclass_json +@dataclass class DateTriggerRequirements(TriggerRequirements): - def __init__(self, dates: Iterable[Union[dt.datetime, dt.date]], entire_day: bool = False): - super().__init__() - """ - :param dates: the list of dates on which to trigger - :param entire_day: flag that indicates whether to check against dates instead of datetimes - """ - self.dates = dates - self.entire_day = entire_day + dates: Iterable[Union[dt.datetime, dt.date]] = field(default=None, metadata=field_metadata) + entire_day: bool = field(default=False, metadata=field_metadata) +@dataclass_json +@dataclass class PortfolioTriggerRequirements(TriggerRequirements): - def __init__(self, data_source: str, trigger_level: float, direction: TriggerDirection): - """ - :param data_source: the portfolio property to check - :param trigger_level: the threshold level on which to trigger - :param direction: a direction for the trigger_level comparison - """ - super().__init__() - self.data_source = data_source - self.trigger_level = trigger_level - self.direction = direction + data_source: str = field(default=None, metadata=field_metadata) + trigger_level: float = field(default=None, metadata=field_metadata) + direction: TriggerDirection = field(default=None, metadata=field_metadata) +@dataclass_json +@dataclass class MeanReversionTriggerRequirements(TriggerRequirements): - def __init__(self, data_source: DataSource, - z_score_bound: float, - rolling_mean_window: int, - rolling_std_window: int): - """ - This trigger will sell when the value hits the z score threshold on the up side, will close out a position - when the value crosses the rolling_mean and buy when the value hits the z score threshold on the down side. - - :param data_source: the asset values - :param z_score_bound: the threshold level on which to trigger - :param rolling_mean_window: the number of values to consider when calculating the rolling mean - :param rolling_std_window: the number of values to consider when calculating the standard deviation - """ - super().__init__() - self.data_source = data_source - self.z_score_bound = z_score_bound - self.rolling_mean_window = rolling_mean_window - self.rolling_std_window = rolling_std_window + data_source: DataSource = field(default=None, metadata=field_metadata) + z_score_bound: float = field(default=None, metadata=field_metadata) + rolling_mean_window: int = field(default=None, metadata=field_metadata) + rolling_std_window: int = field(default=None, metadata=field_metadata) +@dataclass_json +@dataclass class TriggerInfo(object): - def __init__(self, triggered: bool, info_dict: Optional[dict] = None): - self.triggered = triggered - self.info_dict = info_dict + triggered: bool + info_dict: Optional[dict] = None def __eq__(self, other): return self.triggered is other @@ -148,13 +129,14 @@ def __bool__(self): return self.triggered +@dataclass_json +@dataclass class Trigger(object): + trigger_requirements: Optional[TriggerRequirements] = field(default=None, metadata=field_metadata) + actions: Union[Action, Iterable[Action]] = field(default=None, metadata=field_metadata) - def __init__(self, trigger_requirements: Optional[TriggerRequirements], actions: Union[Action, Iterable[Action]]): - self._trigger_requirements = trigger_requirements - self._actions = make_list(actions) - self._risks = [x.risk for x in self.actions if x.risk is not None] - self._calc_type = CalcType.simple + def __post_init__(self): + self.actions = make_list(self.actions) def has_triggered(self, state: dt.date, backtest: BackTest = None) -> TriggerInfo: """ @@ -172,35 +154,27 @@ def get_trigger_times(self): @property def calc_type(self): - return self._calc_type - - @property - def actions(self): - return self._actions - - @property - def trigger_requirements(self): - return self._trigger_requirements + return CalcType.simple @property def risks(self): - return self._risks + return [x.risk for x in make_list(self.actions) if x.risk is not None] +@dataclass_json +@dataclass class PeriodicTrigger(Trigger): - def __init__(self, - trigger_requirements: PeriodicTriggerRequirements, - actions: Union[Action, Iterable[Action]]): - super().__init__(trigger_requirements, actions) - self._trigger_dates = None + trigger_requirements: PeriodicTriggerRequirements = field(default=None, metadata=field_metadata) + actions: Union[Action, Iterable[Action]] = field(default=None, metadata=field_metadata) + _trigger_dates = None def get_trigger_times(self) -> [dt.date]: if not self._trigger_dates: - self._trigger_dates = self._trigger_requirements.dates if \ - hasattr(self._trigger_requirements, 'dates') else \ - RelativeDateSchedule(self._trigger_requirements.frequency, - self._trigger_requirements.start_date, - self._trigger_requirements.end_date).apply_rule( + self._trigger_dates = self.trigger_requirements.dates if \ + hasattr(self.trigger_requirements, 'dates') else \ + RelativeDateSchedule(self.trigger_requirements.frequency, + self.trigger_requirements.start_date, + self.trigger_requirements.end_date).apply_rule( holiday_calendar=self.trigger_requirements.calendar) return self._trigger_dates @@ -210,16 +184,18 @@ def has_triggered(self, state: dt.date, backtest: BackTest = None) -> TriggerInf return TriggerInfo(state in self._trigger_dates) +@dataclass_json +@dataclass class IntradayPeriodicTrigger(Trigger): - def __init__(self, - trigger_requirements: IntradayTriggerRequirements, - actions: Union[Action, Iterable[Action]]): - super().__init__(trigger_requirements, actions) + trigger_requirements: IntradayTriggerRequirements = field(default=None, metadata=field_metadata) + actions: Union[Action, Iterable[Action]] = field(default=None, metadata=field_metadata) + def __post_init__(self): + super().__post_init__() # generate all the trigger times - start = trigger_requirements.start_time - end = trigger_requirements.end_time - freq = trigger_requirements.frequency + start = self.trigger_requirements.start_time + end = self.trigger_requirements.end_time + freq = self.trigger_requirements.frequency self._trigger_times = [] time = start @@ -234,114 +210,127 @@ def has_triggered(self, state: Union[dt.date, dt.datetime], backtest: BackTest = return TriggerInfo(state.time() in self._trigger_times) +@dataclass_json +@dataclass class MktTrigger(Trigger): - def __init__(self, - trigger_requirements: MktTriggerRequirements, - actions: Union[Action, Iterable[Action]]): - super().__init__(trigger_requirements, actions) + trigger_requirements: MktTriggerRequirements = field(default=None, metadata=field_metadata) + actions: Union[Action, Iterable[Action]] = field(default=None, metadata=field_metadata) def has_triggered(self, state: dt.date, backtest: BackTest = None) -> TriggerInfo: - data_value = self._trigger_requirements.data_source.get_data(state) - if self._trigger_requirements.direction == TriggerDirection.ABOVE: - if data_value > self._trigger_requirements.trigger_level: + data_value = self.trigger_requirements.data_source.get_data(state) + if self.trigger_requirements.direction == TriggerDirection.ABOVE: + if data_value > self.trigger_requirements.trigger_level: return TriggerInfo(True) - elif self._trigger_requirements.direction == TriggerDirection.BELOW: - if data_value < self._trigger_requirements.trigger_level: + elif self.trigger_requirements.direction == TriggerDirection.BELOW: + if data_value < self.trigger_requirements.trigger_level: return TriggerInfo(True) else: - if data_value == self._trigger_requirements.trigger_level: + if data_value == self.trigger_requirements.trigger_level: return TriggerInfo(True) return TriggerInfo(False) +@dataclass_json +@dataclass class StrategyRiskTrigger(Trigger): - def __init__(self, - trigger_requirements: RiskTriggerRequirements, - actions: Union[Action, Iterable[Action]]): - super().__init__(trigger_requirements, actions) - self._calc_type = CalcType.path_dependent - self._risks += [trigger_requirements.risk] + trigger_requirements: RiskTriggerRequirements = field(default=None, metadata=field_metadata) + actions: Union[Action, Iterable[Action]] = field(default=None, metadata=field_metadata) + + @property + def calc_type(self): + return CalcType.path_dependent + + @property + def risks(self): + return [x.risk for x in make_list(self.actions) if x.risk is not None] + [self.trigger_requirements.risk] def has_triggered(self, state: dt.date, backtest: BackTest = None) -> TriggerInfo: if self.trigger_requirements.risk_transformation is None: - risk_value = backtest.results[state][self._trigger_requirements.risk].aggregate() + risk_value = backtest.results[state][self.trigger_requirements.risk].aggregate() else: - risk_value = backtest.results[state][self._trigger_requirements.risk].transform( + risk_value = backtest.results[state][self.trigger_requirements.risk].transform( risk_transformation=self.trigger_requirements.risk_transformation).aggregate( allow_mismatch_risk_keys=True) - if self._trigger_requirements.direction == TriggerDirection.ABOVE: - if risk_value > self._trigger_requirements.trigger_level: + if self.trigger_requirements.direction == TriggerDirection.ABOVE: + if risk_value > self.trigger_requirements.trigger_level: return TriggerInfo(True) - elif self._trigger_requirements.direction == TriggerDirection.BELOW: - if risk_value < self._trigger_requirements.trigger_level: + elif self.trigger_requirements.direction == TriggerDirection.BELOW: + if risk_value < self.trigger_requirements.trigger_level: return TriggerInfo(True) else: - if risk_value == self._trigger_requirements.trigger_level: + if risk_value == self.trigger_requirements.trigger_level: return TriggerInfo(True) return TriggerInfo(False) +@dataclass_json +@dataclass class AggregateTrigger(Trigger): - def __init__(self, - trigger_requirements: Optional[AggregateTriggerRequirements] = None, - actions: Optional[Union[Action, Iterable[Action]]] = None, - triggers: Optional[Iterable[Trigger]] = None): - # support previous behaviour where a list of triggers was passed. - if not trigger_requirements and triggers is not None: - warnings.warn('triggers is deprecated; trigger_requirements', DeprecationWarning, 2) - trigger_requirements = AggregateTriggerRequirements(triggers) - actions = [] if not actions else actions - for t in trigger_requirements.triggers: + trigger_requirements: AggregateTriggerRequirements = field(default=None, metadata=field_metadata) + actions: Union[Action, Iterable[Action]] = field(default=None, metadata=field_metadata) + + def __post_init__(self): + super().__post_init__() + actions = [] if not self.actions else make_list(self.actions) + for t in self.trigger_requirements.triggers: actions += [action for action in t.actions] - super().__init__(trigger_requirements, actions) + self.actions = actions def has_triggered(self, state: dt.date, backtest: BackTest = None) -> TriggerInfo: - self._actions = [] info_dict = {} - if self._trigger_requirements.aggregate_type == AggType.ALL_OF: - for trigger in self._trigger_requirements.triggers: + if self.trigger_requirements.aggregate_type == AggType.ALL_OF: + for trigger in self.trigger_requirements.triggers: t_info = trigger.has_triggered(state, backtest) if not t_info: return TriggerInfo(False) else: if t_info.info_dict: info_dict.update(t_info.info_dict) - self._actions.extend(trigger.actions) return TriggerInfo(True, info_dict) - elif self._trigger_requirements.aggregate_type == AggType.ANY_OF: + elif self.trigger_requirements.aggregate_type == AggType.ANY_OF: triggered = False - for trigger in self._trigger_requirements.triggers: + for trigger in self.trigger_requirements.triggers: t_info = trigger.has_triggered(state, backtest) if t_info: triggered = True if t_info.info_dict: info_dict.update(t_info.info_dict) - self._actions.extend(trigger.actions) return TriggerInfo(True, info_dict) if triggered else TriggerInfo(False) else: - raise RuntimeError(f'Unrecognised aggregation type: {self._trigger_requirements.aggregate_type}') + raise RuntimeError(f'Unrecognised aggregation type: {self.trigger_requirements.aggregate_type}') @property def triggers(self) -> Iterable[Trigger]: - return self._trigger_requirements.triggers + return self.trigger_requirements.triggers +@dataclass_json +@dataclass class NotTrigger(Trigger): - def __init__(self, trigger_requirements: NotTriggerRequirements, actions: Optional[Iterable[Action]] = None): - super().__init__(trigger_requirements, actions) + trigger_requirements: NotTriggerRequirements = field(default=None, metadata=field_metadata) + actions: Union[Action, Iterable[Action]] = field(default=None, metadata=field_metadata) + + def __post_init__(self): + super().__post_init__() + actions = [] if not self.actions else self.actions + actions += [action for action in self.trigger_requirements.trigger.actions] def has_triggered(self, state: dt.date, backtest: BackTest = None) -> TriggerInfo: t_info = self.trigger_requirements.trigger.has_triggered(state, backtest) if t_info: return TriggerInfo(False) else: - self._actions.extend(self.trigger_requirements.trigger.actions) return TriggerInfo(True) +@dataclass_json +@dataclass class DateTrigger(Trigger): - def __init__(self, trigger_requirements: DateTriggerRequirements, actions: Iterable[Action]): - super().__init__(trigger_requirements, actions) + trigger_requirements: DateTriggerRequirements = field(default=None, metadata=field_metadata) + actions: Union[Action, Iterable[Action]] = field(default=None, metadata=field_metadata) + + def __post_init__(self): + super().__post_init__() self._dates_from_datetimes = [d.date() if isinstance(d, dt.datetime) else d for d in self.trigger_requirements.dates] \ if self.trigger_requirements.entire_day else None @@ -355,41 +344,49 @@ def has_triggered(self, state: Union[dt.date, dt.datetime], backtest: BackTest = elif isinstance(state, dt.date): return TriggerInfo(state in self._dates_from_datetimes) - return TriggerInfo(state in self._trigger_requirements.dates) + return TriggerInfo(state in self.trigger_requirements.dates) def get_trigger_times(self): - return self._dates_from_datetimes or self._trigger_requirements.dates + return self._dates_from_datetimes or self.trigger_requirements.dates +@dataclass_json +@dataclass class PortfolioTrigger(Trigger): - def __init__(self, trigger_requirements: PortfolioTriggerRequirements, actions: Iterable[Action] = None): - super().__init__(trigger_requirements, actions) + trigger_requirements: PortfolioTriggerRequirements = field(default=None, metadata=field_metadata) + actions: Union[Action, Iterable[Action]] = field(default=None, metadata=field_metadata) + + def __post_init__(self): + super().__post_init__() + self._current_position = 0 def has_triggered(self, state: dt.date, backtest: BackTest = None) -> TriggerInfo: - if self._trigger_requirements.data_source == 'len': + if self.trigger_requirements.data_source == 'len': value = len(backtest.portfolio_dict) - if self._trigger_requirements.direction == TriggerDirection.ABOVE: - if value > self._trigger_requirements.trigger_level: + if self.trigger_requirements.direction == TriggerDirection.ABOVE: + if value > self.trigger_requirements.trigger_level: return TriggerInfo(True) - elif self._trigger_requirements.direction == TriggerDirection.BELOW: - if value < self._trigger_requirements.trigger_level: + elif self.trigger_requirements.direction == TriggerDirection.BELOW: + if value < self.trigger_requirements.trigger_level: return TriggerInfo(True) else: - if value == self._trigger_requirements.trigger_level: + if value == self.trigger_requirements.trigger_level: return TriggerInfo(True) - return TriggerInfo(False) +@dataclass_json +@dataclass class MeanReversionTrigger(Trigger): - def __init__(self, - trigger_requirements: MeanReversionTriggerRequirements, - actions: Union[Action, Iterable[Action]]): - super().__init__(trigger_requirements, actions) + trigger_requirements: MeanReversionTriggerRequirements = field(default=None, metadata=field_metadata) + actions: Union[Action, Iterable[Action]] = field(default=None, metadata=field_metadata) + + def __post_init__(self): + super().__post_init__() self._current_position = 0 def has_triggered(self, state: dt.date, backtest: BackTest = None) -> TriggerInfo: - trigger_req = self._trigger_requirements + trigger_req = self.trigger_requirements rolling_mean = trigger_req.data_source.get_data_range(state, trigger_req.rolling_mean_window).mean() rolling_std = trigger_req.data_source.get_data_range(state, trigger_req.rolling_std_window).std() current_price = trigger_req.data_source.get_data(state) @@ -414,10 +411,15 @@ def has_triggered(self, state: dt.date, backtest: BackTest = None) -> TriggerInf return TriggerInfo(False) +@dataclass_json +@dataclass class OrdersGeneratorTrigger(Trigger): """Base class for triggers used with the PredefinedAssetEngine.""" - def __init__(self): - super().__init__(None, Action()) + + def __post_init__(self): + if not self.actions: + self.actions = [Action()] + super().__post_init__() def get_trigger_times(self) -> list: """ diff --git a/gs_quant/base.py b/gs_quant/base.py index b7ea552b..923476d1 100644 --- a/gs_quant/base.py +++ b/gs_quant/base.py @@ -47,7 +47,7 @@ def exclude_none(o): return o is None -def exlude_always(_o): +def exclude_always(_o): return True @@ -91,7 +91,7 @@ def wrapper(self, *args, **kwargs): field_metadata = config(exclude=exclude_none) -name_metadata = config(exclude=exlude_always) +name_metadata = config(exclude=exclude_always) class RiskKey(namedtuple('RiskKey', ('provider', 'date', 'market', 'params', 'scenario', 'risk_measure'))): @@ -350,6 +350,7 @@ def from_instance(self, instance): __setattr__(self, fld.name, __getattribute__(instance, fld.name)) +@dataclass_json @dataclass class Priceable(Base): diff --git a/gs_quant/test/backtest/test_backtest_eq_vol_engine.py b/gs_quant/test/backtest/test_backtest_eq_vol_engine.py index 6b8f0a2c..93f7acf0 100644 --- a/gs_quant/test/backtest/test_backtest_eq_vol_engine.py +++ b/gs_quant/test/backtest/test_backtest_eq_vol_engine.py @@ -20,8 +20,8 @@ from gs_quant.api.gs.backtests import GsBacktestApi from gs_quant.backtests.strategy import Strategy from gs_quant.backtests.triggers import PeriodicTrigger, PeriodicTriggerRequirements, DateTrigger, \ - DateTriggerRequirements, AggregateTrigger, PortfolioTrigger, PortfolioTriggerRequirements, \ - TriggerDirection + DateTriggerRequirements, AggregateTrigger, AggregateTriggerRequirements, PortfolioTrigger, \ + PortfolioTriggerRequirements, TriggerDirection from gs_quant.backtests.actions import EnterPositionQuantityScaledAction, HedgeAction, ExitPositionAction from gs_quant.backtests.equity_vol_engine import * from gs_quant.common import Currency, AssetClass @@ -271,18 +271,18 @@ def test_engine_mapping_with_signals(mocker): entry_signal_series = pd.Series(data={dt.date(2019, 2, 19): 1}) entry_dates = entry_signal_series[entry_signal_series > 0].keys() - entry_trigger = AggregateTrigger(triggers=[ + entry_trigger = AggregateTrigger(AggregateTriggerRequirements(triggers=[ DateTrigger(trigger_requirements=DateTriggerRequirements(dates=entry_dates), actions=entry_action), PortfolioTrigger(trigger_requirements=PortfolioTriggerRequirements('len', 0, TriggerDirection.EQUAL)) - ]) + ])) exit_signal_series = pd.Series(data={dt.date(2019, 2, 20): 1}) exit_dates = exit_signal_series[exit_signal_series > 0].keys() - exit_trigger = AggregateTrigger(triggers=[ + exit_trigger = AggregateTrigger(AggregateTriggerRequirements(triggers=[ DateTrigger(trigger_requirements=DateTriggerRequirements(dates=exit_dates), actions=ExitPositionAction()), PortfolioTrigger(trigger_requirements=PortfolioTriggerRequirements('len', 0, TriggerDirection.ABOVE)) - ]) + ])) strategy = Strategy(initial_portfolio=None, triggers=[entry_trigger, exit_trigger]) diff --git a/gs_quant/test/backtest/test_triggers.py b/gs_quant/test/backtest/test_triggers.py index eefe7fd7..dfbad626 100644 --- a/gs_quant/test/backtest/test_triggers.py +++ b/gs_quant/test/backtest/test_triggers.py @@ -67,7 +67,7 @@ def test_aggregate_triggger(): agg_trigger = AggregateTrigger(AggregateTriggerRequirements([trigger_1, trigger_2], aggregate_type=AggType.ANY_OF)) assert agg_trigger.has_triggered(dt.date(2021, 11, 8)) - assert isinstance(agg_trigger.actions[0].priceables[0], IRSwaption) + assert isinstance(agg_trigger.actions[1].priceables[0], IRSwaption) assert agg_trigger.has_triggered(dt.date(2021, 11, 10)) assert len(agg_trigger.actions) == 2 diff --git a/gs_quant/timeseries/measures_rates.py b/gs_quant/timeseries/measures_rates.py index 4c00846a..6c3b5de0 100644 --- a/gs_quant/timeseries/measures_rates.py +++ b/gs_quant/timeseries/measures_rates.py @@ -1322,8 +1322,11 @@ def _csa_default(csa=None, currency=None): return csa -def forward_rate(asset: Asset, forward_start_tenor=None, forward_term=None, csa=None, - close_location=None, *, source: str = None, real_time: bool = False): +@plot_measure((AssetClass.Cash,), (AssetType.Currency,), + [MeasureDependency(id_provider=_currency_to_tdapi_swap_rate_asset, + query_type=QueryType.SWAP_RATE)]) +def forward_rate(asset: Asset, forward_start_tenor: str = None, forward_term: str = None, csa: str = None, + close_location: str = None, *, source: str = None, real_time: bool = False) -> Series: """ GS Forward Rate across major currencies. @@ -1356,8 +1359,11 @@ def forward_rate(asset: Asset, forward_start_tenor=None, forward_term=None, csa= return series -def discount_factor(asset: Asset, tenor=None, csa=None, close_location=None, - *, source: str = None, real_time: bool = False): +@plot_measure((AssetClass.Cash,), (AssetType.Currency,), + [MeasureDependency(id_provider=_currency_to_tdapi_swap_rate_asset, + query_type=QueryType.SWAP_RATE)]) +def discount_factor(asset: Asset, tenor: str = None, csa: str = None, close_location: str = None, + *, source: str = None, real_time: bool = False) -> Series: """ GS Discount Factor across major currencies. @@ -1387,8 +1393,11 @@ def discount_factor(asset: Asset, tenor=None, csa=None, close_location=None, return series -def instantaneous_forward_rate(asset: Asset, tenor=None, csa=None, close_location=None, - *, source: str = None, real_time: bool = False): +@plot_measure((AssetClass.Cash,), (AssetType.Currency,), + [MeasureDependency(id_provider=_currency_to_tdapi_swap_rate_asset, + query_type=QueryType.SWAP_RATE)]) +def instantaneous_forward_rate(asset: Asset, tenor: str = None, csa: str = None, close_location: str = None, + *, source: str = None, real_time: bool = False) -> Series: """ GS Floating Rate Benchmark annualised instantaneous forward rates across major currencies. @@ -1418,8 +1427,12 @@ def instantaneous_forward_rate(asset: Asset, tenor=None, csa=None, close_locatio return series -def index_forward_rate(asset: Asset, forward_start_tenor=None, benchmark_type: str = None, fixing_tenor=None, - close_location=None, *, source: str = None, real_time: bool = False): +@plot_measure((AssetClass.Cash,), (AssetType.Currency,), + [MeasureDependency(id_provider=_currency_to_tdapi_swap_rate_asset, + query_type=QueryType.SWAP_RATE)]) +def index_forward_rate(asset: Asset, forward_start_tenor: str = None, benchmark_type: str = None, + fixing_tenor: str = None, close_location: str = None, *, + source: str = None, real_time: bool = False) -> Series: """ GS annualised forward rates across floating rate benchmark