Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Catch rst stream error for all transactions #934

Merged
merged 8 commits into from
May 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 54 additions & 0 deletions google/cloud/spanner_v1/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import datetime
import decimal
import math
import time

from google.protobuf.struct_pb2 import ListValue
from google.protobuf.struct_pb2 import Value
Expand Down Expand Up @@ -294,6 +295,59 @@ def _metadata_with_prefix(prefix, **kw):
return [("google-cloud-resource-prefix", prefix)]


def _retry(
func,
retry_count=5,
delay=2,
allowed_exceptions=None,
):
"""
Retry a function with a specified number of retries, delay between retries, and list of allowed exceptions.

Args:
func: The function to be retried.
retry_count: The maximum number of times to retry the function.
delay: The delay in seconds between retries.
allowed_exceptions: A tuple of exceptions that are allowed to occur without triggering a retry.
asthamohta marked this conversation as resolved.
Show resolved Hide resolved
Passing allowed_exceptions as None will lead to retrying for all exceptions.

Returns:
The result of the function if it is successful, or raises the last exception if all retries fail.
"""
retries = 0
while retries <= retry_count:
try:
return func()
except Exception as exc:
if (
allowed_exceptions is None or exc.__class__ in allowed_exceptions
) and retries < retry_count:
asthamohta marked this conversation as resolved.
Show resolved Hide resolved
if (
allowed_exceptions is not None
and allowed_exceptions[exc.__class__] is not None
):
allowed_exceptions[exc.__class__](exc)
time.sleep(delay)
delay = delay * 2
retries = retries + 1
else:
raise exc


def _check_rst_stream_error(exc):
resumable_error = (
any(
resumable_message in exc.message
for resumable_message in (
"RST_STREAM",
"Received unexpected EOS on DATA frame from server",
)
),
)
if not resumable_error:
raise


