KafkaProducer connection pool in Python: Part 2

In this article, we explore a generic and modular approach to the KafkaProducerPool class.

Introduction

Note: Part 1 of this series is available here

Note: The complete implementation is in my kafka-python-producer-pool repository.

In Part 1 of this series, we introduced a Kafka producer pool that provides safe, multi-threaded connections to Kafka servers.

In this part, we will explore a much cleaner, loosely coupled, and generalized approach to the KafkaProducerPool class.

Here's the code from last time, in case you've forgotten:

KafkaProducerPool

import threading
import random
from dataclasses import dataclass
from typing import Optional

from kafka import KafkaProducer

# ideally, you want to grab this from the settings
KAFKA_BOOTSTRAP_SERVERS = ["128.122.1.1:8900"]

@dataclass
class KafkaProducerInstance:
    producer: KafkaProducer
    lock: threading.Lock


class KafkaProducerPool:
    # number of maximum instances
    INSTANCE_LIMIT: int = 10
    # holds all KafkaProducer instances
    _instances: dict[int, KafkaProducerInstance] = {}
    # lock on the `_instances` dict to make creation of
    # new instances thread-safe
    _creation_lock: threading.Lock = threading.Lock()

    def __new__(cls):
        if not cls._instances:
            cls._provision_instance()
        instance = cls._get_free_instance()
        if not instance:
            cls._provision_instance()
            random_index = random.randint(1, len(cls._instances))
            instance = cls._get_random_instance(random_index)
        return instance

    @classmethod
    def _provision_instance(cls):
        """
        Creates a new instance of the message broker and adds it to the pool.
        This method is thread-safe and is used to create new instances when
        all instances are busy and there is space to create instances.
        """
        with cls._creation_lock:
            if (instance_length := len(cls._instances)) >= cls.INSTANCE_LIMIT:
                # raising an exception is expensive in this context
                return

            producer: KafkaProducer = KafkaProducer(
                bootstrap_servers=KAFKA_BOOTSTRAP_SERVERS,
            )
            instance: KafkaProducerInstance = KafkaProducerInstance(
                producer=producer, lock=threading.Lock()
            )
            cls._instances[instance_length + 1] = instance

    @classmethod
    def _get_free_instance(cls) -> Optional[KafkaProducer]:
        """
        Retrieves a free instance of the message broker. If no free instance
        is found, `None` is returned.
        """
        if not cls._instances:
            return None

        for _, instance in cls._instances.items():
            if not instance.lock.locked():
                with instance.lock:
                    return instance.producer
        return None

    @classmethod
    def _get_random_instance(cls, index: int) -> KafkaProducer:
        """
        Retrieves a random instance of the message broker. This method is used
        when no free instance is found. The index is used to determine the
        instance to retrieve in case all instances are busy. If the index is
        out of range, the first instance is returned.
        """
        instance = cls._instances.get(index, cls._instances[1])
        with instance.lock:
            return instance.producer
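A detail worth spelling out before moving on: calling KafkaProducerPool() hands back a KafkaProducer, not a KafkaProducerPool. This works because when __new__ returns an object that isn't an instance of the class, Python skips __init__ and returns that object directly. Here is a minimal sketch that uses a hypothetical Pool class, with a string standing in for a producer:

```python
class Pool:
    """Hypothetical stand-in for KafkaProducerPool."""

    def __new__(cls):
        # Return an object that is not a Pool, just as KafkaProducerPool
        # returns a KafkaProducer from its __new__.
        return "pooled-producer"

    def __init__(self):
        # Never runs: Python only calls __init__ when __new__ returns
        # an instance of the class.
        raise RuntimeError("__init__ should not be called")


producer = Pool()
print(producer)                    # pooled-producer
print(isinstance(producer, Pool))  # False
```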

Test

...
from concurrent.futures import ThreadPoolExecutor
from unittest.mock import patch

from .kafka import KafkaProducerPool

...

MAX_INSTANCE_LIMIT = 2

# applying mocks

@patch("kafka.KafkaProducer", MockKafkaProducer)
@patch("kafka.KafkaProducerPool.INSTANCE_LIMIT", MAX_INSTANCE_LIMIT)
def test_kafka_producer_pool_will_return_same_instance():
    # spawn a thread pool with 7 threads. With this, 7
    # threads are available to access 2 instances
    with ThreadPoolExecutor(max_workers=7) as executor:
        producer1 = executor.submit(work, thread=1)
        producer2 = executor.submit(work, thread=2)
        producer3 = executor.submit(work, thread=3)
        producer4 = executor.submit(work, thread=4)
        producer5 = executor.submit(work, thread=5)
        producer6 = executor.submit(work, thread=6)
        producer7 = executor.submit(work, thread=7)

        # force all threads to completion and grab the results
        producer1 = producer1.result()
        producer2 = producer2.result()
        producer3 = producer3.result()
        producer4 = producer4.result()
        producer5 = producer5.result()
        producer6 = producer6.result()
        producer7 = producer7.result()

    # confirm that the first instance produced was reused
    assert producer1 in [
        producer2,
        producer3,
        producer4,
        producer5,
        producer6,
        producer7,
    ]

    # confirm that the maximum number of unique instances across
    # all threads is the same as the set limit
    assert len(KafkaProducerPool._instances) == MAX_INSTANCE_LIMIT

Generics and great dependency injection

We've got a working implementation, but I can do you one better: this isn't the version I use in production. Future test cases will depend on KafkaProducerPool, and I don't like mocking KafkaProducer when I could use producers that are easier to inspect. A mocked producer gives us no sure-fire way to look inside the instances with a simple print statement. Why not use integers as the pooled instances in tests and KafkaProducer in production code?

To do that, we'll turn KafkaProducerPool into a generic class that works with any instance type.

KafkaProducerInstance refactor

Our KafkaProducerInstance has to be made generic and will morph into a SingletonInstance:

import threading
from dataclasses import dataclass
from typing import Generic, TypeVar

U = TypeVar("U")

@dataclass
class SingletonInstance(Generic[U]):
    producer: U
    lock: threading.Lock

First, we define a type variable U and make SingletonInstance generic over it. This means we can declare an instance of SingletonInstance like so:

int_singleton_instance: SingletonInstance[int] = SingletonInstance(
    producer=3, lock=threading.Lock()
)

Or a float:

float_singleton_instance: SingletonInstance[float] = SingletonInstance(
    producer=3.4, lock=threading.Lock()
)

And so on 😂 Aww, come on! Tell me I'm a genius.

KafkaProducerPool refactor

Now we define another type variable, T, which parameterizes ProducerPool (renamed from KafkaProducerPool) and is passed on to SingletonInstance:

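from typing import Generic, Optional, TypeVar
...

T = TypeVar("T")

class ProducerPool(Generic[T]):
    # number of maximum instances
    INSTANCE_LIMIT: int = 10

    # holds the class of the message broker instance (set by each subclass)
    producer_class: Optional[type[T]] = None

    # holds all pooled instances, each paired with its own lock
    _instances: dict[int, SingletonInstance[T]] = {}

    # holds the wrapper class that pairs each instance with a lock
    _singleton_instance: type[SingletonInstance] = SingletonInstance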

    # lock on the `_instances` dict to make creation of
    # new instances thread-safe
    _creation_lock: threading.Lock = threading.Lock()

    def __new__(cls) -> T:
        if not cls._instances:
            cls._provision_instance()
        instance = cls._get_free_instance()
        if not instance:
            cls._provision_instance()
            random_index = random.randint(1, len(cls._instances))
            instance = cls._get_random_instance(random_index)
        return instance

The __new__ method is unchanged. I only added the return type annotation, for dramatic effect.
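One thing to watch with this layout: _instances and _creation_lock are defined once on ProducerPool, so every subclass that doesn't redefine them shares the same dict and lock. A KafkaProducerPool and an IntProducerPool in the same process would end up filling one pool together. Below is a sketch of one possible fix. It uses Python's __init_subclass__ hook to give each subclass its own state, and the class names are hypothetical:

```python
import threading
from typing import Generic, TypeVar

T = TypeVar("T")


class SharedStatePool(Generic[T]):
    # Defined once on the base class: every subclass sees this same dict.
    _instances: dict = {}


class PerSubclassPool(Generic[T]):
    _instances: dict = {}
    _creation_lock: threading.Lock = threading.Lock()

    def __init_subclass__(cls, **kwargs):
        # Keep Generic's own subclass bookkeeping working.
        super().__init_subclass__(**kwargs)
        # Give each subclass a fresh dict and lock of its own.
        cls._instances = {}
        cls._creation_lock = threading.Lock()


class IntPoolA(SharedStatePool[int]): ...
class StrPoolA(SharedStatePool[str]): ...

class IntPoolB(PerSubclassPool[int]): ...
class StrPoolB(PerSubclassPool[str]): ...


IntPoolA._instances[1] = 42
print(StrPoolA._instances)  # {1: 42}: the int pool leaked into the str pool

IntPoolB._instances[1] = 42
print(StrPoolB._instances)  # {}: each subclass has its own pool
```

In ProducerPool, the same __init_subclass__ method would sit next to the existing class attributes, and subclasses such as KafkaProducerPool would need no changes.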

_provision_instance refactor

We'll make a few structural changes here:

...

class ProducerPool(Generic[T]):
    ...

    @classmethod
    def _provision_instance(cls):
        """
        Creates a new instance of the message broker and adds it to the pool.
        This method is thread-safe and is used to create new instances when
        all instances are busy and there is space to create instances.
        """
        with cls._creation_lock:
            if (instance_length := len(cls._instances)) >= cls.INSTANCE_LIMIT:
                # raising an exception is expensive in this context
                return

            producer = cls.create_instance()
            instance = cls._singleton_instance(
                producer=producer, lock=threading.Lock()
            )
            cls._instances[instance_length + 1] = instance

Here, we use the cls._singleton_instance class attribute instead of naming the concrete SingletonInstance class directly. The wrapper class is then looked up on the pool, and a subclass could swap it out.

Also, instead of initializing KafkaProducer here, we delegate that to a public method, create_instance(). Each concrete pool overrides this method to build its own kind of instance. The base implementation is straightforward:

...

class ProducerPool(Generic[T]):
    ...

    @classmethod
    def create_instance(cls) -> T:
        raise NotImplementedError

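Any subclass that forgets to override create_instance() fails loudly with a NotImplementedError the first time the pool tries to provision an instance. Note that this is a runtime check, not something enforced when the subclass is defined.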
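You might be tempted to replace raise NotImplementedError with abc.abstractmethod, so that a forgetful subclass fails as soon as it is instantiated. That check, however, lives in object.__new__, and a __new__ like ProducerPool's never calls it, so it would not fire here. A minimal sketch with hypothetical class names:

```python
from abc import ABC, abstractmethod


class AbstractPool(ABC):
    def __new__(cls):
        # Like ProducerPool.__new__: return a pooled object and never call
        # object.__new__, which is where the abstract-method check lives.
        return cls.create_instance()

    @classmethod
    @abstractmethod
    def create_instance(cls):
        ...


class ForgetfulPool(AbstractPool):
    pass  # forgot to override create_instance


# No TypeError: the abstract method's empty body runs and returns None.
print(ForgetfulPool())  # None

# Going through object.__new__ directly does trigger the check.
try:
    object.__new__(ForgetfulPool)
except TypeError:
    print("TypeError: ForgetfulPool is abstract")
```

That is an argument for the plain raise NotImplementedError: at least the mistake becomes an exception the first time create_instance is called.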

_get_free_instance implementation

All we need to do here is update the type information:

from typing import Generic, Optional, TypeVar
...

class ProducerPool(Generic[T]):
    ...

    @classmethod
    def _get_free_instance(cls) -> Optional[T]:
        """
        Retrieves a free instance of the message broker. If no free instance
        is found, `None` is returned.
        """
        if not cls._instances:
            return None

        for _, instance in cls._instances.items():
            if not instance.lock.locked():
                with instance.lock:
                    return instance.producer
        return None
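A caveat that applies to both versions of _get_free_instance: a with block releases its lock as soon as the function returns, so the producer handed back to the caller is not actually reserved. The locked() check followed by acquiring the lock is also not atomic, so two threads can pass the check at the same moment. This minimal sketch reproduces the pattern with a simplified, non-generic copy of SingletonInstance and a plain integer as the producer:

```python
import threading
from dataclasses import dataclass


@dataclass
class SingletonInstance:
    producer: int
    lock: threading.Lock


def get_free_instance(instance: SingletonInstance):
    # Same shape as _get_free_instance, reduced to a single instance.
    if not instance.lock.locked():
        with instance.lock:
            return instance.producer  # the lock is released right here
    return None


slot = SingletonInstance(producer=7, lock=threading.Lock())
print(get_free_instance(slot))  # 7
print(slot.lock.locked())       # False: nothing keeps the instance reserved
```

In practice, this likely means _get_free_instance almost always returns the first instance. New instances are only provisioned when another thread happens to be inside that brief with block.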

_get_random_instance implementation

Here, we make the same type-signature update:

...

class ProducerPool(Generic[T]):
    ...

    @classmethod
    def _get_random_instance(cls, index: int) -> T:
        """
        Retrieves a random instance of the message broker. This method is used
        when no free instance is found. The index is used to determine the
        instance to retrieve in case all instances are busy. If the index is
        out of range, the first instance is returned.
        """
        instance: SingletonInstance[T] = cls._instances.get(
            index, cls._instances[1]
        )
        with instance.lock:
            return instance.producer
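For a caller to keep an instance to itself for the whole send, the lock has to stay held until the caller is done. A context manager expresses this neatly. The sketch below shows one way to do it. It is not the implementation from the repository, and borrow() is a hypothetical helper. It uses lock.acquire(blocking=False) so that checking and reserving an instance happen in one atomic step:

```python
import threading
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Generic, Iterator, TypeVar

U = TypeVar("U")


@dataclass
class SingletonInstance(Generic[U]):
    producer: U
    lock: threading.Lock


@contextmanager
def borrow(instances: dict[int, SingletonInstance[U]]) -> Iterator[U]:
    """Hypothetical helper: keep one instance reserved for the whole with block."""
    for instance in instances.values():
        # acquire(blocking=False) checks and reserves in one atomic step
        if instance.lock.acquire(blocking=False):
            try:
                yield instance.producer
            finally:
                instance.lock.release()
            return
    # Every instance is busy: wait for the first one (assumes keys start at 1).
    with instances[1].lock:
        yield instances[1].producer


pool = {i: SingletonInstance(producer=i * 10, lock=threading.Lock()) for i in (1, 2)}

with borrow(pool) as first:
    print(first, pool[1].lock.locked())  # 10 True
    with borrow(pool) as second:
        print(second)  # 20, because instance 1 is still checked out
print(pool[1].lock.locked(), pool[2].lock.locked())  # False False
```

The trade-off is a different call-site API: callers write with borrow(...) as producer: instead of calling the pool's constructor.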

KafkaProducerPool, the great part

The implementation of KafkaProducerPool is then simplified to a subclass of ProducerPool. We only set producer_class and override the create_instance method. The type argument in ProducerPool[KafkaProducer] carries the type information.

from kafka import KafkaProducer
from shared.event_broker.base import ProducerPool

# ideally, you want to grab this from the settings
KAFKA_BOOTSTRAP_SERVERS = ["128.122.1.1:8900"]

class KafkaProducerPool(ProducerPool[KafkaProducer]):
    producer_class = KafkaProducer

    @classmethod
    def create_instance(cls) -> KafkaProducer:
        return cls.producer_class(
            bootstrap_servers=KAFKA_BOOTSTRAP_SERVERS,
        )

Test for generic implementation

The test is made cleaner this time; we get rid of mocks and use concrete implementations:

import random
import time
from concurrent.futures import ThreadPoolExecutor

import pytest
from kafka.base import ProducerPool


class IntProducerPool(ProducerPool[int]):
    producer_class = int
    INSTANCE_LIMIT = 2

    @classmethod
    def create_instance(cls):
        return random.randint(1, 100)


def work(sleep: int = 4, thread: int = 1):
    time.sleep(sleep)
    result = IntProducerPool()

    # simulate an expensive operation
    time.sleep(sleep)
    return result

MAX_INSTANCE_LIMIT = 2

@pytest.mark.xfail
def test_kafka_producer_pool_will_return_same_instance():
    with ThreadPoolExecutor(max_workers=7) as executor:
        producer1 = executor.submit(work, thread=1)
        producer2 = executor.submit(work, thread=2)
        producer3 = executor.submit(work, thread=3)
        producer4 = executor.submit(work, thread=4)
        producer5 = executor.submit(work, thread=5)
        producer6 = executor.submit(work, thread=6)
        producer7 = executor.submit(work, thread=7)

        producer1 = producer1.result()
        producer2 = producer2.result()
        producer3 = producer3.result()
        producer4 = producer4.result()
        producer5 = producer5.result()
        producer6 = producer6.result()
        producer7 = producer7.result()

    assert producer1 in [
        producer2,
        producer3,
        producer4,
        producer5,
        producer6,
        producer7,
    ]

    assert len(IntProducerPool._instances) == MAX_INSTANCE_LIMIT

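Sorry folks, we can't fix the unreliability of this test. Whether threads end up sharing an instance depends on how they happen to be scheduled, which is why the test is marked xfail. At least, I can't fix it.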
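What can be tested deterministically is the part guarded by _creation_lock. No matter how many threads try to provision at once, the pool should never grow past INSTANCE_LIMIT. Here is a sketch of such a test. It uses a trimmed, self-contained copy of the pool so it runs on its own, and the test name is made up:

```python
import random
import threading
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from typing import Generic, TypeVar

T = TypeVar("T")


@dataclass
class SingletonInstance(Generic[T]):
    producer: T
    lock: threading.Lock


class ProducerPool(Generic[T]):
    # Trimmed copy of the article's base class: only what this test needs.
    INSTANCE_LIMIT: int = 10
    _instances: dict[int, SingletonInstance[T]] = {}
    _creation_lock: threading.Lock = threading.Lock()

    @classmethod
    def create_instance(cls) -> T:
        raise NotImplementedError

    @classmethod
    def _provision_instance(cls):
        with cls._creation_lock:
            if (instance_length := len(cls._instances)) >= cls.INSTANCE_LIMIT:
                return
            producer = cls.create_instance()
            cls._instances[instance_length + 1] = SingletonInstance(
                producer=producer, lock=threading.Lock()
            )


class IntProducerPool(ProducerPool[int]):
    INSTANCE_LIMIT = 2
    _instances = {}  # own dict, so other pools in the test run can't leak in

    @classmethod
    def create_instance(cls) -> int:
        return random.randint(1, 100)


def test_provision_never_exceeds_instance_limit():
    # Start from an empty pool in case another test already filled it.
    IntProducerPool._instances.clear()
    # Far more provisioning attempts than the limit, from several threads at once.
    with ThreadPoolExecutor(max_workers=7) as executor:
        futures = [
            executor.submit(IntProducerPool._provision_instance) for _ in range(20)
        ]
    for future in futures:
        future.result()  # re-raise any exception from a worker thread
    assert len(IntProducerPool._instances) == IntProducerPool.INSTANCE_LIMIT
    assert sorted(IntProducerPool._instances) == [1, 2]
```

This doesn't make the original test reliable. It narrows the check to the behavior that doesn't depend on thread timing.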

How to plug it into your application

Don't go around calling this class directly everywhere in your code. Specify it once in your settings so it can be truly shared, like so:

from shared.event_broker import EventProducerPool

# I re-exported it under a different name for easy
# changes if we decide to switch from Kafka
EventBroker = EventProducerPool

FAQs

Here are some questions you may have after reading this implementation:

I used ChatGPT and it gave me a really elegant solution

I don’t trust it.

The fun answer: it pushes to and pops from the list of connections far too often, and I don’t like that.

The not-so-fun answer: it doesn’t reuse connections. It’s just a class that limits the number of usages of the producer.

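Do I really need a connection pool, given the GIL?

Thinking back on it: if CPython’s global interpreter lock (GIL) stops threads from running Python code in parallel, is it safe to say there is no need for a connection pool? Not necessarily. The GIL only serializes the execution of Python bytecode. Threads still interleave, and a thread blocked on I/O, such as a network call to Kafka, releases the GIL so other threads can run. How much concurrency you actually get also depends on how your WSGI server handles requests, and I haven’t read the WSGI specification (PEP 3333) closely enough to say. If threads truly couldn’t overlap, Python applications would be reeeeeeeeeeeally slow, and moving forward I’d write backends only in Rust, or in async Python using ASGI.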
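Whether threads overlap under the GIL is easy to check for I/O-bound work. In this sketch, time.sleep stands in for a blocking network call, and CPython releases the GIL during both. Five threads that each sleep for 0.2 seconds should finish in roughly 0.2 seconds overall, not 1 second:

```python
import threading
import time


def fake_network_call(delay: float) -> None:
    # time.sleep releases the GIL, much like waiting on a socket does.
    time.sleep(delay)


def run_in_threads(n_threads: int, delay: float) -> float:
    """Run n_threads fake calls at once and return the wall-clock time taken."""
    threads = [
        threading.Thread(target=fake_network_call, args=(delay,))
        for _ in range(n_threads)
    ]
    start = time.perf_counter()
    for thread in threads:
        thread.start()
    for thread in threads:
        thread.join()
    return time.perf_counter() - start


elapsed = run_in_threads(n_threads=5, delay=0.2)
# Expect a little over 0.2s: the sleeps overlap instead of running back to back.
print(f"{elapsed:.2f}s")
```

CPU-bound Python code would not overlap like this, and that is where the GIL does bite.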

Conclusion

I hope this adds to your great library of knowledge and addresses any concerns you have about performance with Kafka and Python.

Note: The complete implementation of this part is in a section of my kafka-python-producer-pool repository.