Testing Celery Chains

How to write unit tests for Celery task chains

Published on May 15, 2018
Estimated reading time: 5 minutes
The full source code is available on https://github.com/ZoomerAnalytics/python-celery-testing-chains

Celery chains allow you to modularise your application and reuse common Celery tasks. A classic use case is a market data system we at Zoomer Analytics built for a hedge fund client.

The aim was to consume market data from different data vendors such as Bloomberg or Reuters. The APIs were of all different kinds and shapes but ultimately the data ended in the same database table.

By chaining the Celery tasks we only had to build a small number of specialised feed and data transformation tasks for each vendor but could reuse common tasks such as deserialising and writing the data.

An example

Let's say we want to download cryptocurrency timeseries data from a number of different APIs and calculate the moving average for each of these timeseries.

For instance, one of these timeseries would be the Bitcoin Price Index, available via the Coindesk API.

After careful consideration of reusability and separation of concerns, we decide to implement a generic Celery task to calculate the moving average. It expects two parameters: a list of dicts - [{"date": "2018-05-01", "value": 1000.0}, {"date": "2018-05-02", "value": 1003.5}, ...] - and the number of days we want to calculate the moving average over:

import pandas as pd
import numpy as np
from celery import chain
from worker import app

@app.task(bind=True, name='calculate_moving_average')
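def calculate_moving_average(self, timeseries, window):
    # timeseries: list of {'date': ..., 'value': ...} dicts
    # window: number of observations (days, for daily data) to average over
    df = pd.DataFrame(timeseries)
    df['ma'] = df['value'].rolling(window=window, center=False).mean()
    # The first window - 1 rows have no moving average yet: replace their NaN
    # with '' and return one dict per row
    return list(df.replace(np.nan, '', regex=True).T.to_dict().values())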
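
To make the expected output of this task concrete, here is a plain-Python sketch of a trailing moving average. It is an illustration only, not the pandas implementation used above: the function name moving_average and the sample values are assumptions, and pandas may differ in the last floating-point digits.

```python
def moving_average(timeseries, window):
    """Plain-Python illustration of a trailing moving average.

    Rows without a full window get '' for 'ma', mirroring the task above.
    """
    values = [row['value'] for row in timeseries]
    result = []
    for i, row in enumerate(timeseries):
        if i + 1 < window:
            ma = ''  # not enough observations yet
        else:
            ma = sum(values[i + 1 - window:i + 1]) / window
        result.append({**row, 'ma': ma})
    return result


rows = moving_average(
    [{'date': '2018-05-01', 'value': 1000.0},
     {'date': '2018-05-02', 'value': 1003.5},
     {'date': '2018-05-03', 'value': 1001.0}],
    window=2)
print([row['ma'] for row in rows])  # ['', 1001.75, 1002.25]
```

With window=2 the first row has no average, and every later row holds the mean of its own value and the previous one.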

We can now implement a data feed Celery task for any given API. As long as that task returns a list of dicts in the expected format [{"date": "2018-05-01", "value": 1000.0}, {"date": "2018-05-02", "value": 1003.5}, ...], we can chain it with the moving average task and calculate the moving average for any API result. For instance, the Celery task for the Bitcoin Price Index Coindesk feed looks like this:

import requests
from worker import app

@app.task(bind=True, name='fetch_bitcoin_price_index')
def fetch_bitcoin_price_index(self, start_date, end_date):
    url = f'https://api.coindesk.com/v1/bpi/historical/close.json?start={start_date}&end={end_date}'
    response = requests.get(url)
    if not response.ok:
        raise ValueError(f'Unexpected status code: {response.status_code}')
    return [{'date': key, 'value': value} for key, value in response.json()['bpi'].items()]
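
Chaining works because Celery passes the return value of one task to the next task as its first positional argument, which is why the moving average task only needs its window argument supplied up front. The sketch below mimics that hand-over with functools.partial. It is a plain-Python analogy, not Celery's implementation, and fetch_prices, count_rows and run_in_sequence are hypothetical names.

```python
from functools import partial


def fetch_prices(start_date, end_date):
    # Stand-in for a feed task: returns canned data instead of calling an API.
    return [{'date': start_date, 'value': 1000.0},
            {'date': end_date, 'value': 1003.5}]


def count_rows(timeseries, label):
    # Stand-in for a downstream task: the previous result arrives first.
    return f'{label}: {len(timeseries)} rows'


def run_in_sequence(*steps):
    """Call each step in order, feeding the previous result in as the
    first positional argument of the next step."""
    first, *rest = steps
    result = first()
    for step in rest:
        result = step(result)
    return result


outcome = run_in_sequence(
    partial(fetch_prices, start_date='2018-05-01', end_date='2018-05-02'),
    partial(count_rows, label='bpi'))
print(outcome)  # bpi: 2 rows
```

Here partial plays the role of a task signature: it fixes some arguments now and leaves room for the upstream result later.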

How do we go about testing Celery chains? As usual, there is more than one answer. Let's have a look at two different strategies and discuss which one makes sense in which context.

Mocking the Celery chain

Previously, we discussed the importance of unit-testing Celery tasks. Assuming we have our Celery tasks test-covered, the only thing we are really interested in when it comes to testing chained tasks is that the chain itself does the right thing.

In other words, we need to test, wherever we invoke a Celery chain, that the individual tasks are called in the correct order with the correct arguments. Whatever happens inside the task is not our concern as that is already covered by the unit test.

from flask import Flask, Response, request, jsonify
from tasks import fetch_bitcoin_price_index, calculate_moving_average
from celery import chain
        
app = Flask(__name__)

@app.route('/', methods=['POST'])
def index():
    chain(
        fetch_bitcoin_price_index.s(
            start_date=request.json['start_date'],
            end_date=request.json['end_date']),
        calculate_moving_average.s(window=request.json['window'])
    ).delay()
    return '', 201

The chain is invoked in this Flask view. Hence, this is what we need to write the test for (another common approach would be to implement a dedicated function that invokes the Celery chain and write tests against that function).


import app
from unittest import TestCase, mock
        
class Tests(TestCase):
    
    def setUp(self):
        app.app.config['TESTING'] = True
        self.client = app.app.test_client()

    @mock.patch('app.chain')
    @mock.patch('app.fetch_bitcoin_price_index')
    @mock.patch('app.calculate_moving_average')
    def test_mocked_chain(self, mock_calculate_moving_average, mock_fetch_bitcoin_price_index, mock_chain):
        response = self.client.post('/', json={'start_date': '2018-05-01', 'end_date': '2018-05-08', 'window': 3})
        self.assertEqual(response.status_code, 201)
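        # A mock returns the same object whatever it is called with, so check
        # the arguments on the .s() calls themselves ...
        mock_fetch_bitcoin_price_index.s.assert_called_once_with(start_date='2018-05-01', end_date='2018-05-08')
        mock_calculate_moving_average.s.assert_called_once_with(window=3)
        # ... then check that the chain received both signatures in order
        # and was actually dispatched.
        mock_chain.assert_called_once_with(
            mock_fetch_bitcoin_price_index.s.return_value,
            mock_calculate_moving_average.s.return_value)
        mock_chain.return_value.delay.assert_called_once_with()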

By mocking the actual Celery chain and the Celery tasks inside the chain, we are able to assert the order in which the tasks are called and their respective arguments.
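
One detail is worth knowing when asserting on mocked signatures: a Mock returns the same return_value regardless of the arguments it is called with. Comparing the results of two .s(...) calls therefore says nothing about the arguments. They have to be checked on the calls recorded by .s itself. A minimal standard-library sketch, where task is a stand-in mock rather than a real Celery task:

```python
from unittest import mock

task = mock.Mock(name='fetch_bitcoin_price_index')

# A Mock hands back the same return_value whatever arguments it receives ...
sig_a = task.s(start_date='2018-05-01', end_date='2018-05-08')
sig_b = task.s(start_date='1999-01-01', end_date='1999-01-02')
same_object = sig_a is sig_b

# ... so the arguments have to be checked on the recorded calls to .s.
first_call = task.s.call_args_list[0]
args_checked = first_call == mock.call(start_date='2018-05-01',
                                       end_date='2018-05-08')
print(same_object, args_checked)  # True True
```

In a test, task.s.assert_called_once_with(...) performs the same check on the recorded call and fails with a readable message.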

Testing the chain synchronously

There is an alternative approach. Instead of mocking the chain and the tasks, we can just test the whole lot in one go by calling the task chain synchronously (in the same way we did it for unit-testing individual Celery tasks).

import unittest
import responses
import tasks

from unittest import TestCase
from celery import chain


class Tests(TestCase):

    @responses.activate
    def test_chain(self):
        responses.add(responses.GET, 'https://api.coindesk.com/v1/bpi/historical/close.json?start=2018-05-01&end=2018-05-08', body='{"bpi":{"2018-05-01":9067.715,"2018-05-02":9219.8638,"2018-05-03":9734.675,"2018-05-04":9692.7175,"2018-05-05":9826.5975,"2018-05-06":9619.1438,"2018-05-07":9362.5338,"2018-05-08":9180.1588},"disclaimer":"This data was produced from the CoinDesk Bitcoin Price Index. BPI value data returned as USD.","time":{"updated":"May 9, 2018 00:03:00 UTC","updatedISO":"2018-05-09T00:03:00+00:00"}}', status=200)
        task = chain(
            tasks.fetch_bitcoin_price_index.s(start_date='2018-05-01', end_date='2018-05-08'),
            tasks.calculate_moving_average.s(window=2)).apply()
        self.assertEqual(task.status, 'SUCCESS')
        self.assertEqual(task.result, [
            {'date': '2018-05-01', 'ma': '', 'value': 9067.715},
            {'date': '2018-05-02', 'ma': 9143.7894, 'value': 9219.8638},
            {'date': '2018-05-03', 'ma': 9477.2694, 'value': 9734.675},
            {'date': '2018-05-04', 'ma': 9713.69625, 'value': 9692.7175},
            {'date': '2018-05-05', 'ma': 9759.657500000001, 'value': 9826.5975},
            {'date': '2018-05-06', 'ma': 9722.87065, 'value': 9619.1438},
            {'date': '2018-05-07', 'ma': 9490.8388, 'value': 9362.5338},
            {'date': '2018-05-08', 'ma': 9271.346300000001, 'value': 9180.1588}
        ])

Here we use responses to mock out the requests call, but other than that we let the test execute the chain and both chained tasks for real. If you unit-test your Celery tasks (and you should), this duplicates coverage you already have, and redundant tests are a drag on your development flow. However, this test setup can make sense if you have Celery tasks that are always called as part of a chain.
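
The expected values in this test include floating-point artefacts such as 9759.657500000001, which makes exact equality brittle across pandas or platform versions. One option is to compare float fields with a tolerance. The sketch below uses only the standard library. The helper name rows_close and its default tolerance are assumptions, not part of the original test suite.

```python
import math


def rows_close(actual, expected, rel_tol=1e-9):
    """Compare two lists of row dicts, using a tolerance for float fields."""
    if len(actual) != len(expected):
        return False
    for got, want in zip(actual, expected):
        if got.keys() != want.keys():
            return False
        for key, want_value in want.items():
            got_value = got[key]
            if isinstance(want_value, float) and isinstance(got_value, float):
                if not math.isclose(got_value, want_value, rel_tol=rel_tol):
                    return False
            elif got_value != want_value:
                # Non-float fields (dates, the '' placeholder) must match exactly.
                return False
    return True


actual = [{'date': '2018-05-05', 'ma': 9759.657500000001, 'value': 9826.5975}]
expected = [{'date': '2018-05-05', 'ma': 9759.6575, 'value': 9826.5975}]
print(rows_close(actual, expected))  # True
```

Inside a TestCase this would read self.assertTrue(rows_close(task.result, expected)).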

How to apply this

Asynchronously linking Celery tasks via task chains is a powerful building block for constructing complex workflows (think Lego). Testing Celery chains is as important as unit-testing individual Celery tasks. Mocking the Celery chain and the chained tasks is an easy and effective way to stay on top of your Celery workflow, however complex.
