Airflow 1.9: Run a task when upstream is skipped by ShortCircuit


Question


I have a task that I'll call final that has multiple upstream connections. When one of the upstreams gets skipped by ShortCircuitOperator, this task gets skipped as well. I don't want the final task to be skipped, as it has to report on DAG success.

To avoid it getting skipped I used trigger_rule='all_done', but it still gets skipped.

If I use BranchPythonOperator instead of ShortCircuitOperator, the final task doesn't get skipped. A branching workflow seems like it could be a solution, even if not optimal, but now final will not respect failures of upstream tasks.

How do I get it to only run when upstreams are successful or skipped?

Sample ShortCircuit DAG:

from airflow import DAG
from airflow.operators.dummy_operator import DummyOperator
from airflow.operators.python_operator import ShortCircuitOperator
from datetime import datetime
from random import randint

default_args = {
    'owner': 'airflow',
    'start_date': datetime(2018, 8, 1)}

dag = DAG(
    'shortcircuit_test',
    default_args=default_args,
    schedule_interval='* * * * *',
    catchup=False)

def shortcircuit_fn():
    return randint(0, 1) == 1

task_1 = DummyOperator(dag=dag, task_id='task_1')
task_2 = DummyOperator(dag=dag, task_id='task_2')

work = DummyOperator(dag=dag, task_id='work')
short = ShortCircuitOperator(dag=dag, task_id='short_circuit', python_callable=shortcircuit_fn)
final = DummyOperator(dag=dag, task_id="final", trigger_rule="all_done")

task_1 >> short >> work >> final
task_1 >> task_2 >> final

Sample Branch DAG:

from airflow import DAG
from airflow.operators.dummy_operator import DummyOperator
from airflow.operators.python_operator import BranchPythonOperator
from datetime import datetime
from random import randint

default_args = {
    'owner': 'airflow',
    'start_date': datetime(2018, 8, 1)}

dag = DAG(
    'branch_test',
    default_args=default_args,
    schedule_interval='* * * * *',
    catchup=False)

# these two tasks exist only to protect downstream tasks from being skipped as direct dependencies of the branch operator
to_do_work = DummyOperator(dag=dag, task_id='to_do_work')
to_skip_work = DummyOperator(dag=dag, task_id='to_skip_work')

def branch_fn():
    return to_do_work.task_id if randint(0, 1) == 1 else to_skip_work.task_id

task_1 = DummyOperator(dag=dag, task_id='task_1')
task_2 = DummyOperator(dag=dag, task_id='task_2')

work = DummyOperator(dag=dag, task_id='work')
branch = BranchPythonOperator(dag=dag, task_id='branch', python_callable=branch_fn)
final = DummyOperator(dag=dag, task_id="final", trigger_rule="all_done")

task_1 >> branch >> to_do_work >> work >> final
branch >> to_skip_work >> final
task_1 >> task_2 >> final


Answer 1:


I've made it work by having the final task check the statuses of the upstream task instances. It's not beautiful, as the only way I found to access their states was by querying the Airflow DB.

# additional imports on top of those in the question's code
from airflow import AirflowException
from airflow.models import TaskInstance
from airflow.operators.python_operator import PythonOperator
from airflow.settings import Session
from airflow.utils.state import State
from airflow.utils.trigger_rule import TriggerRule

def all_upstreams_either_succeeded_or_skipped(dag, task, task_instance, **context):
    """
    Find the directly upstream task instances and count how many are not in
    the preferred states. Return True if none are in a non-preferred state.
    """
    upstream_task_ids = [t.task_id for t in task.get_direct_relatives(upstream=True)]
    session = Session()
    try:
        upstream_task_instances = (
            session
            .query(TaskInstance)
            .filter(
                TaskInstance.dag_id == dag.dag_id,
                TaskInstance.execution_date == task_instance.execution_date,
                TaskInstance.task_id.in_(upstream_task_ids),
            )
            .all()
        )
    finally:
        # the session is created manually, so close it to avoid leaking connections
        session.close()
    unhappy_task_instances = [
        ti for ti in upstream_task_instances
        if ti.state not in (State.SUCCESS, State.SKIPPED)
    ]
    print(unhappy_task_instances)
    return len(unhappy_task_instances) == 0

def final_fn(**context):
    """
    Fail if any upstream task instance is in an unwanted state.
    """
    if not all_upstreams_either_succeeded_or_skipped(**context):
        raise AirflowException("Not all upstream tasks succeeded or were skipped.")
    # Do things

# will run when upstream task instances are done, including failed
final = PythonOperator(
    dag=dag,
    task_id="final",
    trigger_rule=TriggerRule.ALL_DONE,
    python_callable=final_fn,
    provide_context=True)
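
This final replaces the DummyOperator of the same name in the question's sample DAG; the dependency wiring is unchanged (a sketch, reusing the question's task names):

# unchanged wiring from the question's sample DAG
task_1 >> short >> work >> final
task_1 >> task_2 >> final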



Answer 2:


I ended up developing a custom ShortCircuitOperator based on the original one:

# imports needed for the custom operator; SkipMixin lives in airflow.models
# (this mirrors the Airflow 1.10 ShortCircuitOperator)
from airflow.models import SkipMixin
from airflow.operators.python_operator import PythonOperator

class ShortCircuitOperator(PythonOperator, SkipMixin):
    """
    Allows a workflow to continue only if a condition is met. Otherwise, the
    workflow "short-circuits" and downstream tasks that only rely on this operator
    are skipped.

    The ShortCircuitOperator is derived from the PythonOperator. It evaluates a
    condition and short-circuits the workflow if the condition is False. Any
    downstream tasks that only rely on this operator are marked with a state of "skipped".
    If the condition is True, downstream tasks proceed as normal.

    The condition is determined by the result of `python_callable`.
    """

    def find_tasks_to_skip(self, task, found_tasks=None):
        if not found_tasks:
            found_tasks = []
        direct_relatives = task.get_direct_relatives(upstream=False)
        for t in direct_relatives:
            if len(t.upstream_task_ids) == 1:
                found_tasks.append(t)
                self.find_tasks_to_skip(t, found_tasks)
        return found_tasks

    def execute(self, context):
        condition = super(ShortCircuitOperator, self).execute(context)
        self.log.info("Condition result is %s", condition)

        if condition:
            self.log.info('Proceeding with downstream tasks...')
            return

        self.log.info(
            'Skipping downstream tasks that only rely on this path...')

        tasks_to_skip = self.find_tasks_to_skip(context['task'])
        self.log.debug("Tasks to skip: %s", tasks_to_skip)

        if tasks_to_skip:
            self.skip(context['dag_run'], context['ti'].execution_date,
                      tasks_to_skip)

        self.log.info("Done.")

This operator makes sure that no downstream task relying on multiple paths gets skipped because of one skipped upstream path.
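
Dropped into the question's sample DAG, a minimal sketch (assuming the custom class above is defined in the same DAG file, shadowing the stock operator) behaves as wanted: when the condition is False, only work gets skipped, while final keeps its second upstream path and runs.

# a sketch reusing the question's sample DAG; the custom class above
# shadows the stock ShortCircuitOperator
short = ShortCircuitOperator(dag=dag, task_id='short_circuit',
                             python_callable=shortcircuit_fn)
final = DummyOperator(dag=dag, task_id='final', trigger_rule='all_done')

# when shortcircuit_fn returns False, 'work' is skipped because its only
# upstream is 'short_circuit'; 'final' has two upstream paths, so it is
# left alone and its all_done trigger rule lets it run
task_1 >> short >> work >> final
task_1 >> task_2 >> final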




Answer 3:


This may have been added after you asked your initial question, but Airflow now conveniently has a trigger_rule value of none_failed. If you set this on your final task, it will run whether the upstream tasks succeeded or were skipped, but not when any of them fail.

More info: https://airflow.apache.org/concepts.html#trigger-rules
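
Applied to the question's sample DAG, this is a one-line change (a sketch):

# runs when no upstream task has failed, i.e. every upstream task
# either succeeded or was skipped
final = DummyOperator(dag=dag, task_id='final', trigger_rule='none_failed')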



Source: https://stackoverflow.com/questions/51725746/airflow-1-9-run-a-task-when-upstream-is-skipped-by-shortcircuit
