Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

BashOperator: Execute templated bash script as file #42783

Merged
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
feat: execute templated shell script as file
  • Loading branch information
joffreybienvenu-infrabel committed Oct 19, 2024
commit 3233f018588ea8423905e5a5afac5af3fc40e934
19 changes: 16 additions & 3 deletions providers/src/airflow/providers/standard/hooks/subprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,27 @@
from collections import namedtuple
from subprocess import PIPE, STDOUT, Popen
from tempfile import TemporaryDirectory, gettempdir
from typing import Iterator

from airflow.hooks.base import BaseHook

SubprocessResult = namedtuple("SubprocessResult", ["exit_code", "output"])


@contextlib.contextmanager
def working_directory(cwd: str | None = None) -> Iterator[str]:
    """
    Yield the directory a command should run in.

    :param cwd: Directory to use as-is. When ``None``, a temporary directory
        (prefixed with ``airflowtmp``) is created for the duration of the
        context.
    :return: Iterator yielding the path of the working directory.
    """
    if cwd is not None:
        # Caller-supplied directory: nothing to create here.
        yield cwd
        return

    with TemporaryDirectory(prefix="airflowtmp") as tmp_dir:
        yield tmp_dir


class SubprocessHook(BaseHook):
"""Hook for running processes with the ``subprocess`` module."""

