KafkaProducer connection pool in Python: Part 1

This article focuses on implementing a thread-safe connection pool for KafkaProducer to avoid application startup fails due to Kafka servers downtime

Introduction

Note: Part 2 of this series is available here

Note: The complete implementation is in my kafka-python-producer-pool repository.

Kafka is an event-streaming platform used for building event-driven systems. Kafka hinges on three things: Topics, Producers, and Consumers. Producers publish messages under a Topic, and Consumers subscribe to the Topic and consume messages from the Topic

It doesn't use the PubSub architecture as publishers of events to topics are not aware of subscribers to the said topics.

Kafka in Django (insert any other framework)

Using Kafka with Django is quite straightforward and all you have to do is import KafkaProducer from the kafka-python library in the settings.pyfile:

from kafka import KafkaProducer

BROKER = KafkaProducer(
    bootstrap_servers=KAFKA_BOOTSTRAP_SERVERS,
)

What can go wrong?

Problem 1

Now KafkaProducer provided by the kafka-python library is thread-safe and can be used safely across threads. The docs suggest initializing Kafka and passing this instance everywhere it is needed instead of initializing in every place you use it. Since we’re talking about sync Django and not async and this means:

  • For each request, a thread is spawned

  • If you want to initialize an object to be used across threads, you must run that code as part of the start-up code, which puts that automatically in the settings file as we have done above.

Problem 2

But we don’t want application initialization to fail simply because our Kafka servers are down for some remote reason. To solve this, we decided to initialize Kafka at every place it is needed.

Problem 3

But the application uses Kafka in several places and this might cause performance problems when multiple messages are sent in different parts of the code.

Solution: Singleton. Use a singleton to gift out the KafkaProducer instance to threads.

Problem 4

The singleton must be thread-safe. To fix this, we simply have to write a thread-safe singleton.

Problem 5

Threads and ultimately requests will have to wait for locks to be released and this will impact latency

Singleton with a controlled number of instances: factory singleton or more correctly, a connection pool. This means we define the maximum number of instances that can be available at any one point in time. Then KafkaProducerPool as we will implement, dishes out instances to threads that need it or make them wait till a lock on one of the instances is released for use.

Implementation of a KafkaProducerPool

Before going further, there are reasons I am implementing this pool this way and I will address the questions when we are done with the implementation:

  • Why can't we use an array and simply push and pop. With this we can scale up and scale down connections as needed?

  • Why not initialize all the connections at the first time?

Implementation

Note: I will do lots of type annotations as is my convention when writing Python code. Beer with me.

First we implement a container to hold each KafkaProducer instance and its respective lock object:

import threading
from dataclasses import dataclass

from kafka import KafkaProducer


@dataclass
class KafkaProducerInstance:
    producer: KafkaProducer
    lock: threading.Lock

When we initialize our pool, we can set the number of these instances we want.

Next we define the KafkaConnectionPool class which will contain the pool implementation:

class KafkaProducerPool:
    # number of maximum instances
    INSTANCE_LIMIT: int = 10
    # holds all KafkaProducer instances
    _instances: dict[int, KafkaProducerInstance] = {}

The __new__ class method

Contrary to popular belief, the first method called when you instantiate a Python class isn't the __init__ method, but the __new__ method which is in fact a static method. This is in fact the method invoked when you initialize the class. This method by default creates the instance and then passes the instance to the __init__ method as self. You can then add state to the instance. That's why your __init__ method must be initialized with self.

Since we are implementing a connection pool and we need to control how an instance is initialized or whether it is initialized at all, we must override the __new__ method instead of the __init__ method. This method contains the top-level logic for the whole class so we can go ahead to put our expected code there and call expected methods there. Let's first do that in pseudocode:

class KafkaProducerPool:
    ...
    def __new__(cls):
        if no _instances exist:
            create one and add it to the pool
        instance = get a free instance from the pool
        if a free instance cannot be gotten:
            create one and add it to the pool
            instance = get a random instance from the pool
        return instance

And in Python would be:

import random
...

