mirror of https://github.com/hwchase17/langchain
Add ElasticsearchChatMessageHistory (#10932)
**Description** This PR adds the `ElasticsearchChatMessageHistory` implementation that stores chat message history in the configured [Elasticsearch](https://www.elastic.co/elasticsearch/) deployment. ```python from langchain.memory.chat_message_histories import ElasticsearchChatMessageHistory history = ElasticsearchChatMessageHistory( es_url="https://my-elasticsearch-deployment-url:9200", index="chat-history-index", session_id="123" ) history.add_ai_message("This is me, the AI") history.add_user_message("This is me, the human") ``` **Dependencies** - [elasticsearch client](https://elasticsearch-py.readthedocs.io/) required Co-authored-by: Bagatur <baskaryan@gmail.com>pull/11745/head
parent
d3a5090e12
commit
008348ce71
@ -0,0 +1,186 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "683953b3",
|
||||
"metadata": {
|
||||
"id": "683953b3"
|
||||
},
|
||||
"source": [
|
||||
"# Elasticsearch Chat Message History\n",
|
||||
"\n",
|
||||
">[Elasticsearch](https://www.elastic.co/elasticsearch/) is a distributed, RESTful search and analytics engine, capable of performing both vector and lexical search. It is built on top of the Apache Lucene library.\n",
|
||||
"\n",
|
||||
"This notebook shows how to use chat message history functionality with Elasticsearch."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "3c7720c3",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Set up Elasticsearch\n",
|
||||
"\n",
|
||||
"There are two main ways to set up an Elasticsearch instance:\n",
|
||||
"\n",
|
||||
"1. **Elastic Cloud.** Elastic Cloud is a managed Elasticsearch service. Sign up for a [free trial](https://cloud.elastic.co/registration?storm=langchain-notebook).\n",
|
||||
"\n",
|
||||
"2. **Local Elasticsearch installation.** Get started with Elasticsearch by running it locally. The easiest way is to use the official Elasticsearch Docker image. See the [Elasticsearch Docker documentation](https://www.elastic.co/guide/en/elasticsearch/reference/current/docker.html) for more information."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "cdf1d2b7",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Install dependencies"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "e5bbffe2",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"%pip install elasticsearch langchain"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "8be8fcc3",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Initialize Elasticsearch client and chat message history"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"id": "8e2ee0fa",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"from langchain.memory import ElasticsearchChatMessageHistory\n",
|
||||
"\n",
|
||||
"es_url = os.environ.get(\"ES_URL\", \"http://localhost:9200\")\n",
|
||||
"\n",
|
||||
"# If using Elastic Cloud:\n",
|
||||
"# es_cloud_id = os.environ.get(\"ES_CLOUD_ID\")\n",
|
||||
"\n",
|
||||
"# Note: see Authentication section for various authentication methods\n",
|
||||
"\n",
|
||||
"history = ElasticsearchChatMessageHistory(\n",
|
||||
" es_url=es_url,\n",
|
||||
" index=\"test-history\",\n",
|
||||
" session_id=\"test-session\"\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "a63942e2",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Use the chat message history"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"id": "c1c7be79",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"indexing message content='hi!' additional_kwargs={} example=False\n",
|
||||
"indexing message content='whats up?' additional_kwargs={} example=False\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"history.add_user_message(\"hi!\")\n",
|
||||
"history.add_ai_message(\"whats up?\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "c46c216c",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Authentication\n",
|
||||
"\n",
|
||||
"## Username/password\n",
|
||||
"\n",
|
||||
"```python\n",
|
||||
"es_username = os.environ.get(\"ES_USERNAME\", \"elastic\")\n",
|
||||
"es_password = os.environ.get(\"ES_PASSWORD\", \"changeme\")\n",
|
||||
"\n",
|
||||
"history = ElasticsearchChatMessageHistory(\n",
|
||||
" es_url=es_url,\n",
|
||||
" es_user=es_username,\n",
|
||||
" es_password=es_password,\n",
|
||||
" index=\"test-history\",\n",
|
||||
" session_id=\"test-session\"\n",
|
||||
")\n",
|
||||
"```\n",
|
||||
"\n",
|
||||
"### How to obtain a password for the default \"elastic\" user\n",
|
||||
"\n",
|
||||
"To obtain your Elastic Cloud password for the default \"elastic\" user:\n",
|
||||
"1. Log in to the Elastic Cloud console at https://cloud.elastic.co\n",
|
||||
"2. Go to \"Security\" > \"Users\"\n",
|
||||
"3. Locate the \"elastic\" user and click \"Edit\"\n",
|
||||
"4. Click \"Reset password\"\n",
|
||||
"5. Follow the prompts to reset the password\n",
|
||||
"\n",
|
||||
"## API key\n",
|
||||
"\n",
|
||||
"```python\n",
|
||||
"es_api_key = os.environ.get(\"ES_API_KEY\")\n",
|
||||
"\n",
|
||||
"history = ElasticsearchChatMessageHistory(\n",
|
||||
" es_api_key=es_api_key,\n",
|
||||
" index=\"test-history\",\n",
|
||||
" session_id=\"test-session\"\n",
|
||||
")\n",
|
||||
"```\n",
|
||||
"\n",
|
||||
"### How to obtain an API key\n",
|
||||
"\n",
|
||||
"To obtain an API key:\n",
|
||||
"1. Log in to the Elastic Cloud console at https://cloud.elastic.co\n",
|
||||
"2. Open Kibana and go to Stack Management > API Keys\n",
|
||||
"3. Click \"Create API key\"\n",
|
||||
"4. Enter a name for the API key and click \"Create\""
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"provenance": []
|
||||
},
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.9"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
@ -0,0 +1,191 @@
|
||||
import json
|
||||
import logging
|
||||
from time import time
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
||||
|
||||
from langchain.schema import BaseChatMessageHistory
|
||||
from langchain.schema.messages import BaseMessage, _message_to_dict, messages_from_dict
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from elasticsearch import Elasticsearch
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ElasticsearchChatMessageHistory(BaseChatMessageHistory):
|
||||
"""Chat message history that stores history in Elasticsearch.
|
||||
|
||||
Args:
|
||||
es_url: URL of the Elasticsearch instance to connect to.
|
||||
es_cloud_id: Cloud ID of the Elasticsearch instance to connect to.
|
||||
es_user: Username to use when connecting to Elasticsearch.
|
||||
es_password: Password to use when connecting to Elasticsearch.
|
||||
es_api_key: API key to use when connecting to Elasticsearch.
|
||||
es_connection: Optional pre-existing Elasticsearch connection.
|
||||
index: Name of the index to use.
|
||||
session_id: Arbitrary key that is used to store the messages
|
||||
of a single chat session.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
index: str,
|
||||
session_id: str,
|
||||
*,
|
||||
es_connection: Optional["Elasticsearch"] = None,
|
||||
es_url: Optional[str] = None,
|
||||
es_cloud_id: Optional[str] = None,
|
||||
es_user: Optional[str] = None,
|
||||
es_api_key: Optional[str] = None,
|
||||
es_password: Optional[str] = None,
|
||||
):
|
||||
self.index: str = index
|
||||
self.session_id: str = session_id
|
||||
|
||||
# Initialize Elasticsearch client from passed client arg or connection info
|
||||
if es_connection is not None:
|
||||
self.client = es_connection.options(
|
||||
headers={"user-agent": self.get_user_agent()}
|
||||
)
|
||||
elif es_url is not None or es_cloud_id is not None:
|
||||
self.client = ElasticsearchChatMessageHistory.connect_to_elasticsearch(
|
||||
es_url=es_url,
|
||||
username=es_user,
|
||||
password=es_password,
|
||||
cloud_id=es_cloud_id,
|
||||
api_key=es_api_key,
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
"""Either provide a pre-existing Elasticsearch connection, \
|
||||
or valid credentials for creating a new connection."""
|
||||
)
|
||||
|
||||
if self.client.indices.exists(index=index):
|
||||
logger.debug(
|
||||
f"Chat history index {index} already exists, skipping creation."
|
||||
)
|
||||
else:
|
||||
logger.debug(f"Creating index {index} for storing chat history.")
|
||||
|
||||
self.client.indices.create(
|
||||
index=index,
|
||||
mappings={
|
||||
"properties": {
|
||||
"session_id": {"type": "keyword"},
|
||||
"created_at": {"type": "date"},
|
||||
"history": {"type": "text"},
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_user_agent() -> str:
|
||||
from langchain import __version__
|
||||
|
||||
return f"langchain-py-ms/{__version__}"
|
||||
|
||||
@staticmethod
|
||||
def connect_to_elasticsearch(
|
||||
*,
|
||||
es_url: Optional[str] = None,
|
||||
cloud_id: Optional[str] = None,
|
||||
api_key: Optional[str] = None,
|
||||
username: Optional[str] = None,
|
||||
password: Optional[str] = None,
|
||||
) -> "Elasticsearch":
|
||||
try:
|
||||
import elasticsearch
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Could not import elasticsearch python package. "
|
||||
"Please install it with `pip install elasticsearch`."
|
||||
)
|
||||
|
||||
if es_url and cloud_id:
|
||||
raise ValueError(
|
||||
"Both es_url and cloud_id are defined. Please provide only one."
|
||||
)
|
||||
|
||||
connection_params: Dict[str, Any] = {}
|
||||
|
||||
if es_url:
|
||||
connection_params["hosts"] = [es_url]
|
||||
elif cloud_id:
|
||||
connection_params["cloud_id"] = cloud_id
|
||||
else:
|
||||
raise ValueError("Please provide either elasticsearch_url or cloud_id.")
|
||||
|
||||
if api_key:
|
||||
connection_params["api_key"] = api_key
|
||||
elif username and password:
|
||||
connection_params["basic_auth"] = (username, password)
|
||||
|
||||
es_client = elasticsearch.Elasticsearch(
|
||||
**connection_params,
|
||||
headers={"user-agent": ElasticsearchChatMessageHistory.get_user_agent()},
|
||||
)
|
||||
try:
|
||||
es_client.info()
|
||||
except Exception as err:
|
||||
logger.error(f"Error connecting to Elasticsearch: {err}")
|
||||
raise err
|
||||
|
||||
return es_client
|
||||
|
||||
@property
|
||||
def messages(self) -> List[BaseMessage]: # type: ignore[override]
|
||||
"""Retrieve the messages from Elasticsearch"""
|
||||
try:
|
||||
from elasticsearch import ApiError
|
||||
|
||||
result = self.client.search(
|
||||
index=self.index,
|
||||
query={"term": {"session_id": self.session_id}},
|
||||
sort="created_at:asc",
|
||||
)
|
||||
except ApiError as err:
|
||||
logger.error(f"Could not retrieve messages from Elasticsearch: {err}")
|
||||
raise err
|
||||
|
||||
if result and len(result["hits"]["hits"]) > 0:
|
||||
items = [
|
||||
json.loads(document["_source"]["history"])
|
||||
for document in result["hits"]["hits"]
|
||||
]
|
||||
else:
|
||||
items = []
|
||||
|
||||
return messages_from_dict(items)
|
||||
|
||||
def add_message(self, message: BaseMessage) -> None:
|
||||
"""Add a message to the chat session in Elasticsearch"""
|
||||
try:
|
||||
from elasticsearch import ApiError
|
||||
|
||||
self.client.index(
|
||||
index=self.index,
|
||||
document={
|
||||
"session_id": self.session_id,
|
||||
"created_at": round(time() * 1000),
|
||||
"history": json.dumps(_message_to_dict(message)),
|
||||
},
|
||||
refresh=True,
|
||||
)
|
||||
except ApiError as err:
|
||||
logger.error(f"Could not add message to Elasticsearch: {err}")
|
||||
raise err
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Clear session memory in Elasticsearch"""
|
||||
try:
|
||||
from elasticsearch import ApiError
|
||||
|
||||
self.client.delete_by_query(
|
||||
index=self.index,
|
||||
query={"term": {"session_id": self.session_id}},
|
||||
refresh=True,
|
||||
)
|
||||
except ApiError as err:
|
||||
logger.error(f"Could not clear session memory in Elasticsearch: {err}")
|
||||
raise err
|
@ -0,0 +1,34 @@
|
||||
version: "3"
|
||||
|
||||
services:
|
||||
elasticsearch:
|
||||
image: docker.elastic.co/elasticsearch/elasticsearch:8.9.0 # https://www.docker.elastic.co/r/elasticsearch/elasticsearch
|
||||
environment:
|
||||
- discovery.type=single-node
|
||||
- xpack.security.enabled=false # security has been disabled, so no login or password is required.
|
||||
- xpack.security.http.ssl.enabled=false
|
||||
ports:
|
||||
- "9200:9200"
|
||||
healthcheck:
|
||||
test:
|
||||
[
|
||||
"CMD-SHELL",
|
||||
"curl --silent --fail http://localhost:9200/_cluster/health || exit 1",
|
||||
]
|
||||
interval: 10s
|
||||
retries: 60
|
||||
|
||||
kibana:
|
||||
image: docker.elastic.co/kibana/kibana:8.9.0
|
||||
environment:
|
||||
- ELASTICSEARCH_URL=http://elasticsearch:9200
|
||||
ports:
|
||||
- "5601:5601"
|
||||
healthcheck:
|
||||
test:
|
||||
[
|
||||
"CMD-SHELL",
|
||||
"curl --silent --fail http://localhost:5601/login || exit 1",
|
||||
]
|
||||
interval: 10s
|
||||
retries: 60
|
@ -0,0 +1,91 @@
|
||||
import json
|
||||
import os
|
||||
import uuid
|
||||
from typing import Generator, Union
|
||||
|
||||
import pytest
|
||||
|
||||
from langchain.memory import ConversationBufferMemory
|
||||
from langchain.memory.chat_message_histories import ElasticsearchChatMessageHistory
|
||||
from langchain.schema.messages import _message_to_dict
|
||||
|
||||
"""
|
||||
cd tests/integration_tests/memory/docker-compose
|
||||
docker-compose -f elasticsearch.yml up
|
||||
|
||||
By default runs against local docker instance of Elasticsearch.
|
||||
To run against Elastic Cloud, set the following environment variables:
|
||||
- ES_CLOUD_ID
|
||||
- ES_USERNAME
|
||||
- ES_PASSWORD
|
||||
"""
|
||||
|
||||
|
||||
class TestElasticsearch:
|
||||
@pytest.fixture(scope="class", autouse=True)
|
||||
def elasticsearch_connection(self) -> Union[dict, Generator[dict, None, None]]:
|
||||
# Run this integration test against Elasticsearch on localhost,
|
||||
# or an Elastic Cloud instance
|
||||
from elasticsearch import Elasticsearch
|
||||
|
||||
es_url = os.environ.get("ES_URL", "http://localhost:9200")
|
||||
es_cloud_id = os.environ.get("ES_CLOUD_ID")
|
||||
es_username = os.environ.get("ES_USERNAME", "elastic")
|
||||
es_password = os.environ.get("ES_PASSWORD", "changeme")
|
||||
|
||||
if es_cloud_id:
|
||||
es = Elasticsearch(
|
||||
cloud_id=es_cloud_id,
|
||||
basic_auth=(es_username, es_password),
|
||||
)
|
||||
yield {
|
||||
"es_cloud_id": es_cloud_id,
|
||||
"es_user": es_username,
|
||||
"es_password": es_password,
|
||||
}
|
||||
|
||||
else:
|
||||
# Running this integration test with local docker instance
|
||||
es = Elasticsearch(hosts=es_url)
|
||||
yield {"es_url": es_url}
|
||||
|
||||
# Clear all indexes
|
||||
index_names = es.indices.get(index="_all").keys()
|
||||
for index_name in index_names:
|
||||
if index_name.startswith("test_"):
|
||||
es.indices.delete(index=index_name)
|
||||
es.indices.refresh(index="_all")
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def index_name(self) -> str:
|
||||
"""Return the index name."""
|
||||
return f"test_{uuid.uuid4().hex}"
|
||||
|
||||
def test_memory_with_message_store(
|
||||
self, elasticsearch_connection: dict, index_name: str
|
||||
) -> None:
|
||||
"""Test the memory with a message store."""
|
||||
# setup Elasticsearch as a message store
|
||||
message_history = ElasticsearchChatMessageHistory(
|
||||
**elasticsearch_connection, index=index_name, session_id="test-session"
|
||||
)
|
||||
|
||||
memory = ConversationBufferMemory(
|
||||
memory_key="baz", chat_memory=message_history, return_messages=True
|
||||
)
|
||||
|
||||
# add some messages
|
||||
memory.chat_memory.add_ai_message("This is me, the AI")
|
||||
memory.chat_memory.add_user_message("This is me, the human")
|
||||
|
||||
# get the message history from the memory store and turn it into a json
|
||||
messages = memory.chat_memory.messages
|
||||
messages_json = json.dumps([_message_to_dict(msg) for msg in messages])
|
||||
|
||||
assert "This is me, the AI" in messages_json
|
||||
assert "This is me, the human" in messages_json
|
||||
|
||||
# remove the record from Elasticsearch, so the next test run won't pick it up
|
||||
memory.chat_memory.clear()
|
||||
|
||||
assert memory.chat_memory.messages == []
|
Loading…
Reference in New Issue