Skip to content

Commit

Permalink
Support float("inf") and float("nan") with experimental_python_types
Browse files Browse the repository at this point in the history
  • Loading branch information
mdesmet authored and ebyhr committed Apr 25, 2022
1 parent 939cab0 commit 81184b5
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 5 deletions.
8 changes: 3 additions & 5 deletions tests/integration/test_dbapi_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -614,9 +614,8 @@ def test_float_query_param(trino_connection):
assert rows[0][0] == 1.1


@pytest.mark.skip(reason="Nan currently not returning the correct python type for nan")
def test_float_nan_query_param(trino_connection):
cur = trino_connection.cursor()
cur = trino_connection.cursor(experimental_python_types=True)
cur.execute("SELECT ?", params=(float("nan"),))
rows = cur.fetchall()

Expand All @@ -625,15 +624,14 @@ def test_float_nan_query_param(trino_connection):
assert math.isnan(rows[0][0])


@pytest.mark.skip(reason="Nan currently not returning the correct python type fon inf")
def test_float_inf_query_param(trino_connection):
cur = trino_connection.cursor()
cur = trino_connection.cursor(experimental_python_types=True)
cur.execute("SELECT ?", params=(float("inf"),))
rows = cur.fetchall()

assert rows[0][0] == float("inf")

cur.execute("SELECT ?", params=(-float("-inf"),))
cur.execute("SELECT ?", params=(float("-inf"),))
rows = cur.fetchall()

assert rows[0][0] == float("-inf")
Expand Down
12 changes: 12 additions & 0 deletions trino/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,10 @@

_HEADER_EXTRA_CREDENTIAL_KEY_REGEX = re.compile(r'^\S[^\s=]*$')

INF = float("inf")
NEGATIVE_INF = float("-inf")
NAN = float("nan")


class ClientSession(object):
def __init__(
Expand Down Expand Up @@ -526,6 +530,14 @@ def _map_to_python_type(cls, item: Tuple[Any, Dict]) -> Any:
return [cls._map_to_python_type((array_item, raw_type)) for array_item in value]
elif "decimal" in raw_type:
return Decimal(value)
elif raw_type == "double":
if value == 'Infinity':
return INF
elif value == '-Infinity':
return NEGATIVE_INF
elif value == 'NaN':
return NAN
return value
elif raw_type == "date":
return datetime.strptime(value, "%Y-%m-%d").date()
elif raw_type == "timestamp with time zone":
Expand Down

0 comments on commit 81184b5

Please sign in to comment.