Skip to content

Commit

Permalink
extract pool creation from RedisMixin
Browse files Browse the repository at this point in the history
  • Loading branch information
samuelcolvin committed Aug 22, 2017
1 parent 6abdf05 commit 4686da9
Show file tree
Hide file tree
Showing 6 changed files with 55 additions and 43 deletions.
4 changes: 4 additions & 0 deletions HISTORY.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@
History
-------

v0.11.0 (2017-08-XX)
....................
* extract ``create_pool_lenient`` from ``RedixMixin``

v0.10.4 (2017-08-22)
....................
* ``RedisSettings`` repr method
Expand Down
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ lint:
python setup.py check -rms
flake8 arq/ tests/
pytest arq -p no:sugar -q
mypy --ignore-missing-imports --follow-imports=skip arq/
mypy --ignore-missing-imports arq/

.PHONY: test
test:
Expand Down
4 changes: 2 additions & 2 deletions arq/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ async def enqueue_job(self, func_name: str, *args, queue: str=None, **kwargs):
queue = queue or self.DEFAULT_QUEUE
if self._concurrency_enabled:
# use the pool directly rather than get_redis_conn to avoid one extra await
pool = self._redis_pool or await self.get_redis_pool()
pool = self.redis_pool or await self.get_redis_pool()
main_logger.debug('%s.%s → %s', self.name, func_name, queue)
async with pool.get() as redis:
await self.job_future(redis, queue, func_name, *args, **kwargs)
Expand All @@ -141,7 +141,7 @@ def _now(self):

async def run_cron(self):
n = self._now()
pool = self._redis_pool or await self.get_redis_pool()
pool = self.redis_pool or await self.get_redis_pool()
to_run = set()

for cron_job in self.con_jobs:
Expand Down
10 changes: 5 additions & 5 deletions arq/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,17 +134,17 @@ class MockRedisMixin(RedisMixin):
Dependent of RedisMixin which uses MockRedis rather than real redis to enqueue jobs.
"""
async def create_redis_pool(self):
return self._redis_pool or MockRedisPool(self.loop)
return self.redis_pool or MockRedisPool(self.loop)

@property
def mock_data(self):
self._redis_pool = self._redis_pool or MockRedisPool(self.loop)
return self._redis_pool.data
self.redis_pool = self.redis_pool or MockRedisPool(self.loop)
return self.redis_pool.data

@mock_data.setter
def mock_data(self, data):
self._redis_pool = self._redis_pool or MockRedisPool(self.loop)
self._redis_pool.data = data
self.redis_pool = self.redis_pool or MockRedisPool(self.loop)
self.redis_pool.data = data


class MockRedisWorker(MockRedisMixin, BaseWorker):
Expand Down
70 changes: 39 additions & 31 deletions arq/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import aioredis
from aioredis.pool import RedisPool

__all__ = ['RedisSettings', 'RedisMixin', 'next_cron']
__all__ = ['RedisSettings', 'create_pool_lenient', 'RedisMixin', 'next_cron']
logger = logging.getLogger('arq.utils')


Expand Down Expand Up @@ -50,6 +50,33 @@ def __repr__(self):
return '<RedisSettings {}>'.format(' '.join(f'{s}={getattr(self, s)}' for s in self.__slots__))


async def create_pool_lenient(settings: RedisSettings, loop: asyncio.AbstractEventLoop, *, _retry: int=0) -> RedisPool:
"""
Create a new redis pool, retrying up to conn_retries times if the connection fails.
:param settings: RedisSettings instance
:param loop: event loop
:param _retry: retry attempt, this is set when the method calls itself recursively
"""
addr = settings.host, settings.port
try:
pool = await aioredis.create_pool(
addr, loop=loop, db=settings.database, password=settings.password,
create_connection_timeout=settings.conn_timeout
)
except (ConnectionError, OSError, aioredis.RedisError, asyncio.TimeoutError) as e:
if _retry < settings.conn_retries:
logger.warning('redis connection error %s %s, %d retries remaining...',
e.__class__.__name__, e, settings.conn_retries - _retry)
await asyncio.sleep(settings.conn_retry_delay)
return await create_pool_lenient(settings, loop, _retry=_retry + 1)
else:
raise
else:
if _retry > 0:
logger.warning('redis connection successful')
return pool


class RedisMixin:
"""
Mixin used to fined a redis pool and access it.
Expand All @@ -67,40 +94,21 @@ def __init__(self, *,
# loop or redis_settings before calling super().__init__ and don't pass those parameters.
self.loop = loop or getattr(self, 'loop', None) or asyncio.get_event_loop()
self.redis_settings = redis_settings or getattr(self, 'redis_settings', None) or RedisSettings()
self._redis_pool = existing_pool
self.redis_pool = existing_pool
self._create_pool_lock = asyncio.Lock(loop=self.loop)

async def create_redis_pool(self, *, _retry=0) -> RedisPool:
"""
Create a new redis pool.
"""
addr = self.redis_settings.host, self.redis_settings.port
try:
pool = await aioredis.create_pool(
addr, loop=self.loop, db=self.redis_settings.database, password=self.redis_settings.password,
create_connection_timeout=self.redis_settings.conn_timeout
)
except (ConnectionError, OSError, aioredis.RedisError, asyncio.TimeoutError) as e:
if _retry < self.redis_settings.conn_retries:
logger.warning('redis connection error %s %s, %d retries remaining...',
e.__class__.__name__, e, self.redis_settings.conn_retries - _retry)
await asyncio.sleep(self.redis_settings.conn_retry_delay)
return await self.create_redis_pool(_retry=_retry + 1)
else:
raise
else:
if _retry > 0:
logger.warning('redis connection successful')
return pool
async def create_redis_pool(self):
# defined here for easy mocking
return await create_pool_lenient(self.redis_settings, self.loop)

async def get_redis_pool(self) -> RedisPool:
"""
Get the redis pool, if a pool is already initialised it's returned, else one is crated.
"""
async with self._create_pool_lock:
if self._redis_pool is None:
self._redis_pool = await self.create_redis_pool()
return self._redis_pool
if self.redis_pool is None:
self.redis_pool = await self.create_redis_pool()
return self.redis_pool

async def get_redis_conn(self):
"""
Expand All @@ -124,10 +132,10 @@ async def close(self):
"""
Close the pool and wait for all connections to close.
"""
if self._redis_pool:
self._redis_pool.close()
await self._redis_pool.wait_closed()
await self._redis_pool.clear()
if self.redis_pool:
self.redis_pool.close()
await self.redis_pool.wait_closed()
await self.redis_pool.clear()


def create_tz(utcoffset=0) -> timezone:
Expand Down
8 changes: 4 additions & 4 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import pytest

import arq.utils
from arq import RedisMixin, RedisSettings
from arq import RedisMixin, RedisSettings, create_pool_lenient
from arq.logs import ColourHandler
from arq.testing import MockRedis
from arq.utils import next_cron, timestamp
Expand Down Expand Up @@ -65,13 +65,13 @@ async def test_redis_timeout(loop, mocker):


async def test_redis_success_log(loop, caplog):
r = RedisMixin(loop=loop)
pool = await r.create_redis_pool()
settings = RedisSettings()
pool = await create_pool_lenient(settings, loop)
assert 'redis connection successful' not in caplog
pool.close()
await pool.wait_closed()

pool = await r.create_redis_pool(_retry=1)
pool = await create_pool_lenient(settings, loop, _retry=1)
assert 'redis connection successful' not in caplog
pool.close()
await pool.wait_closed()
Expand Down

0 comments on commit 4686da9

Please sign in to comment.