class KafkaProducerPool:
    ...
    def __new__(cls):
        if not cls._instances:
            cls._provision_instance()
        instance = cls._get_free_instance()
        if not instance:
            cls._provision_instance()
            random_index = random.randint(1, len(cls._instances))
            instance = cls._get_random_instance(random_index)
        return instance

The private methods _provision_instance, _get_free_instance, _get_random_instance will be implemented later on as class methods. You know why they are class method, yes? Because they are not called on an instance but on the class since the instance has not been initialized yet. They do not take in self.

Let's look at _provision_instance

_provision_instance has a simple job: Create a new instance of KafkaProducer and add it to the pool. Yet, we must be careful at this point; we want to avoid two things:

  • Creating more instances than necessary

  • Race conditions when creating instances.

For the first, we can simply check with:

if len(cls._instances) >= cls.INSTANCE_LIMIT:
    # create instance here

For the second, we need to place a lock over the creation of instances to allow only one thread create a resource at any point in time. First we add a _creation_lock :

class KafkaProducerPool:
    ...
    # lock on the `_instances` dict to make creation of
    # new instances thread-safe
    _creation_lock: threading.Lock = threading.Lock()

Then we do a check before creating the instance as in the previous code snippet, bringing the _provision_instance method to be:

...
# ideally, you want to grab this from the settings
KAFKA_BOOTSTRAP_SERVERS = ["128.122.1.1:8900"]

class KafkaProducerPool:
    ...
    # lock on the `_instances` dict to make creation of
    # new instances thread-safe
    _creation_lock: threading.Lock = threading.Lock()

    ...

    @classmethod
    def _provision_instance(cls):
        """
        Creates a new instance of the message broker and adds it to the pool.
        This method is thread-safe and is used to create new instances when
        all instances are busy and there is space to create instances.
        """
        with cls._creation_lock:
            if (instance_length := len(cls._instances)) >= cls.INSTANCE_LIMIT:
                # raising an exception is expensive in this context
                return

            producer: KafkaProducer = KafkaProducer(
                bootstrap_servers=KAFKA_BOOTSTRAP_SERVERS,
            )
            instance: KafkaProducerInstance = KafkaProducerInstance(
                producer=producer, lock=threading.Lock()
            )
            cls._instances[instance_length + 1] = instance

And there's it!

You may wonder, threading.Lock is expensive to use, why not check for instance length before you acquire the lock?

If we do that, we run the risk of two concurrent threads thinking it is okay to create an instance (when one space is left) and then we end up with INSTANCE_LIMIT += 1which isn't what we want. Instead, we place a lock directly before the check so all checks can be correct without race conditions.

Implementing _get_free_instance

In the case of _get_free_instance we simply loop through the instances and return one that isn't currently in use. If there is none available, we return None. The implementation is like so:

from typing import Optional
...

class KafkaProducerPool:
    ...

    @classmethod
    def _get_free_instance(cls) -> Optional[KafkaProducer]:
        """
        Retrieves a free instance of the message broker. If no free instance
        is found, `None` is returned.
        """
        if not cls._instances:
            return None

        for _, instance in cls._instances.items():
            if not instance.lock.locked():
                with instance.lock:
                    return instance.producer
        return None

In the case of this method, we aren't really concerned with a global lock. We also aren't strict on whether an instance is really free with the check if not instance.lock.locked(): else we'd place a central lock on that too. Instead we use a much more liberal approach since we can afford to: if another thread had acquired the lock as at the check, the current thread will simply wait for that specific instance to be released. This works because a thread can afford to wait here without impacting performance most of the time.

If you like, you can place a central lock to enforce the checks and grab a free instance when the thread finds one.

Lastly, we

Implement _get_random_instance

