Skip to content

Commit

Permalink
fix: don't lose notifications between notifies() calls
Browse files Browse the repository at this point in the history
This allows to stop periodically the generator to run some queries (for
example to LISTEN/UNLISTEN certain channels) and start the generator
again without fearing to lose notification in the window.

Cloes #962.
  • Loading branch information
dvarrazzo committed Dec 26, 2024
1 parent 1def29d commit bba3c01
Show file tree
Hide file tree
Showing 6 changed files with 200 additions and 51 deletions.
2 changes: 2 additions & 0 deletions docs/news.rst
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ Python 3.3.0 (unreleased)
Psycopg 3.2.4 (unreleased)
^^^^^^^^^^^^^^^^^^^^^^^^^^

- Don't lose notifies received between two `~Connection.notifies()` calls
(:ticket:`#962`).
- Make sure that the notifies callback is called during the use of the
`~Connection.notifies()` generator (:ticket:`#972`).

Expand Down
19 changes: 18 additions & 1 deletion psycopg/psycopg/_connection_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from .rows import Row
from .adapt import AdaptersMap
from ._enums import IsolationLevel
from ._compat import LiteralString, Self, TypeAlias, TypeVar
from ._compat import Deque, LiteralString, Self, TypeAlias, TypeVar
from .pq.misc import connection_summary
from ._pipeline import BasePipeline
from ._preparing import PrepareManager
Expand Down Expand Up @@ -116,6 +116,14 @@ def __init__(self, pgconn: PGconn):
pgconn.notice_handler = partial(BaseConnection._notice_handler, wself)
pgconn.notify_handler = partial(BaseConnection._notify_handler, wself)

# Gather notifies when the notifies() generator is not running.
# This handler is registered after notifies() is used te first time.
# backlog = None means that the handler hasn't been registered.
self._notifies_backlog: Deque[Notify] | None = None
self._notifies_backlog_handler = partial(
BaseConnection._add_notify_to_backlog, wself
)

