langchain/libs/community/langchain_community/chat_message_histories/kafka.py
Chang Liu 62c8a67f56
community: add KafkaChatMessageHistory (#22216)
Add chat history store based on Kafka.

Files added: 
`libs/community/langchain_community/chat_message_histories/kafka.py`
`docs/docs/integrations/memory/kafka_chat_message_history.ipynb`

New issue to be created for future improvement:
1. Async method implementation.
2. Message retrieval based on timestamp.
3. Support for other configs when connecting to cloud hosted Kafka (e.g.
add `api_key` field)
4. Improve unit testing & integration testing.
2024-06-17 20:34:01 -07:00

363 lines
13 KiB
Python

""" Kafka-based chat message history by using confluent-kafka-python.
confluent-kafka-python is under Apache 2.0 license.
https://github.com/confluentinc/confluent-kafka-python
"""
from __future__ import annotations
import json
import logging
import time
from enum import Enum
from typing import TYPE_CHECKING, List, Optional, Sequence
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.messages import BaseMessage, message_to_dict, messages_from_dict
if TYPE_CHECKING:
from confluent_kafka import TopicPartition
from confluent_kafka.admin import AdminClient
logger = logging.getLogger(__name__)
BOOTSTRAP_SERVERS_CONFIG = "bootstrap.servers"
DEFAULT_TTL_MS = 604800000 # 7 days
DEFAULT_REPLICATION_FACTOR = 1
DEFAULT_PARTITION = 3
class ConsumeStartPosition(Enum):
"""Consume start position for Kafka consumer to get chat history messages.
LAST_CONSUMED: Continue from the last consumed offset.
EARLIEST: Start consuming from the beginning.
LATEST: Start consuming from the latest offset.
"""
LAST_CONSUMED = 1
EARLIEST = 2
LATEST = 3
def ensure_topic_exists(
admin_client: AdminClient,
topic_name: str,
replication_factor: int,
partition: int,
ttl_ms: int,
) -> int:
"""Create topic if it doesn't exist, and return the number of partitions.
If the topic already exists, we don't change the topic configuration.
"""
from confluent_kafka.admin import NewTopic
try:
topic_metadata = admin_client.list_topics().topics
if topic_name in topic_metadata:
num_partitions = len(topic_metadata[topic_name].partitions)
logger.info(
f"Topic {topic_name} already exists with {num_partitions} partitions"
)
return num_partitions
except Exception as e:
logger.error(f"Failed to list topics: {e}")
raise e
topics = [
NewTopic(
topic_name,
num_partitions=partition,
replication_factor=replication_factor,
config={"retention.ms": str(ttl_ms)},
)
]
try:
futures = admin_client.create_topics(topics)
for _, f in futures.items():
f.result() # result is None
logger.info(f"Topic {topic_name} created")
except Exception as e:
logger.error(f"Failed to create topic {topic_name}: {e}")
raise e
return partition
class KafkaChatMessageHistory(BaseChatMessageHistory):
"""Chat message history stored in Kafka.
Setup:
Install ``confluent-kafka-python``.
.. code-block:: bash
pip install confluent_kafka
Instantiate:
.. code-block:: python
from langchain_community.chat_message_histories import KafkaChatMessageHistory
history = KafkaChatMessageHistory(
session_id="your_session_id",
bootstrap_servers="host:port",
)
Add and retrieve messages:
.. code-block:: python
# Add messages
history.add_messages([message1, message2, message3, ...])
# Retrieve messages
message_batch_0 = history.messages
# retrieve messages after message_batch_0
message_batch_1 = history.messages
# Reset to beginning and retrieve messages
messages_from_beginning = history.messages_from_beginning()
Retrieving messages is stateful. Internally, it uses Kafka consumer to read.
The consumed offset is maintained persistently.
To retrieve messages, you can use the following methods:
- `messages`:
continue consuming chat messages from last one.
- `messages_from_beginning`:
reset the consumer to the beginning of the chat history and return messages.
Optional parameters:
1. `max_message_count`: maximum number of messages to return.
2. `max_time_sec`: maximum time in seconds to wait for messages.
- `messages_from_latest`:
reset to end of the chat history and try consuming messages.
Optional parameters same as above.
- `messages_from_last_consumed`:
continuing from the last consumed message, similar to `messages`.
Optional parameters same as above.
`max_message_count` and `max_time_sec` are used to avoid blocking indefinitely
when retrieving messages. As a result, the method to retrieve messages may not
return all messages. Change `max_message_count` and `max_time_sec` to retrieve
all history messages.
""" # noqa: E501
def __init__(
self,
session_id: str,
bootstrap_servers: str,
ttl_ms: int = DEFAULT_TTL_MS,
replication_factor: int = DEFAULT_REPLICATION_FACTOR,
partition: int = DEFAULT_PARTITION,
):
"""
Args:
session_id: The ID for single chat session. It is used as Kafka topic name.
bootstrap_servers:
Comma-separated host/port pairs to establish connection to Kafka cluster
https://kafka.apache.org/documentation.html#adminclientconfigs_bootstrap.servers
ttl_ms:
Time-to-live (milliseconds) for automatic expiration of entries.
Default 7 days. -1 for no expiration.
It translates to https://kafka.apache.org/documentation.html#topicconfigs_retention.ms
replication_factor: The replication factor for the topic. Default 1.
partition: The number of partitions for the topic. Default 3.
"""
try:
from confluent_kafka import Producer
from confluent_kafka.admin import AdminClient
except (ImportError, ModuleNotFoundError):
raise ImportError(
"Could not import confluent_kafka package. "
"Please install it with `pip install confluent_kafka`."
)
self.session_id = session_id
self.bootstrap_servers = bootstrap_servers
self.admin_client = AdminClient({BOOTSTRAP_SERVERS_CONFIG: bootstrap_servers})
self.num_partitions = ensure_topic_exists(
self.admin_client, session_id, replication_factor, partition, ttl_ms
)
self.producer = Producer({BOOTSTRAP_SERVERS_CONFIG: bootstrap_servers})
def add_messages(
self,
messages: Sequence[BaseMessage],
flush_timeout_seconds: float = 5.0,
) -> None:
"""Add messages to the chat history by producing to the Kafka topic."""
try:
for message in messages:
self.producer.produce(
topic=self.session_id,
value=json.dumps(message_to_dict(message)),
)
message_remaining = self.producer.flush(flush_timeout_seconds)
if message_remaining > 0:
logger.warning(f"{message_remaining} messages are still in-flight.")
except Exception as e:
logger.error(f"Failed to add messages to Kafka: {e}")
raise e
def __read_messages(
self,
consume_start_pos: ConsumeStartPosition,
max_message_count: Optional[int],
max_time_sec: Optional[float],
) -> List[BaseMessage]:
"""Retrieve messages from Kafka topic for the session.
Please note this method is stateful. Internally, it uses Kafka consumer
to consume messages, and maintains the consumed offset.
Args:
consume_start_pos: Start position for Kafka consumer.
max_message_count: Maximum number of messages to consume.
max_time_sec: Time limit in seconds to consume messages.
Returns:
List of messages.
"""
from confluent_kafka import OFFSET_BEGINNING, OFFSET_END, Consumer
consumer_config = {
BOOTSTRAP_SERVERS_CONFIG: self.bootstrap_servers,
"group.id": self.session_id,
"auto.offset.reset": "latest"
if consume_start_pos == ConsumeStartPosition.LATEST
else "earliest",
}
def assign_beginning(
assigned_consumer: Consumer, assigned_partitions: list[TopicPartition]
) -> None:
for p in assigned_partitions:
p.offset = OFFSET_BEGINNING
assigned_consumer.assign(assigned_partitions)
def assign_latest(
assigned_consumer: Consumer, assigned_partitions: list[TopicPartition]
) -> None:
for p in assigned_partitions:
p.offset = OFFSET_END
assigned_consumer.assign(assigned_partitions)
messages: List[dict] = []
consumer = Consumer(consumer_config)
try:
if consume_start_pos == ConsumeStartPosition.EARLIEST:
consumer.subscribe([self.session_id], on_assign=assign_beginning)
elif consume_start_pos == ConsumeStartPosition.LATEST:
consumer.subscribe([self.session_id], on_assign=assign_latest)
else:
consumer.subscribe([self.session_id])
start_time_sec = time.time()
while True:
if (
max_time_sec is not None
and time.time() - start_time_sec > max_time_sec
):
break
if max_message_count is not None and len(messages) >= max_message_count:
break
message = consumer.poll(timeout=1.0)
if message is None: # poll timeout
continue
if message.error() is not None: # error
logger.error(f"Consumer error: {message.error()}")
continue
if message.value() is None: # empty value
logger.warning("Empty message value")
continue
messages.append(json.loads(message.value()))
except Exception as e:
logger.error(f"Failed to consume messages from Kafka: {e}")
raise e
finally:
consumer.close()
return messages_from_dict(messages)
def messages_from_beginning(
self, max_message_count: Optional[int] = 5, max_time_sec: Optional[float] = 5.0
) -> List[BaseMessage]:
"""Retrieve messages from Kafka topic from the beginning.
This method resets the consumer to the beginning and consumes messages.
Args:
max_message_count: Maximum number of messages to consume.
max_time_sec: Time limit in seconds to consume messages.
Returns:
List of messages.
"""
return self.__read_messages(
consume_start_pos=ConsumeStartPosition.EARLIEST,
max_message_count=max_message_count,
max_time_sec=max_time_sec,
)
def messages_from_latest(
self, max_message_count: Optional[int] = 5, max_time_sec: Optional[float] = 5.0
) -> List[BaseMessage]:
"""Reset to the end offset. Try to consume messages if available.
Args:
max_message_count: Maximum number of messages to consume.
max_time_sec: Time limit in seconds to consume messages.
Returns:
List of messages.
"""
return self.__read_messages(
consume_start_pos=ConsumeStartPosition.LATEST,
max_message_count=max_message_count,
max_time_sec=max_time_sec,
)
def messages_from_last_consumed(
self, max_message_count: Optional[int] = 5, max_time_sec: Optional[float] = 5.0
) -> List[BaseMessage]:
"""Retrieve messages from Kafka topic from the last consumed message.
Please note this method is stateful. Internally, it uses Kafka consumer
to consume messages, and maintains the commit offset.
Args:
max_message_count: Maximum number of messages to consume.
max_time_sec: Time limit in seconds to consume messages.
Returns:
List of messages.
"""
return self.__read_messages(
consume_start_pos=ConsumeStartPosition.LAST_CONSUMED,
max_message_count=max_message_count,
max_time_sec=max_time_sec,
)
@property
def messages(self) -> List[BaseMessage]: # type: ignore
"""
Retrieve the messages for the session, from Kafka topic continuously
from last consumed message. This method is stateful and maintains
consumed(committed) offset based on consumer group.
Alternatively, use messages_from_last_consumed() with specified parameters.
Use messages_from_beginning() to read from the earliest message.
Use messages_from_latest() to read from the latest message.
"""
return self.messages_from_last_consumed()
def clear(self) -> None:
"""Clear the chat history by deleting the Kafka topic."""
try:
futures = self.admin_client.delete_topics([self.session_id])
for _, f in futures.items():
f.result() # result is None
logger.info(f"Topic {self.session_id} deleted")
except Exception as e:
logger.error(f"Failed to delete topic {self.session_id}: {e}")
raise e
def close(self) -> None:
"""Release the resources.
Nothing to be released at this moment.
"""
pass