def _metadata_with_leader_aware_routing(value, **kw):
"""Create RPC metadata containing a leader aware routing header

Expand Down
11 changes: 10 additions & 1 deletion google/cloud/spanner_v1/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

"""Context manager for Cloud Spanner batched writes."""
import functools

from google.cloud.spanner_v1 import CommitRequest
from google.cloud.spanner_v1 import Mutation
Expand All @@ -26,6 +27,9 @@
)
from google.cloud.spanner_v1._opentelemetry_tracing import trace_call
from google.cloud.spanner_v1 import RequestOptions
from google.cloud.spanner_v1._helpers import _retry
from google.cloud.spanner_v1._helpers import _check_rst_stream_error
from google.api_core.exceptions import InternalServerError


class _BatchBase(_SessionWrapper):
Expand Down Expand Up @@ -186,10 +190,15 @@ def commit(self, return_commit_stats=False, request_options=None):
request_options=request_options,
)
with trace_call("CloudSpanner.Commit", self._session, trace_attributes):
response = api.commit(
method = functools.partial(
api.commit,
request=request,
metadata=metadata,
)
response = _retry(
method,
allowed_exceptions={InternalServerError: _check_rst_stream_error},
)
self.committed = response.commit_timestamp
self.commit_stats = response.commit_stats
return self.committed
Expand Down
29 changes: 23 additions & 6 deletions google/cloud/spanner_v1/snapshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,15 @@
from google.api_core.exceptions import ServiceUnavailable
from google.api_core.exceptions import InvalidArgument
from google.api_core import gapic_v1
from google.cloud.spanner_v1._helpers import _make_value_pb
from google.cloud.spanner_v1._helpers import _merge_query_options
from google.cloud.spanner_v1._helpers import (
_make_value_pb,
_merge_query_options,
_metadata_with_prefix,
_metadata_with_leader_aware_routing,
_retry,
_check_rst_stream_error,
_SessionWrapper,
)
from google.cloud.spanner_v1._helpers import _SessionWrapper
from google.cloud.spanner_v1._opentelemetry_tracing import trace_call
from google.cloud.spanner_v1.streamed import StreamedResultSet
from google.cloud.spanner_v1 import RequestOptions
Expand Down Expand Up @@ -560,12 +562,17 @@ def partition_read(
with trace_call(
"CloudSpanner.PartitionReadOnlyTransaction", self._session, trace_attributes
):
response = api.partition_read(
method = functools.partial(
api.partition_read,
request=request,
metadata=metadata,
retry=retry,
timeout=timeout,
)
response = _retry(
method,
allowed_exceptions={InternalServerError: _check_rst_stream_error},
)

return [partition.partition_token for partition in response.partitions]

Expand Down Expand Up @@ -659,12 +666,17 @@ def partition_query(
self._session,
trace_attributes,
):
response = api.partition_query(
method = functools.partial(
api.partition_query,
request=request,
metadata=metadata,
retry=retry,
timeout=timeout,
)
response = _retry(
method,
allowed_exceptions={InternalServerError: _check_rst_stream_error},
)

return [partition.partition_token for partition in response.partitions]

Expand Down Expand Up @@ -791,10 +803,15 @@ def begin(self):
)
txn_selector = self._make_txn_selector()
with trace_call("CloudSpanner.BeginTransaction", self._session):
response = api.begin_transaction(
method = functools.partial(
api.begin_transaction,
session=self._session.name,
options=txn_selector.begin,
metadata=metadata,
)
response = _retry(
method,
allowed_exceptions={InternalServerError: _check_rst_stream_error},
)
self._transaction_id = response.id
return self._transaction_id
34 changes: 29 additions & 5 deletions google/cloud/spanner_v1/transaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
_merge_query_options,
_metadata_with_prefix,
_metadata_with_leader_aware_routing,
_retry,
_check_rst_stream_error,
)
from google.cloud.spanner_v1 import CommitRequest
from google.cloud.spanner_v1 import ExecuteBatchDmlRequest
Expand All @@ -33,6 +35,7 @@
from google.cloud.spanner_v1._opentelemetry_tracing import trace_call
from google.cloud.spanner_v1 import RequestOptions
from google.api_core import gapic_v1
from google.api_core.exceptions import InternalServerError


class Transaction(_SnapshotBase, _BatchBase):
Expand Down Expand Up @@ -102,7 +105,11 @@ def _execute_request(
transaction = self._make_txn_selector()
request.transaction = transaction
with trace_call(trace_name, session, attributes):
response = method(request=request)
method = functools.partial(method, request=request)
response = _retry(
method,
allowed_exceptions={InternalServerError: _check_rst_stream_error},
)

return response

Expand Down Expand Up @@ -132,8 +139,15 @@ def begin(self):
)
txn_options = TransactionOptions(read_write=TransactionOptions.ReadWrite())
with trace_call("CloudSpanner.BeginTransaction", self._session):
response = api.begin_transaction(
session=self._session.name, options=txn_options, metadata=metadata
method = functools.partial(
api.begin_transaction,
session=self._session.name,
options=txn_options,
metadata=metadata,
)
response = _retry(
method,
allowed_exceptions={InternalServerError: _check_rst_stream_error},
)
self._transaction_id = response.id
return self._transaction_id
Expand All @@ -153,11 +167,16 @@ def rollback(self):
)
)
with trace_call("CloudSpanner.Rollback", self._session):
api.rollback(
method = functools.partial(
api.rollback,
session=self._session.name,
transaction_id=self._transaction_id,
metadata=metadata,
)
_retry(
method,
allowed_exceptions={InternalServerError: _check_rst_stream_error},
)
self.rolled_back = True
del self._session._transaction

Expand Down Expand Up @@ -212,10 +231,15 @@ def commit(self, return_commit_stats=False, request_options=None):
request_options=request_options,
)
with trace_call("CloudSpanner.Commit", self._session, trace_attributes):
response = api.commit(
method = functools.partial(
api.commit,
request=request,
metadata=metadata,
)
response = _retry(
method,
allowed_exceptions={InternalServerError: _check_rst_stream_error},
)
self.committed = response.commit_timestamp
if return_commit_stats:
self.commit_stats = response.commit_stats
Expand Down
4 changes: 2 additions & 2 deletions tests/unit/spanner_dbapi/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ def test__session_checkout(self, mock_database):
connection._session_checkout()
self.assertEqual(connection._session, "db_session")

def test__session_checkout_database_error(self):
def test_session_checkout_database_error(self):
from google.cloud.spanner_dbapi import Connection

connection = Connection(INSTANCE)
Expand All @@ -191,7 +191,7 @@ def test__release_session(self, mock_database):
pool.put.assert_called_once_with("session")
self.assertIsNone(connection._session)

def test__release_session_database_error(self):
def test_release_session_database_error(self):
from google.cloud.spanner_dbapi import Connection

connection = Connection(INSTANCE)
Expand Down
78 changes: 78 additions & 0 deletions tests/unit/test__helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@


import unittest
import mock


class Test_merge_query_options(unittest.TestCase):
Expand Down Expand Up @@ -671,6 +672,83 @@ def test(self):
self.assertEqual(metadata, [("google-cloud-resource-prefix", prefix)])


class Test_retry(unittest.TestCase):
class test_class:
def test_fxn(self):
return True

def test_retry_on_error(self):
from google.api_core.exceptions import InternalServerError, NotFound
from google.cloud.spanner_v1._helpers import _retry
import functools

test_api = mock.create_autospec(self.test_class)
test_api.test_fxn.side_effect = [
InternalServerError("testing"),
NotFound("testing"),
True,
]

_retry(functools.partial(test_api.test_fxn))

self.assertEqual(test_api.test_fxn.call_count, 3)

def test_retry_allowed_exceptions(self):
from google.api_core.exceptions import InternalServerError, NotFound
from google.cloud.spanner_v1._helpers import _retry
import functools

test_api = mock.create_autospec(self.test_class)
test_api.test_fxn.side_effect = [
NotFound("testing"),
InternalServerError("testing"),
True,
]

with self.assertRaises(InternalServerError):
_retry(
functools.partial(test_api.test_fxn),
allowed_exceptions={NotFound: None},
)
asthamohta marked this conversation as resolved.
Show resolved Hide resolved

self.assertEqual(test_api.test_fxn.call_count, 2)

def test_retry_count(self):
from google.api_core.exceptions import InternalServerError
from google.cloud.spanner_v1._helpers import _retry
import functools

test_api = mock.create_autospec(self.test_class)
test_api.test_fxn.side_effect = [
InternalServerError("testing"),
InternalServerError("testing"),
]

with self.assertRaises(InternalServerError):
_retry(functools.partial(test_api.test_fxn), retry_count=1)
asthamohta marked this conversation as resolved.
Show resolved Hide resolved

self.assertEqual(test_api.test_fxn.call_count, 2)

def test_check_rst_stream_error(self):
asthamohta marked this conversation as resolved.
Show resolved Hide resolved
from google.api_core.exceptions import InternalServerError
from google.cloud.spanner_v1._helpers import _retry, _check_rst_stream_error
import functools

test_api = mock.create_autospec(self.test_class)
test_api.test_fxn.side_effect = [
InternalServerError("Received unexpected EOS on DATA frame from server"),
InternalServerError("RST_STREAM"),
True,
]

_retry(
functools.partial(test_api.test_fxn),
allowed_exceptions={InternalServerError: _check_rst_stream_error},
)

self.assertEqual(test_api.test_fxn.call_count, 3)


class Test_metadata_with_leader_aware_routing(unittest.TestCase):
def _call_fut(self, *args, **kw):
from google.cloud.spanner_v1._helpers import _metadata_with_leader_aware_routing
Expand Down
Loading