This method receives a random index (produced by random.choice or any random number generator, and acquires the lock of the instance at that index in the KafkaProducerPool._instancesdictionary and yields it to the current thread. It's a 3-line implementation:

class KafkaProducerPool:
    ...

    @classmethod
    def _get_random_instance(cls, index: int) -> KafkaProducer:
        """
        Retrieves a random instance of the message broker. This method is used
        when no free instance is found. The index is used to determine the
        instance to retrieve in case all instances are busy. If the index is
        out of range, the first instance is returned.
        """
        instance = cls._instances.get(index, cls._instances[1])
        with instance.lock:
            return instance.producer

We return the first instance if an out-of-bound index is provided.

With this, our implementation is complete and the KafkaProducerPool implementation should look like:

import threading
import random
from dataclasses import dataclass
from typing import Optional

from kafka import KafkaProducer

# ideally, you want to grab this from the settings
KAFKA_BOOTSTRAP_SERVERS = ["128.122.1.1:8900"]

@dataclass
class KafkaProducerInstance:
    producer: KafkaProducer
    lock: threading.Lock


class KafkaProducerPool:
    # number of maximum instances
    INSTANCE_LIMIT: int = 10
    # holds all KafkaProducer instances
    _instances: dict[int, KafkaProducerInstance] = {}
    # lock on the `_instances` dict to make creation of
    # new instances thread-safe
    _creation_lock: threading.Lock = threading.Lock()

    def __new__(cls):
        if not cls._instances:
            cls._provision_instance()
        instance = cls._get_free_instance()
        if not instance:
            cls._provision_instance()
            random_index = random.randint(1, len(cls._instances))
            instance = cls._get_random_instance(random_index)
        return instance

    @classmethod
    def _provision_instance(cls):
        """
        Creates a new instance of the message broker and adds it to the pool.
        This method is thread-safe and is used to create new instances when
        all instances are busy and there is space to create instances.
        """
        with cls._creation_lock:
            if (instance_length := len(cls._instances)) >= cls.INSTANCE_LIMIT:
                # raising an exception is expensive in this context
                return

            producer: KafkaProducer = KafkaProducer(
                bootstrap_servers=KAFKA_BOOTSTRAP_SERVERS,
            )
            instance: KafkaProducerInstance = KafkaProducerInstance(
                producer=producer, lock=threading.Lock()
            )
            cls._instances[instance_length + 1] = instance

    @classmethod
    def _get_free_instance(cls) -> Optional[KafkaProducer]:
        """
        Retrieves a free instance of the message broker. If no free instance
        is found, `None` is returned.
        """
        if not cls._instances:
            return None

        for _, instance in cls._instances.items():
            if not instance.lock.locked():
                with instance.lock:
                    return instance.producer
        return None

    @classmethod
    def _get_random_instance(cls, index: int) -> KafkaProducer:
        """
        Retrieves a random instance of the message broker. This method is used
        when no free instance is found. The index is used to determine the
        instance to retrieve in case all instances are busy. If the index is
        out of range, the first instance is returned.
        """
        instance = cls._instances.get(index, cls._instances[1])
        with instance.lock:
            return instance.producer

Testing

How are we certain that this works? We will find out with a test case.

What do we test?

We want to test that when n threads attempt to access KafkaProducer which has a maximum limit of x = n/2, x instances are created for n/2 threads and the remaining threads have to wait for locks to be released on other instances before they can proceed.

How do we test?

I prefer using Pytest and I'll write test cases with this assertion. First we have to mock a few things: INSTANCE_LIMIT and KafkaProducer. We can't afford to work with 10 maximum instances so we'll reduce that to 2. In addition, if KafkaProducer isn't mocked, the test case will always trigger a connection to Kafka servers that may have been deployed.

Mock time

Implement a mock class for KafkaProducer:

class MockKafkaProducer:
    def __init__(self, *args, **kwargs):
        pass

unittest.mock provides a patch function for applying mocks which we will use in our test cases. Now, let's get on to the test case

Writing tests

We start by creating a function that simulates expensive work:

import time

def work(sleep: int = 4, thread: int = 1):
    print(f"Starting Thread: {thread}")
    result = KafkaProducerPool()
    # mocking expensive operation
    time.sleep(sleep)
    return result

Next the test case:

...
from concurrent.futures import ThreadPoolExecutor
from unittest.mock import patch

from .kafka import KafkaProducerPool

...

MAX_INSTANCE_LIMIT = 2

# applying mocks

@patch("kafka.KafkaProducer", MockKafkaProducer)
@patch("kafka.KafkaProducerPool.INSTANCE_LIMIT", MAX_INSTANCE_LIMIT)
def test_kafka_producer_pool_will_return_same_instance():
    # spawn a thread pool with 7 threads. With this, 7
    # threads are available to access 2 instances
    with ThreadPoolExecutor(max_workers=7) as executor:
        producer1 = executor.submit(work, thread=1)
        producer2 = executor.submit(work, thread=2)
        producer3 = executor.submit(work, thread=3)
        producer4 = executor.submit(work, thread=4)
        producer5 = executor.submit(work, thread=5)
        producer6 = executor.submit(work, thread=6)
        producer7 = executor.submit(work, thread=7)

        # force all threads to completion and grab the results
        producer1 = producer1.result()
        producer2 = producer2.result()
        producer3 = producer3.result()
        producer4 = producer4.result()
        producer5 = producer5.result()
        producer6 = producer6.result()
        producer7 = producer7.result()

    # confirm that the first instance produced was reused
    assert producer1 in [
        producer2,
        producer3,
        producer4,
        producer5,
        producer6,
        producer7,
    ]

    # confirm that the maximum number of unique instances across
    # all threads is the same as the set limit
    assert len(KafkaProducerPool._instances) == MAX_INSTANCE_LIMIT

Did the test fail?

I got this result on the first test run:

==================================================================== short test summary info ====================================================================
FAILED tests/test_kafka/test_kafka.py::test_kafka_producer_pool_will_return_same_instance - assert 1 == 2
================================================================= 1 failed, 1 warning in 4.21s ==================================================================

This one's tricky, but I'll give you a moment to think about it...

...

...

...

...

...

...

...

...

...

...

tests/test_shared/test_event_broker/test_base.py .                                                                                                        [100%]

======================================================================= warnings summary ========================================================================
../usr/local/lib/python3.9/site-packages/rest_framework_simplejwt/__init__.py:1
  /usr/local/lib/python3.9/site-packages/rest_framework_simplejwt/__init__.py:1: DeprecationWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html
    from pkg_resources import DistributionNotFound, get_distribution

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
================================================================= 1 passed, 1 warning in 8.10s ==================================================================

And it is passing on a subsequent run...

What happened?

Well, testing multithreaded code can be unreliable. In the first case, the threads didn't have to wait long for a lock to be released on the resource, hence a single instance served all the threads. In the second case (as should happen most times), the threads were starved and this triggered the creation of another KafkaProducer instance which happens to be the limit.

Let's satisfy Pytest

Although our implementation works fine, this test can fail at any time so we can mark it as expected to fail like so:

...
import pytest

@pytest.mark.xfail
@patch("shared.kafka.KafkaProducer", MockKafkaProducer)
@patch("shared.kafka.KafkaProducerPool.INSTANCE_LIMIT", MAX_INSTANCE_LIMIT)
def test_kafka_producer_pool_will_return_same_instance():
    ...

and on subsequent runs, we get:

tests/test_kafka/test_kafka.py X                                                                                                                          [100%]

======================================================================= warnings summary ========================================================================
../usr/local/lib/python3.9/site-packages/rest_framework_simplejwt/__init__.py:1
  /usr/local/lib/python3.9/site-packages/rest_framework_simplejwt/__init__.py:1: DeprecationWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html
    from pkg_resources import DistributionNotFound, get_distribution

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
================================================================= 1 xpassed, 1 warning in 4.11s =================================================================

Conclusion

This isn't all, there is a part 2 with a much better implementation. I will also answer some questions about implementation choice and the GIL. Check it out.

I hope this adds to your great library of knowledge and fixes concerns you have about performance with Kafka and Python. Note: The complete implementation is in a section of my kafka-python-producer-pool repository.

Don't have a great year yet, read the second part 🔥