Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fixes, base changes #225

Merged
merged 1 commit into from
Mar 2, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions gs_quant/api/gs/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from gs_quant.api.data import DataApi
from gs_quant.base import Base
from gs_quant.data.core import DataContext, DataFrequency
from gs_quant.data.log import log_debug
from gs_quant.data.log import log_debug, log_warning
from gs_quant.errors import MqValueError
from gs_quant.markets import MarketDataCoordinate
from gs_quant.session import GsSession
Expand Down Expand Up @@ -476,7 +476,11 @@ def get_data_providers(cls,
def get_market_data(cls, query, request_id=None) -> pd.DataFrame:
GsSession.current: GsSession
start = time.perf_counter()
body = GsSession.current._post('/data/measures', payload=query)
try:
body = GsSession.current._post('/data/measures', payload=query)
except Exception as e:
log_warning(request_id, _logger, f'Market data query {query} failed due to {e}')
raise e
log_debug(request_id, _logger, 'market data query (%s) ran in %.3f ms', body.get('requestId'),
(time.perf_counter() - start) * 1000)

Expand Down
10 changes: 6 additions & 4 deletions gs_quant/backtests/generic_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,10 +303,12 @@ def run_backtest(self, strategy, start=None, end=None, frequency='1m', states=No
if type(action) in trigger_infos else None)

logging.info(f'Filtering strategy calculations to run from {strategy_start_date} to {strategy_end_date}')
backtest.portfolio_dict = {k: backtest.portfolio_dict[k] for k in backtest.portfolio_dict
if strategy_start_date <= k <= strategy_end_date}
backtest.scaling_portfolios = {k: backtest.scaling_portfolios[k] for k in backtest.scaling_portfolios
if strategy_start_date <= k <= strategy_end_date}
backtest.portfolio_dict = defaultdict(Portfolio, {k: backtest.portfolio_dict[k]
for k in backtest.portfolio_dict
if strategy_start_date <= k <= strategy_end_date})
backtest.scaling_portfolios = defaultdict(list, {k: backtest.scaling_portfolios[k]
for k in backtest.scaling_portfolios
if strategy_start_date <= k <= strategy_end_date})

logging.info('Pricing simple and semi-deterministic triggers and actions')
with PricingContext(is_batch=True, show_progress=show_progress, csa_term=csa_term, visible_to_gs=visible_to_gs):
Expand Down
24 changes: 17 additions & 7 deletions gs_quant/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,9 @@
from collections import namedtuple
import copy
from dataclasses import Field, InitVar, MISSING, dataclass, field, fields, replace
from dataclasses_json import config, global_config
from dataclasses_json import global_config
from dataclasses_json.core import _is_supported_generic, _decode_generic
import dataclasses_json.core
import datetime as dt
from enum import EnumMeta
from inflection import camelize, underscore
Expand Down Expand Up @@ -544,12 +545,6 @@ def unresolve(self):
self.__unresolved = None


@dataclass
class QuotableBuilder(Base):

valuation_overrides: DictBase = field(default_factory=HashableDict, metadata=config(field_name='overrides'))


@dataclass
class Market(ABC):

Expand Down Expand Up @@ -628,3 +623,18 @@ def get_enum_value(enum_type: EnumMeta, value: Union[EnumBase, str]):

global_config.encoders[Market] = encode_dictable
global_config.encoders[Optional[Market]] = encode_dictable


def __decode_dataclass(cls, kvs, infer_missing):
# EXTREMELY unfortunate
if isinstance(kvs, cls):
return kvs
elif hasattr(cls, 'decode_dataclass'):
return cls.decode_dataclass(kvs)
else:
from dataclasses_json.core import _decode_dataclass_orig
return _decode_dataclass_orig(cls, kvs, infer_missing)


dataclasses_json.core._decode_dataclass_orig = dataclasses_json.core._decode_dataclass
dataclasses_json.core._decode_dataclass = __decode_dataclass
99 changes: 37 additions & 62 deletions gs_quant/instrument/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
under the License.
"""
from dataclasses_json import global_config
from dataclasses_json.core import _decode_dataclass
import datetime as dt
import inspect
import logging
Expand All @@ -24,7 +23,7 @@

from gs_quant.api.gs.parser import GsParserApi
from gs_quant.api.gs.risk import GsRiskApi
from gs_quant.base import get_enum_value, Base, InstrumentBase, Priceable, QuotableBuilder
from gs_quant.base import get_enum_value, Base, InstrumentBase, Priceable
from gs_quant.common import AssetClass, AssetType, XRef, RiskMeasure
from gs_quant.markets import HistoricalPricingContext, MarketDataCoordinate, PricingContext
from gs_quant.priceable import PriceableImpl
Expand Down Expand Up @@ -255,65 +254,35 @@ def cb(f):

@classmethod
def from_dict(cls, values: dict):
return cls.__from_dict(values)
if not values:
return

@classmethod
def __from_dict(cls, values: dict):
if values:
if issubclass(cls, QuotableBuilder):
valuation_overrides = None
if 'builder' in values:
valuation_overrides = values.get('overrides', {})
if valuation_overrides:
valuation_overrides = valuation_overrides.get('properties')

values = values['builder']
elif 'defn' in values:
values = values['defn']
elif 'overrides' in values:
valuation_overrides = values.pop('overrides')

if 'properties' in values:
values.update(values.pop('properties'))

ret = _decode_dataclass(cls, values, False)
if valuation_overrides:
ret.valuation_overrides = valuation_overrides

return ret
elif hasattr(cls, 'asset_class'):
return _decode_dataclass(cls, values, False)
else:
builder_type = values.get('$type') or values.get('builder', values.get('defn', {})).get('$type')
if builder_type:
from gs_quant_internal import tdapi
tdapi_cls = getattr(tdapi, builder_type.replace('Defn', 'Builder'))
if not tdapi_cls:
raise RuntimeError('Cannot resolve TDAPI type {}'.format(tdapi_cls))
values_no_type = values.copy()
del values_no_type['$type']
return tdapi_cls.__from_dict(values_no_type)

asset_class_field = next((f for f in ('asset_class', 'assetClass') if f in values), None)
if not asset_class_field:
raise ValueError('assetClass/asset_class not specified')
if 'type' not in values:
raise ValueError('type not specified')

asset_type = values.pop('type')
asset_class = values.pop(asset_class_field)
default_type = Security if asset_type in [None, "", "Security"] and asset_class in [None, "",
"Security"] \
else None

instrument = Instrument.__asset_class_and_type_to_instrument().get((
get_enum_value(AssetClass, asset_class),
get_enum_value(AssetType, asset_type)), default_type)

if instrument is None:
raise ValueError('unable to build instrument')

return instrument.from_dict(values)
instrument = cls if hasattr(cls, 'asset_class') else None
if instrument is None:
builder_type = values.get('$type') or values.get('builder', values.get('defn', {})).get('$type')
if builder_type:
from gs_quant_internal.base import QuotableBuilder
return QuotableBuilder.from_dict(values)

asset_class_field = next((f for f in ('asset_class', 'assetClass') if f in values), None)
if not asset_class_field:
raise ValueError('assetClass/asset_class not specified')
if 'type' not in values:
raise ValueError('type not specified')

asset_type = values.pop('type')
asset_class = values.pop(asset_class_field)
security_types = (None, '', 'Security')
default_type = Security if asset_type in security_types and asset_class in security_types else None

instrument = Instrument.__asset_class_and_type_to_instrument().get((
get_enum_value(AssetClass, asset_class),
get_enum_value(AssetType, asset_type)), default_type)

if instrument is None:
raise ValueError('unable to build instrument')

return instrument.from_dict(values)

@classmethod
def from_quick_entry(cls, text: str, asset_class: Optional[AssetClass] = None):
Expand Down Expand Up @@ -426,8 +395,14 @@ def __init__(self,
self.quantity_ = quantity


def encode_instrument(instrument: Instrument):
return instrument.to_dict()
def encode_instrument(instrument: Optional[Instrument]) -> Optional[dict]:
if instrument is not None:
return instrument.to_dict()


def encode_instruments(instruments: Optional[Iterable[Instrument]]) -> Optional[Iterable[Optional[dict]]]:
if instruments is not None:
return [encode_instrument(i) for i in instruments]


global_config.decoders[Instrument] = Instrument.from_dict
Expand Down
1 change: 1 addition & 0 deletions gs_quant/timeseries/measures.py
Original file line number Diff line number Diff line change
Expand Up @@ -1917,6 +1917,7 @@ def skew_term(asset: Asset, strike_reference: SkewReference, distance: Real,
series = ExtendedSeries(dtype=float)
else:
df = df.loc[p_date]
df.index = pd.DatetimeIndex(df.index) if not isinstance(df.index, pd.DatetimeIndex) else df.index
df.index = DatetimeIndex([RelativeDate(df['tenor'][i], df.index.date[i]).apply_rule(exchange=asset.exchange)
for i in range(len(df))])
series = _skew(df, 'relativeStrike', 'impliedVolatility', q_strikes)
Expand Down