Creating a DAG Factory in Airflow

If you are using Airflow to orchestrate DBT, and if, like me, you prefer to isolate pipelines into logical units, then you will have noticed the inherent repetition that comes with this approach.

For example, perhaps you have a set of models that read from an operations_db source and write to a set of operations marts, and another set of models that read from a marketing_db source and write to a set of marketing marts. The logical organization of these pipelines would be two DAGs: one for operations, say dbt__operations_build.py, and another for marketing, dbt__marketing_build.py.

In both cases the core components of those pipelines are the same – run the bronze layer, run the silver layer, run the gold layer (more on medallion architecture here).

An example of those stages might look like this:

from airflow.models.baseoperator import chain

# DbtOperator is a custom operator and DBT_ENV holds the dbt environment
# variables; both are defined elsewhere in the project.
from operators.DbtOperator import DbtOperator
from utils.constants import DBT_ENV

stages = ["staging.operations", "intermediate.operations", "marts.operations"]

tasks = []
for stage in stages:
    task_id = f"dbt_build_{stage.split('.')[0]}"
    tasks.append(
        DbtOperator(
            task_id=task_id,
            dbt_command=f"dbt build --select {stage}",
            env=DBT_ENV,
            dag=dag,  # the enclosing DAG object for this pipeline
        )
    )

# Run the stages sequentially: staging -> intermediate -> marts
chain(*tasks)

Which results in a simple linear DAG: dbt_build_staging → dbt_build_intermediate → dbt_build_marts.

For the marketing jobs, the code will be essentially identical; only the selectors change. Now imagine 50 separate pipelines all built from those same three core steps, and it becomes apparent that there must be a better way.
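To make the repetition concrete, the marketing version would look something like this – same DbtOperator, same DBT_ENV, same chain; only the selector strings change:

stages = ["staging.marketing", "intermediate.marketing", "marts.marketing"]

tasks = []
for stage in stages:
    tasks.append(
        DbtOperator(
            task_id=f"dbt_build_{stage.split('.')[0]}",
            dbt_command=f"dbt build --select {stage}",
            env=DBT_ENV,
            dag=dag,  # the marketing DAG object this time
        )
    )

chain(*tasks)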

Yet Another Config File

What if, instead of 50 DAG files for 50 pipelines, we had a single DAG file that generates the other DAGs for us? This is the concept of a DAG factory.

The core components are a dag_config.yaml config file and our dbt__dag_factory.py factory file.
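For orientation, my project layout looks roughly like this (the dags/ location is the standard Airflow layout; the assets path matches the CONFIG_DIR used in the factory code below):

/usr/local/airflow/
├── dags/
│   └── dbt__dag_factory.py
└── assets/
    └── dbt__dag_factory/
        └── dag_config.yaml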

Below is an example dag_config.yaml – for each DAG we define the dag_id and the schedule, and for each task the DBT command to run.

---
dags:
  - dag_id: dbt__operations_daily_build
    schedule: 0 14,18 * * *
    tasks:
      - task_id: build_staging
        dbt_command: dbt build --select staging.operations
      - task_id: build_intermediate
        dbt_command: dbt build --select intermediate.operations
      - task_id: build_marts
        dbt_command: dbt build --select marts.operations

  - dag_id: dbt__marketing_daily_build
    schedule: 0 14,18 * * *
    tasks:
      - task_id: build_staging
        dbt_command: dbt build --select staging.marketing
      - task_id: build_intermediate
        dbt_command: dbt build --select intermediate.marketing
      - task_id: build_marts
        dbt_command: dbt build --select marts.marketing
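As a point of reference, yaml.safe_load() turns this file into plain Python dictionaries and lists, so each entry under dags is exactly the dag_config dict passed to the factory function below. Roughly:

{
    "dags": [
        {
            "dag_id": "dbt__operations_daily_build",
            "schedule": "0 14,18 * * *",
            "tasks": [
                {
                    "task_id": "build_staging",
                    "dbt_command": "dbt build --select staging.operations",
                },
                # ... build_intermediate and build_marts follow the same shape
            ],
        },
        # ... dbt__marketing_daily_build
    ],
}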

Next is the DAG itself:

import os
import yaml
from airflow import DAG
from datetime import datetime
from operators.DbtOperator import DbtOperator
# DBT_ENV is jinja templated to read secrets from environment vars.
from utils.constants import DBT_ENV, ASSETS_DIR

DAG_NAME = "dbt__dag_factory"

# For me this will resolve to /usr/local/airflow/assets/dbt__dag_factory
# and is where I will store the dag_config
CONFIG_DIR = os.path.join(ASSETS_DIR, DAG_NAME)


def create_dbt_dag(dag_config):
    """
    Creates dags based on configuration file

    :param dag_config: dict
    :return: airflow.models.dag.DAG
    """

    dag_id = dag_config.get("dag_id")
    schedule_interval = dag_config.get("schedule")
    tasks = dag_config.get("tasks")

    if not dag_id or not schedule_interval or not tasks:
        raise ValueError("Missing required fields in dag_config.")

    dag = DAG(
        dag_id=dag_id,
        start_date=datetime(2021, 11, 17),
        schedule=schedule_interval,
        catchup=False,
        max_active_runs=1,
        max_active_tasks=1,
    )

    previous_task = None

    for task_config in tasks:
        task_id = task_config.get("task_id")
        dbt_command = task_config.get("dbt_command")

        if not task_id or not dbt_command:
            raise ValueError("Missing required fields in task_config.")

        dbt_task = DbtOperator(
            task_id=task_id,
            dbt_command=dbt_command,
            env=DBT_ENV,
            dag=dag,
        )

        if previous_task is not None:
            previous_task.set_downstream(dbt_task)

        previous_task = dbt_task

    return dag


def load_config(config_file):
    """Load and parse a YAML config file from CONFIG_DIR."""
    with open(os.path.join(CONFIG_DIR, config_file), "r") as f:
        config_data = yaml.safe_load(f)
    return config_data


config_data = load_config("dag_config.yaml")

for dag_config in config_data["dags"]:
    dag = create_dbt_dag(dag_config)
    globals()[dag_config["dag_id"]] = dag

The config file is read in and, for each config block, we call the create_dbt_dag() function. create_dbt_dag() creates a DAG object and sets simple sequential task dependencies. Finally, we assign the DAG object to the globals() dict under its dag_id.

This last line is key. Without it, only the last DAG in the config file is registered by Airflow. The reason comes down to how Airflow parses DAG objects.

A single file can contain multiple DAGs (dag_1 and dag_2, for example). As long as they are defined in “top level” code, Airflow will parse these objects and register each DAG.

In our case, we overwrite the dag variable on each loop iteration, so Airflow would only register the last DAG in the config file, as that is all it really sees at the “top level”. The solution is to register each DAG object we create in the globals() dict, after which every generated DAG appears in the Airflow UI.
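To make the difference concrete, here is a minimal sketch (reusing the config_data and create_dbt_dag() defined above):

# Rebinding a single top-level name: when parsing finishes, only the final
# DAG object is still reachable at module level, so Airflow registers one DAG.
for dag_config in config_data["dags"]:
    dag = create_dbt_dag(dag_config)

# Registering each DAG under its own dag_id via globals(): every DAG object
# is reachable at module level, so Airflow registers all of them.
for dag_config in config_data["dags"]:
    globals()[dag_config["dag_id"]] = create_dbt_dag(dag_config)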

Final thoughts

Generally speaking, I avoid designing things that are driven by config files. I find they too often try to do too much; the YAML becomes complex and generally unpleasant to work with. For DBT workloads, however, it’s a nice solution, as the scope is small (i.e. just DBT) and the operations performed are highly generic. Rather than replicate DAGs, it’s far easier to replicate config blocks and put the time into something more interesting – just like this guy did: https://github.com/ajbosco/dag-factory.
