diff --git a/google/cloud/spanner_dbapi/client_side_statement_executor.py b/google/cloud/spanner_dbapi/client_side_statement_executor.py new file mode 100644 index 00000000000..6fef255c4ba --- /dev/null +++ b/google/cloud/spanner_dbapi/client_side_statement_executor.py @@ -0,0 +1,29 @@ +# Copyright 20203 Google LLC All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from google.cloud.spanner_dbapi.parsed_statement import ( + ParsedStatement, + ClientSideStatementType, +) + + +class StatementExecutor(object): + def __init__(self, connection): + self.connection = connection + + def execute(self, parsed_statement: ParsedStatement): + if ( + parsed_statement.client_side_statement_type + == ClientSideStatementType.COMMIT + ): + self.connection.commit() diff --git a/google/cloud/spanner_dbapi/client_side_statement_parser.py b/google/cloud/spanner_dbapi/client_side_statement_parser.py new file mode 100644 index 00000000000..baaaf78d539 --- /dev/null +++ b/google/cloud/spanner_dbapi/client_side_statement_parser.py @@ -0,0 +1,37 @@ +# Copyright 20203 Google LLC All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import re + +from google.cloud.spanner_dbapi.parsed_statement import ( + ParsedStatement, + StatementType, + ClientSideStatementType, +) + +RE_COMMIT = re.compile(r"^\s*(COMMIT)(TRANSACTION)?", re.IGNORECASE) + +RE_BEGIN = re.compile(r"^\s*(BEGIN|START)(TRANSACTION)?", re.IGNORECASE) + + +def parse_stmt(query): + if RE_COMMIT.match(query): + return ParsedStatement( + StatementType.CLIENT_SIDE, query, ClientSideStatementType.COMMIT + ) + if RE_BEGIN.match(query): + return ParsedStatement( + StatementType.CLIENT_SIDE, query, ClientSideStatementType.BEGIN + ) + return None diff --git a/google/cloud/spanner_dbapi/cursor.py b/google/cloud/spanner_dbapi/cursor.py index 91bccedd4ce..a49b295ab6d 100644 --- a/google/cloud/spanner_dbapi/cursor.py +++ b/google/cloud/spanner_dbapi/cursor.py @@ -27,6 +27,7 @@ from google.cloud import spanner_v1 as spanner from google.cloud.spanner_dbapi.checksum import ResultsChecksum +from google.cloud.spanner_dbapi.client_side_statement_executor import StatementExecutor from google.cloud.spanner_dbapi.exceptions import IntegrityError from google.cloud.spanner_dbapi.exceptions import InterfaceError from google.cloud.spanner_dbapi.exceptions import OperationalError @@ -39,6 +40,7 @@ from google.cloud.spanner_dbapi import parse_utils from google.cloud.spanner_dbapi.parse_utils import get_param_types from google.cloud.spanner_dbapi.parse_utils import sql_pyformat_args_to_spanner +from google.cloud.spanner_dbapi.parsed_statement import StatementType from google.cloud.spanner_dbapi.utils import PeekIterator from google.cloud.spanner_dbapi.utils import StreamedManyResultSets @@ -85,6 +87,7 @@ def __init__(self, connection): self._row_count = _UNSET_COUNT self.lastrowid = None self.connection = connection + self.client_side_statement_executor = StatementExecutor(connection) self._is_closed = False # the currently running SQL statement results checksum self._checksum = None @@ -239,8 +242,10 @@ def execute(self, sql, args=None): self._handle_DQL(sql, args or None) return - class_ = parse_utils.classify_stmt(sql) - if class_ == parse_utils.STMT_DDL: + parsed_statement = parse_utils.classify_stmt(sql) + if parsed_statement.statement_type == StatementType.CLIENT_SIDE: + self.client_side_statement_executor.execute(parsed_statement) + if parsed_statement.statement_type == StatementType.DDL: self._batch_DDLs(sql) if self.connection.autocommit: self.connection.run_prior_DDL_statements() @@ -251,7 +256,7 @@ def execute(self, sql, args=None): # self._run_prior_DDL_statements() self.connection.run_prior_DDL_statements() - if class_ == parse_utils.STMT_UPDATING: + if parsed_statement.statement_type == StatementType.UPDATE: sql = parse_utils.ensure_where_clause(sql) sql, args = sql_pyformat_args_to_spanner(sql, args or None) @@ -276,7 +281,7 @@ def execute(self, sql, args=None): self.connection.retry_transaction() return - if class_ == parse_utils.STMT_NON_UPDATING: + if parsed_statement.statement_type == StatementType.QUERY: self._handle_DQL(sql, args or None) else: self.connection.database.run_in_transaction( @@ -309,15 +314,23 @@ def executemany(self, operation, seq_of_params): self._result_set = None self._row_count = _UNSET_COUNT - class_ = parse_utils.classify_stmt(operation) - if class_ == parse_utils.STMT_DDL: + parsed_statement = parse_utils.classify_stmt(operation) + if parsed_statement.statement_type == StatementType.DDL: raise ProgrammingError( "Executing DDL statements with executemany() method is not allowed." ) + if parsed_statement.statement_type == StatementType.CLIENT_SIDE: + raise ProgrammingError( + "Executing ClientSide statements with executemany() method is not allowed." + ) + many_result_set = StreamedManyResultSets() - if class_ in (parse_utils.STMT_INSERT, parse_utils.STMT_UPDATING): + if parsed_statement.statement_type in ( + StatementType.INSERT, + StatementType.UPDATE, + ): statements = [] for params in seq_of_params: diff --git a/google/cloud/spanner_dbapi/parse_utils.py b/google/cloud/spanner_dbapi/parse_utils.py index 84cb2dc7a59..436a82e2ca0 100644 --- a/google/cloud/spanner_dbapi/parse_utils.py +++ b/google/cloud/spanner_dbapi/parse_utils.py @@ -21,8 +21,10 @@ import sqlparse from google.cloud import spanner_v1 as spanner from google.cloud.spanner_v1 import JsonObject +from . import client_side_statement_parser from .exceptions import Error +from .parsed_statement import ParsedStatement, StatementType from .types import DateStr, TimestampStr from .utils import sanitize_literals_for_upload @@ -139,11 +141,6 @@ "WITHIN", } -STMT_DDL = "DDL" -STMT_NON_UPDATING = "NON_UPDATING" -STMT_UPDATING = "UPDATING" -STMT_INSERT = "INSERT" - # Heuristic for identifying statements that don't need to be run as updates. RE_NON_UPDATE = re.compile(r"^\W*(SELECT)", re.IGNORECASE) @@ -180,27 +177,29 @@ def classify_stmt(query): :type query: str :param query: A SQL query. - :rtype: str - :returns: The query type name. + :rtype: ParsedStatement + :returns: parsed statement attributes. """ # sqlparse will strip Cloud Spanner comments, # still, special commenting styles, like # PostgreSQL dollar quoted comments are not # supported and will not be stripped. query = sqlparse.format(query, strip_comments=True).strip() - + parsed_statement = client_side_statement_parser.parse_stmt(query) + if parsed_statement is not None: + return parsed_statement if RE_DDL.match(query): - return STMT_DDL + return ParsedStatement(StatementType.DDL, query) if RE_IS_INSERT.match(query): - return STMT_INSERT + return ParsedStatement(StatementType.INSERT, query) if RE_NON_UPDATE.match(query) or RE_WITH.match(query): # As of 13-March-2020, Cloud Spanner only supports WITH for DQL # statements and doesn't yet support WITH for DML statements. - return STMT_NON_UPDATING + return ParsedStatement(StatementType.QUERY, query) - return STMT_UPDATING + return ParsedStatement(StatementType.UPDATE, query) def sql_pyformat_args_to_spanner(sql, params): diff --git a/google/cloud/spanner_dbapi/parsed_statement.py b/google/cloud/spanner_dbapi/parsed_statement.py new file mode 100644 index 00000000000..c36bc1d81cf --- /dev/null +++ b/google/cloud/spanner_dbapi/parsed_statement.py @@ -0,0 +1,36 @@ +# Copyright 20203 Google LLC All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from enum import Enum + + +class StatementType(Enum): + CLIENT_SIDE = 1 + DDL = 2 + QUERY = 3 + UPDATE = 4 + INSERT = 5 + + +class ClientSideStatementType(Enum): + COMMIT = 1 + BEGIN = 2 + + +@dataclass +class ParsedStatement: + statement_type: StatementType + query: str + client_side_statement_type: ClientSideStatementType = None diff --git a/noxfile.py b/noxfile.py index b1274090f0b..0f38a8ea4e7 100644 --- a/noxfile.py +++ b/noxfile.py @@ -32,7 +32,7 @@ ISORT_VERSION = "isort==5.11.0" LINT_PATHS = ["docs", "google", "tests", "noxfile.py", "setup.py"] -DEFAULT_PYTHON_VERSION = "3.8" +DEFAULT_PYTHON_VERSION = "3.9" UNIT_TEST_PYTHON_VERSIONS: List[str] = ["3.7", "3.8", "3.9", "3.10", "3.11"] UNIT_TEST_STANDARD_DEPENDENCIES = [ @@ -48,7 +48,7 @@ UNIT_TEST_EXTRAS: List[str] = [] UNIT_TEST_EXTRAS_BY_PYTHON: Dict[str, List[str]] = {} -SYSTEM_TEST_PYTHON_VERSIONS: List[str] = ["3.8"] +SYSTEM_TEST_PYTHON_VERSIONS: List[str] = ["3.9"] SYSTEM_TEST_STANDARD_DEPENDENCIES: List[str] = [ "mock", "pytest", diff --git a/tests/unit/spanner_dbapi/test_cursor.py b/tests/unit/spanner_dbapi/test_cursor.py index 46a093b109a..b5dc848397a 100644 --- a/tests/unit/spanner_dbapi/test_cursor.py +++ b/tests/unit/spanner_dbapi/test_cursor.py @@ -18,6 +18,8 @@ import sys import unittest +from google.cloud.spanner_dbapi.parsed_statement import ParsedStatement, StatementType + class TestCursor(unittest.TestCase): INSTANCE = "test-instance" @@ -192,17 +194,16 @@ def test_execute_insert_statement_autocommit_off(self): cursor.connection.transaction_checkout = mock.MagicMock(autospec=True) cursor._checksum = ResultsChecksum() + sql = "INSERT INTO django_migrations (app, name, applied) VALUES (%s, %s, %s)" with mock.patch( "google.cloud.spanner_dbapi.parse_utils.classify_stmt", - return_value=parse_utils.STMT_UPDATING, + return_value=ParsedStatement(StatementType.DDL, sql), ): with mock.patch( "google.cloud.spanner_dbapi.connection.Connection.run_statement", return_value=(mock.MagicMock(), ResultsChecksum()), ): - cursor.execute( - sql="INSERT INTO django_migrations (app, name, applied) VALUES (%s, %s, %s)" - ) + cursor.execute(sql) self.assertIsInstance(cursor._result_set, mock.MagicMock) self.assertIsInstance(cursor._itr, PeekIterator) diff --git a/tests/unit/spanner_dbapi/test_parse_utils.py b/tests/unit/spanner_dbapi/test_parse_utils.py index 887f984c2c5..81adcb134ea 100644 --- a/tests/unit/spanner_dbapi/test_parse_utils.py +++ b/tests/unit/spanner_dbapi/test_parse_utils.py @@ -15,6 +15,7 @@ import sys import unittest +from google.cloud.spanner_dbapi.parsed_statement import StatementType from google.cloud.spanner_v1 import param_types from google.cloud.spanner_v1 import JsonObject @@ -24,45 +25,47 @@ class TestParseUtils(unittest.TestCase): skip_message = "Subtests are not supported in Python 2" def test_classify_stmt(self): - from google.cloud.spanner_dbapi.parse_utils import STMT_DDL - from google.cloud.spanner_dbapi.parse_utils import STMT_INSERT - from google.cloud.spanner_dbapi.parse_utils import STMT_NON_UPDATING - from google.cloud.spanner_dbapi.parse_utils import STMT_UPDATING from google.cloud.spanner_dbapi.parse_utils import classify_stmt cases = ( - ("SELECT 1", STMT_NON_UPDATING), - ("SELECT s.SongName FROM Songs AS s", STMT_NON_UPDATING), - ("(SELECT s.SongName FROM Songs AS s)", STMT_NON_UPDATING), + ("SELECT 1", StatementType.QUERY), + ("SELECT s.SongName FROM Songs AS s", StatementType.QUERY), + ("(SELECT s.SongName FROM Songs AS s)", StatementType.QUERY), ( "WITH sq AS (SELECT SchoolID FROM Roster) SELECT * from sq", - STMT_NON_UPDATING, + StatementType.QUERY, ), ( "CREATE TABLE django_content_type (id STRING(64) NOT NULL, name STRING(100) " "NOT NULL, app_label STRING(100) NOT NULL, model STRING(100) NOT NULL) PRIMARY KEY(id)", - STMT_DDL, + StatementType.DDL, ), ( "CREATE INDEX SongsBySingerAlbumSongNameDesc ON " "Songs(SingerId, AlbumId, SongName DESC), INTERLEAVE IN Albums", - STMT_DDL, + StatementType.DDL, ), - ("CREATE INDEX SongsBySongName ON Songs(SongName)", STMT_DDL), + ("CREATE INDEX SongsBySongName ON Songs(SongName)", StatementType.DDL), ( "CREATE INDEX AlbumsByAlbumTitle2 ON Albums(AlbumTitle) STORING (MarketingBudget)", - STMT_DDL, + StatementType.DDL, ), - ("CREATE ROLE parent", STMT_DDL), - ("GRANT SELECT ON TABLE Singers TO ROLE parent", STMT_DDL), - ("REVOKE SELECT ON TABLE Singers TO ROLE parent", STMT_DDL), - ("GRANT ROLE parent TO ROLE child", STMT_DDL), - ("INSERT INTO table (col1) VALUES (1)", STMT_INSERT), - ("UPDATE table SET col1 = 1 WHERE col1 = NULL", STMT_UPDATING), + ("CREATE ROLE parent", StatementType.DDL), + ("commit", StatementType.CLIENT_SIDE), + (" commit TRANSACTION ", StatementType.CLIENT_SIDE), + ("begin", StatementType.CLIENT_SIDE), + ("start", StatementType.CLIENT_SIDE), + ("begin transaction", StatementType.CLIENT_SIDE), + ("start transaction", StatementType.CLIENT_SIDE), + ("GRANT SELECT ON TABLE Singers TO ROLE parent", StatementType.DDL), + ("REVOKE SELECT ON TABLE Singers TO ROLE parent", StatementType.DDL), + ("GRANT ROLE parent TO ROLE child", StatementType.DDL), + ("INSERT INTO table (col1) VALUES (1)", StatementType.INSERT), + ("UPDATE table SET col1 = 1 WHERE col1 = NULL", StatementType.UPDATE), ) for query, want_class in cases: - self.assertEqual(classify_stmt(query), want_class) + self.assertEqual(classify_stmt(query).statement_type, want_class) @unittest.skipIf(skip_condition, skip_message) def test_sql_pyformat_args_to_spanner(self):