Skip to content

Commit

Permalink
Implementing client side statement in dbapi starting with commit
Browse files Browse the repository at this point in the history
  • Loading branch information
ankiaga committed Nov 17, 2023
1 parent 38d62b2 commit eec98bf
Show file tree
Hide file tree
Showing 8 changed files with 162 additions and 44 deletions.
29 changes: 29 additions & 0 deletions google/cloud/spanner_dbapi/client_side_statement_executor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# Copyright 2023 Google LLC All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from google.cloud.spanner_dbapi.parsed_statement import (
ParsedStatement,
ClientSideStatementType,
)


class StatementExecutor(object):
    """Dispatch parsed client-side statements to a DB-API connection.

    :param connection: Object the statements are applied to. ``execute``
        calls ``commit()`` on it.
    """

    def __init__(self, connection):
        self.connection = connection

    def execute(self, parsed_statement: ParsedStatement):
        """Apply ``parsed_statement`` to the wrapped connection.

        Only ``ClientSideStatementType.COMMIT`` is acted on, by committing
        the connection. Any other client-side statement type is a no-op.

        :param parsed_statement: Statement whose
            ``client_side_statement_type`` selects the action.
        """
        statement_kind = parsed_statement.client_side_statement_type
        is_commit = statement_kind == ClientSideStatementType.COMMIT
        if is_commit:
            self.connection.commit()
37 changes: 37 additions & 0 deletions google/cloud/spanner_dbapi/client_side_statement_parser.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# Copyright 2023 Google LLC All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import re

from google.cloud.spanner_dbapi.parsed_statement import (
ParsedStatement,
StatementType,
ClientSideStatementType,
)

# Client-side statements must match as a WHOLE statement:
#   * the optional TRANSACTION keyword has to be separated by whitespace;
#   * only whitespace and an optional trailing semicolon may follow.
# Without the end anchor, any statement that merely starts with the keyword
# (e.g. "COMMITTED ...", "BEGINNING ...", "START BATCH DDL") would be
# misclassified as a client-side COMMIT/BEGIN.
RE_COMMIT = re.compile(
    r"^\s*(COMMIT)(\s+TRANSACTION)?\s*;?\s*$", re.IGNORECASE
)

RE_BEGIN = re.compile(
    r"^\s*(BEGIN|START)(\s+TRANSACTION)?\s*;?\s*$", re.IGNORECASE
)


def parse_stmt(query):
    """Classify ``query`` as a client-side statement, if it is one.

    :type query: str
    :param query: A SQL statement.

    :rtype: ParsedStatement or None
    :returns: A ``ParsedStatement`` of type ``StatementType.CLIENT_SIDE``
        tagged with ``ClientSideStatementType.COMMIT`` or
        ``ClientSideStatementType.BEGIN`` when the whole statement is
        ``COMMIT [TRANSACTION]`` or ``BEGIN|START [TRANSACTION]``
        (case-insensitive, optional trailing ``;``). ``None`` when the
        statement is not a recognised client-side statement.
    """
    if RE_COMMIT.match(query):
        return ParsedStatement(
            StatementType.CLIENT_SIDE, query, ClientSideStatementType.COMMIT
        )
    if RE_BEGIN.match(query):
        return ParsedStatement(
            StatementType.CLIENT_SIDE, query, ClientSideStatementType.BEGIN
        )
    return None
27 changes: 20 additions & 7 deletions google/cloud/spanner_dbapi/cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@

from google.cloud import spanner_v1 as spanner
from google.cloud.spanner_dbapi.checksum import ResultsChecksum
from google.cloud.spanner_dbapi.client_side_statement_executor import StatementExecutor
from google.cloud.spanner_dbapi.exceptions import IntegrityError
from google.cloud.spanner_dbapi.exceptions import InterfaceError
from google.cloud.spanner_dbapi.exceptions import OperationalError
Expand All @@ -39,6 +40,7 @@
from google.cloud.spanner_dbapi import parse_utils
from google.cloud.spanner_dbapi.parse_utils import get_param_types
from google.cloud.spanner_dbapi.parse_utils import sql_pyformat_args_to_spanner
from google.cloud.spanner_dbapi.parsed_statement import StatementType
from google.cloud.spanner_dbapi.utils import PeekIterator
from google.cloud.spanner_dbapi.utils import StreamedManyResultSets

