Skip to content

Commit

Permalink
perf: remove repeated GetSession calls for FixedSizePool (#1252)
Browse files Browse the repository at this point in the history
* test: add mock server tests

* chore: move to testing folder + fix formatting

* refactor: move mock server tests to separate directory

* feat: add database admin service

Adds a DatabaseAdminService to the mock server and sets up a basic
test case for this.

Also removes the generated stubs in the grpc files, as these are
not needed.

* test: add DDL test

* test: add async client tests

* chore: remove async + add transaction handling

* chore: cleanup

* perf: remove repeated GetSession calls for FixedSizePool

Add a _last_use_time to Session and use this to determine whether the
FixedSizePool should check whether the session still exists, and whether
it should be replaced. This significantly reduces the number of times that
GetSession is called when using FixedSizePool.

* chore: run code formatter

* chore: revert to utcnow()

* 🦉 Updates from OwlBot post-processor

See https://github.com/googleapis/repo-automation-bots/blob/main/packages/owl-bot/README.md

* fix: update _last_use_time in trace_call

* chore: fix formatting

* fix: remove unnecessary update of _last_use_time

---------

Co-authored-by: Owl Bot <gcf-owl-bot[bot]@users.noreply.github.com>
  • Loading branch information
olavloite and gcf-owl-bot[bot] authored Dec 5, 2024
1 parent a214885 commit c064815
Show file tree
Hide file tree
Showing 6 changed files with 55 additions and 11 deletions.
4 changes: 4 additions & 0 deletions google/cloud/spanner_v1/_opentelemetry_tracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
"""Manages OpenTelemetry trace creation and handling"""

from contextlib import contextmanager
from datetime import datetime
import os

from google.cloud.spanner_v1 import SpannerClient
Expand Down Expand Up @@ -56,6 +57,9 @@ def get_tracer(tracer_provider=None):

@contextmanager
def trace_call(name, session, extra_attributes=None, observability_options=None):
if session:
session._last_use_time = datetime.now()

if not HAS_OPENTELEMETRY_INSTALLED or not session:
# Empty context manager. Users will have to check if the generated value is None or a span
yield None
Expand Down
9 changes: 7 additions & 2 deletions google/cloud/spanner_v1/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,8 @@ class FixedSizePool(AbstractSessionPool):
- Pre-allocates / creates a fixed number of sessions.
- "Pings" existing sessions via :meth:`session.exists` before returning
them, and replaces expired sessions.
sessions that have not been used for more than 55 minutes and replaces
expired sessions.
- Blocks, with a timeout, when :meth:`get` is called on an empty pool.
Raises after timing out.
Expand All @@ -171,18 +172,21 @@ class FixedSizePool(AbstractSessionPool):

DEFAULT_SIZE = 10
DEFAULT_TIMEOUT = 10
DEFAULT_MAX_AGE_MINUTES = 55

def __init__(
    self,
    size=DEFAULT_SIZE,
    default_timeout=DEFAULT_TIMEOUT,
    labels=None,
    database_role=None,
    max_age_minutes=DEFAULT_MAX_AGE_MINUTES,
):
    """Configure a fixed-size pool; only an empty queue is created here.

    :param size: stored as ``self.size`` and used as the capacity of the
        internal LIFO session queue.
    :param default_timeout: stored as ``self.default_timeout``; read by
        :meth:`get` as its fallback timeout.
    :param labels: forwarded unchanged to the base pool initializer.
    :param database_role: forwarded unchanged to the base pool initializer.
    :param max_age_minutes: idle time, in minutes, that is converted to a
        ``timedelta`` and kept as ``self._max_age``; :meth:`get` compares a
        session's time since last use against it before calling
        ``session.exists()``.
    """
    # Label / database-role handling lives in the base class.
    super(FixedSizePool, self).__init__(labels=labels, database_role=database_role)

    self.size = size
    self.default_timeout = default_timeout

    # LIFO, so the most recently returned session is handed out first.
    session_queue = queue.LifoQueue(size)
    self._sessions = session_queue

    idle_limit = datetime.timedelta(minutes=max_age_minutes)
    self._max_age = idle_limit

def bind(self, database):
"""Associate the pool with a database.
Expand Down Expand Up @@ -230,8 +234,9 @@ def get(self, timeout=None):
timeout = self.default_timeout

session = self._sessions.get(block=True, timeout=timeout)
age = _NOW() - session.last_use_time

if not session.exists():
if age >= self._max_age and not session.exists():
session = self._database.session()
session.create()

Expand Down
11 changes: 11 additions & 0 deletions google/cloud/spanner_v1/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from functools import total_ordering
import random
import time
from datetime import datetime

from google.api_core.exceptions import Aborted
from google.api_core.exceptions import GoogleAPICallError
Expand Down Expand Up @@ -69,6 +70,7 @@ def __init__(self, database, labels=None, database_role=None):
labels = {}
self._labels = labels
self._database_role = database_role
self._last_use_time = datetime.utcnow()

def __lt__(self, other):
# Sessions order by their ``_session_id`` values; ``other`` must expose
# the same attribute.
return self._session_id < other._session_id
Expand All @@ -78,6 +80,14 @@ def session_id(self):
"""Read-only ID, set by the back-end during :meth:`create`."""
return self._session_id

@property
def last_use_time(self):
    """Approximate last use time of this session.

    The value is whatever was most recently stored in ``_last_use_time``;
    this property only reads it.

    :rtype: datetime
    :returns: the approximate last use time of this session
    """
    return self._last_use_time

@property
def database_role(self):
"""User-assigned database-role for the session.
Expand Down Expand Up @@ -222,6 +232,7 @@ def ping(self):
metadata = _metadata_with_prefix(self._database.name)
request = ExecuteSqlRequest(session=self.name, sql="SELECT 1")
api.execute_sql(request=request, metadata=metadata)
self._last_use_time = datetime.now()

def snapshot(self, **kw):
"""Create a snapshot to perform a set of reads with shared staleness.
Expand Down
2 changes: 2 additions & 0 deletions google/cloud/spanner_v1/snapshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

"""Model a set of read-only queries to a database as a snapshot."""

from datetime import datetime
import functools
import threading
from google.protobuf.struct_pb2 import Struct
Expand Down Expand Up @@ -364,6 +365,7 @@ def read(
)

self._read_request_count += 1
self._session._last_use_time = datetime.now()

if self._multi_use:
return StreamedResultSet(
Expand Down
8 changes: 2 additions & 6 deletions tests/mockserver_tests/test_basics.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
FixedSizePool,
BatchCreateSessionsRequest,
ExecuteSqlRequest,
GetSessionRequest,
)
from google.cloud.spanner_v1.database import Database
from google.cloud.spanner_v1.instance import Instance
Expand Down Expand Up @@ -125,12 +124,9 @@ def test_select1(self):
self.assertEqual(1, row[0])
self.assertEqual(1, len(result_list))
requests = self.spanner_service.requests
self.assertEqual(3, len(requests))
self.assertEqual(2, len(requests), msg=requests)
self.assertTrue(isinstance(requests[0], BatchCreateSessionsRequest))
# TODO: Optimize FixedSizePool so this GetSessionRequest is not executed
# every time a session is fetched.
self.assertTrue(isinstance(requests[1], GetSessionRequest))
self.assertTrue(isinstance(requests[2], ExecuteSqlRequest))
self.assertTrue(isinstance(requests[1], ExecuteSqlRequest))

def test_create_table(self):
database_admin_api = self.client.database_admin_api
Expand Down
32 changes: 29 additions & 3 deletions tests/unit/test_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

from functools import total_ordering
import unittest
from datetime import datetime, timedelta

import mock

Expand Down Expand Up @@ -184,13 +185,30 @@ def test_bind(self):
for session in SESSIONS:
session.create.assert_not_called()

def test_get_non_expired(self):
def test_get_active(self):
# Sessions built with _Session's default last_use_time are expected to be
# handed out by the pool without exists() being called on them.
pool = self._make_one(size=4)
database = _Database("name")
SESSIONS = sorted([_Session(database) for i in range(0, 4)])
database._sessions.extend(SESSIONS)
pool.bind(database)

# check if sessions returned in LIFO order
for i in (3, 2, 1, 0):
session = pool.get()
self.assertIs(session, SESSIONS[i])
# _Session.exists() records its invocation in _exists_checked, so a
# False value here means the pool skipped the existence check.
self.assertFalse(session._exists_checked)
self.assertFalse(pool._sessions.full())

def test_get_non_expired(self):
pool = self._make_one(size=4)
database = _Database("name")
last_use_time = datetime.utcnow() - timedelta(minutes=56)
SESSIONS = sorted(
[_Session(database, last_use_time=last_use_time) for i in range(0, 4)]
)
database._sessions.extend(SESSIONS)
pool.bind(database)

# check if sessions returned in LIFO order
for i in (3, 2, 1, 0):
session = pool.get()
Expand All @@ -201,7 +219,8 @@ def test_get_non_expired(self):
def test_get_expired(self):
pool = self._make_one(size=4)
database = _Database("name")
SESSIONS = [_Session(database)] * 5
last_use_time = datetime.utcnow() - timedelta(minutes=65)
SESSIONS = [_Session(database, last_use_time=last_use_time)] * 5
SESSIONS[0]._exists = False
database._sessions.extend(SESSIONS)
pool.bind(database)
Expand Down Expand Up @@ -915,18 +934,25 @@ def _make_transaction(*args, **kw):
class _Session(object):
_transaction = None

def __init__(self, database, exists=True, transaction=None):
def __init__(self, database, exists=True, transaction=None, last_use_time=None):
    """Build a fake session for pool tests.

    :param database: stored as ``_database``.
    :param exists: canned value that :meth:`exists` returns.
    :param transaction: stored as ``_transaction``.
    :param last_use_time: value reported by :attr:`last_use_time`; when
        ``None`` (the default) the current naive UTC time is taken at
        construction.
    """
    self._database = database
    self._exists = exists
    self._exists_checked = False
    self._pinged = False
    self.create = mock.Mock()
    self._deleted = False
    self._transaction = transaction
    # A ``datetime.utcnow()`` default in the signature is evaluated once,
    # when the ``def`` runs, so every instance would share a timestamp that
    # grows stale as the test run proceeds. Resolve it per instance instead.
    if last_use_time is None:
        last_use_time = datetime.utcnow()
    self._last_use_time = last_use_time

def __lt__(self, other):
# Order by object identity so that sorted() over fake sessions works.
return id(self) < id(other)

@property
def last_use_time(self):
# Read-only view of the timestamp stored in ``_last_use_time``.
return self._last_use_time

def exists(self):
    """Return the canned ``_exists`` flag, recording that the check ran."""
    # Tests inspect this marker to learn whether the pool asked at all.
    self._exists_checked = True
    canned_answer = self._exists
    return canned_answer
Expand Down

0 comments on commit c064815

Please sign in to comment.