diff --git a/google/cloud/spanner_v1/_helpers.py b/google/cloud/spanner_v1/_helpers.py index 1e647db339..4f708b20cf 100644 --- a/google/cloud/spanner_v1/_helpers.py +++ b/google/cloud/spanner_v1/_helpers.py @@ -17,6 +17,7 @@ import datetime import decimal import math +import time from google.protobuf.struct_pb2 import ListValue from google.protobuf.struct_pb2 import Value @@ -294,6 +295,59 @@ def _metadata_with_prefix(prefix, **kw): return [("google-cloud-resource-prefix", prefix)] +def _retry( + func, + retry_count=5, + delay=2, + allowed_exceptions=None, +): + """ + Retry a function with a specified number of retries, delay between retries, and list of allowed exceptions. + + Args: + func: The function to be retried. + retry_count: The maximum number of times to retry the function. + delay: The delay in seconds between retries. + allowed_exceptions: A tuple of exceptions that are allowed to occur without triggering a retry. + Passing allowed_exceptions as None will lead to retrying for all exceptions. + + Returns: + The result of the function if it is successful, or raises the last exception if all retries fail. + """ + retries = 0 + while retries <= retry_count: + try: + return func() + except Exception as exc: + if ( + allowed_exceptions is None or exc.__class__ in allowed_exceptions + ) and retries < retry_count: + if ( + allowed_exceptions is not None + and allowed_exceptions[exc.__class__] is not None + ): + allowed_exceptions[exc.__class__](exc) + time.sleep(delay) + delay = delay * 2 + retries = retries + 1 + else: + raise exc + + +def _check_rst_stream_error(exc): + resumable_error = ( + any( + resumable_message in exc.message + for resumable_message in ( + "RST_STREAM", + "Received unexpected EOS on DATA frame from server", + ) + ), + ) + if not resumable_error: + raise + + def _metadata_with_leader_aware_routing(value, **kw): """Create RPC metadata containing a leader aware routing header diff --git a/google/cloud/spanner_v1/batch.py b/google/cloud/spanner_v1/batch.py index 7ee0392aa4..6b71e6d825 100644 --- a/google/cloud/spanner_v1/batch.py +++ b/google/cloud/spanner_v1/batch.py @@ -13,6 +13,7 @@ # limitations under the License. """Context manager for Cloud Spanner batched writes.""" +import functools from google.cloud.spanner_v1 import CommitRequest from google.cloud.spanner_v1 import Mutation @@ -26,6 +27,9 @@ ) from google.cloud.spanner_v1._opentelemetry_tracing import trace_call from google.cloud.spanner_v1 import RequestOptions +from google.cloud.spanner_v1._helpers import _retry +from google.cloud.spanner_v1._helpers import _check_rst_stream_error +from google.api_core.exceptions import InternalServerError class _BatchBase(_SessionWrapper): @@ -186,10 +190,15 @@ def commit(self, return_commit_stats=False, request_options=None): request_options=request_options, ) with trace_call("CloudSpanner.Commit", self._session, trace_attributes): - response = api.commit( + method = functools.partial( + api.commit, request=request, metadata=metadata, ) + response = _retry( + method, + allowed_exceptions={InternalServerError: _check_rst_stream_error}, + ) self.committed = response.commit_timestamp self.commit_stats = response.commit_stats return self.committed diff --git a/google/cloud/spanner_v1/snapshot.py b/google/cloud/spanner_v1/snapshot.py index dc526c9504..6d17bfc386 100644 --- a/google/cloud/spanner_v1/snapshot.py +++ b/google/cloud/spanner_v1/snapshot.py @@ -29,13 +29,15 @@ from google.api_core.exceptions import ServiceUnavailable from google.api_core.exceptions import InvalidArgument from google.api_core import gapic_v1 -from google.cloud.spanner_v1._helpers import _make_value_pb -from google.cloud.spanner_v1._helpers import _merge_query_options from google.cloud.spanner_v1._helpers import ( + _make_value_pb, + _merge_query_options, _metadata_with_prefix, _metadata_with_leader_aware_routing, + _retry, + _check_rst_stream_error, + _SessionWrapper, ) -from google.cloud.spanner_v1._helpers import _SessionWrapper from google.cloud.spanner_v1._opentelemetry_tracing import trace_call from google.cloud.spanner_v1.streamed import StreamedResultSet from google.cloud.spanner_v1 import RequestOptions @@ -560,12 +562,17 @@ def partition_read( with trace_call( "CloudSpanner.PartitionReadOnlyTransaction", self._session, trace_attributes ): - response = api.partition_read( + method = functools.partial( + api.partition_read, request=request, metadata=metadata, retry=retry, timeout=timeout, ) + response = _retry( + method, + allowed_exceptions={InternalServerError: _check_rst_stream_error}, + ) return [partition.partition_token for partition in response.partitions] @@ -659,12 +666,17 @@ def partition_query( self._session, trace_attributes, ): - response = api.partition_query( + method = functools.partial( + api.partition_query, request=request, metadata=metadata, retry=retry, timeout=timeout, ) + response = _retry( + method, + allowed_exceptions={InternalServerError: _check_rst_stream_error}, + ) return [partition.partition_token for partition in response.partitions] @@ -791,10 +803,15 @@ def begin(self): ) txn_selector = self._make_txn_selector() with trace_call("CloudSpanner.BeginTransaction", self._session): - response = api.begin_transaction( + method = functools.partial( + api.begin_transaction, session=self._session.name, options=txn_selector.begin, metadata=metadata, ) + response = _retry( + method, + allowed_exceptions={InternalServerError: _check_rst_stream_error}, + ) self._transaction_id = response.id return self._transaction_id diff --git a/google/cloud/spanner_v1/transaction.py b/google/cloud/spanner_v1/transaction.py index 31ce4b24f8..dee99a0c6f 100644 --- a/google/cloud/spanner_v1/transaction.py +++ b/google/cloud/spanner_v1/transaction.py @@ -22,6 +22,8 @@ _merge_query_options, _metadata_with_prefix, _metadata_with_leader_aware_routing, + _retry, + _check_rst_stream_error, ) from google.cloud.spanner_v1 import CommitRequest from google.cloud.spanner_v1 import ExecuteBatchDmlRequest @@ -33,6 +35,7 @@ from google.cloud.spanner_v1._opentelemetry_tracing import trace_call from google.cloud.spanner_v1 import RequestOptions from google.api_core import gapic_v1 +from google.api_core.exceptions import InternalServerError class Transaction(_SnapshotBase, _BatchBase): @@ -102,7 +105,11 @@ def _execute_request( transaction = self._make_txn_selector() request.transaction = transaction with trace_call(trace_name, session, attributes): - response = method(request=request) + method = functools.partial(method, request=request) + response = _retry( + method, + allowed_exceptions={InternalServerError: _check_rst_stream_error}, + ) return response @@ -132,8 +139,15 @@ def begin(self): ) txn_options = TransactionOptions(read_write=TransactionOptions.ReadWrite()) with trace_call("CloudSpanner.BeginTransaction", self._session): - response = api.begin_transaction( - session=self._session.name, options=txn_options, metadata=metadata + method = functools.partial( + api.begin_transaction, + session=self._session.name, + options=txn_options, + metadata=metadata, + ) + response = _retry( + method, + allowed_exceptions={InternalServerError: _check_rst_stream_error}, ) self._transaction_id = response.id return self._transaction_id @@ -153,11 +167,16 @@ def rollback(self): ) ) with trace_call("CloudSpanner.Rollback", self._session): - api.rollback( + method = functools.partial( + api.rollback, session=self._session.name, transaction_id=self._transaction_id, metadata=metadata, ) + _retry( + method, + allowed_exceptions={InternalServerError: _check_rst_stream_error}, + ) self.rolled_back = True del self._session._transaction @@ -212,10 +231,15 @@ def commit(self, return_commit_stats=False, request_options=None): request_options=request_options, ) with trace_call("CloudSpanner.Commit", self._session, trace_attributes): - response = api.commit( + method = functools.partial( + api.commit, request=request, metadata=metadata, ) + response = _retry( + method, + allowed_exceptions={InternalServerError: _check_rst_stream_error}, + ) self.committed = response.commit_timestamp if return_commit_stats: self.commit_stats = response.commit_stats diff --git a/tests/unit/spanner_dbapi/test_connection.py b/tests/unit/spanner_dbapi/test_connection.py index 6867c20d36..1628f84062 100644 --- a/tests/unit/spanner_dbapi/test_connection.py +++ b/tests/unit/spanner_dbapi/test_connection.py @@ -170,7 +170,7 @@ def test__session_checkout(self, mock_database): connection._session_checkout() self.assertEqual(connection._session, "db_session") - def test__session_checkout_database_error(self): + def test_session_checkout_database_error(self): from google.cloud.spanner_dbapi import Connection connection = Connection(INSTANCE) @@ -191,7 +191,7 @@ def test__release_session(self, mock_database): pool.put.assert_called_once_with("session") self.assertIsNone(connection._session) - def test__release_session_database_error(self): + def test_release_session_database_error(self): from google.cloud.spanner_dbapi import Connection connection = Connection(INSTANCE) diff --git a/tests/unit/test__helpers.py b/tests/unit/test__helpers.py index e90d2dec82..0e0ec903a2 100644 --- a/tests/unit/test__helpers.py +++ b/tests/unit/test__helpers.py @@ -14,6 +14,7 @@ import unittest +import mock class Test_merge_query_options(unittest.TestCase): @@ -671,6 +672,83 @@ def test(self): self.assertEqual(metadata, [("google-cloud-resource-prefix", prefix)]) +class Test_retry(unittest.TestCase): + class test_class: + def test_fxn(self): + return True + + def test_retry_on_error(self): + from google.api_core.exceptions import InternalServerError, NotFound + from google.cloud.spanner_v1._helpers import _retry + import functools + + test_api = mock.create_autospec(self.test_class) + test_api.test_fxn.side_effect = [ + InternalServerError("testing"), + NotFound("testing"), + True, + ] + + _retry(functools.partial(test_api.test_fxn)) + + self.assertEqual(test_api.test_fxn.call_count, 3) + + def test_retry_allowed_exceptions(self): + from google.api_core.exceptions import InternalServerError, NotFound + from google.cloud.spanner_v1._helpers import _retry + import functools + + test_api = mock.create_autospec(self.test_class) + test_api.test_fxn.side_effect = [ + NotFound("testing"), + InternalServerError("testing"), + True, + ] + + with self.assertRaises(InternalServerError): + _retry( + functools.partial(test_api.test_fxn), + allowed_exceptions={NotFound: None}, + ) + + self.assertEqual(test_api.test_fxn.call_count, 2) + + def test_retry_count(self): + from google.api_core.exceptions import InternalServerError + from google.cloud.spanner_v1._helpers import _retry + import functools + + test_api = mock.create_autospec(self.test_class) + test_api.test_fxn.side_effect = [ + InternalServerError("testing"), + InternalServerError("testing"), + ] + + with self.assertRaises(InternalServerError): + _retry(functools.partial(test_api.test_fxn), retry_count=1) + + self.assertEqual(test_api.test_fxn.call_count, 2) + + def test_check_rst_stream_error(self): + from google.api_core.exceptions import InternalServerError + from google.cloud.spanner_v1._helpers import _retry, _check_rst_stream_error + import functools + + test_api = mock.create_autospec(self.test_class) + test_api.test_fxn.side_effect = [ + InternalServerError("Received unexpected EOS on DATA frame from server"), + InternalServerError("RST_STREAM"), + True, + ] + + _retry( + functools.partial(test_api.test_fxn), + allowed_exceptions={InternalServerError: _check_rst_stream_error}, + ) + + self.assertEqual(test_api.test_fxn.call_count, 3) + + class Test_metadata_with_leader_aware_routing(unittest.TestCase): def _call_fut(self, *args, **kw): from google.cloud.spanner_v1._helpers import _metadata_with_leader_aware_routing diff --git a/tests/unit/test_snapshot.py b/tests/unit/test_snapshot.py index 2731e4f258..285328387c 100644 --- a/tests/unit/test_snapshot.py +++ b/tests/unit/test_snapshot.py @@ -1155,6 +1155,40 @@ def test_partition_read_other_error(self): ), ) + def test_partition_read_w_retry(self): + from google.cloud.spanner_v1.keyset import KeySet + from google.api_core.exceptions import InternalServerError + from google.cloud.spanner_v1 import Partition + from google.cloud.spanner_v1 import PartitionResponse + from google.cloud.spanner_v1 import Transaction + + keyset = KeySet(all_=True) + database = _Database() + api = database.spanner_api = self._make_spanner_api() + new_txn_id = b"ABECAB91" + token_1 = b"FACE0FFF" + token_2 = b"BADE8CAF" + response = PartitionResponse( + partitions=[ + Partition(partition_token=token_1), + Partition(partition_token=token_2), + ], + transaction=Transaction(id=new_txn_id), + ) + database.spanner_api.partition_read.side_effect = [ + InternalServerError("Received unexpected EOS on DATA frame from server"), + response, + ] + + session = _Session(database) + derived = self._makeDerived(session) + derived._multi_use = True + derived._transaction_id = TXN_ID + + list(derived.partition_read(TABLE_NAME, COLUMNS, keyset)) + + self.assertEqual(api.partition_read.call_count, 2) + def test_partition_read_ok_w_index_no_options(self): self._partition_read_helper(multi_use=True, w_txn=True, index="index") @@ -1609,6 +1643,25 @@ def test_begin_w_other_error(self): attributes=BASE_ATTRIBUTES, ) + def test_begin_w_retry(self): + from google.cloud.spanner_v1 import ( + Transaction as TransactionPB, + ) + from google.api_core.exceptions import InternalServerError + + database = _Database() + api = database.spanner_api = self._make_spanner_api() + database.spanner_api.begin_transaction.side_effect = [ + InternalServerError("Received unexpected EOS on DATA frame from server"), + TransactionPB(id=TXN_ID), + ] + timestamp = self._makeTimestamp() + session = _Session(database) + snapshot = self._make_one(session, read_timestamp=timestamp, multi_use=True) + + snapshot.begin() + self.assertEqual(api.begin_transaction.call_count, 2) + def test_begin_ok_exact_staleness(self): from google.protobuf.duration_pb2 import Duration from google.cloud.spanner_v1 import ( diff --git a/tests/unit/test_transaction.py b/tests/unit/test_transaction.py index ccf52f6a9f..4eb42027f7 100644 --- a/tests/unit/test_transaction.py +++ b/tests/unit/test_transaction.py @@ -194,6 +194,25 @@ def test_begin_ok(self): "CloudSpanner.BeginTransaction", attributes=TestTransaction.BASE_ATTRIBUTES ) + def test_begin_w_retry(self): + from google.cloud.spanner_v1 import ( + Transaction as TransactionPB, + ) + from google.api_core.exceptions import InternalServerError + + database = _Database() + api = database.spanner_api = self._make_spanner_api() + database.spanner_api.begin_transaction.side_effect = [ + InternalServerError("Received unexpected EOS on DATA frame from server"), + TransactionPB(id=self.TRANSACTION_ID), + ] + + session = _Session(database) + transaction = self._make_one(session) + transaction.begin() + + self.assertEqual(api.begin_transaction.call_count, 2) + def test_rollback_not_begun(self): database = _Database() api = database.spanner_api = self._make_spanner_api()