Creating a Custom XCom Backend in Airflow

In this article, we will take a look at how to build a custom XCom backend in Airflow so that tasks in your data pipelines can exchange larger values, such as pandas DataFrames.

In Airflow, XComs (short for "cross-communications") are a mechanism that lets tasks communicate and exchange data with each other.

An XCom is identified by a key (essentially its name), as well as the task_id and dag_id it came from. XComs can hold any serializable value; however, they are designed for very small data only. The exact size limit depends on the metadata database backend you use:

  • SQLite: 2 GB
  • Postgres: 1 GB
  • MySQL: 64 KB
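
To make the mechanics concrete, here is a minimal sketch of plain XCom usage. The task_id and key names are illustrative, and the callables are assumed to be wired up as PythonOperator tasks that receive the task instance ti:

def producer(ti):
    # Push a small value under an explicit key.
    ti.xcom_push(key="row_count", value=42)

def consumer(ti):
    # Pull it back using the producing task's task_id and the key.
    row_count = ti.xcom_pull(task_ids="producer_task", key="row_count")
    print(row_count)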

That's why, in their default form, XComs can't be used to send and retrieve data frames or other larger objects. It's worth mentioning that Airflow is not designed for heavy data processing; for that use case, you may be better off with a specialized tool like Spark.

That being said, at Pilotcore we find it handy to be able to exchange data between tasks that is sometimes a little bigger than 64 KB. For instance, a first task might build a DataFrame from records in an external database (one not managed by us), send it to a second task for processing, and finally a third task might send us a report. By exchanging small data frames between tasks, their roles stay nicely isolated, and if there is an error in the processing, we have visibility into the data for troubleshooting.

Luckily, Airflow gives you the option to create your own XCom backend implementation. Let's start by importing everything we will need:

import os
import uuid
import pandas as pd

from typing import Any
from airflow.models.xcom import BaseXCom
from airflow.providers.amazon.aws.hooks.s3 import S3Hook

Then let's create a new class subclassing the original BaseXCom. We also add two class variables to it; we will get to them later:

class S3XComBackend(BaseXCom):
    PREFIX = "xcom_s3"
    BUCKET_NAME = os.environ.get("S3_XCOM_BUCKET_NAME")

Now we need to implement two static methods, serialize_value and deserialize_value. As the names suggest, one will be used to serialize values into an XCom-compatible format, and the other to retrieve them.

In this implementation, we will limit ourselves to pandas DataFrames while keeping backward compatibility for everything else. As you will see, the approach can easily be extended to other types as well.

In the serialization method, we will first check whether the value is an instance of a pandas DataFrame; if not, we can simply pass it through.

Otherwise, we will create an S3 hook, serialize the DataFrame to CSV, upload it to S3, and in the end return only the S3 path from the task.

@staticmethod
def serialize_value(value: Any):
    if isinstance(value, pd.DataFrame):
        hook = S3Hook()
        key = f"data_{uuid.uuid4()}.csv"
        filename = f"/tmp/{key}"
        # Write the DataFrame to a local CSV file, then upload it to S3.
        value.to_csv(filename, index=False)
        hook.load_file(
            filename=filename,
            key=key,
            bucket_name=S3XComBackend.BUCKET_NAME,
            replace=True
        )
        # Store only the (small) S3 reference in the XCom table.
        value = f"{S3XComBackend.PREFIX}://{S3XComBackend.BUCKET_NAME}/{key}"

    return BaseXCom.serialize_value(value)

You may have noticed that we use the prefix xcom_s3 instead of just s3. This is because, in the deserialization method, we need to distinguish between S3 addresses created by our serialization method and S3 addresses returned directly by a task.

In the deserialization method, all we need to do is check whether the value is a string starting with our custom prefix and, if so, load it from S3:

@staticmethod
def deserialize_value(result) -> Any:
    result = BaseXCom.deserialize_value(result)

    if isinstance(result, str) and result.startswith(S3XComBackend.PREFIX):
        hook = S3Hook()
        # Strip the xcom_s3://<bucket>/ prefix to recover the S3 key.
        key = result.replace(f"{S3XComBackend.PREFIX}://{S3XComBackend.BUCKET_NAME}/", "")
        filename = hook.download_file(
            key=key,
            bucket_name=S3XComBackend.BUCKET_NAME,
            local_path="/tmp"
        )
        result = pd.read_csv(filename)

    return result

To configure Airflow to use the new backend, set the environment variable AIRFLOW__CORE__XCOM_BACKEND to xcom_s3_backend.S3XComBackend, where xcom_s3_backend is the name of the Python file on your PYTHONPATH that contains the S3XComBackend class.
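
For example, assuming the class is saved in a file called xcom_s3_backend.py on the PYTHONPATH and my-xcom-bucket is a placeholder bucket name, the Airflow environment could be configured like this:

AIRFLOW__CORE__XCOM_BACKEND=xcom_s3_backend.S3XComBackend
S3_XCOM_BUCKET_NAME=my-xcom-bucket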

Now we can put everything together, so you can just copy-paste the code into your editor.

Because the target S3 bucket name is a deployment-dependent property, we make it configurable via an environment variable. In the final code, we therefore also add a method called _assert_s3_backend that checks whether the class property BUCKET_NAME has been correctly initialized.

import os
import uuid
import pandas as pd

from typing import Any
from airflow.models.xcom import BaseXCom
from airflow.providers.amazon.aws.hooks.s3 import S3Hook


class S3XComBackend(BaseXCom):
    PREFIX = "xcom_s3"
    BUCKET_NAME = os.environ.get("S3_XCOM_BUCKET_NAME")

    @staticmethod
    def _assert_s3_backend():
        if S3XComBackend.BUCKET_NAME is None:
            raise ValueError("Unknown bucket for S3 backend.")

    @staticmethod
    def serialize_value(value: Any):
        if isinstance(value, pd.DataFrame):
            S3XComBackend._assert_s3_backend()
            hook = S3Hook()
            key = f"data_{uuid.uuid4()}.csv"
            filename = f"/tmp/{key}"
            # Write the DataFrame to a local CSV file, then upload it to S3.
            value.to_csv(filename, index=False)
            hook.load_file(
                filename=filename,
                key=key,
                bucket_name=S3XComBackend.BUCKET_NAME,
                replace=True
            )
            # Store only the (small) S3 reference in the XCom table.
            value = f"{S3XComBackend.PREFIX}://{S3XComBackend.BUCKET_NAME}/{key}"

        return BaseXCom.serialize_value(value)

    @staticmethod
    def deserialize_value(result) -> Any:
        result = BaseXCom.deserialize_value(result)

        if isinstance(result, str) and result.startswith(S3XComBackend.PREFIX):
            S3XComBackend._assert_s3_backend()
            hook = S3Hook()
            # Strip the xcom_s3://<bucket>/ prefix to recover the S3 key.
            key = result.replace(f"{S3XComBackend.PREFIX}://{S3XComBackend.BUCKET_NAME}/", "")
            filename = hook.download_file(
                key=key,
                bucket_name=S3XComBackend.BUCKET_NAME,
                local_path="/tmp"
            )
            result = pd.read_csv(filename)

        return result

And that's it! Now you can use this custom XCom implementation in your DAGs and modify the serialization logic to your needs; for example, you can add support for NumPy arrays or any other format you need.
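
As a quick illustration, here is a minimal sketch of a TaskFlow-style DAG that passes a DataFrame between two tasks through the custom backend. The DAG and task names are placeholders, and it assumes the environment variables above are set:

from datetime import datetime

import pandas as pd
from airflow.decorators import dag, task


@dag(schedule_interval=None, start_date=datetime(2023, 1, 1), catchup=False)
def xcom_s3_demo():

    @task
    def make_frame() -> pd.DataFrame:
        # The returned DataFrame goes through S3XComBackend.serialize_value,
        # so only a small xcom_s3://... path lands in the metadata database.
        return pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})

    @task
    def summarize(df: pd.DataFrame):
        # The backend downloads the CSV from S3 and hands us a DataFrame again.
        print(df.describe())

    summarize(make_frame())


xcom_s3_demo()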

If you'd like to see the full code on GitHub, you can check out our repository, and as always, if you are looking for assistance with your company's machine learning and cloud initiatives, reach out to us today!
