Extend SQLChatMessageHistory (#9849)

### Description

There is a really nice class for saving chat messages into a database -
SQLChatMessageHistory.
It leverages SqlAlchemy to be compatible with any supported database (in
contrast with PostgresChatMessageHistory, which is basically the same
but is limited to Postgres).

However, the class is not really customizable in terms of what you can
store. I can imagine a lot of use cases, when one will need to save a
message date, along with some additional metadata.

To solve this, I propose to extract the converting logic from
BaseMessage to SQLAlchemy model (and vice versa) into a separate class -
message converter. So instead of rewriting the whole
SQLChatMessageHistory class, a user will only need to write a custom
model and a simple mapping class, and pass its instance as a parameter.

I also noticed that there is no documentation on this class, so I added
that too, with an example of custom message converter.

### Issue

N/A

### Dependencies

N/A

### Tag maintainer

Not yet

### Twitter handle

N/A
pull/9791/head
Viktor Zhemchuzhnikov 1 year ago committed by GitHub
parent fed137a8a9
commit 507e46844e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -0,0 +1,235 @@
{
"cells": [
{
"cell_type": "markdown",
"source": [
"# SQL Chat Message History\n",
"\n",
"This notebook goes over a **SQLChatMessageHistory** class that allows to store chat history in any database supported by SQLAlchemy.\n",
"\n",
"Please note that to use it with databases other than SQLite, you will need to install the corresponding database driver."
],
"metadata": {
"collapsed": false
},
"id": "f22eab3f84cbeb37"
},
{
"cell_type": "markdown",
"source": [
"### Basic Usage\n",
"\n",
"To use the storage you need to provide only 2 things:\n",
"\n",
"1. Session Id - a unique identifier of the session, like user name, email, chat id etc.\n",
"2. Connection string - a string that specifies the database connection. It will be passed to SQLAlchemy create_engine function."
],
"metadata": {
"collapsed": false
},
"id": "f8f2830ee9ca1e01"
},
{
"cell_type": "code",
"execution_count": 1,
"outputs": [],
"source": [
"from langchain.memory.chat_message_histories import SQLChatMessageHistory\n",
"\n",
"chat_message_history = SQLChatMessageHistory(\n",
"\tsession_id='test_session',\n",
"\tconnection_string='sqlite:///sqlite.db'\n",
")\n",
"\n",
"chat_message_history.add_user_message('Hello')\n",
"chat_message_history.add_ai_message('Hi')"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2023-08-28T10:04:38.077748Z",
"start_time": "2023-08-28T10:04:36.105894Z"
}
},
"id": "4576e914a866fb40"
},
{
"cell_type": "code",
"execution_count": 2,
"outputs": [
{
"data": {
"text/plain": "[HumanMessage(content='Hello', additional_kwargs={}, example=False),\n AIMessage(content='Hi', additional_kwargs={}, example=False)]"
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"chat_message_history.messages"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2023-08-28T10:04:38.929396Z",
"start_time": "2023-08-28T10:04:38.915727Z"
}
},
"id": "b476688cbb32ba90"
},
{
"cell_type": "markdown",
"source": [
"### Custom Storage Format\n",
"\n",
"By default, only the session id and message dictionary are stored in the table.\n",
"\n",
"However, sometimes you might want to store some additional information, like message date, author, language etc.\n",
"\n",
"To do that, you can create a custom message converter, by implementing **BaseMessageConverter** interface."
],
"metadata": {
"collapsed": false
},
"id": "2e5337719d5614fd"
},
{
"cell_type": "code",
"execution_count": 3,
"outputs": [],
"source": [
"from datetime import datetime\n",
"from langchain.schema import BaseMessage, HumanMessage, AIMessage, SystemMessage\n",
"from typing import Any\n",
"from sqlalchemy import Column, Integer, Text, DateTime\n",
"from sqlalchemy.orm import declarative_base\n",
"from langchain.memory.chat_message_histories.sql import BaseMessageConverter\n",
"\n",
"\n",
"Base = declarative_base()\n",
"\n",
"\n",
"class CustomMessage(Base):\n",
"\t__tablename__ = 'custom_message_store'\n",
"\n",
"\tid = Column(Integer, primary_key=True)\n",
"\tsession_id = Column(Text)\n",
"\ttype = Column(Text)\n",
"\tcontent = Column(Text)\n",
"\tcreated_at = Column(DateTime)\n",
"\tauthor_email = Column(Text)\n",
"\n",
"\n",
"class CustomMessageConverter(BaseMessageConverter):\n",
"\tdef __init__(self, author_email: str):\n",
"\t\tself.author_email = author_email\n",
"\t\n",
"\tdef from_sql_model(self, sql_message: Any) -> BaseMessage:\n",
"\t\tif sql_message.type == 'human':\n",
"\t\t\treturn HumanMessage(\n",
"\t\t\t\tcontent=sql_message.content,\n",
"\t\t\t)\n",
"\t\telif sql_message.type == 'ai':\n",
"\t\t\treturn AIMessage(\n",
"\t\t\t\tcontent=sql_message.content,\n",
"\t\t\t)\n",
"\t\telif sql_message.type == 'system':\n",
"\t\t\treturn SystemMessage(\n",
"\t\t\t\tcontent=sql_message.content,\n",
"\t\t\t)\n",
"\t\telse:\n",
"\t\t\traise ValueError(f'Unknown message type: {sql_message.type}')\n",
"\t\n",
"\tdef to_sql_model(self, message: BaseMessage, session_id: str) -> Any:\n",
"\t\tnow = datetime.now()\n",
"\t\treturn CustomMessage(\n",
"\t\t\tsession_id=session_id,\n",
"\t\t\ttype=message.type,\n",
"\t\t\tcontent=message.content,\n",
"\t\t\tcreated_at=now,\n",
"\t\t\tauthor_email=self.author_email\n",
"\t\t)\n",
"\t\n",
"\tdef get_sql_model_class(self) -> Any:\n",
"\t\treturn CustomMessage\n",
"\n",
"\n",
"chat_message_history = SQLChatMessageHistory(\n",
"\tsession_id='test_session',\n",
"\tconnection_string='sqlite:///sqlite.db',\n",
"\tcustom_message_converter=CustomMessageConverter(\n",
"\t\tauthor_email='test@example.com'\n",
" )\n",
")\n",
"\n",
"chat_message_history.add_user_message('Hello')\n",
"chat_message_history.add_ai_message('Hi')"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2023-08-28T10:04:41.510498Z",
"start_time": "2023-08-28T10:04:41.494912Z"
}
},
"id": "fdfde84c07d071bb"
},
{
"cell_type": "code",
"execution_count": 4,
"outputs": [
{
"data": {
"text/plain": "[HumanMessage(content='Hello', additional_kwargs={}, example=False),\n AIMessage(content='Hi', additional_kwargs={}, example=False)]"
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"chat_message_history.messages"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2023-08-28T10:04:43.497990Z",
"start_time": "2023-08-28T10:04:43.492517Z"
}
},
"id": "4a6a54d8a9e2856f"
},
{
"cell_type": "markdown",
"source": [
"You also might want to change the name of session_id column. In this case you'll need to specify `session_id_field_name` parameter."
],
"metadata": {
"collapsed": false
},
"id": "622aded629a1adeb"
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

@ -1,6 +1,7 @@
import json
import logging
from typing import List
from abc import ABC, abstractmethod
from typing import Any, List, Optional
from sqlalchemy import Column, Integer, Text, create_engine
@ -18,6 +19,25 @@ from langchain.schema.messages import BaseMessage, _message_to_dict, messages_fr
logger = logging.getLogger(__name__)
class BaseMessageConverter(ABC):
"""The class responsible for converting BaseMessage to your SQLAlchemy model."""
@abstractmethod
def from_sql_model(self, sql_message: Any) -> BaseMessage:
"""Convert a SQLAlchemy model to a BaseMessage instance."""
raise NotImplementedError
@abstractmethod
def to_sql_model(self, message: BaseMessage, session_id: str) -> Any:
"""Convert a BaseMessage instance to a SQLAlchemy model."""
raise NotImplementedError
@abstractmethod
def get_sql_model_class(self) -> Any:
"""Get the SQLAlchemy model class."""
raise NotImplementedError
def create_message_model(table_name, DynamicBase): # type: ignore
"""
Create a message model for a given table name.
@ -41,6 +61,24 @@ def create_message_model(table_name, DynamicBase): # type: ignore
return Message
class DefaultMessageConverter(BaseMessageConverter):
"""The default message converter for SQLChatMessageHistory."""
def __init__(self, table_name: str):
self.model_class = create_message_model(table_name, declarative_base())
def from_sql_model(self, sql_message: Any) -> BaseMessage:
return messages_from_dict([json.loads(sql_message.message)])[0]
def to_sql_model(self, message: BaseMessage, session_id: str) -> Any:
return self.model_class(
session_id=session_id, message=json.dumps(_message_to_dict(message))
)
def get_sql_model_class(self) -> Any:
return self.model_class
class SQLChatMessageHistory(BaseChatMessageHistory):
"""Chat message history stored in an SQL database."""
@ -49,44 +87,49 @@ class SQLChatMessageHistory(BaseChatMessageHistory):
session_id: str,
connection_string: str,
table_name: str = "message_store",
session_id_field_name: str = "session_id",
custom_message_converter: Optional[BaseMessageConverter] = None,
):
self.table_name = table_name
self.connection_string = connection_string
self.engine = create_engine(connection_string, echo=False)
self.session_id_field_name = session_id_field_name
self.converter = custom_message_converter or DefaultMessageConverter(table_name)
self.sql_model_class = self.converter.get_sql_model_class()
if not hasattr(self.sql_model_class, session_id_field_name):
raise ValueError("SQL model class must have session_id column")
self._create_table_if_not_exists()
self.session_id = session_id
self.Session = sessionmaker(self.engine)
def _create_table_if_not_exists(self) -> None:
DynamicBase = declarative_base()
self.Message = create_message_model(self.table_name, DynamicBase)
# Create all does the check for us in case the table exists.
DynamicBase.metadata.create_all(self.engine)
self.sql_model_class.metadata.create_all(self.engine)
@property
def messages(self) -> List[BaseMessage]: # type: ignore
"""Retrieve all messages from db"""
with self.Session() as session:
result = session.query(self.Message).where(
self.Message.session_id == self.session_id
result = session.query(self.sql_model_class).where(
getattr(self.sql_model_class, self.session_id_field_name)
== self.session_id
)
items = [json.loads(record.message) for record in result]
messages = messages_from_dict(items)
messages = []
for record in result:
messages.append(self.converter.from_sql_model(record))
return messages
def add_message(self, message: BaseMessage) -> None:
"""Append the message to the record in db"""
with self.Session() as session:
jsonstr = json.dumps(_message_to_dict(message))
session.add(self.Message(session_id=self.session_id, message=jsonstr))
session.add(self.converter.to_sql_model(message, self.session_id))
session.commit()
def clear(self) -> None:
"""Clear session memory from db"""
with self.Session() as session:
session.query(self.Message).filter(
self.Message.session_id == self.session_id
session.query(self.sql_model_class).filter(
getattr(self.sql_model_class, self.session_id_field_name)
== self.session_id
).delete()
session.commit()

@ -1,21 +1,26 @@
from pathlib import Path
from typing import Tuple
from typing import Any, Generator, Tuple
import pytest
from sqlalchemy import Column, Integer, Text
from sqlalchemy.orm import DeclarativeBase
from langchain.memory.chat_message_histories import SQLChatMessageHistory
from langchain.memory.chat_message_histories.sql import DefaultMessageConverter
from langchain.schema.messages import AIMessage, HumanMessage
# @pytest.fixture(params=[("SQLite"), ("postgresql")])
@pytest.fixture(params=[("SQLite")])
def sql_histories(request, tmp_path: Path): # type: ignore
if request.param == "SQLite":
file_path = tmp_path / "db.sqlite3"
con_str = f"sqlite:///{file_path}"
elif request.param == "postgresql":
con_str = "postgresql://postgres:postgres@localhost/postgres"
@pytest.fixture()
def con_str(tmp_path: Path) -> str:
file_path = tmp_path / "db.sqlite3"
con_str = f"sqlite:///{file_path}"
return con_str
@pytest.fixture()
def sql_histories(
con_str: str,
) -> Generator[Tuple[SQLChatMessageHistory, SQLChatMessageHistory], None, None]:
message_history = SQLChatMessageHistory(
session_id="123", connection_string=con_str, table_name="test_table"
)
@ -24,7 +29,7 @@ def sql_histories(request, tmp_path: Path): # type: ignore
session_id="456", connection_string=con_str, table_name="test_table"
)
yield (message_history, other_history)
yield message_history, other_history
message_history.clear()
other_history.clear()
@ -83,3 +88,24 @@ def test_clear_messages(
sql_history.clear()
assert len(sql_history.messages) == 0
assert len(other_history.messages) == 1
def test_model_no_session_id_field_error(con_str: str) -> None:
class Base(DeclarativeBase):
pass
class Model(Base):
__tablename__ = "test_table"
id = Column(Integer, primary_key=True)
test_field = Column(Text)
class CustomMessageConverter(DefaultMessageConverter):
def get_sql_model_class(self) -> Any:
return Model
with pytest.raises(ValueError):
SQLChatMessageHistory(
"test",
con_str,
custom_message_converter=CustomMessageConverter("test_table"),
)

Loading…
Cancel
Save