import json
import logging
from time import time
from typing import TYPE_CHECKING, List, Optional

from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.messages import (
    BaseMessage,
    message_to_dict,
    messages_from_dict,
)

from langchain_elasticsearch._utilities import with_user_agent_header
from langchain_elasticsearch.client import create_elasticsearch_client

if TYPE_CHECKING:
    from elasticsearch import Elasticsearch

logger = logging.getLogger(__name__)


class ElasticsearchChatMessageHistory(BaseChatMessageHistory):
    """Chat message history that stores history in Elasticsearch.

    Each message is stored as one document holding the session id, a
    creation timestamp in epoch milliseconds, and the JSON-serialized message.

    Args:
        index: Name of the index to use. It is created on construction if it
            does not exist yet.
        session_id: Arbitrary key that is used to store the messages of a
            single chat session.
        es_connection: Optional pre-existing Elasticsearch connection. Takes
            precedence over the connection parameters below.
        es_url: URL of the Elasticsearch instance to connect to.
        es_cloud_id: Cloud ID of the Elasticsearch instance to connect to.
        es_user: Username to use when connecting to Elasticsearch.
        es_api_key: API key to use when connecting to Elasticsearch.
        es_password: Password to use when connecting to Elasticsearch.
        esnsure_ascii: Forwarded to ``json.dumps`` as ``ensure_ascii``; when
            true, non-ASCII characters in the serialized message are escaped.
            Defaults to True. (The parameter name really is spelled
            ``esnsure_ascii``; it is kept as-is so existing callers keep
            working.)

    Raises:
        ValueError: If neither ``es_connection`` nor one of ``es_url`` /
            ``es_cloud_id`` is provided.
    """

    # Elasticsearch returns only 10 hits when ``size`` is omitted, which would
    # silently truncate longer sessions. 10,000 is the default
    # ``index.max_result_window``, i.e. the largest page a single search
    # request may ask for on an index with default settings.
    _MAX_SEARCH_HITS = 10_000

    def __init__(
        self,
        index: str,
        session_id: str,
        *,
        es_connection: Optional["Elasticsearch"] = None,
        es_url: Optional[str] = None,
        es_cloud_id: Optional[str] = None,
        es_user: Optional[str] = None,
        es_api_key: Optional[str] = None,
        es_password: Optional[str] = None,
        esnsure_ascii: Optional[bool] = True,
    ):
        self.index: str = index
        self.session_id: str = session_id
        self.ensure_ascii = esnsure_ascii

        # Initialize Elasticsearch client from passed client arg or connection info
        if es_connection is not None:
            self.client = es_connection
        elif es_url is not None or es_cloud_id is not None:
            try:
                self.client = create_elasticsearch_client(
                    url=es_url,
                    username=es_user,
                    password=es_password,
                    cloud_id=es_cloud_id,
                    api_key=es_api_key,
                )
            except Exception as err:
                logger.error("Error connecting to Elasticsearch: %s", err)
                raise
        else:
            # Built from adjacent literals so the message contains no stray
            # indentation or line-continuation artifacts.
            raise ValueError(
                "Either provide a pre-existing Elasticsearch connection, "
                "or valid credentials for creating a new connection."
            )

        self.client = with_user_agent_header(self.client, "langchain-py-ms")

        if self.client.indices.exists(index=index):
            logger.debug(
                "Chat history index %s already exists, skipping creation.", index
            )
        else:
            logger.debug("Creating index %s for storing chat history.", index)
            self.client.indices.create(
                index=index,
                mappings={
                    "properties": {
                        "session_id": {"type": "keyword"},
                        "created_at": {"type": "date"},
                        "history": {"type": "text"},
                    }
                },
            )

    @property
    def messages(self) -> List[BaseMessage]:  # type: ignore[override]
        """Retrieve the messages of this session from Elasticsearch.

        Returns:
            The session's messages ordered by ascending creation time. At most
            ``_MAX_SEARCH_HITS`` messages are returned by the single search
            request issued here.

        Raises:
            elasticsearch.ApiError: If the search request fails.
        """
        # Import before the ``try`` so that a failing import surfaces as an
        # ImportError instead of an UnboundLocalError raised while evaluating
        # the ``except ApiError`` clause.
        from elasticsearch import ApiError

        try:
            result = self.client.search(
                index=self.index,
                query={"term": {"session_id": self.session_id}},
                sort="created_at:asc",
                size=self._MAX_SEARCH_HITS,
            )
        except ApiError as err:
            logger.error("Could not retrieve messages from Elasticsearch: %s", err)
            raise

        hits = result["hits"]["hits"] if result else []
        items = [json.loads(hit["_source"]["history"]) for hit in hits]
        return messages_from_dict(items)

    def add_message(self, message: BaseMessage) -> None:
        """Add a message to the chat session in Elasticsearch.

        Args:
            message: The message to serialize and store.

        Raises:
            elasticsearch.ApiError: If the index request fails.
        """
        # See ``messages`` for why the import precedes the ``try``.
        from elasticsearch import ApiError

        document = {
            "session_id": self.session_id,
            # Epoch milliseconds; also the sort key used when reading back.
            "created_at": round(time() * 1000),
            "history": json.dumps(
                message_to_dict(message),
                ensure_ascii=bool(self.ensure_ascii),
            ),
        }
        try:
            # refresh=True asks Elasticsearch to refresh the index as part of
            # the request, so the document is visible to the next search.
            self.client.index(index=self.index, document=document, refresh=True)
        except ApiError as err:
            logger.error("Could not add message to Elasticsearch: %s", err)
            raise

    def clear(self) -> None:
        """Clear session memory in Elasticsearch.

        Deletes every document in the index whose ``session_id`` matches this
        session.

        Raises:
            elasticsearch.ApiError: If the delete-by-query request fails.
        """
        # See ``messages`` for why the import precedes the ``try``.
        from elasticsearch import ApiError

        try:
            self.client.delete_by_query(
                index=self.index,
                query={"term": {"session_id": self.session_id}},
                refresh=True,
            )
        except ApiError as err:
            logger.error("Could not clear session memory in Elasticsearch: %s", err)
            raise