2023-12-11 21:53:30 +00:00
|
|
|
import json
|
|
|
|
from typing import List
|
|
|
|
|
|
|
|
from langchain_core.chat_history import BaseChatMessageHistory
|
|
|
|
from langchain_core.messages import (
|
|
|
|
BaseMessage,
|
|
|
|
message_to_dict,
|
|
|
|
messages_from_dict,
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
class XataChatMessageHistory(BaseChatMessageHistory):
    """Chat message history stored in a Xata database.

    Every message is written as one record of the configured table and is
    tied to its conversation through the ``sessionId`` column.
    """

    def __init__(
        self,
        session_id: str,
        db_url: str,
        api_key: str,
        branch_name: str = "main",
        table_name: str = "messages",
        create_table: bool = True,
    ) -> None:
        """Initialize with Xata client.

        Args:
            session_id: Value stored in the ``sessionId`` column of every
                record written by this instance, and used to filter reads
                and deletes.
            db_url: Passed to ``XataClient`` as ``db_url``.
            api_key: Passed to ``XataClient`` as ``api_key``.
            branch_name: Passed to ``XataClient`` as ``branch_name``.
            table_name: Name of the table that holds the messages.
            create_table: When true, the table (and its schema) is created
                during construction if the schema lookup reports 404.

        Raises:
            ImportError: If the ``xata`` package cannot be imported.
            Exception: If ``create_table`` is true and checking or creating
                the table fails.
        """
        try:
            from xata.client import XataClient
        except ImportError as e:
            # Chain the original error so the root cause stays visible.
            raise ImportError(
                "Could not import xata python package. "
                "Please install it with `pip install xata`."
            ) from e
        self._client = XataClient(
            api_key=api_key, db_url=db_url, branch_name=branch_name
        )
        self._table_name = table_name
        self._session_id = session_id

        if create_table:
            self._create_table_if_not_exists()

    def _create_table_if_not_exists(self) -> None:
        """Create the message table and set its schema unless it exists.

        A schema lookup answering with a status up to 299 means the table is
        already there and nothing is done. A 404 triggers creation; any other
        status is treated as an error.

        Raises:
            Exception: If the lookup fails with a status other than 404, or
                if creating the table or setting its schema fails.
        """
        r = self._client.table().get_schema(self._table_name)
        if r.status_code <= 299:
            return
        if r.status_code != 404:
            raise Exception(
                f"Error checking if table exists in Xata: {r.status_code} {r}"
            )
        r = self._client.table().create(self._table_name)
        if r.status_code > 299:
            raise Exception(f"Error creating table in Xata: {r.status_code} {r}")
        r = self._client.table().set_schema(
            self._table_name,
            payload={
                "columns": [
                    {"name": "sessionId", "type": "string"},
                    {"name": "type", "type": "string"},
                    {"name": "role", "type": "string"},
                    {"name": "content", "type": "text"},
                    {"name": "name", "type": "string"},
                    {"name": "additionalKwargs", "type": "json"},
                ]
            },
        )
        if r.status_code > 299:
            raise Exception(f"Error setting table schema in Xata: {r.status_code} {r}")

    def add_message(self, message: BaseMessage) -> None:
        """Append the message to the Xata table.

        Args:
            message: The message to store. Its type, content, role, name and
                JSON-encoded ``additional_kwargs`` are written as one record
                tagged with this instance's session id.

        Raises:
            Exception: If the insert answers with a status above 299.
        """
        msg = message_to_dict(message)
        r = self._client.records().insert(
            self._table_name,
            {
                "sessionId": self._session_id,
                "type": msg["type"],
                "content": message.content,
                # Stored as a JSON string; decoded again in ``messages``.
                "additionalKwargs": json.dumps(message.additional_kwargs),
                "role": msg["data"].get("role"),
                "name": msg["data"].get("name"),
            },
        )
        if r.status_code > 299:
            raise Exception(f"Error adding message to Xata: {r.status_code} {r}")

    @property
    def messages(self) -> List[BaseMessage]:  # type: ignore
        """Return the messages of this session, ordered by creation time.

        Raises:
            Exception: If the query answers with a status other than 200.
        """
        # NOTE(review): the query sets no page size, so only the records the
        # service returns for a single page are read -- confirm whether long
        # sessions need pagination.
        r = self._client.data().query(
            self._table_name,
            payload={
                "filter": {
                    "sessionId": self._session_id,
                },
                "sort": {"xata.createdAt": "asc"},
            },
        )
        if r.status_code != 200:
            raise Exception(f"Error running query: {r.status_code} {r}")
        msgs = messages_from_dict(
            [
                {
                    "type": m["type"],
                    "data": {
                        "content": m["content"],
                        "role": m.get("role"),
                        "name": m.get("name"),
                        # ``add_message`` always writes this column; fall
                        # back to an empty mapping for records where it is
                        # missing or null instead of failing the whole read.
                        "additional_kwargs": json.loads(
                            m.get("additionalKwargs") or "{}"
                        ),
                    },
                }
                for m in r["records"]
            ]
        )
        return msgs

    def clear(self) -> None:
        """Delete session from Xata table.

        Repeatedly queries the ids of the session's records and deletes each
        returned batch in one transaction, until a query returns no records.

        Raises:
            Exception: If a query answers with a status other than 200, or if
                a delete transaction answers with a status above 299.
        """
        while True:
            r = self._client.data().query(
                self._table_name,
                payload={
                    "columns": ["id"],
                    "filter": {
                        "sessionId": self._session_id,
                    },
                },
            )
            if r.status_code != 200:
                raise Exception(f"Error running query: {r.status_code} {r}")
            ids = [rec["id"] for rec in r["records"]]
            if not ids:
                break
            operations = [
                {"delete": {"table": self._table_name, "id": record_id}}
                for record_id in ids
            ]
            r = self._client.records().transaction(payload={"operations": operations})
            # A failed transaction leaves the records in place, so the next
            # query would return the same ids and the loop would never end.
            if r.status_code > 299:
                raise Exception(
                    f"Error deleting messages from Xata: {r.status_code} {r}"
                )
|