mirror of
https://github.com/hwchase17/langchain
synced 2024-11-06 03:20:49 +00:00
SingleStoreDBChatMessageHistory: Add singlestoredb support for ChatMessageHistory (#11705)
**Description** - Added the `SingleStoreDBChatMessageHistory` class that inherits `BaseChatMessageHistory` and allows to use of a SingleStoreDB database as a storage for chat message history. - Added integration test to check that everything works (requires `singlestoredb` to be installed) - Added notebook with usage example - Removed custom retriever for SingleStoreDB vector store (as it is useless) --------- Co-authored-by: Volodymyr Tkachuk <vtkachuk-ua@singlestore.com>
This commit is contained in:
parent
634ccb8ccd
commit
ff8e6981ff
@ -0,0 +1,65 @@
|
|||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"attachments": {},
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "91c6a7ef",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"# SingleStoreDB Chat Message History\n",
|
||||||
|
"\n",
|
||||||
|
"This notebook goes over how to use SingleStoreDB to store chat message history."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "d15e3302",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"from langchain.memory import SingleStoreDBChatMessageHistory\n",
|
||||||
|
"\n",
|
||||||
|
"history = SingleStoreDBChatMessageHistory(\n",
|
||||||
|
" session_id=\"foo\",\n",
|
||||||
|
" host=\"root:pass@localhost:3306/db\"\n",
|
||||||
|
")\n",
|
||||||
|
"\n",
|
||||||
|
"history.add_user_message(\"hi!\")\n",
|
||||||
|
"\n",
|
||||||
|
"history.add_ai_message(\"whats up?\")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "64fc465e",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"history.messages"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"kernelspec": {
|
||||||
|
"display_name": "Python 3 (ipykernel)",
|
||||||
|
"language": "python",
|
||||||
|
"name": "python3"
|
||||||
|
},
|
||||||
|
"language_info": {
|
||||||
|
"codemirror_mode": {
|
||||||
|
"name": "ipython",
|
||||||
|
"version": 3
|
||||||
|
},
|
||||||
|
"file_extension": ".py",
|
||||||
|
"mimetype": "text/x-python",
|
||||||
|
"name": "python",
|
||||||
|
"nbconvert_exporter": "python",
|
||||||
|
"pygments_lexer": "ipython3",
|
||||||
|
"version": "3.11.2"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 5
|
||||||
|
}
|
@ -42,6 +42,7 @@ from langchain.memory.chat_message_histories import (
|
|||||||
MongoDBChatMessageHistory,
|
MongoDBChatMessageHistory,
|
||||||
PostgresChatMessageHistory,
|
PostgresChatMessageHistory,
|
||||||
RedisChatMessageHistory,
|
RedisChatMessageHistory,
|
||||||
|
SingleStoreDBChatMessageHistory,
|
||||||
SQLChatMessageHistory,
|
SQLChatMessageHistory,
|
||||||
StreamlitChatMessageHistory,
|
StreamlitChatMessageHistory,
|
||||||
UpstashRedisChatMessageHistory,
|
UpstashRedisChatMessageHistory,
|
||||||
@ -90,6 +91,7 @@ __all__ = [
|
|||||||
"ReadOnlySharedMemory",
|
"ReadOnlySharedMemory",
|
||||||
"RedisChatMessageHistory",
|
"RedisChatMessageHistory",
|
||||||
"RedisEntityStore",
|
"RedisEntityStore",
|
||||||
|
"SingleStoreDBChatMessageHistory",
|
||||||
"SQLChatMessageHistory",
|
"SQLChatMessageHistory",
|
||||||
"SQLiteEntityStore",
|
"SQLiteEntityStore",
|
||||||
"SimpleMemory",
|
"SimpleMemory",
|
||||||
|
@ -16,6 +16,9 @@ from langchain.memory.chat_message_histories.mongodb import MongoDBChatMessageHi
|
|||||||
from langchain.memory.chat_message_histories.postgres import PostgresChatMessageHistory
|
from langchain.memory.chat_message_histories.postgres import PostgresChatMessageHistory
|
||||||
from langchain.memory.chat_message_histories.redis import RedisChatMessageHistory
|
from langchain.memory.chat_message_histories.redis import RedisChatMessageHistory
|
||||||
from langchain.memory.chat_message_histories.rocksetdb import RocksetChatMessageHistory
|
from langchain.memory.chat_message_histories.rocksetdb import RocksetChatMessageHistory
|
||||||
|
from langchain.memory.chat_message_histories.singlestoredb import (
|
||||||
|
SingleStoreDBChatMessageHistory,
|
||||||
|
)
|
||||||
from langchain.memory.chat_message_histories.sql import SQLChatMessageHistory
|
from langchain.memory.chat_message_histories.sql import SQLChatMessageHistory
|
||||||
from langchain.memory.chat_message_histories.streamlit import (
|
from langchain.memory.chat_message_histories.streamlit import (
|
||||||
StreamlitChatMessageHistory,
|
StreamlitChatMessageHistory,
|
||||||
@ -41,6 +44,7 @@ __all__ = [
|
|||||||
"RocksetChatMessageHistory",
|
"RocksetChatMessageHistory",
|
||||||
"SQLChatMessageHistory",
|
"SQLChatMessageHistory",
|
||||||
"StreamlitChatMessageHistory",
|
"StreamlitChatMessageHistory",
|
||||||
|
"SingleStoreDBChatMessageHistory",
|
||||||
"XataChatMessageHistory",
|
"XataChatMessageHistory",
|
||||||
"ZepChatMessageHistory",
|
"ZepChatMessageHistory",
|
||||||
"UpstashRedisChatMessageHistory",
|
"UpstashRedisChatMessageHistory",
|
||||||
|
@ -0,0 +1,275 @@
|
|||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import re
|
||||||
|
from typing import (
|
||||||
|
Any,
|
||||||
|
List,
|
||||||
|
)
|
||||||
|
|
||||||
|
from langchain.schema import (
|
||||||
|
BaseChatMessageHistory,
|
||||||
|
)
|
||||||
|
from langchain.schema.messages import BaseMessage, _message_to_dict, messages_from_dict
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class SingleStoreDBChatMessageHistory(BaseChatMessageHistory):
|
||||||
|
"""Chat message history stored in a SingleStoreDB database."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
session_id: str,
|
||||||
|
*,
|
||||||
|
table_name: str = "message_store",
|
||||||
|
id_field: str = "id",
|
||||||
|
session_id_field: str = "session_id",
|
||||||
|
message_field: str = "message",
|
||||||
|
pool_size: int = 5,
|
||||||
|
max_overflow: int = 10,
|
||||||
|
timeout: float = 30,
|
||||||
|
**kwargs: Any,
|
||||||
|
):
|
||||||
|
"""Initialize with necessary components.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
|
||||||
|
|
||||||
|
table_name (str, optional): Specifies the name of the table in use.
|
||||||
|
Defaults to "message_store".
|
||||||
|
id_field (str, optional): Specifies the name of the id field in the table.
|
||||||
|
Defaults to "id".
|
||||||
|
session_id_field (str, optional): Specifies the name of the session_id
|
||||||
|
field in the table. Defaults to "session_id".
|
||||||
|
message_field (str, optional): Specifies the name of the message field
|
||||||
|
in the table. Defaults to "message".
|
||||||
|
|
||||||
|
Following arguments pertain to the connection pool:
|
||||||
|
|
||||||
|
pool_size (int, optional): Determines the number of active connections in
|
||||||
|
the pool. Defaults to 5.
|
||||||
|
max_overflow (int, optional): Determines the maximum number of connections
|
||||||
|
allowed beyond the pool_size. Defaults to 10.
|
||||||
|
timeout (float, optional): Specifies the maximum wait time in seconds for
|
||||||
|
establishing a connection. Defaults to 30.
|
||||||
|
|
||||||
|
Following arguments pertain to the database connection:
|
||||||
|
|
||||||
|
host (str, optional): Specifies the hostname, IP address, or URL for the
|
||||||
|
database connection. The default scheme is "mysql".
|
||||||
|
user (str, optional): Database username.
|
||||||
|
password (str, optional): Database password.
|
||||||
|
port (int, optional): Database port. Defaults to 3306 for non-HTTP
|
||||||
|
connections, 80 for HTTP connections, and 443 for HTTPS connections.
|
||||||
|
database (str, optional): Database name.
|
||||||
|
|
||||||
|
Additional optional arguments provide further customization over the
|
||||||
|
database connection:
|
||||||
|
|
||||||
|
pure_python (bool, optional): Toggles the connector mode. If True,
|
||||||
|
operates in pure Python mode.
|
||||||
|
local_infile (bool, optional): Allows local file uploads.
|
||||||
|
charset (str, optional): Specifies the character set for string values.
|
||||||
|
ssl_key (str, optional): Specifies the path of the file containing the SSL
|
||||||
|
key.
|
||||||
|
ssl_cert (str, optional): Specifies the path of the file containing the SSL
|
||||||
|
certificate.
|
||||||
|
ssl_ca (str, optional): Specifies the path of the file containing the SSL
|
||||||
|
certificate authority.
|
||||||
|
ssl_cipher (str, optional): Sets the SSL cipher list.
|
||||||
|
ssl_disabled (bool, optional): Disables SSL usage.
|
||||||
|
ssl_verify_cert (bool, optional): Verifies the server's certificate.
|
||||||
|
Automatically enabled if ``ssl_ca`` is specified.
|
||||||
|
ssl_verify_identity (bool, optional): Verifies the server's identity.
|
||||||
|
conv (dict[int, Callable], optional): A dictionary of data conversion
|
||||||
|
functions.
|
||||||
|
credential_type (str, optional): Specifies the type of authentication to
|
||||||
|
use: auth.PASSWORD, auth.JWT, or auth.BROWSER_SSO.
|
||||||
|
autocommit (bool, optional): Enables autocommits.
|
||||||
|
results_type (str, optional): Determines the structure of the query results:
|
||||||
|
tuples, namedtuples, dicts.
|
||||||
|
results_format (str, optional): Deprecated. This option has been renamed to
|
||||||
|
results_type.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
Basic Usage:
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
from langchain.memory.chat_message_histories import (
|
||||||
|
SingleStoreDBChatMessageHistory
|
||||||
|
)
|
||||||
|
|
||||||
|
message_history = SingleStoreDBChatMessageHistory(
|
||||||
|
session_id="my-session",
|
||||||
|
host="https://user:password@127.0.0.1:3306/database"
|
||||||
|
)
|
||||||
|
|
||||||
|
Advanced Usage:
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
from langchain.memory.chat_message_histories import (
|
||||||
|
SingleStoreDBChatMessageHistory
|
||||||
|
)
|
||||||
|
|
||||||
|
message_history = SingleStoreDBChatMessageHistory(
|
||||||
|
session_id="my-session",
|
||||||
|
host="127.0.0.1",
|
||||||
|
port=3306,
|
||||||
|
user="user",
|
||||||
|
password="password",
|
||||||
|
database="db",
|
||||||
|
table_name="my_custom_table",
|
||||||
|
pool_size=10,
|
||||||
|
timeout=60,
|
||||||
|
)
|
||||||
|
|
||||||
|
Using environment variables:
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
from langchain.memory.chat_message_histories import (
|
||||||
|
SingleStoreDBChatMessageHistory
|
||||||
|
)
|
||||||
|
|
||||||
|
os.environ['SINGLESTOREDB_URL'] = 'me:p455w0rd@s2-host.com/my_db'
|
||||||
|
message_history = SingleStoreDBChatMessageHistory("my-session")
|
||||||
|
"""
|
||||||
|
|
||||||
|
self.table_name = self._sanitize_input(table_name)
|
||||||
|
self.session_id = self._sanitize_input(session_id)
|
||||||
|
self.id_field = self._sanitize_input(id_field)
|
||||||
|
self.session_id_field = self._sanitize_input(session_id_field)
|
||||||
|
self.message_field = self._sanitize_input(message_field)
|
||||||
|
|
||||||
|
# Pass the rest of the kwargs to the connection.
|
||||||
|
self.connection_kwargs = kwargs
|
||||||
|
|
||||||
|
# Add connection attributes to the connection kwargs.
|
||||||
|
if "conn_attrs" not in self.connection_kwargs:
|
||||||
|
self.connection_kwargs["conn_attrs"] = dict()
|
||||||
|
|
||||||
|
self.connection_kwargs["conn_attrs"]["_connector_name"] = "langchain python sdk"
|
||||||
|
self.connection_kwargs["conn_attrs"]["_connector_version"] = "1.0.1"
|
||||||
|
|
||||||
|
# Create a connection pool.
|
||||||
|
try:
|
||||||
|
from sqlalchemy.pool import QueuePool
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError(
|
||||||
|
"Could not import sqlalchemy.pool python package. "
|
||||||
|
"Please install it with `pip install singlestoredb`."
|
||||||
|
)
|
||||||
|
|
||||||
|
self.connection_pool = QueuePool(
|
||||||
|
self._get_connection,
|
||||||
|
max_overflow=max_overflow,
|
||||||
|
pool_size=pool_size,
|
||||||
|
timeout=timeout,
|
||||||
|
)
|
||||||
|
self.table_created = False
|
||||||
|
|
||||||
|
def _sanitize_input(self, input_str: str) -> str:
|
||||||
|
# Remove characters that are not alphanumeric or underscores
|
||||||
|
return re.sub(r"[^a-zA-Z0-9_]", "", input_str)
|
||||||
|
|
||||||
|
def _get_connection(self) -> Any:
|
||||||
|
try:
|
||||||
|
import singlestoredb as s2
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError(
|
||||||
|
"Could not import singlestoredb python package. "
|
||||||
|
"Please install it with `pip install singlestoredb`."
|
||||||
|
)
|
||||||
|
return s2.connect(**self.connection_kwargs)
|
||||||
|
|
||||||
|
def _create_table_if_not_exists(self) -> None:
|
||||||
|
"""Create table if it doesn't exist."""
|
||||||
|
if self.table_created:
|
||||||
|
return
|
||||||
|
conn = self.connection_pool.connect()
|
||||||
|
try:
|
||||||
|
cur = conn.cursor()
|
||||||
|
try:
|
||||||
|
cur.execute(
|
||||||
|
"""CREATE TABLE IF NOT EXISTS {}
|
||||||
|
({} BIGINT PRIMARY KEY AUTO_INCREMENT,
|
||||||
|
{} TEXT NOT NULL,
|
||||||
|
{} JSON NOT NULL);""".format(
|
||||||
|
self.table_name,
|
||||||
|
self.id_field,
|
||||||
|
self.session_id_field,
|
||||||
|
self.message_field,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
self.table_created = True
|
||||||
|
finally:
|
||||||
|
cur.close()
|
||||||
|
finally:
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def messages(self) -> List[BaseMessage]: # type: ignore
|
||||||
|
"""Retrieve the messages from SingleStoreDB"""
|
||||||
|
self._create_table_if_not_exists()
|
||||||
|
conn = self.connection_pool.connect()
|
||||||
|
items = []
|
||||||
|
try:
|
||||||
|
cur = conn.cursor()
|
||||||
|
try:
|
||||||
|
cur.execute(
|
||||||
|
"""SELECT {} FROM {} WHERE {} = %s""".format(
|
||||||
|
self.message_field,
|
||||||
|
self.table_name,
|
||||||
|
self.session_id_field,
|
||||||
|
),
|
||||||
|
(self.session_id),
|
||||||
|
)
|
||||||
|
for row in cur.fetchall():
|
||||||
|
items.append(row[0])
|
||||||
|
finally:
|
||||||
|
cur.close()
|
||||||
|
finally:
|
||||||
|
conn.close()
|
||||||
|
messages = messages_from_dict(items)
|
||||||
|
return messages
|
||||||
|
|
||||||
|
def add_message(self, message: BaseMessage) -> None:
|
||||||
|
"""Append the message to the record in SingleStoreDB"""
|
||||||
|
self._create_table_if_not_exists()
|
||||||
|
conn = self.connection_pool.connect()
|
||||||
|
try:
|
||||||
|
cur = conn.cursor()
|
||||||
|
try:
|
||||||
|
cur.execute(
|
||||||
|
"""INSERT INTO {} ({}, {}) VALUES (%s, %s)""".format(
|
||||||
|
self.table_name,
|
||||||
|
self.session_id_field,
|
||||||
|
self.message_field,
|
||||||
|
),
|
||||||
|
(self.session_id, json.dumps(_message_to_dict(message))),
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
cur.close()
|
||||||
|
finally:
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
def clear(self) -> None:
|
||||||
|
"""Clear session memory from SingleStoreDB"""
|
||||||
|
self._create_table_if_not_exists()
|
||||||
|
conn = self.connection_pool.connect()
|
||||||
|
try:
|
||||||
|
cur = conn.cursor()
|
||||||
|
try:
|
||||||
|
cur.execute(
|
||||||
|
"""DELETE FROM {} WHERE {} = %s""".format(
|
||||||
|
self.table_name,
|
||||||
|
self.session_id_field,
|
||||||
|
),
|
||||||
|
(self.session_id),
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
cur.close()
|
||||||
|
finally:
|
||||||
|
conn.close()
|
@ -1,11 +1,10 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import json
|
import json
|
||||||
|
import re
|
||||||
from typing import (
|
from typing import (
|
||||||
Any,
|
Any,
|
||||||
Callable,
|
Callable,
|
||||||
ClassVar,
|
|
||||||
Collection,
|
|
||||||
Iterable,
|
Iterable,
|
||||||
List,
|
List,
|
||||||
Optional,
|
Optional,
|
||||||
@ -15,10 +14,6 @@ from typing import (
|
|||||||
|
|
||||||
from sqlalchemy.pool import QueuePool
|
from sqlalchemy.pool import QueuePool
|
||||||
|
|
||||||
from langchain.callbacks.manager import (
|
|
||||||
AsyncCallbackManagerForRetrieverRun,
|
|
||||||
CallbackManagerForRetrieverRun,
|
|
||||||
)
|
|
||||||
from langchain.docstore.document import Document
|
from langchain.docstore.document import Document
|
||||||
from langchain.schema.embeddings import Embeddings
|
from langchain.schema.embeddings import Embeddings
|
||||||
from langchain.schema.vectorstore import VectorStore, VectorStoreRetriever
|
from langchain.schema.vectorstore import VectorStore, VectorStoreRetriever
|
||||||
@ -186,22 +181,22 @@ class SingleStoreDB(VectorStore):
|
|||||||
|
|
||||||
self.embedding = embedding
|
self.embedding = embedding
|
||||||
self.distance_strategy = distance_strategy
|
self.distance_strategy = distance_strategy
|
||||||
self.table_name = table_name
|
self.table_name = self._sanitize_input(table_name)
|
||||||
self.content_field = content_field
|
self.content_field = self._sanitize_input(content_field)
|
||||||
self.metadata_field = metadata_field
|
self.metadata_field = self._sanitize_input(metadata_field)
|
||||||
self.vector_field = vector_field
|
self.vector_field = self._sanitize_input(vector_field)
|
||||||
|
|
||||||
"""Pass the rest of the kwargs to the connection."""
|
# Pass the rest of the kwargs to the connection.
|
||||||
self.connection_kwargs = kwargs
|
self.connection_kwargs = kwargs
|
||||||
|
|
||||||
"""Add program name and version to connection attributes."""
|
# Add program name and version to connection attributes.
|
||||||
if "conn_attrs" not in self.connection_kwargs:
|
if "conn_attrs" not in self.connection_kwargs:
|
||||||
self.connection_kwargs["conn_attrs"] = dict()
|
self.connection_kwargs["conn_attrs"] = dict()
|
||||||
|
|
||||||
self.connection_kwargs["conn_attrs"]["_connector_name"] = "langchain python sdk"
|
self.connection_kwargs["conn_attrs"]["_connector_name"] = "langchain python sdk"
|
||||||
self.connection_kwargs["conn_attrs"]["_connector_version"] = "1.0.0"
|
self.connection_kwargs["conn_attrs"]["_connector_version"] = "1.0.1"
|
||||||
|
|
||||||
"""Create connection pool."""
|
# Create connection pool.
|
||||||
self.connection_pool = QueuePool(
|
self.connection_pool = QueuePool(
|
||||||
self._get_connection,
|
self._get_connection,
|
||||||
max_overflow=max_overflow,
|
max_overflow=max_overflow,
|
||||||
@ -214,6 +209,10 @@ class SingleStoreDB(VectorStore):
|
|||||||
def embeddings(self) -> Embeddings:
|
def embeddings(self) -> Embeddings:
|
||||||
return self.embedding
|
return self.embedding
|
||||||
|
|
||||||
|
def _sanitize_input(self, input_str: str) -> str:
|
||||||
|
# Remove characters that are not alphanumeric or underscores
|
||||||
|
return re.sub(r"[^a-zA-Z0-9_]", "", input_str)
|
||||||
|
|
||||||
def _select_relevance_score_fn(self) -> Callable[[float], float]:
|
def _select_relevance_score_fn(self) -> Callable[[float], float]:
|
||||||
return self._max_inner_product_relevance_score_fn
|
return self._max_inner_product_relevance_score_fn
|
||||||
|
|
||||||
@ -444,31 +443,6 @@ class SingleStoreDB(VectorStore):
|
|||||||
instance.add_texts(texts, metadatas, embedding.embed_documents(texts), **kwargs)
|
instance.add_texts(texts, metadatas, embedding.embed_documents(texts), **kwargs)
|
||||||
return instance
|
return instance
|
||||||
|
|
||||||
def as_retriever(self, **kwargs: Any) -> SingleStoreDBRetriever:
|
|
||||||
tags = kwargs.pop("tags", None) or []
|
|
||||||
tags.extend(self._get_retriever_tags())
|
|
||||||
return SingleStoreDBRetriever(vectorstore=self, **kwargs, tags=tags)
|
|
||||||
|
|
||||||
|
# SingleStoreDBRetriever is not needed, but we keep it for backwards compatibility
|
||||||
class SingleStoreDBRetriever(VectorStoreRetriever):
|
SingleStoreDBRetriever = VectorStoreRetriever
|
||||||
"""Retriever for SingleStoreDB vector stores."""
|
|
||||||
|
|
||||||
vectorstore: SingleStoreDB
|
|
||||||
k: int = 4
|
|
||||||
allowed_search_types: ClassVar[Collection[str]] = ("similarity",)
|
|
||||||
|
|
||||||
def _get_relevant_documents(
|
|
||||||
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
|
|
||||||
) -> List[Document]:
|
|
||||||
if self.search_type == "similarity":
|
|
||||||
docs = self.vectorstore.similarity_search(query, k=self.k)
|
|
||||||
else:
|
|
||||||
raise ValueError(f"search_type of {self.search_type} not allowed.")
|
|
||||||
return docs
|
|
||||||
|
|
||||||
async def _aget_relevant_documents(
|
|
||||||
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
|
|
||||||
) -> List[Document]:
|
|
||||||
raise NotImplementedError(
|
|
||||||
"SingleStoreDBVectorStoreRetriever does not support async"
|
|
||||||
)
|
|
||||||
|
@ -0,0 +1,35 @@
|
|||||||
|
import json
|
||||||
|
|
||||||
|
from langchain.memory import ConversationBufferMemory, SingleStoreDBChatMessageHistory
|
||||||
|
from langchain.schema.messages import _message_to_dict
|
||||||
|
|
||||||
|
# Replace these with your mongodb connection string
|
||||||
|
TEST_SINGLESTOREDB_URL = "root:pass@localhost:3306/db"
|
||||||
|
|
||||||
|
|
||||||
|
def test_memory_with_message_store() -> None:
|
||||||
|
"""Test the memory with a message store."""
|
||||||
|
# setup SingleStoreDB as a message store
|
||||||
|
message_history = SingleStoreDBChatMessageHistory(
|
||||||
|
session_id="test-session",
|
||||||
|
host=TEST_SINGLESTOREDB_URL,
|
||||||
|
)
|
||||||
|
memory = ConversationBufferMemory(
|
||||||
|
memory_key="baz", chat_memory=message_history, return_messages=True
|
||||||
|
)
|
||||||
|
|
||||||
|
# add some messages
|
||||||
|
memory.chat_memory.add_ai_message("This is me, the AI")
|
||||||
|
memory.chat_memory.add_user_message("This is me, the human")
|
||||||
|
|
||||||
|
# get the message history from the memory store and turn it into a json
|
||||||
|
messages = memory.chat_memory.messages
|
||||||
|
messages_json = json.dumps([_message_to_dict(msg) for msg in messages])
|
||||||
|
|
||||||
|
assert "This is me, the AI" in messages_json
|
||||||
|
assert "This is me, the human" in messages_json
|
||||||
|
|
||||||
|
# remove the record from SingleStoreDB, so the next test run won't pick it up
|
||||||
|
memory.chat_memory.clear()
|
||||||
|
|
||||||
|
assert memory.chat_memory.messages == []
|
@ -349,3 +349,27 @@ def test_singlestoredb_filter_metadata_7(texts: List[str]) -> None:
|
|||||||
)
|
)
|
||||||
]
|
]
|
||||||
drop(table_name)
|
drop(table_name)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(not singlestoredb_installed, reason="singlestoredb not installed")
|
||||||
|
def test_singlestoredb_as_retriever(texts: List[str]) -> None:
|
||||||
|
table_name = "test_singlestoredb_8"
|
||||||
|
drop(table_name)
|
||||||
|
docsearch = SingleStoreDB.from_texts(
|
||||||
|
texts,
|
||||||
|
FakeEmbeddings(),
|
||||||
|
distance_strategy=DistanceStrategy.EUCLIDEAN_DISTANCE,
|
||||||
|
table_name=table_name,
|
||||||
|
host=TEST_SINGLESTOREDB_URL,
|
||||||
|
)
|
||||||
|
retriever = docsearch.as_retriever(search_kwargs={"k": 2})
|
||||||
|
output = retriever.get_relevant_documents("foo")
|
||||||
|
assert output == [
|
||||||
|
Document(
|
||||||
|
page_content="foo",
|
||||||
|
),
|
||||||
|
Document(
|
||||||
|
page_content="bar",
|
||||||
|
),
|
||||||
|
]
|
||||||
|
drop(table_name)
|
||||||
|
Loading…
Reference in New Issue
Block a user