From ff8e6981ff921c5ca04a0e866d764d9b988ed938 Mon Sep 17 00:00:00 2001 From: volodymyr-memsql <57520563+volodymyr-memsql@users.noreply.github.com> Date: Tue, 17 Oct 2023 04:59:45 +0300 Subject: [PATCH] SingleStoreDBChatMessageHistory: Add singlestoredb support for ChatMessageHistory (#11705) **Description** - Added the `SingleStoreDBChatMessageHistory` class that inherits `BaseChatMessageHistory` and allows the use of a SingleStoreDB database as storage for chat message history. - Added integration test to check that everything works (requires `singlestoredb` to be installed) - Added notebook with usage example - Removed custom retriever for SingleStoreDB vector store (it duplicated the generic `VectorStoreRetriever`; the old name is kept as an alias) --------- Co-authored-by: Volodymyr Tkachuk --- .../singlestoredb_chat_message_history.ipynb | 65 +++++ libs/langchain/langchain/memory/__init__.py | 2 + .../memory/chat_message_histories/__init__.py | 4 + .../chat_message_histories/singlestoredb.py | 275 ++++++++++++++++++ .../langchain/vectorstores/singlestoredb.py | 56 +--- .../memory/test_singlestoredb.py | 35 +++ .../vectorstores/test_singlestoredb.py | 24 ++ 7 files changed, 420 insertions(+), 41 deletions(-) create mode 100644 docs/extras/integrations/memory/singlestoredb_chat_message_history.ipynb create mode 100644 libs/langchain/langchain/memory/chat_message_histories/singlestoredb.py create mode 100644 libs/langchain/tests/integration_tests/memory/test_singlestoredb.py diff --git a/docs/extras/integrations/memory/singlestoredb_chat_message_history.ipynb b/docs/extras/integrations/memory/singlestoredb_chat_message_history.ipynb new file mode 100644 index 0000000000..9b5912e79c --- /dev/null +++ b/docs/extras/integrations/memory/singlestoredb_chat_message_history.ipynb @@ -0,0 +1,65 @@ +{ + "cells": [ + { + "attachments": {}, + "cell_type": "markdown", + "id": "91c6a7ef", + "metadata": {}, + "source": [ + "# SingleStoreDB Chat Message History\n", + "\n", + "This notebook goes over how to use SingleStoreDB to store chat 
message history." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d15e3302", + "metadata": {}, + "outputs": [], + "source": [ + "from langchain.memory import SingleStoreDBChatMessageHistory\n", + "\n", + "history = SingleStoreDBChatMessageHistory(\n", + " session_id=\"foo\",\n", + " host=\"root:pass@localhost:3306/db\"\n", + ")\n", + "\n", + "history.add_user_message(\"hi!\")\n", + "\n", + "history.add_ai_message(\"whats up?\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "64fc465e", + "metadata": {}, + "outputs": [], + "source": [ + "history.messages" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.2" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/libs/langchain/langchain/memory/__init__.py b/libs/langchain/langchain/memory/__init__.py index 6db5840347..a7049e98b9 100644 --- a/libs/langchain/langchain/memory/__init__.py +++ b/libs/langchain/langchain/memory/__init__.py @@ -42,6 +42,7 @@ from langchain.memory.chat_message_histories import ( MongoDBChatMessageHistory, PostgresChatMessageHistory, RedisChatMessageHistory, + SingleStoreDBChatMessageHistory, SQLChatMessageHistory, StreamlitChatMessageHistory, UpstashRedisChatMessageHistory, @@ -90,6 +91,7 @@ __all__ = [ "ReadOnlySharedMemory", "RedisChatMessageHistory", "RedisEntityStore", + "SingleStoreDBChatMessageHistory", "SQLChatMessageHistory", "SQLiteEntityStore", "SimpleMemory", diff --git a/libs/langchain/langchain/memory/chat_message_histories/__init__.py b/libs/langchain/langchain/memory/chat_message_histories/__init__.py index 760e9d2eff..c0b7c544a9 100644 --- 
a/libs/langchain/langchain/memory/chat_message_histories/__init__.py +++ b/libs/langchain/langchain/memory/chat_message_histories/__init__.py @@ -16,6 +16,9 @@ from langchain.memory.chat_message_histories.mongodb import MongoDBChatMessageHi from langchain.memory.chat_message_histories.postgres import PostgresChatMessageHistory from langchain.memory.chat_message_histories.redis import RedisChatMessageHistory from langchain.memory.chat_message_histories.rocksetdb import RocksetChatMessageHistory +from langchain.memory.chat_message_histories.singlestoredb import ( + SingleStoreDBChatMessageHistory, +) from langchain.memory.chat_message_histories.sql import SQLChatMessageHistory from langchain.memory.chat_message_histories.streamlit import ( StreamlitChatMessageHistory, @@ -41,6 +44,7 @@ __all__ = [ "RocksetChatMessageHistory", "SQLChatMessageHistory", "StreamlitChatMessageHistory", + "SingleStoreDBChatMessageHistory", "XataChatMessageHistory", "ZepChatMessageHistory", "UpstashRedisChatMessageHistory", diff --git a/libs/langchain/langchain/memory/chat_message_histories/singlestoredb.py b/libs/langchain/langchain/memory/chat_message_histories/singlestoredb.py new file mode 100644 index 0000000000..74489abdd5 --- /dev/null +++ b/libs/langchain/langchain/memory/chat_message_histories/singlestoredb.py @@ -0,0 +1,275 @@ +import json +import logging +import re +from typing import ( + Any, + List, +) + +from langchain.schema import ( + BaseChatMessageHistory, +) +from langchain.schema.messages import BaseMessage, _message_to_dict, messages_from_dict + +logger = logging.getLogger(__name__) + + +class SingleStoreDBChatMessageHistory(BaseChatMessageHistory): + """Chat message history stored in a SingleStoreDB database.""" + + def __init__( + self, + session_id: str, + *, + table_name: str = "message_store", + id_field: str = "id", + session_id_field: str = "session_id", + message_field: str = "message", + pool_size: int = 5, + max_overflow: int = 10, + timeout: float = 30, + 
**kwargs: Any, + ): + """Initialize with necessary components. + + Args: + + + table_name (str, optional): Specifies the name of the table in use. + Defaults to "message_store". + id_field (str, optional): Specifies the name of the id field in the table. + Defaults to "id". + session_id_field (str, optional): Specifies the name of the session_id + field in the table. Defaults to "session_id". + message_field (str, optional): Specifies the name of the message field + in the table. Defaults to "message". + + Following arguments pertain to the connection pool: + + pool_size (int, optional): Determines the number of active connections in + the pool. Defaults to 5. + max_overflow (int, optional): Determines the maximum number of connections + allowed beyond the pool_size. Defaults to 10. + timeout (float, optional): Specifies the maximum wait time in seconds for + establishing a connection. Defaults to 30. + + Following arguments pertain to the database connection: + + host (str, optional): Specifies the hostname, IP address, or URL for the + database connection. The default scheme is "mysql". + user (str, optional): Database username. + password (str, optional): Database password. + port (int, optional): Database port. Defaults to 3306 for non-HTTP + connections, 80 for HTTP connections, and 443 for HTTPS connections. + database (str, optional): Database name. + + Additional optional arguments provide further customization over the + database connection: + + pure_python (bool, optional): Toggles the connector mode. If True, + operates in pure Python mode. + local_infile (bool, optional): Allows local file uploads. + charset (str, optional): Specifies the character set for string values. + ssl_key (str, optional): Specifies the path of the file containing the SSL + key. + ssl_cert (str, optional): Specifies the path of the file containing the SSL + certificate. + ssl_ca (str, optional): Specifies the path of the file containing the SSL + certificate authority. 
+ ssl_cipher (str, optional): Sets the SSL cipher list. + ssl_disabled (bool, optional): Disables SSL usage. + ssl_verify_cert (bool, optional): Verifies the server's certificate. + Automatically enabled if ``ssl_ca`` is specified. + ssl_verify_identity (bool, optional): Verifies the server's identity. + conv (dict[int, Callable], optional): A dictionary of data conversion + functions. + credential_type (str, optional): Specifies the type of authentication to + use: auth.PASSWORD, auth.JWT, or auth.BROWSER_SSO. + autocommit (bool, optional): Enables autocommits. + results_type (str, optional): Determines the structure of the query results: + tuples, namedtuples, dicts. + results_format (str, optional): Deprecated. This option has been renamed to + results_type. + + Examples: + Basic Usage: + + .. code-block:: python + + from langchain.memory.chat_message_histories import ( + SingleStoreDBChatMessageHistory + ) + + message_history = SingleStoreDBChatMessageHistory( + session_id="my-session", + host="https://user:password@127.0.0.1:3306/database" + ) + + Advanced Usage: + + .. code-block:: python + + from langchain.memory.chat_message_histories import ( + SingleStoreDBChatMessageHistory + ) + + message_history = SingleStoreDBChatMessageHistory( + session_id="my-session", + host="127.0.0.1", + port=3306, + user="user", + password="password", + database="db", + table_name="my_custom_table", + pool_size=10, + timeout=60, + ) + + Using environment variables: + + .. 
code-block:: python + + from langchain.memory.chat_message_histories import ( + SingleStoreDBChatMessageHistory + ) + + os.environ['SINGLESTOREDB_URL'] = 'me:p455w0rd@s2-host.com/my_db' + message_history = SingleStoreDBChatMessageHistory("my-session") + """ + + self.table_name = self._sanitize_input(table_name) + self.session_id = self._sanitize_input(session_id) + self.id_field = self._sanitize_input(id_field) + self.session_id_field = self._sanitize_input(session_id_field) + self.message_field = self._sanitize_input(message_field) + + # Pass the rest of the kwargs to the connection. + self.connection_kwargs = kwargs + + # Add connection attributes to the connection kwargs. + if "conn_attrs" not in self.connection_kwargs: + self.connection_kwargs["conn_attrs"] = dict() + + self.connection_kwargs["conn_attrs"]["_connector_name"] = "langchain python sdk" + self.connection_kwargs["conn_attrs"]["_connector_version"] = "1.0.1" + + # Create a connection pool. + try: + from sqlalchemy.pool import QueuePool + except ImportError: + raise ImportError( + "Could not import sqlalchemy.pool python package. " + "Please install it with `pip install singlestoredb`." + ) + + self.connection_pool = QueuePool( + self._get_connection, + max_overflow=max_overflow, + pool_size=pool_size, + timeout=timeout, + ) + self.table_created = False + + def _sanitize_input(self, input_str: str) -> str: + # Remove characters that are not alphanumeric or underscores + return re.sub(r"[^a-zA-Z0-9_]", "", input_str) + + def _get_connection(self) -> Any: + try: + import singlestoredb as s2 + except ImportError: + raise ImportError( + "Could not import singlestoredb python package. " + "Please install it with `pip install singlestoredb`." 
+ ) + return s2.connect(**self.connection_kwargs) + + def _create_table_if_not_exists(self) -> None: + """Create table if it doesn't exist.""" + if self.table_created: + return + conn = self.connection_pool.connect() + try: + cur = conn.cursor() + try: + cur.execute( + """CREATE TABLE IF NOT EXISTS {} + ({} BIGINT PRIMARY KEY AUTO_INCREMENT, + {} TEXT NOT NULL, + {} JSON NOT NULL);""".format( + self.table_name, + self.id_field, + self.session_id_field, + self.message_field, + ), + ) + self.table_created = True + finally: + cur.close() + finally: + conn.close() + + @property + def messages(self) -> List[BaseMessage]: # type: ignore + """Retrieve the messages from SingleStoreDB""" + self._create_table_if_not_exists() + conn = self.connection_pool.connect() + items = [] + try: + cur = conn.cursor() + try: + cur.execute( + """SELECT {} FROM {} WHERE {} = %s""".format( + self.message_field, + self.table_name, + self.session_id_field, + ), + (self.session_id), + ) + for row in cur.fetchall(): + items.append(row[0]) + finally: + cur.close() + finally: + conn.close() + messages = messages_from_dict(items) + return messages + + def add_message(self, message: BaseMessage) -> None: + """Append the message to the record in SingleStoreDB""" + self._create_table_if_not_exists() + conn = self.connection_pool.connect() + try: + cur = conn.cursor() + try: + cur.execute( + """INSERT INTO {} ({}, {}) VALUES (%s, %s)""".format( + self.table_name, + self.session_id_field, + self.message_field, + ), + (self.session_id, json.dumps(_message_to_dict(message))), + ) + finally: + cur.close() + finally: + conn.close() + + def clear(self) -> None: + """Clear session memory from SingleStoreDB""" + self._create_table_if_not_exists() + conn = self.connection_pool.connect() + try: + cur = conn.cursor() + try: + cur.execute( + """DELETE FROM {} WHERE {} = %s""".format( + self.table_name, + self.session_id_field, + ), + (self.session_id), + ) + finally: + cur.close() + finally: + conn.close() diff 
--git a/libs/langchain/langchain/vectorstores/singlestoredb.py b/libs/langchain/langchain/vectorstores/singlestoredb.py index 4e41269fef..070eeb7b0c 100644 --- a/libs/langchain/langchain/vectorstores/singlestoredb.py +++ b/libs/langchain/langchain/vectorstores/singlestoredb.py @@ -1,11 +1,10 @@ from __future__ import annotations import json +import re from typing import ( Any, Callable, - ClassVar, - Collection, Iterable, List, Optional, @@ -15,10 +14,6 @@ from typing import ( from sqlalchemy.pool import QueuePool -from langchain.callbacks.manager import ( - AsyncCallbackManagerForRetrieverRun, - CallbackManagerForRetrieverRun, -) from langchain.docstore.document import Document from langchain.schema.embeddings import Embeddings from langchain.schema.vectorstore import VectorStore, VectorStoreRetriever @@ -186,22 +181,22 @@ class SingleStoreDB(VectorStore): self.embedding = embedding self.distance_strategy = distance_strategy - self.table_name = table_name - self.content_field = content_field - self.metadata_field = metadata_field - self.vector_field = vector_field + self.table_name = self._sanitize_input(table_name) + self.content_field = self._sanitize_input(content_field) + self.metadata_field = self._sanitize_input(metadata_field) + self.vector_field = self._sanitize_input(vector_field) - """Pass the rest of the kwargs to the connection.""" + # Pass the rest of the kwargs to the connection. self.connection_kwargs = kwargs - """Add program name and version to connection attributes.""" + # Add program name and version to connection attributes. if "conn_attrs" not in self.connection_kwargs: self.connection_kwargs["conn_attrs"] = dict() self.connection_kwargs["conn_attrs"]["_connector_name"] = "langchain python sdk" - self.connection_kwargs["conn_attrs"]["_connector_version"] = "1.0.0" + self.connection_kwargs["conn_attrs"]["_connector_version"] = "1.0.1" - """Create connection pool.""" + # Create connection pool. 
self.connection_pool = QueuePool( self._get_connection, max_overflow=max_overflow, @@ -214,6 +209,10 @@ class SingleStoreDB(VectorStore): def embeddings(self) -> Embeddings: return self.embedding + def _sanitize_input(self, input_str: str) -> str: + # Remove characters that are not alphanumeric or underscores + return re.sub(r"[^a-zA-Z0-9_]", "", input_str) + def _select_relevance_score_fn(self) -> Callable[[float], float]: return self._max_inner_product_relevance_score_fn @@ -444,31 +443,6 @@ class SingleStoreDB(VectorStore): instance.add_texts(texts, metadatas, embedding.embed_documents(texts), **kwargs) return instance - def as_retriever(self, **kwargs: Any) -> SingleStoreDBRetriever: - tags = kwargs.pop("tags", None) or [] - tags.extend(self._get_retriever_tags()) - return SingleStoreDBRetriever(vectorstore=self, **kwargs, tags=tags) - - -class SingleStoreDBRetriever(VectorStoreRetriever): - """Retriever for SingleStoreDB vector stores.""" - vectorstore: SingleStoreDB - k: int = 4 - allowed_search_types: ClassVar[Collection[str]] = ("similarity",) - - def _get_relevant_documents( - self, query: str, *, run_manager: CallbackManagerForRetrieverRun - ) -> List[Document]: - if self.search_type == "similarity": - docs = self.vectorstore.similarity_search(query, k=self.k) - else: - raise ValueError(f"search_type of {self.search_type} not allowed.") - return docs - - async def _aget_relevant_documents( - self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun - ) -> List[Document]: - raise NotImplementedError( - "SingleStoreDBVectorStoreRetriever does not support async" - ) +# SingleStoreDBRetriever is not needed, but we keep it for backwards compatibility +SingleStoreDBRetriever = VectorStoreRetriever diff --git a/libs/langchain/tests/integration_tests/memory/test_singlestoredb.py b/libs/langchain/tests/integration_tests/memory/test_singlestoredb.py new file mode 100644 index 0000000000..92a611a529 --- /dev/null +++ 
b/libs/langchain/tests/integration_tests/memory/test_singlestoredb.py @@ -0,0 +1,35 @@ +import json + +from langchain.memory import ConversationBufferMemory, SingleStoreDBChatMessageHistory +from langchain.schema.messages import _message_to_dict + +# Replace this with your SingleStoreDB connection string +TEST_SINGLESTOREDB_URL = "root:pass@localhost:3306/db" + + +def test_memory_with_message_store() -> None: + """Test the memory with a message store.""" + # setup SingleStoreDB as a message store + message_history = SingleStoreDBChatMessageHistory( + session_id="test-session", + host=TEST_SINGLESTOREDB_URL, + ) + memory = ConversationBufferMemory( + memory_key="baz", chat_memory=message_history, return_messages=True + ) + + # add some messages + memory.chat_memory.add_ai_message("This is me, the AI") + memory.chat_memory.add_user_message("This is me, the human") + + # get the message history from the memory store and turn it into a json + messages = memory.chat_memory.messages + messages_json = json.dumps([_message_to_dict(msg) for msg in messages]) + + assert "This is me, the AI" in messages_json + assert "This is me, the human" in messages_json + + # remove the record from SingleStoreDB, so the next test run won't pick it up + memory.chat_memory.clear() + + assert memory.chat_memory.messages == [] diff --git a/libs/langchain/tests/integration_tests/vectorstores/test_singlestoredb.py b/libs/langchain/tests/integration_tests/vectorstores/test_singlestoredb.py index b53ed8ba27..c68ed47180 100644 --- a/libs/langchain/tests/integration_tests/vectorstores/test_singlestoredb.py +++ b/libs/langchain/tests/integration_tests/vectorstores/test_singlestoredb.py @@ -349,3 +349,27 @@ def test_singlestoredb_filter_metadata_7(texts: List[str]) -> None: ) ] drop(table_name) + + +@pytest.mark.skipif(not singlestoredb_installed, reason="singlestoredb not installed") +def test_singlestoredb_as_retriever(texts: List[str]) -> None: + table_name = "test_singlestoredb_8" + drop(table_name) + 
docsearch = SingleStoreDB.from_texts( + texts, + FakeEmbeddings(), + distance_strategy=DistanceStrategy.EUCLIDEAN_DISTANCE, + table_name=table_name, + host=TEST_SINGLESTOREDB_URL, + ) + retriever = docsearch.as_retriever(search_kwargs={"k": 2}) + output = retriever.get_relevant_documents("foo") + assert output == [ + Document( + page_content="foo", + ), + Document( + page_content="bar", + ), + ] + drop(table_name)