Expand Down Expand Up @@ -85,6 +87,7 @@ def __init__(self, connection):
self._row_count = _UNSET_COUNT
self.lastrowid = None
self.connection = connection
self.client_side_statement_executor = StatementExecutor(connection)
self._is_closed = False
# the currently running SQL statement results checksum
self._checksum = None
Expand Down Expand Up @@ -239,8 +242,10 @@ def execute(self, sql, args=None):
self._handle_DQL(sql, args or None)
return

class_ = parse_utils.classify_stmt(sql)
if class_ == parse_utils.STMT_DDL:
parsed_statement = parse_utils.classify_stmt(sql)
if parsed_statement.statement_type == StatementType.CLIENT_SIDE:
self.client_side_statement_executor.execute(parsed_statement)
if parsed_statement.statement_type == StatementType.DDL:
self._batch_DDLs(sql)
if self.connection.autocommit:
self.connection.run_prior_DDL_statements()
Expand All @@ -251,7 +256,7 @@ def execute(self, sql, args=None):
# self._run_prior_DDL_statements()
self.connection.run_prior_DDL_statements()

if class_ == parse_utils.STMT_UPDATING:
if parsed_statement.statement_type == StatementType.UPDATE:
sql = parse_utils.ensure_where_clause(sql)

sql, args = sql_pyformat_args_to_spanner(sql, args or None)
Expand All @@ -276,7 +281,7 @@ def execute(self, sql, args=None):
self.connection.retry_transaction()
return

if class_ == parse_utils.STMT_NON_UPDATING:
if parsed_statement.statement_type == StatementType.QUERY:
self._handle_DQL(sql, args or None)
else:
self.connection.database.run_in_transaction(
Expand Down Expand Up @@ -309,15 +314,23 @@ def executemany(self, operation, seq_of_params):
self._result_set = None
self._row_count = _UNSET_COUNT

class_ = parse_utils.classify_stmt(operation)
if class_ == parse_utils.STMT_DDL:
parsed_statement = parse_utils.classify_stmt(operation)
if parsed_statement.statement_type == StatementType.DDL:
raise ProgrammingError(
"Executing DDL statements with executemany() method is not allowed."
)

if parsed_statement.statement_type == StatementType.CLIENT_SIDE:
raise ProgrammingError(
"Executing ClientSide statements with executemany() method is not allowed."
)

many_result_set = StreamedManyResultSets()

if class_ in (parse_utils.STMT_INSERT, parse_utils.STMT_UPDATING):
if parsed_statement.statement_type in (
StatementType.INSERT,
StatementType.UPDATE,
):
statements = []

for params in seq_of_params:
Expand Down
23 changes: 11 additions & 12 deletions google/cloud/spanner_dbapi/parse_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,10 @@
import sqlparse
from google.cloud import spanner_v1 as spanner
from google.cloud.spanner_v1 import JsonObject
from . import client_side_statement_parser

from .exceptions import Error
from .parsed_statement import ParsedStatement, StatementType
from .types import DateStr, TimestampStr
from .utils import sanitize_literals_for_upload

Expand Down Expand Up @@ -139,11 +141,6 @@
"WITHIN",
}

STMT_DDL = "DDL"
STMT_NON_UPDATING = "NON_UPDATING"
STMT_UPDATING = "UPDATING"
STMT_INSERT = "INSERT"

# Heuristic for identifying statements that don't need to be run as updates.
RE_NON_UPDATE = re.compile(r"^\W*(SELECT)", re.IGNORECASE)

