Xata as a chat message memory store (#9719)

This adds Xata as a memory store also to the python version of
LangChain, similar to the [one for
LangChain.js](https://github.com/hwchase17/langchainjs/pull/2217).

I have added a Jupyter Notebook with a simple and a more complex example
using an agent.

To run the integration test, you need to execute something like:

```
XATA_API_KEY='xau_...' XATA_DB_URL="https://demo-uni3q8.eu-west-1.xata.sh/db/langchain"  poetry run pytest tests/integration_tests/memory/test_xata.py
```

Where `langchain` is the database you create in Xata.
pull/9737/head
Tudor Golubenco 1 year ago committed by GitHub
parent dff00ea91e
commit dc30edf51c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -0,0 +1,326 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Xata chat memory\n",
"\n",
"[Xata](https://xata.io) is a serverless data platform, based on PostgreSQL and Elasticsearch. It provides a Python SDK for interacting with your database, and a UI for managing your data. With the `XataChatMessageHistory` class, you can use Xata databases for longer-term persistence of chat sessions.\n",
"\n",
"This notebook covers:\n",
"\n",
"* A simple example showing what `XataChatMessageHistory` does.\n",
"* A more complex example using a REACT agent that answer questions based on a knowledge based or documentation (stored in Xata as a vector store) and also having a long-term searchable history of its past messages (stored in Xata as a memory store)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Setup\n",
"\n",
"### Create a database\n",
"\n",
"In the [Xata UI](https://app.xata.io) create a new database. You can name it whatever you want, in this notepad we'll use `langchain`. The Langchain integration can auto-create the table used for storying the memory, and this is what we'll use in this example. If you want to pre-create the table, ensure it has the right schema and set `create_table` to `False` when creating the class. Pre-creating the table saves one round-trip to the database during each session initialization."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's first install our dependencies:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!pip install xata==1.0.0rc0 openai langchain"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Next, we need to get the environment variables for Xata. You can create a new API key by visiting your [account settings](https://app.xata.io/settings). To find the database URL, go to the Settings page of the database that you have created. The database URL should look something like this: `https://demo-uni3q8.eu-west-1.xata.sh/db/langchain`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import getpass\n",
"\n",
"api_key = getpass.getpass(\"Xata API key: \")\n",
"db_url = input(\"Xata database URL (copy it from your DB settings):\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Create a simple memory store\n",
"\n",
"To test the memory store functionality in isolation, let's use the following code snippet:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from langchain.memory import XataChatMessageHistory\n",
"\n",
"history = XataChatMessageHistory(\n",
" session_id=\"session-1\",\n",
" api_key=api_key,\n",
" db_url=db_url,\n",
" table_name=\"memory\"\n",
")\n",
"\n",
"history.add_user_message(\"hi!\")\n",
"\n",
"history.add_ai_message(\"whats up?\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The above code creates a session with the ID `session-1` and stores two messages in it. After running the above, if you visit the Xata UI, you should see a table named `memory` and the two messages added to it.\n",
"\n",
"You can retrieve the message history for a particular session with the following code:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"history.messages"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Conversational Q&A chain on your data with memory\n",
"\n",
"Let's now see a more complex example in which we combine OpenAI, the Xata Vector Store integration, and the Xata memory store integration to create a Q&A chat bot on your data, with follow-up questions and history."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We're going to need to access the OpenAI API, so let's configure the API key:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"\n",
"os.environ[\"OPENAI_API_KEY\"] = getpass.getpass(\"OpenAI API Key:\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To store the documents that the chatbot will search for answers, add a table named `docs` to your `langchain` database using the Xata UI, and add the following columns:\n",
"\n",
"* `content` of type \"Text\". This is used to store the `Document.pageContent` values.\n",
"* `embedding` of type \"Vector\". Use the dimension used by the model you plan to use. In this notebook we use OpenAI embeddings, which have 1536 dimensions."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's create the vector store and add some sample docs to it:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from langchain.embeddings.openai import OpenAIEmbeddings\n",
"from langchain.vectorstores.xata import XataVectorStore\n",
"\n",
"embeddings = OpenAIEmbeddings()\n",
"\n",
"texts = [\n",
" \"Xata is a Serverless Data platform based on PostgreSQL\",\n",
" \"Xata offers a built-in vector type that can be used to store and query vectors\",\n",
" \"Xata includes similarity search\"\n",
"]\n",
"\n",
"vector_store = XataVectorStore.from_texts(texts, embeddings, api_key=api_key, db_url=db_url, table_name=\"docs\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"After running the above command, if you go to the Xata UI, you should see the documents loaded together with their embeddings in the `docs` table."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's now create a ConversationBufferMemory to store the chat messages from both the user and the AI."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from langchain.memory import ConversationBufferMemory\n",
"from uuid import uuid4\n",
"\n",
"chat_memory = XataChatMessageHistory(\n",
" session_id=str(uuid4()), # needs to be unique per user session\n",
" api_key=api_key,\n",
" db_url=db_url,\n",
" table_name=\"memory\"\n",
")\n",
"memory = ConversationBufferMemory(memory_key=\"chat_history\", chat_memory=chat_memory, return_messages=True)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now it's time to create an Agent to use both the vector store and the chat memory together."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from langchain.agents import initialize_agent, AgentType\n",
"from langchain.agents.agent_toolkits import create_retriever_tool\n",
"from langchain.chat_models import ChatOpenAI\n",
"\n",
"tool = create_retriever_tool(\n",
" vector_store.as_retriever(), \n",
" \"search_docs\",\n",
" \"Searches and returns documents from the Xata manual. Useful when you need to answer questions about Xata.\"\n",
")\n",
"tools = [tool]\n",
"\n",
"llm = ChatOpenAI(temperature=0)\n",
"\n",
"agent = initialize_agent(\n",
" tools,\n",
" llm,\n",
" agent=AgentType.CHAT_CONVERSATIONAL_REACT_DESCRIPTION,\n",
" verbose=True,\n",
" memory=memory)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To test, let's tell the agent our name:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"agent.run(input=\"My name is bob\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now, let's now ask the agent some questions about Xata:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"agent.run(input=\"What is xata?\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Notice that it answers based on the data stored in the document store. And now, let's ask a follow up question:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"agent.run(input=\"Does it support similarity search?\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"And now let's test its memory:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"agent.run(input=\"Did I tell you my name? What is it?\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.9"
}
},
"nbformat": 4,
"nbformat_minor": 4
}

@ -43,6 +43,7 @@ from langchain.memory.chat_message_histories import (
RedisChatMessageHistory,
SQLChatMessageHistory,
StreamlitChatMessageHistory,
XataChatMessageHistory,
ZepChatMessageHistory,
)
from langchain.memory.combined import CombinedMemory
@ -90,6 +91,7 @@ __all__ = [
"SimpleMemory",
"StreamlitChatMessageHistory",
"VectorStoreRetrieverMemory",
"XataChatMessageHistory",
"ZepChatMessageHistory",
"ZepMemory",
]

@ -17,6 +17,7 @@ from langchain.memory.chat_message_histories.sql import SQLChatMessageHistory
from langchain.memory.chat_message_histories.streamlit import (
StreamlitChatMessageHistory,
)
from langchain.memory.chat_message_histories.xata import XataChatMessageHistory
from langchain.memory.chat_message_histories.zep import ZepChatMessageHistory
__all__ = [
@ -33,5 +34,6 @@ __all__ = [
"RocksetChatMessageHistory",
"SQLChatMessageHistory",
"StreamlitChatMessageHistory",
"XataChatMessageHistory",
"ZepChatMessageHistory",
]

@ -0,0 +1,132 @@
import json
from typing import List
from langchain.schema import (
BaseChatMessageHistory,
)
from langchain.schema.messages import BaseMessage, _message_to_dict, messages_from_dict
class XataChatMessageHistory(BaseChatMessageHistory):
"""Chat message history stored in a Xata database."""
def __init__(
self,
session_id: str,
db_url: str,
api_key: str,
branch_name: str = "main",
table_name: str = "messages",
create_table: bool = True,
) -> None:
"""Initialize with Xata client."""
try:
from xata.client import XataClient # noqa: F401
except ImportError:
raise ValueError(
"Could not import xata python package. "
"Please install it with `pip install xata`."
)
self._client = XataClient(
api_key=api_key, db_url=db_url, branch_name=branch_name
)
self._table_name = table_name
self._session_id = session_id
if create_table:
self._create_table_if_not_exists()
def _create_table_if_not_exists(self) -> None:
r = self._client.table().get_schema(self._table_name)
if r.status_code <= 299:
return
if r.status_code != 404:
raise Exception(
f"Error checking if table exists in Xata: {r.status_code} {r}"
)
r = self._client.table().create(self._table_name)
if r.status_code > 299:
raise Exception(f"Error creating table in Xata: {r.status_code} {r}")
r = self._client.table().set_schema(
self._table_name,
payload={
"columns": [
{"name": "sessionId", "type": "string"},
{"name": "type", "type": "string"},
{"name": "role", "type": "string"},
{"name": "content", "type": "text"},
{"name": "name", "type": "string"},
{"name": "additionalKwargs", "type": "text"},
]
},
)
if r.status_code > 299:
raise Exception(f"Error setting table schema in Xata: {r.status_code} {r}")
def add_message(self, message: BaseMessage) -> None:
"""Append the message to the Xata table"""
msg = _message_to_dict(message)
r = self._client.records().insert(
self._table_name,
{
"sessionId": self._session_id,
"type": msg["type"],
"content": message.content,
"additionalKwargs": json.dumps(message.additional_kwargs),
"role": msg["data"].get("role"),
"name": msg["data"].get("name"),
},
)
if r.status_code > 299:
raise Exception(f"Error adding message to Xata: {r.status_code} {r}")
@property
def messages(self) -> List[BaseMessage]: # type: ignore
r = self._client.data().query(
self._table_name,
payload={
"filter": {
"sessionId": self._session_id,
},
"sort": {"xata.createdAt": "asc"},
},
)
if r.status_code != 200:
raise Exception(f"Error running query: {r.status_code} {r}")
msgs = messages_from_dict(
[
{
"type": m["type"],
"data": {
"content": m["content"],
"role": m.get("role"),
"name": m.get("name"),
"additionalKwargs": json.loads(m["additionalKwargs"]),
},
}
for m in r["records"]
]
)
return msgs
def clear(self) -> None:
"""Delete session from Xata table."""
while True:
r = self._client.data().query(
self._table_name,
payload={
"columns": ["id"],
"filter": {
"sessionId": self._session_id,
},
},
)
if r.status_code != 200:
raise Exception(f"Error running query: {r.status_code} {r}")
ids = [rec["id"] for rec in r["records"]]
if len(ids) == 0:
break
operations = [
{"delete": {"table": self._table_name, "id": id}} for id in ids
]
self._client.records().transaction(payload={"operations": operations})

@ -0,0 +1,41 @@
"""Test Xata chat memory store functionality.
Before running this test, please create a Xata database.
"""
import json
import os
from langchain.memory import ConversationBufferMemory
from langchain.memory.chat_message_histories import XataChatMessageHistory
from langchain.schema.messages import _message_to_dict
class TestXata:
@classmethod
def setup_class(cls) -> None:
assert os.getenv("XATA_API_KEY"), "XATA_API_KEY environment variable is not set"
assert os.getenv("XATA_DB_URL"), "XATA_DB_URL environment variable is not set"
def test_xata_chat_memory(self) -> None:
message_history = XataChatMessageHistory(
api_key=os.getenv("XATA_API_KEY", ""),
db_url=os.getenv("XATA_DB_URL", ""),
session_id="integration-test-session",
)
memory = ConversationBufferMemory(
memory_key="baz", chat_memory=message_history, return_messages=True
)
# add some messages
memory.chat_memory.add_ai_message("This is me, the AI")
memory.chat_memory.add_user_message("This is me, the human")
# get the message history from the memory store and turn it into a json
messages = memory.chat_memory.messages
messages_json = json.dumps([_message_to_dict(msg) for msg in messages])
assert "This is me, the AI" in messages_json
assert "This is me, the human" in messages_json
# remove the record from Redis, so the next test run won't pick it up
memory.chat_memory.clear()
Loading…
Cancel
Save