Skip to content

Commit

Permalink
Add support for executemany in DBAPI
Browse files Browse the repository at this point in the history
  • Loading branch information
ebyhr committed Apr 11, 2022
1 parent 0430df3 commit e21b06d
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 3 deletions.
43 changes: 42 additions & 1 deletion tests/integration/test_dbapi_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

import trino
from tests.integration.conftest import trino_version
from trino.exceptions import TrinoQueryError, TrinoUserError
from trino.exceptions import TrinoQueryError, TrinoUserError, NotSupportedError
from trino.transaction import IsolationLevel


Expand Down Expand Up @@ -124,6 +124,47 @@ def test_string_query_param(trino_connection):
assert rows[0][0] == "six'"


def test_execute_many(trino_connection):
cur = trino_connection.cursor()
cur.execute("CREATE TABLE memory.default.test_execute_many (key int, value varchar)")
cur.fetchall()
operation = "INSERT INTO memory.default.test_execute_many (key, value) VALUES (?, ?)"
cur.executemany(operation, [(1, "value1")])
cur.fetchall()
cur.execute("SELECT * FROM memory.default.test_execute_many ORDER BY key")
rows = cur.fetchall()
assert len(list(rows)) == 1
assert rows[0] == [1, "value1"]

operation = "INSERT INTO memory.default.test_execute_many (key, value) VALUES (?, ?)"
cur.executemany(operation, [(2, "value2"), (3, "value3")])
cur.fetchall()

cur.execute("SELECT * FROM memory.default.test_execute_many ORDER BY key")
rows = cur.fetchall()
assert len(list(rows)) == 3
assert rows[0] == [1, "value1"]
assert rows[1] == [2, "value2"]
assert rows[2] == [3, "value3"]


def test_execute_many_without_params(trino_connection):
cur = trino_connection.cursor()
cur.execute("CREATE TABLE memory.default.test_execute_many_without_param (value varchar)")
cur.fetchall()
cur.executemany("INSERT INTO memory.default.test_execute_many_without_param (value) VALUES (?)", [])
with pytest.raises(TrinoUserError) as e:
cur.fetchall()
assert "Incorrect number of parameters: expected 1 but found 0" in str(e.value)


def test_execute_many_select(trino_connection):
cur = trino_connection.cursor()
with pytest.raises(NotSupportedError) as e:
cur.executemany("SELECT ?, ?", [(1, "value1"), (2, "value2")])
assert "Query must return update type" in str(e.value)


def test_python_types_not_used_when_experimental_python_types_is_not_set(trino_connection):
cur = trino_connection.cursor()

Expand Down
10 changes: 9 additions & 1 deletion trino/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,12 +107,13 @@ def get_session_property_values(headers, header):


class TrinoStatus(object):
def __init__(self, id, stats, warnings, info_uri, next_uri, rows, columns=None):
def __init__(self, id, stats, warnings, info_uri, next_uri, update_type, rows, columns=None):
self.id = id
self.stats = stats
self.warnings = warnings
self.info_uri = info_uri
self.next_uri = next_uri
self.update_type = update_type
self.rows = rows
self.columns = columns

Expand Down Expand Up @@ -448,6 +449,7 @@ def process(self, http_response) -> TrinoStatus:
warnings=response.get("warnings", []),
info_uri=response["infoUri"],
next_uri=self._next_uri,
update_type=response.get("updateType"),
rows=response.get("data", []),
columns=response.get("columns"),
)
Expand Down Expand Up @@ -572,6 +574,7 @@ def __init__(
self._finished = False
self._cancelled = False
self._request = request
self._update_type = None
self._sql = sql
self._result = TrinoResult(self, experimental_python_types=experimental_python_types)
self._response_headers = None
Expand All @@ -590,6 +593,10 @@ def columns(self):
def stats(self):
return self._stats

@property
def update_type(self):
return self._update_type

@property
def warnings(self):
return self._warnings
Expand Down Expand Up @@ -627,6 +634,7 @@ def execute(self, additional_http_headers=None) -> TrinoResult:

def _update_state(self, status):
self._stats.update(status.stats)
self._update_type = status.update_type
if status.columns:
self._columns = status.columns

Expand Down
31 changes: 30 additions & 1 deletion trino/dbapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,12 @@ def info_uri(self):
return self._query.info_uri
return None

@property
def update_type(self):
if self._query is not None:
return self._query.update_type
return None

@property
def description(self):
if self._query.columns is None:
Expand Down Expand Up @@ -465,7 +471,30 @@ def execute(self, operation, params=None):
return result

def executemany(self, operation, seq_of_params):
raise trino.exceptions.NotSupportedError
"""
PEP-0249: Prepare a database operation (query or command) and then
execute it against all parameter sequences or mappings found in the sequence seq_of_parameters.
Modules are free to implement this method using multiple calls to
the .execute() method or by using array operations to have the
database process the sequence as a whole in one call.
Use of this method for an operation which produces one or more result
sets constitutes undefined behavior, and the implementation is permitted (but not required)
to raise an exception when it detects that a result set has been created by an invocation of the operation.
The same comments as for .execute() also apply accordingly to this method.
Return values are not defined.
"""
for parameters in seq_of_params[:-1]:
self.execute(operation, parameters)
self.fetchall()
if self._query.update_type is None:
raise NotSupportedError("Query must return update type")
if seq_of_params:
self.execute(operation, seq_of_params[-1])
else:
self.execute(operation)

def fetchone(self) -> Optional[List[Any]]:
"""
Expand Down

0 comments on commit e21b06d

Please sign in to comment.