# Attribute is only set if the connection is from a pool so we can tell
# apart a connection in the pool too (when _pool = None)
self._pool: BasePool | None
Expand Down Expand Up @@ -377,6 +385,15 @@ def _notify_handler(
for cb in self._notify_handlers:
cb(n)

@staticmethod
def _add_notify_to_backlog(
wself: ReferenceType[BaseConnection[Row]], notify: Notify
) -> None:
self = wself()
if not self or self._notifies_backlog is None:
return
self._notifies_backlog.append(notify)

@property
def prepare_threshold(self) -> int | None:
"""
Expand Down
69 changes: 45 additions & 24 deletions psycopg/psycopg/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from .rows import Row, RowFactory, tuple_row, args_row
from .adapt import AdaptersMap
from ._enums import IsolationLevel
from ._compat import Self
from ._compat import Deque, Self
from .conninfo import make_conninfo, conninfo_to_dict
from .conninfo import conninfo_attempts, timeout_from_conninfo
from ._pipeline import Pipeline
Expand Down Expand Up @@ -338,31 +338,52 @@ def notifies(

with self.lock:
enc = self.pgconn._encoding
while True:
try:
ns = self.wait(notifies(self.pgconn), interval=interval)
except e._NO_TRACEBACK as ex:
raise ex.with_traceback(None)

# Emit the notifications received.
for pgn in ns:
n = Notify(
pgn.relname.decode(enc), pgn.extra.decode(enc), pgn.be_pid
)
yield n
nreceived += 1

# Stop if we have received enough notifications.
if stop_after is not None and nreceived >= stop_after:
break

# Check the deadline after the loop to ensure that timeout=0
# polls at least once.
if deadline:
interval = min(_WAIT_INTERVAL, deadline - monotonic())
if interval < 0.0:

# If the backlog is set to not-None, then the handler is also set.
# Remove the handler for the duration of this critical section to
# avoid reporting notifies twice.
if self._notifies_backlog is not None:
self.remove_notify_handler(self._notifies_backlog_handler)

try:
while True:
# if notifies were received when the generator was off,
# return them in a first batch.
if self._notifies_backlog:
while self._notifies_backlog:
yield self._notifies_backlog.popleft()
nreceived += 1
else:
try:
pgns = self.wait(notifies(self.pgconn), interval=interval)
except e._NO_TRACEBACK as ex:
raise ex.with_traceback(None)
# Emit the notifications received.
for pgn in pgns:
yield Notify(
pgn.relname.decode(enc),
pgn.extra.decode(enc),
pgn.be_pid,
)
nreceived += 1

# Stop if we have received enough notifications.
if stop_after is not None and nreceived >= stop_after:
break

# Check the deadline after the loop to ensure that timeout=0
# polls at least once.
if deadline:
interval = min(_WAIT_INTERVAL, deadline - monotonic())
if interval < 0.0:
break
finally:
# Install, or re-install, the backlog notify handler
# to catch notifications received while the generator was off.
if self._notifies_backlog is None:
self._notifies_backlog = Deque()
self.add_notify_handler(self._notifies_backlog_handler)

@contextmanager
def pipeline(self) -> Iterator[Pipeline]:
"""Context manager to switch the connection into pipeline mode."""
Expand Down
73 changes: 49 additions & 24 deletions psycopg/psycopg/connection_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from .rows import Row, AsyncRowFactory, tuple_row, args_row
from .adapt import AdaptersMap
from ._enums import IsolationLevel
from ._compat import Self
from ._compat import Deque, Self
from .conninfo import make_conninfo, conninfo_to_dict
from .conninfo import conninfo_attempts_async, timeout_from_conninfo
from ._pipeline import AsyncPipeline
Expand Down Expand Up @@ -358,31 +358,56 @@ async def notifies(

async with self.lock:
enc = self.pgconn._encoding
while True:
try:
ns = await self.wait(notifies(self.pgconn), interval=interval)
except e._NO_TRACEBACK as ex:
raise ex.with_traceback(None)

# Emit the notifications received.
for pgn in ns:
n = Notify(
pgn.relname.decode(enc), pgn.extra.decode(enc), pgn.be_pid
)
yield n
nreceived += 1

# Stop if we have received enough notifications.
if stop_after is not None and nreceived >= stop_after:
break

# Check the deadline after the loop to ensure that timeout=0
# polls at least once.
if deadline:
interval = min(_WAIT_INTERVAL, deadline - monotonic())
if interval < 0.0:

# If the backlog is set to not-None, then the handler is also set.
# Remove the handler for the duration of this critical section to
# avoid reporting notifies twice.
if self._notifies_backlog is not None:
self.remove_notify_handler(self._notifies_backlog_handler)

try:
while True:
# if notifies were received when the generator was off,
# return them in a first batch.
if self._notifies_backlog:
while self._notifies_backlog:
yield self._notifies_backlog.popleft()
nreceived += 1
else:
try:
pgns = await self.wait(
notifies(self.pgconn), interval=interval
)
except e._NO_TRACEBACK as ex:
raise ex.with_traceback(None)

# Emit the notifications received.
for pgn in pgns:
yield Notify(
pgn.relname.decode(enc),
pgn.extra.decode(enc),
pgn.be_pid,
)
nreceived += 1

# Stop if we have received enough notifications.
if stop_after is not None and nreceived >= stop_after:
break

# Check the deadline after the loop to ensure that timeout=0
# polls at least once.
if deadline:
interval = min(_WAIT_INTERVAL, deadline - monotonic())
if interval < 0.0:
break
finally:
# Install, or re-install, the backlog notify handler
# to catch notifications received while the generator was off.
if self._notifies_backlog is None:
self._notifies_backlog = Deque()

self.add_notify_handler(self._notifies_backlog_handler)

@asynccontextmanager
async def pipeline(self) -> AsyncIterator[AsyncPipeline]:
"""Context manager to switch the connection into pipeline mode."""
Expand Down
44 changes: 43 additions & 1 deletion tests/test_notify.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import pytest
from psycopg import Notify

from .acompat import sleep, gather, spawn
from .acompat import Event, sleep, gather, spawn

pytestmark = pytest.mark.crdb_skip("notify")

Expand Down Expand Up @@ -253,3 +253,45 @@ def listener():

assert n1
assert n2


@pytest.mark.slow
@pytest.mark.timing
@pytest.mark.parametrize("sleep_on", ["server", "client"])
def test_notify_query_notify(conn_cls, dsn, sleep_on):
e = Event()
by_gen: list[int] = []
by_cb: list[int] = []
workers = []

def notifier():
with conn_cls.connect(dsn, autocommit=True) as conn:
sleep(0.1)
for i in range(3):
conn.execute("select pg_notify('counter', %s)", (str(i),))
sleep(0.2)

def listener():
with conn_cls.connect(dsn, autocommit=True) as conn:
conn.add_notify_handler(lambda n: by_cb.append(int(n.payload)))

conn.execute("listen counter")
e.set()
for n in conn.notifies(timeout=0.2):
by_gen.append(int(n.payload))

if sleep_on == "server":
conn.execute("select pg_sleep(0.2)")
else:
assert sleep_on == "client"
sleep(0.2)

for n in conn.notifies(timeout=0.2):
by_gen.append(int(n.payload))

workers.append(spawn(listener))
e.wait()
workers.append(spawn(notifier))
gather(*workers)

assert list(range(3)) == by_cb == by_gen, f"by_gen={by_gen!r}, by_cb={by_cb!r}"
44 changes: 43 additions & 1 deletion tests/test_notify_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import pytest
from psycopg import Notify

from .acompat import alist, asleep, gather, spawn
from .acompat import AEvent, alist, asleep, gather, spawn

pytestmark = pytest.mark.crdb_skip("notify")

Expand Down Expand Up @@ -250,3 +250,45 @@ async def listener():

assert n1
assert n2


@pytest.mark.slow
@pytest.mark.timing
@pytest.mark.parametrize("sleep_on", ["server", "client"])
async def test_notify_query_notify(aconn_cls, dsn, sleep_on):
e = AEvent()
by_gen: list[int] = []
by_cb: list[int] = []
workers = []

async def notifier():
async with await aconn_cls.connect(dsn, autocommit=True) as aconn:
await asleep(0.1)
for i in range(3):
await aconn.execute("select pg_notify('counter', %s)", (str(i),))
await asleep(0.2)

async def listener():
async with await aconn_cls.connect(dsn, autocommit=True) as aconn:
aconn.add_notify_handler(lambda n: by_cb.append(int(n.payload)))

await aconn.execute("listen counter")
e.set()
async for n in aconn.notifies(timeout=0.2):
by_gen.append(int(n.payload))

if sleep_on == "server":
await aconn.execute("select pg_sleep(0.2)")
else:
assert sleep_on == "client"
await asleep(0.2)

async for n in aconn.notifies(timeout=0.2):
by_gen.append(int(n.payload))

workers.append(spawn(listener))
await e.wait()
workers.append(spawn(notifier))
await gather(*workers)

assert list(range(3)) == by_cb == by_gen, f"{by_gen=}, {by_cb=}"

0 comments on commit bba3c01

Please sign in to comment.