Expand Down Expand Up @@ -180,27 +177,29 @@ def classify_stmt(query):
:type query: str
:param query: A SQL query.
:rtype: str
:returns: The query type name.
:rtype: ParsedStatement
:returns: parsed statement attributes.
"""
# sqlparse will strip Cloud Spanner comments,
# still, special commenting styles, like
# PostgreSQL dollar quoted comments are not
# supported and will not be stripped.
query = sqlparse.format(query, strip_comments=True).strip()

parsed_statement = client_side_statement_parser.parse_stmt(query)
if parsed_statement is not None:
return parsed_statement
if RE_DDL.match(query):
return STMT_DDL
return ParsedStatement(StatementType.DDL, query)

if RE_IS_INSERT.match(query):
return STMT_INSERT
return ParsedStatement(StatementType.INSERT, query)

if RE_NON_UPDATE.match(query) or RE_WITH.match(query):
# As of 13-March-2020, Cloud Spanner only supports WITH for DQL
# statements and doesn't yet support WITH for DML statements.
return STMT_NON_UPDATING
return ParsedStatement(StatementType.QUERY, query)

return STMT_UPDATING
return ParsedStatement(StatementType.UPDATE, query)


def sql_pyformat_args_to_spanner(sql, params):
Expand Down
36 changes: 36 additions & 0 deletions google/cloud/spanner_dbapi/parsed_statement.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# Copyright 2023 Google LLC All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
from enum import Enum
from typing import Optional


class StatementType(Enum):
    """Broad category assigned to a SQL statement when it is classified."""

    # Handled by the DB-API layer itself (e.g. COMMIT / BEGIN).
    CLIENT_SIDE = 1
    # Schema-changing statement.
    DDL = 2
    # Read-only statement (SELECT / WITH).
    QUERY = 3
    # Fallback category for statements matching none of the others.
    UPDATE = 4
    # INSERT statement.
    INSERT = 5


class ClientSideStatementType(Enum):
    """Specific kind of a ``StatementType.CLIENT_SIDE`` statement."""

    # COMMIT [TRANSACTION]
    COMMIT = 1
    # BEGIN | START [TRANSACTION]
    BEGIN = 2


@dataclass
class ParsedStatement:
    """Result of classifying a single SQL statement."""

    # Broad category of the statement.
    statement_type: StatementType
    # The SQL text that was classified.
    query: str
    # Specific client-side kind. Left as None unless the creator supplies it,
    # which is expected only for StatementType.CLIENT_SIDE statements.
    client_side_statement_type: Optional[ClientSideStatementType] = None
4 changes: 2 additions & 2 deletions noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
ISORT_VERSION = "isort==5.11.0"
LINT_PATHS = ["docs", "google", "tests", "noxfile.py", "setup.py"]

DEFAULT_PYTHON_VERSION = "3.8"
DEFAULT_PYTHON_VERSION = "3.9"

UNIT_TEST_PYTHON_VERSIONS: List[str] = ["3.7", "3.8", "3.9", "3.10", "3.11"]
UNIT_TEST_STANDARD_DEPENDENCIES = [
Expand All @@ -48,7 +48,7 @@
UNIT_TEST_EXTRAS: List[str] = []
UNIT_TEST_EXTRAS_BY_PYTHON: Dict[str, List[str]] = {}

SYSTEM_TEST_PYTHON_VERSIONS: List[str] = ["3.8"]
SYSTEM_TEST_PYTHON_VERSIONS: List[str] = ["3.9"]
SYSTEM_TEST_STANDARD_DEPENDENCIES: List[str] = [
"mock",
"pytest",
Expand Down
9 changes: 5 additions & 4 deletions tests/unit/spanner_dbapi/test_cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
import sys
import unittest

from google.cloud.spanner_dbapi.parsed_statement import ParsedStatement, StatementType


class TestCursor(unittest.TestCase):
INSTANCE = "test-instance"
Expand Down Expand Up @@ -192,17 +194,16 @@ def test_execute_insert_statement_autocommit_off(self):
cursor.connection.transaction_checkout = mock.MagicMock(autospec=True)

cursor._checksum = ResultsChecksum()
sql = "INSERT INTO django_migrations (app, name, applied) VALUES (%s, %s, %s)"
with mock.patch(
"google.cloud.spanner_dbapi.parse_utils.classify_stmt",
return_value=parse_utils.STMT_UPDATING,
return_value=ParsedStatement(StatementType.DDL, sql),
):
with mock.patch(
"google.cloud.spanner_dbapi.connection.Connection.run_statement",
return_value=(mock.MagicMock(), ResultsChecksum()),
):
cursor.execute(
sql="INSERT INTO django_migrations (app, name, applied) VALUES (%s, %s, %s)"
)
cursor.execute(sql)
self.assertIsInstance(cursor._result_set, mock.MagicMock)
self.assertIsInstance(cursor._itr, PeekIterator)

Expand Down
41 changes: 22 additions & 19 deletions tests/unit/spanner_dbapi/test_parse_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import sys
import unittest

from google.cloud.spanner_dbapi.parsed_statement import StatementType
from google.cloud.spanner_v1 import param_types
from google.cloud.spanner_v1 import JsonObject

Expand All @@ -24,45 +25,47 @@ class TestParseUtils(unittest.TestCase):
skip_message = "Subtests are not supported in Python 2"

def test_classify_stmt(self):
from google.cloud.spanner_dbapi.parse_utils import STMT_DDL
from google.cloud.spanner_dbapi.parse_utils import STMT_INSERT
from google.cloud.spanner_dbapi.parse_utils import STMT_NON_UPDATING
from google.cloud.spanner_dbapi.parse_utils import STMT_UPDATING
from google.cloud.spanner_dbapi.parse_utils import classify_stmt

cases = (
("SELECT 1", STMT_NON_UPDATING),
("SELECT s.SongName FROM Songs AS s", STMT_NON_UPDATING),
("(SELECT s.SongName FROM Songs AS s)", STMT_NON_UPDATING),
("SELECT 1", StatementType.QUERY),
("SELECT s.SongName FROM Songs AS s", StatementType.QUERY),
("(SELECT s.SongName FROM Songs AS s)", StatementType.QUERY),
(
"WITH sq AS (SELECT SchoolID FROM Roster) SELECT * from sq",
STMT_NON_UPDATING,
StatementType.QUERY,
),
(
"CREATE TABLE django_content_type (id STRING(64) NOT NULL, name STRING(100) "
"NOT NULL, app_label STRING(100) NOT NULL, model STRING(100) NOT NULL) PRIMARY KEY(id)",
STMT_DDL,
StatementType.DDL,
),
(
"CREATE INDEX SongsBySingerAlbumSongNameDesc ON "
"Songs(SingerId, AlbumId, SongName DESC), INTERLEAVE IN Albums",
STMT_DDL,
StatementType.DDL,
),
("CREATE INDEX SongsBySongName ON Songs(SongName)", STMT_DDL),
("CREATE INDEX SongsBySongName ON Songs(SongName)", StatementType.DDL),
(
"CREATE INDEX AlbumsByAlbumTitle2 ON Albums(AlbumTitle) STORING (MarketingBudget)",
STMT_DDL,
StatementType.DDL,
),
("CREATE ROLE parent", STMT_DDL),
("GRANT SELECT ON TABLE Singers TO ROLE parent", STMT_DDL),
("REVOKE SELECT ON TABLE Singers TO ROLE parent", STMT_DDL),
("GRANT ROLE parent TO ROLE child", STMT_DDL),
("INSERT INTO table (col1) VALUES (1)", STMT_INSERT),
("UPDATE table SET col1 = 1 WHERE col1 = NULL", STMT_UPDATING),
("CREATE ROLE parent", StatementType.DDL),
("commit", StatementType.CLIENT_SIDE),
(" commit TRANSACTION ", StatementType.CLIENT_SIDE),
("begin", StatementType.CLIENT_SIDE),
("start", StatementType.CLIENT_SIDE),
("begin transaction", StatementType.CLIENT_SIDE),
("start transaction", StatementType.CLIENT_SIDE),
("GRANT SELECT ON TABLE Singers TO ROLE parent", StatementType.DDL),
("REVOKE SELECT ON TABLE Singers TO ROLE parent", StatementType.DDL),
("GRANT ROLE parent TO ROLE child", StatementType.DDL),
("INSERT INTO table (col1) VALUES (1)", StatementType.INSERT),
("UPDATE table SET col1 = 1 WHERE col1 = NULL", StatementType.UPDATE),
)

for query, want_class in cases:
self.assertEqual(classify_stmt(query), want_class)
self.assertEqual(classify_stmt(query).statement_type, want_class)

@unittest.skipIf(skip_condition, skip_message)
def test_sql_pyformat_args_to_spanner(self):
Expand Down

0 comments on commit eec98bf

Please sign in to comment.