diff --git a/airflow/jobs/scheduler_job_runner.py b/airflow/jobs/scheduler_job_runner.py
index ed67e7a4ac8d1..e085efc3142b8 100644
--- a/airflow/jobs/scheduler_job_runner.py
+++ b/airflow/jobs/scheduler_job_runner.py
@@ -24,7 +24,6 @@
 import sys
 import time
 from collections import Counter, defaultdict, deque
-from dataclasses import dataclass
 from datetime import timedelta
 from functools import lru_cache, partial
 from pathlib import Path
@@ -83,7 +82,6 @@
     from datetime import datetime
     from types import FrameType
 
-    from sqlalchemy.engine import Result
    from sqlalchemy.orm import Query, Session
 
     from airflow.dag_processing.manager import DagFileProcessorAgent
@@ -99,7 +97,6 @@ DM = DagModel
 
 
-@dataclass
 class ConcurrencyMap:
     """
-    Dataclass to represent concurrency maps.
+    Class to represent concurrency maps.
@@ -109,17 +106,24 @@ class ConcurrencyMap:
     to # of task instances in the given state list in each DAG run.
     """
 
-    dag_active_tasks_map: dict[str, int]
-    task_concurrency_map: dict[tuple[str, str], int]
-    task_dagrun_concurrency_map: dict[tuple[str, str, str], int]
-
-    @classmethod
-    def from_concurrency_map(cls, mapping: dict[tuple[str, str, str], int]) -> ConcurrencyMap:
-        instance = cls(Counter(), Counter(), Counter(mapping))
-        for (d, _, t), c in mapping.items():
-            instance.dag_active_tasks_map[d] += c
-            instance.task_concurrency_map[(d, t)] += c
-        return instance
+    def __init__(self):
+        self.dag_run_active_tasks_map: Counter[tuple[str, str]] = Counter()
+        self.task_concurrency_map: Counter[tuple[str, str]] = Counter()
+        self.task_dagrun_concurrency_map: Counter[tuple[str, str, str]] = Counter()
+
+    def load(self, session: Session) -> None:
+        self.dag_run_active_tasks_map.clear()
+        self.task_concurrency_map.clear()
+        self.task_dagrun_concurrency_map.clear()
+        query = session.execute(
+            select(TI.dag_id, TI.task_id, TI.run_id, func.count("*"))
+            .where(TI.state.in_(EXECUTION_STATES))
+            .group_by(TI.task_id, TI.run_id, TI.dag_id)
+        )
+        for dag_id, task_id, run_id, c in query:
+            self.dag_run_active_tasks_map[dag_id, run_id] += c
+            self.task_concurrency_map[(dag_id, task_id)] += c
+            self.task_dagrun_concurrency_map[(dag_id, run_id, task_id)] += c
 
 
 def _is_parent_process() -> bool:
@@ -258,22 +262,6 @@ def _debug_dump(self, signum: int, frame: FrameType | None) -> None:
             executor.debug_dump()
             self.log.info("-" * 80)
 
-    def __get_concurrency_maps(self, states: Iterable[TaskInstanceState], session: Session) -> ConcurrencyMap:
-        """
-        Get the concurrency maps.
-
-        :param states: List of states to query for
-        :return: Concurrency map
-        """
-        ti_concurrency_query: Result = session.execute(
-            select(TI.task_id, TI.run_id, TI.dag_id, func.count("*"))
-            .where(TI.state.in_(states))
-            .group_by(TI.task_id, TI.run_id, TI.dag_id)
-        )
-        return ConcurrencyMap.from_concurrency_map(
-            {(dag_id, run_id, task_id): count for task_id, run_id, dag_id, count in ti_concurrency_query}
-        )
-
     def _executable_task_instances_to_queued(self, max_tis: int, session: Session) -> list[TI]:
         """
         Find TIs that are ready for execution based on conditions.
@@ -326,7 +314,8 @@ def _executable_task_instances_to_queued(self, max_tis: int, session: Session) -> list[TI]:
         starved_pools = {pool_name for pool_name, stats in pools.items() if stats["open"] <= 0}
 
-        # dag_id to # of running tasks and (dag_id, task_id) to # of running tasks.
+        # (dag_id, run_id) to # of active tasks and (dag_id, task_id) to # of active tasks.
-        concurrency_map = self.__get_concurrency_maps(states=EXECUTION_STATES, session=session)
+        concurrency_map = ConcurrencyMap()
+        concurrency_map.load(session=session)
 
         # Number of tasks that cannot be scheduled because of no open slot in pool
         num_starving_tasks_total = 0
@@ -465,22 +454,22 @@ def _executable_task_instances_to_queued(self, max_tis: int, session: Session) -> list[TI]:
                 # Check to make sure that the task max_active_tasks of the DAG hasn't been
                 # reached.
                 dag_id = task_instance.dag_id
-
-                current_active_tasks_per_dag = concurrency_map.dag_active_tasks_map[dag_id]
-                max_active_tasks_per_dag_limit = task_instance.dag_model.max_active_tasks
+                dag_run_key = (dag_id, task_instance.run_id)
+                current_active_tasks_per_dag_run = concurrency_map.dag_run_active_tasks_map[dag_run_key]
+                dag_max_active_tasks = task_instance.dag_model.max_active_tasks
                 self.log.info(
                     "DAG %s has %s/%s running and queued tasks",
                     dag_id,
-                    current_active_tasks_per_dag,
-                    max_active_tasks_per_dag_limit,
+                    current_active_tasks_per_dag_run,
+                    dag_max_active_tasks,
                 )
-                if current_active_tasks_per_dag >= max_active_tasks_per_dag_limit:
+                if current_active_tasks_per_dag_run >= dag_max_active_tasks:
                     self.log.info(
                         "Not executing %s since the number of tasks running or queued "
                         "from DAG %s is >= to the DAG's max_active_tasks limit of %s",
                         task_instance,
                         dag_id,
-                        max_active_tasks_per_dag_limit,
+                        dag_max_active_tasks,
                     )
                     starved_dags.add(dag_id)
                     continue
@@ -571,7 +560,7 @@ def _executable_task_instances_to_queued(self, max_tis: int, session: Session) -> list[TI]:
                     executable_tis.append(task_instance)
                     open_slots -= task_instance.pool_slots
-                    concurrency_map.dag_active_tasks_map[dag_id] += 1
+                    concurrency_map.dag_run_active_tasks_map[dag_run_key] += 1
                     concurrency_map.task_concurrency_map[(task_instance.dag_id, task_instance.task_id)] += 1
                     concurrency_map.task_dagrun_concurrency_map[
                         (task_instance.dag_id, task_instance.run_id, task_instance.task_id)
diff --git a/newsfragments/42953.significant b/newsfragments/42953.significant
new file mode 100644
index 0000000000000..20f25b434572a
--- /dev/null
+++ b/newsfragments/42953.significant
@@ -0,0 +1,3 @@
+DAG.max_active_tasks now evaluated per dag run
+
+Previously, this was evaluated across all runs of the DAG. This behavior change was passed by lazy consensus. Vote thread: https://lists.apache.org/thread/9o84d3yn934m32gtlpokpwtbbmtxj47l.
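For illustration only (not part of the patch): a minimal standalone sketch of what "evaluated per dag run" means for the concurrency accounting. The DAG id "my_dag", the run ids, and the has_open_slot helper are hypothetical names; dag_run_active_tasks_map mimics the counter that ConcurrencyMap.load builds, assuming a DAG with max_active_tasks=2 and one active task in each of two runs.

from collections import Counter

MAX_ACTIVE_TASKS = 2  # stand-in for DAG.max_active_tasks

# (dag_id, run_id) -> count of running/queued TIs, mirroring the shape of
# ConcurrencyMap.dag_run_active_tasks_map after load()
dag_run_active_tasks_map: Counter = Counter({("my_dag", "run_1"): 1, ("my_dag", "run_2"): 1})

def has_open_slot(dag_id: str, run_id: str) -> bool:
    # hypothetical helper: the new check compares each run's own count to the limit
    return dag_run_active_tasks_map[(dag_id, run_id)] < MAX_ACTIVE_TASKS

# Old behavior summed across runs: 1 + 1 == 2 hits the limit, starving the whole DAG.
per_dag_total = sum(c for (dag_id, _run_id), c in dag_run_active_tasks_map.items() if dag_id == "my_dag")
assert per_dag_total >= MAX_ACTIVE_TASKS

# New behavior: each run is one task under its own limit, so both can still queue work.
assert has_open_slot("my_dag", "run_1")
assert has_open_slot("my_dag", "run_2")

The real check in _executable_task_instances_to_queued also consults pools, per-task limits, and executor slots; this sketch isolates only the per-run max_active_tasks arithmetic.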
diff --git a/tests/jobs/test_scheduler_job.py b/tests/jobs/test_scheduler_job.py
index 52ad2f6a8c6c0..17fd851714bb5 100644
--- a/tests/jobs/test_scheduler_job.py
+++ b/tests/jobs/test_scheduler_job.py
@@ -22,7 +22,7 @@
 import logging
 import os
 import sys
-from collections import deque
+from collections import Counter, deque
 from datetime import timedelta
 from importlib import reload
 from typing import Generator
@@ -1169,86 +1169,78 @@ def test_tis_for_queued_dagruns_are_not_run(self, dag_maker):
         assert ti1.state == State.SCHEDULED
         assert ti2.state == State.QUEUED
 
-    def test_find_executable_task_instances_concurrency(self, dag_maker):
-        dag_id = "SchedulerJobTest.test_find_executable_task_instances_concurrency"
-        session = settings.Session()
+    @pytest.mark.parametrize("active_state", [TaskInstanceState.RUNNING, TaskInstanceState.QUEUED])
+    def test_find_executable_task_instances_concurrency(self, dag_maker, active_state, session):
+        """Verify that, with varying numbers of queued / running / scheduled tasks,
+        the correct number of TIs is queued."""
+        dag_id = "check_MAT_dag"
         with dag_maker(dag_id=dag_id, max_active_tasks=2, session=session):
-            EmptyOperator(task_id="dummy")
+            EmptyOperator(task_id="task_1")
+            EmptyOperator(task_id="task_2")
+            EmptyOperator(task_id="task_3")
 
         scheduler_job = Job()
         self.job_runner = SchedulerJobRunner(job=scheduler_job, subdir=os.devnull)
 
-        dr1 = dag_maker.create_dagrun(run_type=DagRunType.SCHEDULED)
-        dr2 = dag_maker.create_dagrun_after(dr1, run_type=DagRunType.SCHEDULED)
-        dr3 = dag_maker.create_dagrun_after(dr2, run_type=DagRunType.SCHEDULED)
+        dr1 = dag_maker.create_dagrun(run_type=DagRunType.SCHEDULED, run_id="run_1", session=session)
+        dr2 = dag_maker.create_dagrun_after(
+            dr1, run_type=DagRunType.SCHEDULED, run_id="run_2", session=session
+        )
+        dr3 = dag_maker.create_dagrun_after(
+            dr2, run_type=DagRunType.SCHEDULED, run_id="run_3", session=session
+        )
 
-        ti1 = dr1.task_instances[0]
-        ti2 = dr2.task_instances[0]
-        ti3 = dr3.task_instances[0]
-        ti1.state = State.RUNNING
-        ti2.state = State.SCHEDULED
-        ti3.state = State.SCHEDULED
-        session.merge(ti1)
-        session.merge(ti2)
-        session.merge(ti3)
+        # set 2 tis in dr1 to an active state
+        # no more can be queued
+        t1, t2, t3 = dr1.get_task_instances(session=session)
+        t1.state = active_state
+        t2.state = active_state
+        t3.state = State.SCHEDULED
+        session.merge(t1)
+        session.merge(t2)
+        session.merge(t3)
+        # set 1 ti in dr2 to an active state
+        # one can be queued
+        t1, t2, t3 = dr2.get_task_instances(session=session)
+        t1.state = active_state
+        t2.state = State.SCHEDULED
+        t3.state = State.SCHEDULED
+        session.merge(t1)
+        session.merge(t2)
+        session.merge(t3)
+        # set no tis in dr3 to an active state
+        # two can be queued
+        t1, t2, t3 = dr3.get_task_instances(session=session)
+        t1.state = State.SCHEDULED
+        t2.state = State.SCHEDULED
+        t3.state = State.SCHEDULED
+        session.merge(t1)
+        session.merge(t2)
+        session.merge(t3)
         session.flush()
 
-        res = self.job_runner._executable_task_instances_to_queued(max_tis=32, session=session)
-
-        assert 1 == len(res)
-        res_keys = (x.key for x in res)
-        assert ti2.key in res_keys
-
-        ti2.state = State.RUNNING
-        session.merge(ti2)
-        session.flush()
-
-        res = self.job_runner._executable_task_instances_to_queued(max_tis=32, session=session)
-
-        assert 0 == len(res)
-        session.rollback()
-
-    def test_find_executable_task_instances_concurrency_queued(self, dag_maker):
-        dag_id = "SchedulerJobTest.test_find_executable_task_instances_concurrency_queued"
-        with dag_maker(dag_id=dag_id, max_active_tasks=3):
-            task1 = EmptyOperator(task_id="dummy1")
-            task2 = EmptyOperator(task_id="dummy2")
-            task3 = EmptyOperator(task_id="dummy3")
-
-        scheduler_job = Job()
-        self.job_runner = SchedulerJobRunner(job=scheduler_job, subdir=os.devnull)
-        session = settings.Session()
-
-        dag_run = dag_maker.create_dagrun()
-
-        ti1 = dag_run.get_task_instance(task1.task_id)
-        ti2 = dag_run.get_task_instance(task2.task_id)
-        ti3 = dag_run.get_task_instance(task3.task_id)
-        ti1.state = State.RUNNING
-        ti2.state = State.QUEUED
-        ti3.state = State.SCHEDULED
-
-        session.merge(ti1)
-        session.merge(ti2)
-        session.merge(ti3)
+        queued_tis = self.job_runner._executable_task_instances_to_queued(max_tis=32, session=session)
+        queued_runs = Counter([x.run_id for x in queued_tis])
+        assert queued_runs["run_1"] == 0
+        assert queued_runs["run_2"] == 1
+        assert queued_runs["run_3"] == 2
 
-        session.flush()
+        session.commit()
+        session.query(TaskInstance).all()
 
-        res = self.job_runner._executable_task_instances_to_queued(max_tis=32, session=session)
+        # every run is now at max_active_tasks, so nothing more will be queued
+        queued_tis = self.job_runner._executable_task_instances_to_queued(max_tis=32, session=session)
+        assert queued_tis == []
 
-        assert 1 == len(res)
-        assert res[0].key == ti3.key
         session.rollback()
 
     # TODO: This is a hack, I think I need to just remove the setting and have it on always
     def test_find_executable_task_instances_max_active_tis_per_dag(self, dag_maker):
         dag_id = "SchedulerJobTest.test_find_executable_task_instances_max_active_tis_per_dag"
-        task_id_1 = "dummy"
-        task_id_2 = "dummy2"
         with dag_maker(dag_id=dag_id, max_active_tasks=16):
-            task1 = EmptyOperator(task_id=task_id_1, max_active_tis_per_dag=2)
-            task2 = EmptyOperator(task_id=task_id_2)
+            task1 = EmptyOperator(task_id="dummy", max_active_tis_per_dag=2)
+            task2 = EmptyOperator(task_id="dummy2")
 
         executor = MockExecutor(do_update=True)
@@ -1653,65 +1645,88 @@ def test_enqueue_task_instances_sets_ti_state_to_None_if_dagrun_in_finish_state(
             ("secondary_exec", "secondary_exec"),
         ],
     )
-    def test_critical_section_enqueue_task_instances(self, task1_exec, task2_exec, dag_maker, mock_executors):
+    def test_critical_section_enqueue_task_instances(
+        self, task1_exec, task2_exec, dag_maker, mock_executors, session
+    ):
         dag_id = "SchedulerJobTest.test_execute_task_instances"
-        task_id_1 = "dummy_task"
-        task_id_2 = "dummy_task_nonexistent_queue"
-        session = settings.Session()
 
         # important that len(tasks) is less than max_active_tasks
         # because before scheduler._execute_task_instances would only
         # check the num tasks once so if max_active_tasks was 3,
         # we could execute arbitrarily many tasks in the second run
         with dag_maker(dag_id=dag_id, max_active_tasks=3, session=session) as dag:
-            task1 = EmptyOperator(task_id=task_id_1, executor=task1_exec)
-            task2 = EmptyOperator(task_id=task_id_2, executor=task2_exec)
+            task1 = EmptyOperator(task_id="t1", executor=task1_exec)
+            task2 = EmptyOperator(task_id="t2", executor=task2_exec)
+            task3 = EmptyOperator(task_id="t3", executor=task2_exec)
+            task4 = EmptyOperator(task_id="t4", executor=task2_exec)
 
         scheduler_job = Job()
         self.job_runner = SchedulerJobRunner(job=scheduler_job, subdir=os.devnull)
 
-        # create first dag run with 2 running tasks
+        # create first dag run with 3 running tasks
 
-        dr1 = dag_maker.create_dagrun(run_type=DagRunType.SCHEDULED)
+        dr1 = dag_maker.create_dagrun(run_type=DagRunType.SCHEDULED, session=session)
 
-        ti1 = dr1.get_task_instance(task1.task_id, session)
-        ti2 = dr1.get_task_instance(task2.task_id, session)
-        ti1.state = State.RUNNING
-        ti2.state = State.RUNNING
+        dr1_ti1 = dr1.get_task_instance(task1.task_id, session)
+        dr1_ti2 = dr1.get_task_instance(task2.task_id, session)
+        dr1_ti3 = dr1.get_task_instance(task3.task_id, session)
+        dr1_ti4 = dr1.get_task_instance(task4.task_id, session)
+        dr1_ti1.state = State.RUNNING
+        dr1_ti2.state = State.RUNNING
+        dr1_ti3.state = State.RUNNING
+        dr1_ti4.state = State.SCHEDULED
         session.flush()
 
-        assert State.RUNNING == dr1.state
-        assert 2 == DAG.get_num_task_instances(
-            dag_id, task_ids=dag.task_ids, states=[State.RUNNING], session=session
+        assert dr1.state == State.RUNNING
+        num_tis = DAG.get_num_task_instances(
+            dag_id=dag_id,
+            task_ids=dag.task_ids,
+            states=[State.RUNNING],
+            session=session,
         )
+        assert num_tis == 3
 
         # create second dag run
-        dr2 = dag_maker.create_dagrun_after(dr1, run_type=DagRunType.SCHEDULED)
-        ti3 = dr2.get_task_instance(task1.task_id, session)
-        ti4 = dr2.get_task_instance(task2.task_id, session)
+        dr2 = dag_maker.create_dagrun_after(dr1, run_type=DagRunType.SCHEDULED, session=session)
+        dr2_ti1 = dr2.get_task_instance(task1.task_id, session)
+        dr2_ti2 = dr2.get_task_instance(task2.task_id, session)
+        dr2_ti3 = dr2.get_task_instance(task3.task_id, session)
+        dr2_ti4 = dr2.get_task_instance(task4.task_id, session)
         # manually set to scheduled so we can pick them up
-        ti3.state = State.SCHEDULED
-        ti4.state = State.SCHEDULED
+        dr2_ti1.state = State.SCHEDULED
+        dr2_ti2.state = State.SCHEDULED
+        dr2_ti3.state = State.SCHEDULED
+        dr2_ti4.state = State.SCHEDULED
         session.flush()
 
-        assert State.RUNNING == dr2.state
+        assert dr2.state == State.RUNNING
 
-        res = self.job_runner._critical_section_enqueue_task_instances(session)
+        num_queued = self.job_runner._critical_section_enqueue_task_instances(session=session)
+        assert num_queued == 3
 
         # check that max_active_tasks is respected
-        ti1.refresh_from_db()
-        ti2.refresh_from_db()
-        ti3.refresh_from_db()
-        ti4.refresh_from_db()
-        assert 3 == DAG.get_num_task_instances(
-            dag_id, task_ids=dag.task_ids, states=[State.RUNNING, State.QUEUED], session=session
-        )
-        assert State.RUNNING == ti1.state
-        assert State.RUNNING == ti2.state
-        assert {State.QUEUED, State.SCHEDULED} == {ti3.state, ti4.state}
-        assert 1 == res
-        res = self.job_runner._critical_section_enqueue_task_instances(session)
-        assert 0 == res
+        num_tis = DAG.get_num_task_instances(
+            dag_id=dag_id,
+            task_ids=dag.task_ids,
+            states=[State.RUNNING, State.QUEUED],
+            session=session,
+        )
+        assert num_tis == 6
+
+        # dr1's states are unchanged; we set them manually above
+        dr1_counter = Counter(x.state for x in dr1.get_task_instances(session=session))
+        assert dr1_counter[State.RUNNING] == 3
+        assert dr1_counter[State.SCHEDULED] == 1
+
+        # the meaningful check: three of dr2's tasks should be queued, since
+        # max_active_tasks is 3 and is evaluated per dag run
+        dr2_counter = Counter(x.state for x in dr2.get_task_instances(session=session))
+        assert dr2_counter[State.QUEUED] == 3
+        assert dr2_counter[State.SCHEDULED] == 1
+
+        num_queued = self.job_runner._critical_section_enqueue_task_instances(session=session)
+        assert num_queued == 0
 
     def test_execute_task_instances_limit_second_executor(self, dag_maker, mock_executors):
         dag_id = "SchedulerJobTest.test_execute_task_instances_limit"