Expand Down Expand Up @@ -61,9 +76,7 @@ def run_command(
or stdout
"""
self.log.info("Tmp dir root location: %s", gettempdir())
with contextlib.ExitStack() as stack:
if cwd is None:
cwd = stack.enter_context(TemporaryDirectory(prefix="airflowtmp"))
with working_directory(cwd=cwd) as cwd:

def pre_exec():
# Restore default signal disposition and invoke setsid
Expand Down
74 changes: 61 additions & 13 deletions providers/src/airflow/providers/standard/operators/bash.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,14 @@

import os
import shutil
import tempfile
import warnings
from functools import cached_property
from typing import TYPE_CHECKING, Any, Callable, Container, Sequence, cast

from airflow.exceptions import AirflowException, AirflowSkipException
from airflow.models.baseoperator import BaseOperator
from airflow.providers.standard.hooks.subprocess import SubprocessHook
from airflow.providers.standard.hooks.subprocess import SubprocessHook, SubprocessResult, working_directory
from airflow.utils.operator_helpers import context_to_airflow_vars
from airflow.utils.types import ArgNotSet

Expand Down Expand Up @@ -63,6 +64,9 @@ class BashOperator(BaseOperator):
If None (default), the command is run in a temporary directory.
To use current DAG folder as the working directory,
you might set template ``{{ dag_run.dag.folder }}``.
When bash_command is a '.sh' or '.bash' file, Airflow must have write
access to the working directory. The script will be rendered (Jinja
template) into a new temporary file in this directory.
:param output_processor: Function to further process the output of the bash script
(default is lambda output: output).

Expand Down Expand Up @@ -97,10 +101,14 @@ class BashOperator(BaseOperator):

.. note::

Add a space after the script name when directly calling a ``.sh`` script with the
``bash_command`` argument -- for example ``bash_command="my_script.sh "``. This
is because Airflow tries to apply load this file and process it as a Jinja template to
it ends with ``.sh``, which will likely not be what most users want.
To simply execute a ``.sh`` or ``.bash`` script (without any Jinja template), add a space after the
script name in the ``bash_command`` argument -- for example ``bash_command="my_script.sh "``. This
is because Airflow tries to load this file and process it as a Jinja template when
it ends with ``.sh`` or ``.bash``.

If your script contains Jinja templates, do not add the trailing space, and add the script's directory
to the DAG's ``template_searchpath``. If you specify a ``cwd``, Airflow must have write access to
this directory. The script will be rendered (Jinja template) into a new temporary file in this directory.

.. warning::

Expand Down Expand Up @@ -180,6 +188,11 @@ def __init__(
# determine whether the bash_command value needs to re-rendered.
self._init_bash_command_not_set = isinstance(self.bash_command, ArgNotSet)

# Keep a copy of the original bash_command, without the Jinja template rendered.
# This is later used to determine if the bash_command is a script or an inline string command.
# We do this later, because the bash_command is not available in __init__ when using @task.bash.
self._unrendered_bash_command: str | ArgNotSet = bash_command

@cached_property
def subprocess_hook(self):
"""Returns hook for running the bash command."""
Expand All @@ -200,7 +213,7 @@ def refresh_bash_command(ti: TaskInstance) -> None:

RenderedTaskInstanceFields._update_runtime_evaluated_template_fields(ti)

def get_env(self, context):
def get_env(self, context) -> dict:
"""Build the set of environment variables to be exposed for the bash command."""
system_env = os.environ.copy()
env = self.env
Expand All @@ -220,7 +233,7 @@ def get_env(self, context):
return env

def execute(self, context: Context):
bash_path = shutil.which("bash") or "bash"
bash_path: str = shutil.which("bash") or "bash"
if self.cwd is not None:
if not os.path.exists(self.cwd):
raise AirflowException(f"Can not find the cwd: {self.cwd}")
Expand All @@ -234,15 +247,17 @@ def execute(self, context: Context):
# Both will ensure the correct Bash command is executed and that the Rendered Template view in the UI
# displays the executed command (otherwise it will display as an ArgNotSet type).
if self._init_bash_command_not_set:
is_inline_command = self._is_inline_command(bash_command=cast(str, self.bash_command))
ti = cast("TaskInstance", context["ti"])
self.refresh_bash_command(ti)
else:
is_inline_command = self._is_inline_command(bash_command=cast(str, self._unrendered_bash_command))

if is_inline_command:
result = self._run_inline_command(bash_path=bash_path, env=env)
else:
result = self._run_rendered_script_file(bash_path=bash_path, env=env)

result = self.subprocess_hook.run_command(
command=[bash_path, "-c", self.bash_command],
env=env,
output_encoding=self.output_encoding,
cwd=self.cwd,
)
if result.exit_code in self.skip_on_exit_code:
raise AirflowSkipException(f"Bash command returned exit code {result.exit_code}. Skipping.")
elif result.exit_code != 0:
Expand All @@ -252,5 +267,38 @@ def execute(self, context: Context):

return self.output_processor(result.output)

def _run_inline_command(self, bash_path: str, env: dict) -> SubprocessResult:
"""Pass the bash command as string directly in the subprocess."""
return self.subprocess_hook.run_command(
command=[bash_path, "-c", self.bash_command],
env=env,
output_encoding=self.output_encoding,
cwd=self.cwd,
)

def _run_rendered_script_file(self, bash_path: str, env: dict) -> SubprocessResult:
    """
    Write the bash command to a temporary ``.sh`` file and execute that file.

    Running from a file instead of ``bash -c`` allows for longer commands and
    prevents the "Argument list too long" error.

    :param bash_path: Path (or name) of the bash executable to invoke.
    :param env: Environment mapping forwarded to the subprocess hook.
    :return: Whatever ``subprocess_hook.run_command`` returns.
    """
    with working_directory(cwd=self.cwd) as run_dir, tempfile.NamedTemporaryFile(
        mode="w", dir=run_dir, suffix=".sh"
    ) as script:
        script.write(cast(str, self.bash_command))
        # Flush before running: the script is executed while the file handle is still open.
        script.flush()

        # The script lives in run_dir, which is also the subprocess cwd,
        # so only its base name is passed to bash.
        script_name = os.path.basename(script.name)
        return self.subprocess_hook.run_command(
            command=[bash_path, script_name],
            env=env,
            output_encoding=self.output_encoding,
            cwd=run_dir,
        )

@classmethod
def _is_inline_command(cls, bash_command: str) -> bool:
"""Return True if the bash command is an inline string. False if it's a bash script file."""
return not bash_command.endswith(tuple(cls.template_ext))

def on_kill(self) -> None:
    """Delegate to ``subprocess_hook.send_sigterm()`` when the task is killed."""
    hook = self.subprocess_hook
    hook.send_sigterm()
28 changes: 28 additions & 0 deletions providers/tests/standard/operators/test_bash.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from datetime import datetime, timedelta
from pathlib import Path
from time import sleep
from typing import TYPE_CHECKING
from unittest import mock

import pytest
Expand All @@ -38,6 +39,9 @@
if AIRFLOW_V_3_0_PLUS:
from airflow.utils.types import DagRunTriggeredByType

if TYPE_CHECKING:
from airflow.models import TaskInstance

DEFAULT_DATE = datetime(2016, 1, 1, tzinfo=timezone.utc)
END_DATE = datetime(2016, 1, 2, tzinfo=timezone.utc)
INTERVAL = timedelta(hours=12)
Expand All @@ -60,6 +64,7 @@ def test_bash_operator_init(self):
assert op.skip_on_exit_code == [99]
assert op.cwd is None
assert op._init_bash_command_not_set is False
assert op._unrendered_bash_command == "echo"

@pytest.mark.db_test
@pytest.mark.parametrize(
Expand Down Expand Up @@ -284,3 +289,26 @@ def test_templated_fields(self, create_task_instance_of_operator):
assert task.bash_command == 'echo "test_templated_fields_dag"'
assert task.env == {"FOO": "2024-02-01"}
assert task.cwd == Path(__file__).absolute().parent.as_posix()

def test_templated_bash_script(self, dag_maker, tmp_path, session):
    """
    Render and execute a Jinja-templated ``.sh`` script through ``BashOperator``.

    A script echoing ``{{ ti.task_id }}`` is written to ``tmp_path``, which is
    registered as the DAG's ``template_searchpath``. The operator output must
    equal the task id.
    """
    script_name: str = "sample.sh"
    script_path: Path = tmp_path / script_name
    script_path.write_text('echo "{{ ti.task_id }}"')
    search_path = os.fspath(script_path.parent)

    with dag_maker(dag_id="test_templated_bash_script", session=session, template_searchpath=search_path):
        BashOperator(task_id="test_templated_fields_task", bash_command=script_name)

    dag_run = dag_maker.create_dagrun()
    ti: TaskInstance = dag_run.task_instances[0]
    session.add(ti)
    session.commit()

    # Render the template fields first so the operator holds the rendered script content.
    context = ti.get_template_context(session=session)
    ti.render_templates(context=context)

    task: BashOperator = ti.task
    output = task.execute(context=context)
    assert output == "test_templated_fields_task"
assert result == "test_templated_fields_task"
35 changes: 35 additions & 0 deletions tests/decorators/test_bash.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
import stat
import warnings
from contextlib import nullcontext as no_raise
from pathlib import Path
from typing import TYPE_CHECKING
from unittest import mock

import pytest
Expand All @@ -32,6 +34,10 @@

from tests_common.test_utils.db import clear_db_dags, clear_db_runs, clear_rendered_ti_fields

if TYPE_CHECKING:
from airflow.models import TaskInstance
from airflow.operators.bash import BashOperator

DEFAULT_DATE = timezone.datetime(2023, 1, 1)


Expand Down Expand Up @@ -500,3 +506,32 @@ def bash():
with pytest.raises(AirflowException):
ti.run()
assert ti.task.bash_command == f"{DEFAULT_DATE.date()}; exit 1;"

def test_templated_bash_script(self, dag_maker, tmp_path, session):
    """
    Render and execute a Jinja-templated ``.sh`` script returned by a ``@task.bash`` function.

    A script echoing ``{{ ti.task_id }}`` is written to ``tmp_path``, which is
    registered as the DAG's ``template_searchpath``. The decorated task returns
    the script name, and the executed output must equal the task id.
    """
    script_name: str = "sample.sh"
    script_path: Path = tmp_path / script_name
    script_path.write_text('echo "{{ ti.task_id }}"')
    search_path = os.fspath(script_path.parent)

    with dag_maker(dag_id="test_templated_bash_script", session=session, template_searchpath=search_path):

        @task.bash
        def test_templated_fields_task():
            return script_name

        test_templated_fields_task()

    dag_run = dag_maker.create_dagrun()
    ti: TaskInstance = dag_run.task_instances[0]
    session.add(ti)
    session.commit()

    # Render the template fields before executing the operator directly.
    context = ti.get_template_context(session=session)
    ti.render_templates(context=context)

    op: BashOperator = ti.task
    output = op.execute(context=context)
    assert output == "test_templated_fields_task"