SingleStoreDBChatMessageHistory: Add singlestoredb support for ChatMessageHistory (#11705)

**Description**

- Added the `SingleStoreDBChatMessageHistory` class that inherits
`BaseChatMessageHistory` and allows to use of a SingleStoreDB database
as a storage for chat message history.
- Added integration test to check that everything works (requires
`singlestoredb` to be installed)
- Added notebook with usage example
- Removed custom retriever for SingleStoreDB vector store (as it is
useless)

---------

Co-authored-by: Volodymyr Tkachuk <vtkachuk-ua@singlestore.com>
This commit is contained in:
volodymyr-memsql 2023-10-17 04:59:45 +03:00 committed by GitHub
parent 634ccb8ccd
commit ff8e6981ff
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 420 additions and 41 deletions

View File

@ -0,0 +1,65 @@
{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"id": "91c6a7ef",
"metadata": {},
"source": [
"# SingleStoreDB Chat Message History\n",
"\n",
"This notebook goes over how to use SingleStoreDB to store chat message history."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d15e3302",
"metadata": {},
"outputs": [],
"source": [
"from langchain.memory import SingleStoreDBChatMessageHistory\n",
"\n",
"history = SingleStoreDBChatMessageHistory(\n",
" session_id=\"foo\",\n",
" host=\"root:pass@localhost:3306/db\"\n",
")\n",
"\n",
"history.add_user_message(\"hi!\")\n",
"\n",
"history.add_ai_message(\"whats up?\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "64fc465e",
"metadata": {},
"outputs": [],
"source": [
"history.messages"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.2"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View File

@ -42,6 +42,7 @@ from langchain.memory.chat_message_histories import (
MongoDBChatMessageHistory, MongoDBChatMessageHistory,
PostgresChatMessageHistory, PostgresChatMessageHistory,
RedisChatMessageHistory, RedisChatMessageHistory,
SingleStoreDBChatMessageHistory,
SQLChatMessageHistory, SQLChatMessageHistory,
StreamlitChatMessageHistory, StreamlitChatMessageHistory,
UpstashRedisChatMessageHistory, UpstashRedisChatMessageHistory,
@ -90,6 +91,7 @@ __all__ = [
"ReadOnlySharedMemory", "ReadOnlySharedMemory",
"RedisChatMessageHistory", "RedisChatMessageHistory",
"RedisEntityStore", "RedisEntityStore",
"SingleStoreDBChatMessageHistory",
"SQLChatMessageHistory", "SQLChatMessageHistory",
"SQLiteEntityStore", "SQLiteEntityStore",
"SimpleMemory", "SimpleMemory",

View File

@ -16,6 +16,9 @@ from langchain.memory.chat_message_histories.mongodb import MongoDBChatMessageHi
from langchain.memory.chat_message_histories.postgres import PostgresChatMessageHistory from langchain.memory.chat_message_histories.postgres import PostgresChatMessageHistory
from langchain.memory.chat_message_histories.redis import RedisChatMessageHistory from langchain.memory.chat_message_histories.redis import RedisChatMessageHistory
from langchain.memory.chat_message_histories.rocksetdb import RocksetChatMessageHistory from langchain.memory.chat_message_histories.rocksetdb import RocksetChatMessageHistory
from langchain.memory.chat_message_histories.singlestoredb import (
SingleStoreDBChatMessageHistory,
)
from langchain.memory.chat_message_histories.sql import SQLChatMessageHistory from langchain.memory.chat_message_histories.sql import SQLChatMessageHistory
from langchain.memory.chat_message_histories.streamlit import ( from langchain.memory.chat_message_histories.streamlit import (
StreamlitChatMessageHistory, StreamlitChatMessageHistory,
@ -41,6 +44,7 @@ __all__ = [
"RocksetChatMessageHistory", "RocksetChatMessageHistory",
"SQLChatMessageHistory", "SQLChatMessageHistory",
"StreamlitChatMessageHistory", "StreamlitChatMessageHistory",
"SingleStoreDBChatMessageHistory",
"XataChatMessageHistory", "XataChatMessageHistory",
"ZepChatMessageHistory", "ZepChatMessageHistory",
"UpstashRedisChatMessageHistory", "UpstashRedisChatMessageHistory",

View File

@ -0,0 +1,275 @@
import json
import logging
import re
from typing import (
Any,
List,
)
from langchain.schema import (
BaseChatMessageHistory,
)
from langchain.schema.messages import BaseMessage, _message_to_dict, messages_from_dict
logger = logging.getLogger(__name__)
class SingleStoreDBChatMessageHistory(BaseChatMessageHistory):
"""Chat message history stored in a SingleStoreDB database."""
def __init__(
self,
session_id: str,
*,
table_name: str = "message_store",
id_field: str = "id",
session_id_field: str = "session_id",
message_field: str = "message",
pool_size: int = 5,
max_overflow: int = 10,
timeout: float = 30,
**kwargs: Any,
):
"""Initialize with necessary components.
Args:
table_name (str, optional): Specifies the name of the table in use.
Defaults to "message_store".
id_field (str, optional): Specifies the name of the id field in the table.
Defaults to "id".
session_id_field (str, optional): Specifies the name of the session_id
field in the table. Defaults to "session_id".
message_field (str, optional): Specifies the name of the message field
in the table. Defaults to "message".
Following arguments pertain to the connection pool:
pool_size (int, optional): Determines the number of active connections in
the pool. Defaults to 5.
max_overflow (int, optional): Determines the maximum number of connections
allowed beyond the pool_size. Defaults to 10.
timeout (float, optional): Specifies the maximum wait time in seconds for
establishing a connection. Defaults to 30.
Following arguments pertain to the database connection:
host (str, optional): Specifies the hostname, IP address, or URL for the
database connection. The default scheme is "mysql".
user (str, optional): Database username.
password (str, optional): Database password.
port (int, optional): Database port. Defaults to 3306 for non-HTTP
connections, 80 for HTTP connections, and 443 for HTTPS connections.
database (str, optional): Database name.
Additional optional arguments provide further customization over the
database connection:
pure_python (bool, optional): Toggles the connector mode. If True,
operates in pure Python mode.
local_infile (bool, optional): Allows local file uploads.
charset (str, optional): Specifies the character set for string values.
ssl_key (str, optional): Specifies the path of the file containing the SSL
key.
ssl_cert (str, optional): Specifies the path of the file containing the SSL
certificate.
ssl_ca (str, optional): Specifies the path of the file containing the SSL
certificate authority.
ssl_cipher (str, optional): Sets the SSL cipher list.
ssl_disabled (bool, optional): Disables SSL usage.
ssl_verify_cert (bool, optional): Verifies the server's certificate.
Automatically enabled if ``ssl_ca`` is specified.
ssl_verify_identity (bool, optional): Verifies the server's identity.
conv (dict[int, Callable], optional): A dictionary of data conversion
functions.
credential_type (str, optional): Specifies the type of authentication to
use: auth.PASSWORD, auth.JWT, or auth.BROWSER_SSO.
autocommit (bool, optional): Enables autocommits.
results_type (str, optional): Determines the structure of the query results:
tuples, namedtuples, dicts.
results_format (str, optional): Deprecated. This option has been renamed to
results_type.
Examples:
Basic Usage:
.. code-block:: python
from langchain.memory.chat_message_histories import (
SingleStoreDBChatMessageHistory
)
message_history = SingleStoreDBChatMessageHistory(
session_id="my-session",
host="https://user:password@127.0.0.1:3306/database"
)
Advanced Usage:
.. code-block:: python
from langchain.memory.chat_message_histories import (
SingleStoreDBChatMessageHistory
)
message_history = SingleStoreDBChatMessageHistory(
session_id="my-session",
host="127.0.0.1",
port=3306,
user="user",
password="password",
database="db",
table_name="my_custom_table",
pool_size=10,
timeout=60,
)
Using environment variables:
.. code-block:: python
from langchain.memory.chat_message_histories import (
SingleStoreDBChatMessageHistory
)
os.environ['SINGLESTOREDB_URL'] = 'me:p455w0rd@s2-host.com/my_db'
message_history = SingleStoreDBChatMessageHistory("my-session")
"""
self.table_name = self._sanitize_input(table_name)
self.session_id = self._sanitize_input(session_id)
self.id_field = self._sanitize_input(id_field)
self.session_id_field = self._sanitize_input(session_id_field)
self.message_field = self._sanitize_input(message_field)
# Pass the rest of the kwargs to the connection.
self.connection_kwargs = kwargs
# Add connection attributes to the connection kwargs.
if "conn_attrs" not in self.connection_kwargs:
self.connection_kwargs["conn_attrs"] = dict()
self.connection_kwargs["conn_attrs"]["_connector_name"] = "langchain python sdk"
self.connection_kwargs["conn_attrs"]["_connector_version"] = "1.0.1"
# Create a connection pool.
try:
from sqlalchemy.pool import QueuePool
except ImportError:
raise ImportError(
"Could not import sqlalchemy.pool python package. "
"Please install it with `pip install singlestoredb`."
)
self.connection_pool = QueuePool(
self._get_connection,
max_overflow=max_overflow,
pool_size=pool_size,
timeout=timeout,
)
self.table_created = False
def _sanitize_input(self, input_str: str) -> str:
# Remove characters that are not alphanumeric or underscores
return re.sub(r"[^a-zA-Z0-9_]", "", input_str)
def _get_connection(self) -> Any:
try:
import singlestoredb as s2
except ImportError:
raise ImportError(
"Could not import singlestoredb python package. "
"Please install it with `pip install singlestoredb`."
)
return s2.connect(**self.connection_kwargs)
def _create_table_if_not_exists(self) -> None:
"""Create table if it doesn't exist."""
if self.table_created:
return
conn = self.connection_pool.connect()
try:
cur = conn.cursor()
try:
cur.execute(
"""CREATE TABLE IF NOT EXISTS {}
({} BIGINT PRIMARY KEY AUTO_INCREMENT,
{} TEXT NOT NULL,
{} JSON NOT NULL);""".format(
self.table_name,
self.id_field,
self.session_id_field,
self.message_field,
),
)
self.table_created = True
finally:
cur.close()
finally:
conn.close()
@property
def messages(self) -> List[BaseMessage]: # type: ignore
"""Retrieve the messages from SingleStoreDB"""
self._create_table_if_not_exists()
conn = self.connection_pool.connect()
items = []
try:
cur = conn.cursor()
try:
cur.execute(
"""SELECT {} FROM {} WHERE {} = %s""".format(
self.message_field,
self.table_name,
self.session_id_field,
),
(self.session_id),
)
for row in cur.fetchall():
items.append(row[0])
finally:
cur.close()
finally:
conn.close()
messages = messages_from_dict(items)
return messages
def add_message(self, message: BaseMessage) -> None:
"""Append the message to the record in SingleStoreDB"""
self._create_table_if_not_exists()
conn = self.connection_pool.connect()
try:
cur = conn.cursor()
try:
cur.execute(
"""INSERT INTO {} ({}, {}) VALUES (%s, %s)""".format(
self.table_name,
self.session_id_field,
self.message_field,
),
(self.session_id, json.dumps(_message_to_dict(message))),
)
finally:
cur.close()
finally:
conn.close()
def clear(self) -> None:
"""Clear session memory from SingleStoreDB"""
self._create_table_if_not_exists()
conn = self.connection_pool.connect()
try:
cur = conn.cursor()
try:
cur.execute(
"""DELETE FROM {} WHERE {} = %s""".format(
self.table_name,
self.session_id_field,
),
(self.session_id),
)
finally:
cur.close()
finally:
conn.close()

View File

@ -1,11 +1,10 @@
from __future__ import annotations from __future__ import annotations
import json import json
import re
from typing import ( from typing import (
Any, Any,
Callable, Callable,
ClassVar,
Collection,
Iterable, Iterable,
List, List,
Optional, Optional,
@ -15,10 +14,6 @@ from typing import (
from sqlalchemy.pool import QueuePool from sqlalchemy.pool import QueuePool
from langchain.callbacks.manager import (
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
)
from langchain.docstore.document import Document from langchain.docstore.document import Document
from langchain.schema.embeddings import Embeddings from langchain.schema.embeddings import Embeddings
from langchain.schema.vectorstore import VectorStore, VectorStoreRetriever from langchain.schema.vectorstore import VectorStore, VectorStoreRetriever
@ -186,22 +181,22 @@ class SingleStoreDB(VectorStore):
self.embedding = embedding self.embedding = embedding
self.distance_strategy = distance_strategy self.distance_strategy = distance_strategy
self.table_name = table_name self.table_name = self._sanitize_input(table_name)
self.content_field = content_field self.content_field = self._sanitize_input(content_field)
self.metadata_field = metadata_field self.metadata_field = self._sanitize_input(metadata_field)
self.vector_field = vector_field self.vector_field = self._sanitize_input(vector_field)
"""Pass the rest of the kwargs to the connection.""" # Pass the rest of the kwargs to the connection.
self.connection_kwargs = kwargs self.connection_kwargs = kwargs
"""Add program name and version to connection attributes.""" # Add program name and version to connection attributes.
if "conn_attrs" not in self.connection_kwargs: if "conn_attrs" not in self.connection_kwargs:
self.connection_kwargs["conn_attrs"] = dict() self.connection_kwargs["conn_attrs"] = dict()
self.connection_kwargs["conn_attrs"]["_connector_name"] = "langchain python sdk" self.connection_kwargs["conn_attrs"]["_connector_name"] = "langchain python sdk"
self.connection_kwargs["conn_attrs"]["_connector_version"] = "1.0.0" self.connection_kwargs["conn_attrs"]["_connector_version"] = "1.0.1"
"""Create connection pool.""" # Create connection pool.
self.connection_pool = QueuePool( self.connection_pool = QueuePool(
self._get_connection, self._get_connection,
max_overflow=max_overflow, max_overflow=max_overflow,
@ -214,6 +209,10 @@ class SingleStoreDB(VectorStore):
def embeddings(self) -> Embeddings: def embeddings(self) -> Embeddings:
return self.embedding return self.embedding
def _sanitize_input(self, input_str: str) -> str:
# Remove characters that are not alphanumeric or underscores
return re.sub(r"[^a-zA-Z0-9_]", "", input_str)
def _select_relevance_score_fn(self) -> Callable[[float], float]: def _select_relevance_score_fn(self) -> Callable[[float], float]:
return self._max_inner_product_relevance_score_fn return self._max_inner_product_relevance_score_fn
@ -444,31 +443,6 @@ class SingleStoreDB(VectorStore):
instance.add_texts(texts, metadatas, embedding.embed_documents(texts), **kwargs) instance.add_texts(texts, metadatas, embedding.embed_documents(texts), **kwargs)
return instance return instance
def as_retriever(self, **kwargs: Any) -> SingleStoreDBRetriever:
tags = kwargs.pop("tags", None) or []
tags.extend(self._get_retriever_tags())
return SingleStoreDBRetriever(vectorstore=self, **kwargs, tags=tags)
# SingleStoreDBRetriever is not needed, but we keep it for backwards compatibility
class SingleStoreDBRetriever(VectorStoreRetriever): SingleStoreDBRetriever = VectorStoreRetriever
"""Retriever for SingleStoreDB vector stores."""
vectorstore: SingleStoreDB
k: int = 4
allowed_search_types: ClassVar[Collection[str]] = ("similarity",)
def _get_relevant_documents(
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
) -> List[Document]:
if self.search_type == "similarity":
docs = self.vectorstore.similarity_search(query, k=self.k)
else:
raise ValueError(f"search_type of {self.search_type} not allowed.")
return docs
async def _aget_relevant_documents(
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
) -> List[Document]:
raise NotImplementedError(
"SingleStoreDBVectorStoreRetriever does not support async"
)

View File

@ -0,0 +1,35 @@
import json
from langchain.memory import ConversationBufferMemory, SingleStoreDBChatMessageHistory
from langchain.schema.messages import _message_to_dict
# Replace these with your mongodb connection string
TEST_SINGLESTOREDB_URL = "root:pass@localhost:3306/db"
def test_memory_with_message_store() -> None:
"""Test the memory with a message store."""
# setup SingleStoreDB as a message store
message_history = SingleStoreDBChatMessageHistory(
session_id="test-session",
host=TEST_SINGLESTOREDB_URL,
)
memory = ConversationBufferMemory(
memory_key="baz", chat_memory=message_history, return_messages=True
)
# add some messages
memory.chat_memory.add_ai_message("This is me, the AI")
memory.chat_memory.add_user_message("This is me, the human")
# get the message history from the memory store and turn it into a json
messages = memory.chat_memory.messages
messages_json = json.dumps([_message_to_dict(msg) for msg in messages])
assert "This is me, the AI" in messages_json
assert "This is me, the human" in messages_json
# remove the record from SingleStoreDB, so the next test run won't pick it up
memory.chat_memory.clear()
assert memory.chat_memory.messages == []

View File

@ -349,3 +349,27 @@ def test_singlestoredb_filter_metadata_7(texts: List[str]) -> None:
) )
] ]
drop(table_name) drop(table_name)
@pytest.mark.skipif(not singlestoredb_installed, reason="singlestoredb not installed")
def test_singlestoredb_as_retriever(texts: List[str]) -> None:
table_name = "test_singlestoredb_8"
drop(table_name)
docsearch = SingleStoreDB.from_texts(
texts,
FakeEmbeddings(),
distance_strategy=DistanceStrategy.EUCLIDEAN_DISTANCE,
table_name=table_name,
host=TEST_SINGLESTOREDB_URL,
)
retriever = docsearch.as_retriever(search_kwargs={"k": 2})
output = retriever.get_relevant_documents("foo")
assert output == [
Document(
page_content="foo",
),
Document(
page_content="bar",
),
]
drop(table_name)