mirror of https://github.com/hwchase17/langchain
Xata as a chat message memory store (#9719)
This adds Xata as a memory store also to the python version of LangChain, similar to the [one for LangChain.js](https://github.com/hwchase17/langchainjs/pull/2217). I have added a Jupyter Notebook with a simple and a more complex example using an agent. To run the integration test, you need to execute something like: ``` XATA_API_KEY='xau_...' XATA_DB_URL="https://demo-uni3q8.eu-west-1.xata.sh/db/langchain" poetry run pytest tests/integration_tests/memory/test_xata.py ``` Where `langchain` is the database you create in Xata.pull/9737/head
parent
dff00ea91e
commit
dc30edf51c
@ -0,0 +1,326 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Xata chat memory\n",
|
||||
"\n",
|
||||
"[Xata](https://xata.io) is a serverless data platform, based on PostgreSQL and Elasticsearch. It provides a Python SDK for interacting with your database, and a UI for managing your data. With the `XataChatMessageHistory` class, you can use Xata databases for longer-term persistence of chat sessions.\n",
|
||||
"\n",
|
||||
"This notebook covers:\n",
|
||||
"\n",
|
||||
"* A simple example showing what `XataChatMessageHistory` does.\n",
|
||||
"* A more complex example using a REACT agent that answer questions based on a knowledge based or documentation (stored in Xata as a vector store) and also having a long-term searchable history of its past messages (stored in Xata as a memory store)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Setup\n",
|
||||
"\n",
|
||||
"### Create a database\n",
|
||||
"\n",
|
||||
"In the [Xata UI](https://app.xata.io) create a new database. You can name it whatever you want, in this notepad we'll use `langchain`. The Langchain integration can auto-create the table used for storying the memory, and this is what we'll use in this example. If you want to pre-create the table, ensure it has the right schema and set `create_table` to `False` when creating the class. Pre-creating the table saves one round-trip to the database during each session initialization."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Let's first install our dependencies:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"!pip install xata==1.0.0rc0 openai langchain"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Next, we need to get the environment variables for Xata. You can create a new API key by visiting your [account settings](https://app.xata.io/settings). To find the database URL, go to the Settings page of the database that you have created. The database URL should look something like this: `https://demo-uni3q8.eu-west-1.xata.sh/db/langchain`."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import getpass\n",
|
||||
"\n",
|
||||
"api_key = getpass.getpass(\"Xata API key: \")\n",
|
||||
"db_url = input(\"Xata database URL (copy it from your DB settings):\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Create a simple memory store\n",
|
||||
"\n",
|
||||
"To test the memory store functionality in isolation, let's use the following code snippet:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.memory import XataChatMessageHistory\n",
|
||||
"\n",
|
||||
"history = XataChatMessageHistory(\n",
|
||||
" session_id=\"session-1\",\n",
|
||||
" api_key=api_key,\n",
|
||||
" db_url=db_url,\n",
|
||||
" table_name=\"memory\"\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"history.add_user_message(\"hi!\")\n",
|
||||
"\n",
|
||||
"history.add_ai_message(\"whats up?\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"The above code creates a session with the ID `session-1` and stores two messages in it. After running the above, if you visit the Xata UI, you should see a table named `memory` and the two messages added to it.\n",
|
||||
"\n",
|
||||
"You can retrieve the message history for a particular session with the following code:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"history.messages"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Conversational Q&A chain on your data with memory\n",
|
||||
"\n",
|
||||
"Let's now see a more complex example in which we combine OpenAI, the Xata Vector Store integration, and the Xata memory store integration to create a Q&A chat bot on your data, with follow-up questions and history."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"We're going to need to access the OpenAI API, so let's configure the API key:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"\n",
|
||||
"os.environ[\"OPENAI_API_KEY\"] = getpass.getpass(\"OpenAI API Key:\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"To store the documents that the chatbot will search for answers, add a table named `docs` to your `langchain` database using the Xata UI, and add the following columns:\n",
|
||||
"\n",
|
||||
"* `content` of type \"Text\". This is used to store the `Document.pageContent` values.\n",
|
||||
"* `embedding` of type \"Vector\". Use the dimension used by the model you plan to use. In this notebook we use OpenAI embeddings, which have 1536 dimensions."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Let's create the vector store and add some sample docs to it:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.embeddings.openai import OpenAIEmbeddings\n",
|
||||
"from langchain.vectorstores.xata import XataVectorStore\n",
|
||||
"\n",
|
||||
"embeddings = OpenAIEmbeddings()\n",
|
||||
"\n",
|
||||
"texts = [\n",
|
||||
" \"Xata is a Serverless Data platform based on PostgreSQL\",\n",
|
||||
" \"Xata offers a built-in vector type that can be used to store and query vectors\",\n",
|
||||
" \"Xata includes similarity search\"\n",
|
||||
"]\n",
|
||||
"\n",
|
||||
"vector_store = XataVectorStore.from_texts(texts, embeddings, api_key=api_key, db_url=db_url, table_name=\"docs\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"After running the above command, if you go to the Xata UI, you should see the documents loaded together with their embeddings in the `docs` table."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Let's now create a ConversationBufferMemory to store the chat messages from both the user and the AI."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.memory import ConversationBufferMemory\n",
|
||||
"from uuid import uuid4\n",
|
||||
"\n",
|
||||
"chat_memory = XataChatMessageHistory(\n",
|
||||
" session_id=str(uuid4()), # needs to be unique per user session\n",
|
||||
" api_key=api_key,\n",
|
||||
" db_url=db_url,\n",
|
||||
" table_name=\"memory\"\n",
|
||||
")\n",
|
||||
"memory = ConversationBufferMemory(memory_key=\"chat_history\", chat_memory=chat_memory, return_messages=True)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Now it's time to create an Agent to use both the vector store and the chat memory together."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.agents import initialize_agent, AgentType\n",
|
||||
"from langchain.agents.agent_toolkits import create_retriever_tool\n",
|
||||
"from langchain.chat_models import ChatOpenAI\n",
|
||||
"\n",
|
||||
"tool = create_retriever_tool(\n",
|
||||
" vector_store.as_retriever(), \n",
|
||||
" \"search_docs\",\n",
|
||||
" \"Searches and returns documents from the Xata manual. Useful when you need to answer questions about Xata.\"\n",
|
||||
")\n",
|
||||
"tools = [tool]\n",
|
||||
"\n",
|
||||
"llm = ChatOpenAI(temperature=0)\n",
|
||||
"\n",
|
||||
"agent = initialize_agent(\n",
|
||||
" tools,\n",
|
||||
" llm,\n",
|
||||
" agent=AgentType.CHAT_CONVERSATIONAL_REACT_DESCRIPTION,\n",
|
||||
" verbose=True,\n",
|
||||
" memory=memory)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"To test, let's tell the agent our name:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"agent.run(input=\"My name is bob\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Now, let's now ask the agent some questions about Xata:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"agent.run(input=\"What is xata?\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Notice that it answers based on the data stored in the document store. And now, let's ask a follow up question:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"agent.run(input=\"Does it support similarity search?\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"And now let's test its memory:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"agent.run(input=\"Did I tell you my name? What is it?\")"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.9"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 4
|
||||
}
|
@ -0,0 +1,132 @@
|
||||
import json
|
||||
from typing import List
|
||||
|
||||
from langchain.schema import (
|
||||
BaseChatMessageHistory,
|
||||
)
|
||||
from langchain.schema.messages import BaseMessage, _message_to_dict, messages_from_dict
|
||||
|
||||
|
||||
class XataChatMessageHistory(BaseChatMessageHistory):
|
||||
"""Chat message history stored in a Xata database."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
session_id: str,
|
||||
db_url: str,
|
||||
api_key: str,
|
||||
branch_name: str = "main",
|
||||
table_name: str = "messages",
|
||||
create_table: bool = True,
|
||||
) -> None:
|
||||
"""Initialize with Xata client."""
|
||||
try:
|
||||
from xata.client import XataClient # noqa: F401
|
||||
except ImportError:
|
||||
raise ValueError(
|
||||
"Could not import xata python package. "
|
||||
"Please install it with `pip install xata`."
|
||||
)
|
||||
self._client = XataClient(
|
||||
api_key=api_key, db_url=db_url, branch_name=branch_name
|
||||
)
|
||||
self._table_name = table_name
|
||||
self._session_id = session_id
|
||||
|
||||
if create_table:
|
||||
self._create_table_if_not_exists()
|
||||
|
||||
def _create_table_if_not_exists(self) -> None:
|
||||
r = self._client.table().get_schema(self._table_name)
|
||||
if r.status_code <= 299:
|
||||
return
|
||||
if r.status_code != 404:
|
||||
raise Exception(
|
||||
f"Error checking if table exists in Xata: {r.status_code} {r}"
|
||||
)
|
||||
r = self._client.table().create(self._table_name)
|
||||
if r.status_code > 299:
|
||||
raise Exception(f"Error creating table in Xata: {r.status_code} {r}")
|
||||
r = self._client.table().set_schema(
|
||||
self._table_name,
|
||||
payload={
|
||||
"columns": [
|
||||
{"name": "sessionId", "type": "string"},
|
||||
{"name": "type", "type": "string"},
|
||||
{"name": "role", "type": "string"},
|
||||
{"name": "content", "type": "text"},
|
||||
{"name": "name", "type": "string"},
|
||||
{"name": "additionalKwargs", "type": "text"},
|
||||
]
|
||||
},
|
||||
)
|
||||
if r.status_code > 299:
|
||||
raise Exception(f"Error setting table schema in Xata: {r.status_code} {r}")
|
||||
|
||||
def add_message(self, message: BaseMessage) -> None:
|
||||
"""Append the message to the Xata table"""
|
||||
msg = _message_to_dict(message)
|
||||
r = self._client.records().insert(
|
||||
self._table_name,
|
||||
{
|
||||
"sessionId": self._session_id,
|
||||
"type": msg["type"],
|
||||
"content": message.content,
|
||||
"additionalKwargs": json.dumps(message.additional_kwargs),
|
||||
"role": msg["data"].get("role"),
|
||||
"name": msg["data"].get("name"),
|
||||
},
|
||||
)
|
||||
if r.status_code > 299:
|
||||
raise Exception(f"Error adding message to Xata: {r.status_code} {r}")
|
||||
|
||||
@property
|
||||
def messages(self) -> List[BaseMessage]: # type: ignore
|
||||
r = self._client.data().query(
|
||||
self._table_name,
|
||||
payload={
|
||||
"filter": {
|
||||
"sessionId": self._session_id,
|
||||
},
|
||||
"sort": {"xata.createdAt": "asc"},
|
||||
},
|
||||
)
|
||||
if r.status_code != 200:
|
||||
raise Exception(f"Error running query: {r.status_code} {r}")
|
||||
msgs = messages_from_dict(
|
||||
[
|
||||
{
|
||||
"type": m["type"],
|
||||
"data": {
|
||||
"content": m["content"],
|
||||
"role": m.get("role"),
|
||||
"name": m.get("name"),
|
||||
"additionalKwargs": json.loads(m["additionalKwargs"]),
|
||||
},
|
||||
}
|
||||
for m in r["records"]
|
||||
]
|
||||
)
|
||||
return msgs
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Delete session from Xata table."""
|
||||
while True:
|
||||
r = self._client.data().query(
|
||||
self._table_name,
|
||||
payload={
|
||||
"columns": ["id"],
|
||||
"filter": {
|
||||
"sessionId": self._session_id,
|
||||
},
|
||||
},
|
||||
)
|
||||
if r.status_code != 200:
|
||||
raise Exception(f"Error running query: {r.status_code} {r}")
|
||||
ids = [rec["id"] for rec in r["records"]]
|
||||
if len(ids) == 0:
|
||||
break
|
||||
operations = [
|
||||
{"delete": {"table": self._table_name, "id": id}} for id in ids
|
||||
]
|
||||
self._client.records().transaction(payload={"operations": operations})
|
@ -0,0 +1,41 @@
|
||||
"""Test Xata chat memory store functionality.
|
||||
|
||||
Before running this test, please create a Xata database.
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
|
||||
from langchain.memory import ConversationBufferMemory
|
||||
from langchain.memory.chat_message_histories import XataChatMessageHistory
|
||||
from langchain.schema.messages import _message_to_dict
|
||||
|
||||
|
||||
class TestXata:
|
||||
@classmethod
|
||||
def setup_class(cls) -> None:
|
||||
assert os.getenv("XATA_API_KEY"), "XATA_API_KEY environment variable is not set"
|
||||
assert os.getenv("XATA_DB_URL"), "XATA_DB_URL environment variable is not set"
|
||||
|
||||
def test_xata_chat_memory(self) -> None:
|
||||
message_history = XataChatMessageHistory(
|
||||
api_key=os.getenv("XATA_API_KEY", ""),
|
||||
db_url=os.getenv("XATA_DB_URL", ""),
|
||||
session_id="integration-test-session",
|
||||
)
|
||||
memory = ConversationBufferMemory(
|
||||
memory_key="baz", chat_memory=message_history, return_messages=True
|
||||
)
|
||||
# add some messages
|
||||
memory.chat_memory.add_ai_message("This is me, the AI")
|
||||
memory.chat_memory.add_user_message("This is me, the human")
|
||||
|
||||
# get the message history from the memory store and turn it into a json
|
||||
messages = memory.chat_memory.messages
|
||||
messages_json = json.dumps([_message_to_dict(msg) for msg in messages])
|
||||
|
||||
assert "This is me, the AI" in messages_json
|
||||
assert "This is me, the human" in messages_json
|
||||
|
||||
# remove the record from Redis, so the next test run won't pick it up
|
||||
memory.chat_memory.clear()
|
Loading…
Reference in New Issue