diff --git a/gs_quant/api/gs/portfolios.py b/gs_quant/api/gs/portfolios.py index dad523e0..30b0d29d 100644 --- a/gs_quant/api/gs/portfolios.py +++ b/gs_quant/api/gs/portfolios.py @@ -26,7 +26,7 @@ from gs_quant.target.portfolios import Portfolio, Position, PositionSet from gs_quant.target.reports import Report from gs_quant.target.risk_models import RiskModelTerm as Term -from gs_quant.target.workflow_quote import WorkflowPosition, WorkflowPositionsResponse, SaveQuoteRequest +from gs_quant.workflow import WorkflowPosition, WorkflowPositionsResponse, SaveQuoteRequest _logger = logging.getLogger(__name__) diff --git a/gs_quant/quote_reports/core.py b/gs_quant/quote_reports/core.py index 4dea0699..463b733f 100644 --- a/gs_quant/quote_reports/core.py +++ b/gs_quant/quote_reports/core.py @@ -17,7 +17,7 @@ from dataclasses_json.cfg import _GlobalConfig -from gs_quant.target.workflow_quote import VisualStructuringReport, BinaryImageComments, HyperLinkImageComments, \ +from gs_quant.workflow import VisualStructuringReport, BinaryImageComments, HyperLinkImageComments, \ CustomDeltaHedge, DeltaHedge global_config = _GlobalConfig() diff --git a/gs_quant/session.py b/gs_quant/session.py index 55b81b1e..4fae1362 100644 --- a/gs_quant/session.py +++ b/gs_quant/session.py @@ -14,7 +14,6 @@ under the License. """ import asyncio -import contextlib import inspect import itertools import json @@ -39,7 +38,7 @@ from gs_quant import version as APP_VERSION from gs_quant.base import Base -from gs_quant.context_base import ContextBase +from gs_quant.context_base import ContextBase, nullcontext from gs_quant.errors import MqError, MqRequestError, MqAuthenticationError, MqUninitialisedError from gs_quant.json_encoder import JSONEncoder, encode_default from gs_quant.tracing import Tracer @@ -235,7 +234,8 @@ def _build_request_params( include_version: Optional[bool], timeout: Optional[int], use_body: bool, - data_key: str + data_key: str, + tracing_scope: Optional[dict] ) -> Tuple[dict, str]: is_dataframe = isinstance(payload, pd.DataFrame) if not is_dataframe: @@ -244,15 +244,28 @@ def _build_request_params( kwargs = { 'timeout': timeout } + + if tracing_scope: + tracing_scope.span.set_tag('path', path) + tracing_scope.span.set_tag('timeout', timeout) + tracing_scope.span.set_tag(HTTP_URL, url) + tracing_scope.span.set_tag(HTTP_METHOD, method) + tracing_scope.span.set_tag('span.kind', 'client') + if method in ['GET', 'DELETE'] and not use_body: kwargs['params'] = payload + if tracing_scope: + headers = self._session.headers.copy() + Tracer.inject(Format.HTTP_HEADERS, headers) + kwargs['headers'] = headers elif method in ['POST', 'PUT'] or (method in ['GET', 'DELETE'] and use_body): headers = self._session.headers.copy() if request_headers: headers.update(request_headers) - Tracer.inject(Format.HTTP_HEADERS, headers) + if tracing_scope: + Tracer.inject(Format.HTTP_HEADERS, headers) if 'Content-Type' not in headers: headers.update({'Content-Type': 'application/json; charset=utf-8'}) @@ -305,16 +318,10 @@ def __request( use_body: bool = False ) -> Union[Base, tuple, dict]: span = Tracer.get_instance().active_span - tracer = Tracer(f'http:/{path}') if span else contextlib.nullcontext() + tracer = Tracer(f'http:/{path}') if span else nullcontext() with tracer as scope: kwargs, url = self._build_request_params(method, path, payload, request_headers, include_version, timeout, - use_body, "data") - if scope: - scope.span.set_tag('path', path) - scope.span.set_tag('timeout', timeout) - scope.span.set_tag(HTTP_URL, url) - scope.span.set_tag(HTTP_METHOD, method) - scope.span.set_tag('span.kind', 'client') + use_body, "data", scope) response = self._session.request(method, url, **kwargs) request_id = response.headers.get('x-dash-requestid') logger.debug('Handling response for [Request ID]: %s [Method]: %s [URL]: %s', request_id, method, url) @@ -344,10 +351,17 @@ async def __request_async( use_body: bool = False ) -> Union[Base, tuple, dict]: self._init_async() - kwargs, url = self._build_request_params(method, path, payload, request_headers, include_version, timeout, - use_body, "content") - response = await self._session_async.request(method, url, **kwargs) - request_id = response.headers.get('x-dash-requestid') + span = Tracer.get_instance().active_span + tracer = Tracer(f'http:/{path}') if span else nullcontext() + with tracer as scope: + kwargs, url = self._build_request_params(method, path, payload, request_headers, include_version, timeout, + use_body, "content", scope) + response = await self._session_async.request(method, url, **kwargs) + request_id = response.headers.get('x-dash-requestid') + if scope: + scope.span.set_tag(HTTP_STATUS_CODE, response.status_code) + scope.span.set_tag('dash.request.id', request_id) + logger.debug('Handling response for [Request ID]: %s [Method]: %s [URL]: %s', request_id, method, url) if response.status_code == 401: # Expired token or other authorization issue diff --git a/gs_quant/test/api/test_json.py b/gs_quant/test/api/test_json.py index 0e30bf60..26238481 100644 --- a/gs_quant/test/api/test_json.py +++ b/gs_quant/test/api/test_json.py @@ -20,7 +20,7 @@ import pytz from gs_quant.json_encoder import JSONEncoder -from gs_quant.target.workflow_quote import BinaryImageComments, ImgType, Encoding, HyperLinkImageComments, \ +from gs_quant.workflow import BinaryImageComments, ImgType, Encoding, HyperLinkImageComments, \ VisualStructuringReport, ChartingParameters, OverlayType diff --git a/gs_quant/tracing/tracing.py b/gs_quant/tracing/tracing.py index 61e04ca8..6b0f40ea 100644 --- a/gs_quant/tracing/tracing.py +++ b/gs_quant/tracing/tracing.py @@ -20,7 +20,7 @@ from typing import Tuple, Optional import pandas as pd -from opentracing import Span, UnsupportedFormatException +from opentracing import Span, UnsupportedFormatException, SpanContextCorruptedException from opentracing import Tracer as OpenTracer from opentracing.mocktracer import MockTracer @@ -55,6 +55,14 @@ def inject(format, carrier): except UnsupportedFormatException: pass + @staticmethod + def extract(format, carrier): + instance = Tracer.get_instance() + try: + return instance.extract(format, carrier) + except (UnsupportedFormatException, SpanContextCorruptedException): + pass + @staticmethod def set_factory(factory: TracerFactory): Tracer.__factory = factory @@ -148,7 +156,7 @@ def plot(reset=False): fig.show() @staticmethod - def gather_data(as_string: bool = True): + def gather_data(as_string: bool = True, root_id: Optional[str] = None): spans = Tracer.get_spans() spans_by_parent = {} @@ -168,7 +176,8 @@ def _build_tree(parent_span, depth): total = 0 lines = [] - for span in reversed(spans_by_parent.get(None, [])): + # By default, we look for the span with no parent, but this might not always be what we want + for span in reversed(spans_by_parent.get(root_id, [])): _build_tree(span, 0) total += (span.finish_time - span.start_time) * 1000 @@ -179,8 +188,8 @@ def _build_tree(parent_span, depth): return lines, total @staticmethod - def print(reset=True): - tracing_str, total = Tracer.gather_data() + def print(reset=True, root_id=None): + tracing_str, total = Tracer.gather_data(root_id=root_id) _logger.warning(f'Tracing Info:\n{tracing_str}\n{"-" * 61}\nTOTAL:{total:>52.1f} ms') if reset: Tracer.reset()