Merge branch 'main' into main
samuelcolvin authored Apr 1, 2024
2 parents edd6014 + 94cd878 commit e54eace
Showing 17 changed files with 202 additions and 112 deletions.
2 changes: 1 addition & 1 deletion arq/cli.py
@@ -60,7 +60,7 @@ async def watch_reload(path: str, worker_settings: 'WorkerSettingsType') -> None
except ImportError as e: # pragma: no cover
raise ImportError('watchfiles not installed, use `pip install watchfiles`') from e

loop = asyncio.get_event_loop()
loop = asyncio.get_running_loop()
stop_event = asyncio.Event()

def worker_on_stop(s: Signals) -> None:
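A note on the cli.py change above: watch_reload is a coroutine, so asyncio.get_running_loop() is the appropriate call here, while asyncio.get_event_loop() is deprecated for this use when no loop is already running. A minimal sketch of the difference, using only the standard library:

import asyncio

async def inside_a_coroutine():
    # Safe: we are inside a running event loop, so this simply returns it.
    loop = asyncio.get_running_loop()
    print(loop.is_running())  # True

# get_running_loop() raises RuntimeError if called with no loop running,
# whereas get_event_loop() would create one implicitly in older Pythons
# and now emits a deprecation warning for that pattern.
asyncio.run(inside_a_coroutine())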
15 changes: 9 additions & 6 deletions arq/connections.py
@@ -55,7 +55,8 @@ class RedisSettings:
@classmethod
def from_dsn(cls, dsn: str) -> 'RedisSettings':
conf = urlparse(dsn)
assert conf.scheme in {'redis', 'rediss', 'unix'}, 'invalid DSN scheme'
if conf.scheme not in {'redis', 'rediss', 'unix'}:
raise RuntimeError('invalid DSN scheme')
query_db = parse_qs(conf.query).get('db')
if query_db:
# e.g. redis://localhost:6379?db=1
@@ -143,7 +144,8 @@ async def enqueue_job(
_queue_name = self.default_queue_name
job_id = _job_id or uuid4().hex
job_key = job_key_prefix + job_id
assert not (_defer_until and _defer_by), "use either 'defer_until' or 'defer_by' or neither, not both"
if _defer_until and _defer_by:
raise RuntimeError("use either 'defer_until' or 'defer_by' or neither, not both")

defer_by_ms = to_ms(_defer_by)
expires_ms = to_ms(_expires)
@@ -195,9 +197,11 @@ async def all_job_results(self) -> List[JobResult]:
async def _get_job_def(self, job_id: bytes, score: int) -> JobDef:
key = job_key_prefix + job_id.decode()
v = await self.get(key)
assert v is not None, f'job "{key}" not found'
if v is None:
raise RuntimeError(f'job "{key}" not found')
jd = deserialize_job(v, deserializer=self.job_deserializer)
jd.score = score
jd.job_id = job_id.decode()
return jd

async def queued_jobs(self, *, queue_name: Optional[str] = None) -> List[JobDef]:
@@ -226,9 +230,8 @@ async def create_pool(
"""
settings: RedisSettings = RedisSettings() if settings_ is None else settings_

assert not (
type(settings.host) is str and settings.sentinel
), "str provided for 'host' but 'sentinel' is true; list of sentinels expected"
if isinstance(settings.host, str) and settings.sentinel:
raise RuntimeError("str provided for 'host' but 'sentinel' is true; list of sentinels expected")

if settings.sentinel:

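The thread running through connections.py (and cron.py and worker.py further down) is replacing assert statements with explicit raises: assertions are stripped when Python runs with the -O flag, so the checks would silently disappear in optimized mode. A small illustration of the difference, hypothetical code rather than arq's own:

# check.py -- hypothetical module, for illustration only
def connect(host, sentinel=False):
    # An assert-based check vanishes under `python -O check.py`:
    #     assert not (isinstance(host, str) and sentinel), 'list of sentinels expected'
    # An explicit raise survives -O and always runs:
    if isinstance(host, str) and sentinel:
        raise RuntimeError("str provided for 'host' but 'sentinel' is true; list of sentinels expected")
    return host, sentinel

print(connect(['sentinel-1', 'sentinel-2'], sentinel=True))  # ok
print(connect('localhost', sentinel=True))                   # raises RuntimeError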
15 changes: 9 additions & 6 deletions arq/cron.py
@@ -58,17 +58,18 @@ def _get_next_dt(dt_: datetime, options: Options) -> Optional[datetime]: # noqa
next_v = getattr(dt_, field)
if isinstance(v, int):
mismatch = next_v != v
else:
assert isinstance(v, (set, list, tuple)), v
elif isinstance(v, (set, list, tuple)):
mismatch = next_v not in v
else:
raise RuntimeError(v)
# print(field, v, next_v, mismatch)
if mismatch:
micro = max(dt_.microsecond - options.microsecond, 0)
if field == 'month':
if dt_.month == 12:
return datetime(dt_.year + 1, 1, 1)
return datetime(dt_.year + 1, 1, 1, tzinfo=dt_.tzinfo)
else:
return datetime(dt_.year, dt_.month + 1, 1)
return datetime(dt_.year, dt_.month + 1, 1, tzinfo=dt_.tzinfo)
elif field in ('day', 'weekday'):
return (
dt_
@@ -82,7 +83,8 @@ def _get_next_dt(dt_: datetime, options: Options) -> Optional[datetime]: # noqa
elif field == 'second':
return dt_ + timedelta(seconds=1) - timedelta(microseconds=micro)
else:
assert field == 'microsecond', field
if field != 'microsecond':
raise RuntimeError(field)
return dt_ + timedelta(microseconds=options.microsecond - dt_.microsecond)
return None

@@ -173,7 +175,8 @@ def cron(
else:
coroutine_ = coroutine

assert asyncio.iscoroutinefunction(coroutine_), f'{coroutine_} is not a coroutine function'
if not asyncio.iscoroutinefunction(coroutine_):
raise RuntimeError(f'{coroutine_} is not a coroutine function')
timeout = to_seconds(timeout)
keep_result = to_seconds(keep_result)

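Besides the assert changes, the cron.py hunks also start passing tzinfo=dt_.tzinfo when the next run date rolls over into the next month or year; building a bare datetime(year, month, 1) silently drops the timezone from an aware input. A short standard-library sketch of the behaviour being fixed:

from datetime import datetime, timezone

now = datetime(2023, 12, 31, 23, 59, tzinfo=timezone.utc)

naive = datetime(now.year + 1, 1, 1)                     # tzinfo lost
aware = datetime(now.year + 1, 1, 1, tzinfo=now.tzinfo)  # tzinfo preserved

print(naive.tzinfo)  # None
print(aware.tzinfo)  # UTC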
6 changes: 5 additions & 1 deletion arq/jobs.py
@@ -47,6 +47,7 @@ class JobDef:
job_try: int
enqueue_time: datetime
score: Optional[int]
job_id: Optional[str]

def __post_init__(self) -> None:
if isinstance(self.score, float):
@@ -60,7 +61,6 @@ class JobResult(JobDef):
start_time: datetime
finish_time: datetime
queue_name: str
job_id: Optional[str] = None


class Job:
@@ -238,6 +238,7 @@ def serialize_result(
finished_ms: int,
ref: str,
queue_name: str,
job_id: str,
*,
serializer: Optional[Serializer] = None,
) -> Optional[bytes]:
@@ -252,6 +253,7 @@
'st': start_ms,
'ft': finished_ms,
'q': queue_name,
'id': job_id,
}
if serializer is None:
serializer = pickle.dumps
@@ -281,6 +283,7 @@ def deserialize_job(r: bytes, *, deserializer: Optional[Deserializer] = None) ->
job_try=d['t'],
enqueue_time=ms_to_datetime(d['et']),
score=None,
job_id=None,
)
except Exception as e:
raise DeserializationError('unable to deserialize job') from e
@@ -315,6 +318,7 @@ def deserialize_result(r: bytes, *, deserializer: Optional[Deserializer] = None)
start_time=ms_to_datetime(d['st']),
finish_time=ms_to_datetime(d['ft']),
queue_name=d.get('q', '<unknown>'),
job_id=d.get('id', '<unknown>'),
)
except Exception as e:
raise DeserializationError('unable to deserialize job result') from e
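The jobs.py changes move job_id onto JobDef, pass it through serialize_result, and read it back with an '<unknown>' fallback so results written by older workers still deserialize. A rough sketch of that backwards-compatible read, with the key names taken from the hunks above but the payload otherwise simplified for illustration:

import pickle

# New-style payload: this commit adds the 'id' key via serialize_result above.
# The exact key layout is arq's internal detail, shown here only for illustration.
new_payload = pickle.dumps({'st': 1711929600000, 'ft': 1711929600123, 'q': 'arq:queue', 'id': 'my_job'})
old_payload = pickle.dumps({'st': 1711929600000, 'ft': 1711929600123, 'q': 'arq:queue'})  # pre-commit worker

for raw in (new_payload, old_payload):
    d = pickle.loads(raw)
    print(d.get('id', '<unknown>'))  # 'my_job', then '<unknown>' for the older result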
37 changes: 30 additions & 7 deletions arq/worker.py
@@ -85,7 +85,8 @@ def func(
else:
coroutine_ = coroutine

assert asyncio.iscoroutinefunction(coroutine_), f'{coroutine_} is not a coroutine function'
if not asyncio.iscoroutinefunction(coroutine_):
raise RuntimeError(f'{coroutine_} is not a coroutine function')
timeout = to_seconds(timeout)
keep_result = to_seconds(keep_result)

@@ -226,17 +227,23 @@ def __init__(
self.queue_name = queue_name
self.cron_jobs: List[CronJob] = []
if cron_jobs is not None:
assert all(isinstance(cj, CronJob) for cj in cron_jobs), 'cron_jobs, must be instances of CronJob'
if not all(isinstance(cj, CronJob) for cj in cron_jobs):
raise RuntimeError('cron_jobs, must be instances of CronJob')
self.cron_jobs = list(cron_jobs)
self.functions.update({cj.name: cj for cj in self.cron_jobs})
assert len(self.functions) > 0, 'at least one function or cron_job must be registered'
if len(self.functions) == 0:
raise RuntimeError('at least one function or cron_job must be registered')
self.burst = burst
self.on_startup = on_startup
self.on_shutdown = on_shutdown
self.on_job_start = on_job_start
self.on_job_end = on_job_end
self.after_job_end = after_job_end
self.sem = asyncio.BoundedSemaphore(max_jobs)

self.max_jobs = max_jobs
self.sem = asyncio.BoundedSemaphore(max_jobs + 1)
self.job_counter: int = 0

self.job_timeout_s = to_seconds(job_timeout)
self.keep_result_s = to_seconds(keep_result)
self.keep_result_forever = keep_result_forever
@@ -374,13 +381,13 @@ async def _poll_iteration(self) -> None:
return
count = min(burst_jobs_remaining, count)
if self.allow_pick_jobs:
async with self.sem: # don't bother with zrangebyscore until we have "space" to run the jobs
if self.job_counter < self.max_jobs:
now = timestamp_ms()
job_ids = await self.pool.zrangebyscore(
self.queue_name, min=float('-inf'), start=self._queue_read_offset, num=count, max=now
)

await self.start_jobs(job_ids)
await self.start_jobs(job_ids)

if self.allow_abort_jobs:
await self._cancel_aborted_jobs()
@@ -419,12 +426,23 @@ async def _cancel_aborted_jobs(self) -> None:
self.aborting_tasks.update(aborted)
await self.pool.zrem(abort_jobs_ss, *aborted)

def _release_sem_dec_counter_on_complete(self) -> None:
self.job_counter = self.job_counter - 1
self.sem.release()

async def start_jobs(self, job_ids: List[bytes]) -> None:
"""
For each job id, get the job definition, check it's not running and start it in a task
"""
for job_id_b in job_ids:
await self.sem.acquire()

if self.job_counter >= self.max_jobs:
self.sem.release()
return None

self.job_counter = self.job_counter + 1

job_id = job_id_b.decode()
in_progress_key = in_progress_key_prefix + job_id
async with self.pool.pipeline(transaction=True) as pipe:
@@ -433,6 +451,7 @@ async def start_jobs(self, job_ids: List[bytes]) -> None:
score = await pipe.zscore(self.queue_name, job_id)
if ongoing_exists or not score:
# job already started elsewhere, or already finished and removed from queue
self.job_counter = self.job_counter - 1
self.sem.release()
logger.debug('job %s already running elsewhere', job_id)
continue
@@ -445,11 +464,12 @@ async def start_jobs(self, job_ids: List[bytes]) -> None:
await pipe.execute()
except (ResponseError, WatchError):
# job already started elsewhere since we got 'existing'
self.job_counter = self.job_counter - 1
self.sem.release()
logger.debug('multi-exec error, job %s already started elsewhere', job_id)
else:
t = self.loop.create_task(self.run_job(job_id, int(score)))
t.add_done_callback(lambda _: self.sem.release())
t.add_done_callback(lambda _: self._release_sem_dec_counter_on_complete())
self.tasks[job_id] = t

async def run_job(self, job_id: str, score: int) -> None: # noqa: C901
@@ -484,6 +504,7 @@ async def job_failed(exc: BaseException) -> None:
ref=f'{job_id}:{function_name}',
serializer=self.job_serializer,
queue_name=self.queue_name,
job_id=job_id,
)
await asyncio.shield(self.finish_failed_job(job_id, result_data_))

@@ -539,6 +560,7 @@ async def job_failed(exc: BaseException) -> None:
timestamp_ms(),
ref,
self.queue_name,
job_id=job_id,
serializer=self.job_serializer,
)
return await asyncio.shield(self.finish_failed_job(job_id, result_data))
@@ -632,6 +654,7 @@ async def job_failed(exc: BaseException) -> None:
finished_ms,
ref,
self.queue_name,
job_id=job_id,
serializer=self.job_serializer,
)

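The biggest functional change in worker.py: job concurrency is now tracked with an explicit job_counter checked against max_jobs, alongside a semaphore sized max_jobs + 1, with the counter decremented in each task's done callback. A stripped-down sketch of that bookkeeping pattern, a hypothetical helper rather than the real Worker class:

import asyncio

class JobLimiter:
    """Hypothetical helper mirroring the counter + semaphore pattern above."""

    def __init__(self, max_jobs: int) -> None:
        self.max_jobs = max_jobs
        self.job_counter = 0
        self.sem = asyncio.BoundedSemaphore(max_jobs + 1)

    async def try_start(self, coro) -> bool:
        await self.sem.acquire()
        if self.job_counter >= self.max_jobs:
            self.sem.release()  # at capacity: hand the permit back and skip this job
            return False
        self.job_counter += 1
        task = asyncio.get_running_loop().create_task(coro)
        task.add_done_callback(lambda _: self._on_complete())
        return True

    def _on_complete(self) -> None:
        # runs when the job task finishes, mirroring _release_sem_dec_counter_on_complete
        self.job_counter -= 1
        self.sem.release()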
10 changes: 10 additions & 0 deletions docs/examples/job_ids.py
@@ -2,6 +2,8 @@

from arq import create_pool
from arq.connections import RedisSettings
from arq.jobs import Job


async def the_task(ctx):
print('running the task with id', ctx['job_id'])
@@ -37,6 +39,14 @@ async def main():
> None
"""

# you can retrieve jobs by using arq.jobs.Job
await redis.enqueue_job('the_task', _job_id='my_job')
job5 = Job(job_id='my_job', redis=redis)
print(job5)
"""
<arq job my_job>
"""

class WorkerSettings:
functions = [the_task]

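For context on the example added above: once you hold a Job handle such as job5, you can poll it for status or wait for its result. A small sketch, assuming the usual arq Job API of status() and result():

from arq.jobs import Job

async def check(redis):
    job = Job(job_id='my_job', redis=redis)
    print(await job.status())           # e.g. queued / in_progress / complete
    print(await job.result(timeout=5))  # wait up to 5 seconds for the result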
2 changes: 1 addition & 1 deletion docs/index.rst
@@ -136,7 +136,7 @@ Sometimes you want a job to only be run once at a time (eg. a backup) or once fo
invoices for a particular company).

*arq* supports this via custom job ids, see :func:`arq.connections.ArqRedis.enqueue_job`. It guarantees
that a job with a particular ID cannot be enqueued again until its execution has finished.
that a job with a particular ID cannot be enqueued again until its execution has finished and its result has cleared. To control when a finished job's result clears, you can use the `keep_result` setting on your worker, see :func:`arq.worker.func`.

.. literalinclude:: examples/job_ids.py

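The docs change above has a practical corollary: with a custom _job_id, a second enqueue of the same id is a no-op until the first job has finished and its result has expired, and keep_result on the function controls how long that result is held. A hedged sketch of what that looks like in user code; enqueue_job returning None for a duplicate id is the behaviour the docs describe, and keep_result=0 is taken here to mean the result is not kept:

from arq import create_pool
from arq.connections import RedisSettings
from arq.worker import func

async def do_backup(ctx):
    ...

async def main():
    redis = await create_pool(RedisSettings())
    first = await redis.enqueue_job('do_backup', _job_id='nightly-backup')
    second = await redis.enqueue_job('do_backup', _job_id='nightly-backup')
    print(first)   # <arq job nightly-backup>
    print(second)  # None -- id still in use, nothing enqueued

class WorkerSettings:
    # assumption: keep_result=0 means the result is not stored, so the job id
    # frees up as soon as execution finishes
    functions = [func(do_backup, keep_result=0)]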
4 changes: 2 additions & 2 deletions requirements/docs.txt
@@ -1,5 +1,5 @@
#
# This file is autogenerated by pip-compile with python 3.9
# This file is autogenerated by pip-compile with python 3.11
# To update, run:
#
# pip-compile --output-file=requirements/docs.txt requirements/docs.in
@@ -35,7 +35,7 @@ requests==2.28.1
snowballstemmer==2.2.0
# via sphinx
sphinx==5.1.1
# via -r docs.in
# via -r requirements/docs.in
sphinxcontrib-applehelp==1.0.2
# via sphinx
sphinxcontrib-devhelp==1.0.2
10 changes: 2 additions & 8 deletions requirements/linting.txt
@@ -1,5 +1,5 @@
#
# This file is autogenerated by pip-compile with python 3.9
# This file is autogenerated by pip-compile with python 3.11
# To update, run:
#
# pip-compile --output-file=requirements/linting.txt requirements/linting.in
@@ -34,15 +34,9 @@ pycodestyle==2.9.1
# via flake8
pyflakes==2.5.0
# via flake8
tomli==2.0.1
# via
# black
# mypy
types-pytz==2022.2.1.0
# via -r requirements/linting.in
types-redis==4.2.8
# via -r requirements/linting.in
typing-extensions==4.3.0
# via
# black
# mypy
# via mypy
14 changes: 3 additions & 11 deletions requirements/pyproject.txt
@@ -1,5 +1,5 @@
#
# This file is autogenerated by pip-compile with python 3.9
# This file is autogenerated by pip-compile with python 3.11
# To update, run:
#
# pip-compile --extra=watch --output-file=requirements/pyproject.txt pyproject.toml
@@ -10,23 +10,15 @@ async-timeout==4.0.2
# via redis
click==8.1.3
# via arq (pyproject.toml)
deprecated==1.2.13
# via redis
hiredis==2.0.0
hiredis==2.1.0
# via redis
idna==3.3
# via anyio
packaging==21.3
# via redis
pyparsing==3.0.9
# via packaging
redis[hiredis]==4.3.4
redis[hiredis]==4.4.0
# via arq (pyproject.toml)
sniffio==1.2.0
# via anyio
typing-extensions==4.3.0
# via arq (pyproject.toml)
watchfiles==0.16.1
# via arq (pyproject.toml)
wrapt==1.14.1
# via deprecated
2 changes: 1 addition & 1 deletion requirements/testing.in
@@ -3,7 +3,7 @@ dirty-equals>=0.4,<1
msgpack>=1,<2
pydantic>=1.9.2,<2
pytest>=7,<8
pytest-asyncio>=0.19,<0.20
pytest-asyncio>=0.20.3
pytest-mock>=3,<4
pytest-sugar>=0.9,<1
pytest-timeout>=2,<3