From cfc225ecb3f442b06f552d7eff2ec73b39f9beb5 Mon Sep 17 00:00:00 2001
From: gcheron
Date: Wed, 24 Jan 2024 01:50:48 +0100
Subject: [PATCH] community: SQLStrStore/SQLDocStore provide an easy SQL
 alternative to `InMemoryStore` to persist data remotely in a SQL storage
 (#15909)

**Description:**

- Implement `SQLStrStore` and `SQLDocStore` classes that inherit from
  `BaseStore` to allow persisting data remotely on a SQL server.
- SQL is widely used, and sometimes we do not want to install a caching
  solution like Redis.
- Multiple issues/comments complain that there is no easy remote and
  persistent solution that is not in memory (users want to replace
  InMemoryStore), e.g.,
  https://github.com/langchain-ai/langchain/issues/14267,
  https://github.com/langchain-ai/langchain/issues/15633,
  https://github.com/langchain-ai/langchain/issues/14643,
  https://stackoverflow.com/questions/77385587/persist-parentdocumentretriever-of-langchain
- This is particularly painful when wanting to use `ParentDocumentRetriever`.
- This implementation is particularly useful when:
  * it's expensive to construct an InMemoryDocstore/dict
  * you want to retrieve documents from remote sources
  * you just want to reuse existing objects
- This implementation integrates well with PGVector: when using PGVector,
  you already have a SQL instance running, and `SQLDocStore` is a
  convenient way of using this instance to store documents associated with
  vectors. An integration example with ParentDocumentRetriever and PGVector
  is provided in docs/docs/integrations/stores/sql.ipynb or
  [here](https://github.com/gcheron/langchain/blob/sql-store/docs/docs/integrations/stores/sql.ipynb).
- It persists `str` and `Document` objects but can easily be extended.

**Issue:**

Provide an easy SQL alternative to `InMemoryStore`.

---------

Co-authored-by: Harrison Chase
---
 docs/docs/integrations/stores/sql.ipynb       | 186 ++++++++++
 .../langchain_community/storage/__init__.py   |   6 +
 .../langchain_community/storage/sql.py        | 345 ++++++++++++++++++
 .../integration_tests/storage/test_sql.py     | 228 ++++++++++++
 .../tests/unit_tests/storage/test_imports.py  |   2 +
 .../tests/unit_tests/storage/test_sql.py      |   7 +
 6 files changed, 774 insertions(+)
 create mode 100644 docs/docs/integrations/stores/sql.ipynb
 create mode 100644 libs/community/langchain_community/storage/sql.py
 create mode 100644 libs/community/tests/integration_tests/storage/test_sql.py
 create mode 100644 libs/community/tests/unit_tests/storage/test_sql.py

diff --git a/docs/docs/integrations/stores/sql.ipynb b/docs/docs/integrations/stores/sql.ipynb
new file mode 100644
index 0000000000..ecb2f472a8
--- /dev/null
+++ b/docs/docs/integrations/stores/sql.ipynb
@@ -0,0 +1,186 @@
+{
+ "cells": [
+  {
+   "cell_type": "raw",
+   "metadata": {},
+   "source": [
+    "---\n",
+    "sidebar_label: SQL\n",
+    "---"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# SQLStore\n",
+    "\n",
+    "The `SQLStrStore` and `SQLDocStore` implement remote data access and persistence to store strings or LangChain documents in your SQL instance."
+   ]
+  },
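+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "To run this notebook you need a reachable SQL instance (the examples below assume a local PostgreSQL server). A minimal setup sketch; the package names are assumptions to adjust to your environment:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# assumed dependencies for this notebook (adjust to your setup)\n",
+    "%pip install -qU langchain-community psycopg2-binary"
+   ]
+  },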
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "['value1', 'value2']\n",
+      "['key2']\n",
+      "['key2']\n"
+     ]
+    }
+   ],
+   "source": [
+    "from langchain_community.storage import SQLStrStore\n",
+    "\n",
+    "# simple example using an SQLStrStore to store strings\n",
+    "# same usage as \"InMemoryStore\" but with SQL persistence\n",
+    "CONNECTION_STRING = \"postgresql+psycopg2://user:pass@localhost:5432/db\"\n",
+    "COLLECTION_NAME = \"test_collection\"\n",
+    "\n",
+    "store = SQLStrStore(\n",
+    "    collection_name=COLLECTION_NAME,\n",
+    "    connection_string=CONNECTION_STRING,\n",
+    ")\n",
+    "store.mset([(\"key1\", \"value1\"), (\"key2\", \"value2\")])\n",
+    "print(store.mget([\"key1\", \"key2\"]))\n",
+    "# ['value1', 'value2']\n",
+    "store.mdelete([\"key1\"])\n",
+    "print(list(store.yield_keys()))\n",
+    "# ['key2']\n",
+    "print(list(store.yield_keys(prefix=\"k\")))\n",
+    "# ['key2']"
+   ]
+  },
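+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "`SQLDocStore` works the same way but persists LangChain `Document` objects instead of strings. A minimal sketch reusing the connection above (the collection name here is arbitrary):"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from langchain_community.storage import SQLDocStore\n",
+    "from langchain_core.documents import Document\n",
+    "\n",
+    "doc_store = SQLDocStore(\n",
+    "    collection_name=\"docs_collection\",\n",
+    "    connection_string=CONNECTION_STRING,\n",
+    ")\n",
+    "doc_store.mset([(\"doc1\", Document(page_content=\"hello\"))])\n",
+    "print(doc_store.mget([\"doc1\"]))\n",
+    "# [Document(page_content='hello')]\n",
+    "doc_store.delete_collection()"
+   ]
+  },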
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Integration with ParentDocumentRetriever and PGVector\n",
+    "\n",
+    "When using PGVector, you already have a SQL instance running. Here is a convenient way of using this instance to store documents associated with vectors."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Prepare the PGVector vectorstore with something like this:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from langchain_community.vectorstores import PGVector\n",
+    "from langchain_openai import OpenAIEmbeddings"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "embeddings = OpenAIEmbeddings()\n",
+    "vector_db = PGVector.from_existing_index(\n",
+    "    embedding=embeddings,\n",
+    "    collection_name=COLLECTION_NAME,\n",
+    "    connection_string=CONNECTION_STRING,\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Then create the parent retriever using `SQLDocStore` to persist the documents:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from langchain.document_loaders import TextLoader\n",
+    "from langchain.retrievers import ParentDocumentRetriever\n",
+    "from langchain.text_splitter import RecursiveCharacterTextSplitter\n",
+    "from langchain_community.storage import SQLDocStore\n",
+    "\n",
+    "CONNECTION_STRING = \"postgresql+psycopg2://user:pass@localhost:5432/db\"\n",
+    "COLLECTION_NAME = \"state_of_the_union_test\"\n",
+    "docstore = SQLDocStore(\n",
+    "    collection_name=COLLECTION_NAME,\n",
+    "    connection_string=CONNECTION_STRING,\n",
+    ")\n",
+    "\n",
+    "loader = TextLoader(\"./state_of_the_union.txt\")\n",
+    "documents = loader.load()\n",
+    "\n",
+    "parent_splitter = RecursiveCharacterTextSplitter(chunk_size=400)\n",
+    "child_splitter = RecursiveCharacterTextSplitter(chunk_size=50)\n",
+    "\n",
+    "retriever = ParentDocumentRetriever(\n",
+    "    vectorstore=vector_db,\n",
+    "    docstore=docstore,\n",
+    "    child_splitter=child_splitter,\n",
+    "    parent_splitter=parent_splitter,\n",
+    ")\n",
+    "retriever.add_documents(documents)"
+   ]
+  },
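+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "The retriever searches the small child chunks in PGVector but returns the full parent documents from the `SQLDocStore`. A quick sanity check with an illustrative query (results depend on your data and embedding model):"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "retrieved_docs = retriever.get_relevant_documents(\n",
+    "    \"what did he say about ketanji brown jackson\"\n",
+    ")\n",
+    "print(retrieved_docs[0].page_content[:200])"
+   ]
+  },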
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Delete a collection"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from langchain_community.storage import SQLStrStore\n",
+    "\n",
+    "# delete the COLLECTION_NAME collection\n",
+    "CONNECTION_STRING = \"postgresql+psycopg2://user:pass@localhost:5432/db\"\n",
+    "COLLECTION_NAME = \"test_collection\"\n",
+    "store = SQLStrStore(\n",
+    "    collection_name=COLLECTION_NAME,\n",
+    "    connection_string=CONNECTION_STRING,\n",
+    ")\n",
+    "store.delete_collection()"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3 (ipykernel)",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.10.1"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/libs/community/langchain_community/storage/__init__.py b/libs/community/langchain_community/storage/__init__.py
index 494591b03c..5c28015e57 100644
--- a/libs/community/langchain_community/storage/__init__.py
+++ b/libs/community/langchain_community/storage/__init__.py
@@ -11,6 +11,10 @@ from langchain_community.storage.astradb import (
     AstraDBStore,
 )
 from langchain_community.storage.redis import RedisStore
+from langchain_community.storage.sql import (
+    SQLDocStore,
+    SQLStrStore,
+)
 from langchain_community.storage.upstash_redis import (
     UpstashRedisByteStore,
     UpstashRedisStore,
@@ -22,4 +26,6 @@ __all__ = [
     "RedisStore",
     "UpstashRedisByteStore",
     "UpstashRedisStore",
+    "SQLDocStore",
+    "SQLStrStore",
 ]
diff --git a/libs/community/langchain_community/storage/sql.py b/libs/community/langchain_community/storage/sql.py
new file mode 100644
index 0000000000..7baaf64285
--- /dev/null
+++ b/libs/community/langchain_community/storage/sql.py
@@ -0,0 +1,345 @@
+"""SQL storage that persists data in a SQL database
+and supports data isolation using collections."""
+from __future__ import annotations
+
+import uuid
+from typing import Any, Generic, Iterator, List, Optional, Sequence, Tuple, TypeVar
+
+import sqlalchemy
+from sqlalchemy import JSON, UUID
+from sqlalchemy.orm import Session, relationship
+
+try:
+    from sqlalchemy.orm import declarative_base
+except ImportError:
+    from sqlalchemy.ext.declarative import declarative_base
+
+from langchain_core.documents import Document
+from langchain_core.load import Serializable, dumps, loads
+from langchain_core.stores import BaseStore
+
+V = TypeVar("V")
+
+ITERATOR_WINDOW_SIZE = 1000
+
+Base = declarative_base()  # type: Any
+
+
+_LANGCHAIN_DEFAULT_COLLECTION_NAME = "langchain"
+
+
+class BaseModel(Base):
+    """Base model for the SQL stores."""
+
+    __abstract__ = True
+    uuid = sqlalchemy.Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
+
+
+_classes: Any = None
+
+
+def _get_storage_stores() -> Any:
+    """Lazily declare and cache the ORM classes so that the SQLAlchemy
+    models are only defined once per process."""
+    global _classes
+    if _classes is not None:
+        return _classes
+
+    class CollectionStore(BaseModel):
+        """Collection store."""
+
+        __tablename__ = "langchain_storage_collection"
+
+        name = sqlalchemy.Column(sqlalchemy.String)
+        cmetadata = sqlalchemy.Column(JSON)
+
+        items = relationship(
+            "ItemStore",
+            back_populates="collection",
+            passive_deletes=True,
+        )
+
+        @classmethod
+        def get_by_name(
+            cls, session: Session, name: str
+        ) -> Optional["CollectionStore"]:
+            return session.query(cls).filter(cls.name == name).first()  # type: ignore
+
+        @classmethod
+        def get_or_create(
+            cls,
+            session: Session,
+            name: str,
+            cmetadata: Optional[dict] = None,
+        ) -> Tuple["CollectionStore", bool]:
+            """
+            Get or create a collection.
+
+            Returns a tuple (collection, created) where created is True
+            if the collection was created.
+            """
+            created = False
+            collection = cls.get_by_name(session, name)
+            if collection:
+                return collection, created
+
+            collection = cls(name=name, cmetadata=cmetadata)
+            session.add(collection)
+            session.commit()
+            created = True
+            return collection, created
+
+    class ItemStore(BaseModel):
+        """Item store."""
+
+        __tablename__ = "langchain_storage_items"
+
+        collection_id = sqlalchemy.Column(
+            UUID(as_uuid=True),
+            sqlalchemy.ForeignKey(
+                f"{CollectionStore.__tablename__}.uuid",
+                ondelete="CASCADE",
+            ),
+        )
+        collection = relationship(CollectionStore, back_populates="items")
+
+        content = sqlalchemy.Column(sqlalchemy.String, nullable=True)
+
+        # custom_id: any user-defined id
+        custom_id = sqlalchemy.Column(sqlalchemy.String, nullable=True)
+
+    _classes = (ItemStore, CollectionStore)
+
+    return _classes
+
+
+class SQLBaseStore(BaseStore[str, V], Generic[V]):
+    """SQL storage.
+
+    Args:
+        connection_string: SQL connection string that will be passed to SQLAlchemy.
+        collection_name: The name of the collection to use. (default: langchain)
+            NOTE: Collections are useful to isolate your data in a given database.
+            This is not the name of the table, but the name of the collection.
+            The tables will be created when initializing the store (if they do
+            not exist), so make sure the user has the right permissions to
+            create tables.
+        pre_delete_collection: If True, will delete the collection if it exists.
+            (default: False). Useful for testing.
+        engine_args: SQLAlchemy's create engine arguments.
+
+    Example:
+        .. code-block:: python
+
+            from langchain_community.storage import SQLDocStore
+            from langchain_community.embeddings.openai import OpenAIEmbeddings
+
+            # example using an SQLDocStore to store Document objects for
+            # a ParentDocumentRetriever
+            CONNECTION_STRING = "postgresql+psycopg2://user:pass@localhost:5432/db"
+            COLLECTION_NAME = "state_of_the_union_test"
+            docstore = SQLDocStore(
+                collection_name=COLLECTION_NAME,
+                connection_string=CONNECTION_STRING,
+            )
+            child_splitter = RecursiveCharacterTextSplitter(chunk_size=400)
+            vectorstore = ...
+
+            retriever = ParentDocumentRetriever(
+                vectorstore=vectorstore,
+                docstore=docstore,
+                child_splitter=child_splitter,
+            )
+
+            # example using an SQLStrStore to store strings
+            # same example as in "InMemoryStore" but using SQL persistence
+            store = SQLStrStore(
+                collection_name=COLLECTION_NAME,
+                connection_string=CONNECTION_STRING,
+            )
+            store.mset([('key1', 'value1'), ('key2', 'value2')])
+            store.mget(['key1', 'key2'])
+            # ['value1', 'value2']
+            store.mdelete(['key1'])
+            list(store.yield_keys())
+            # ['key2']
+            list(store.yield_keys(prefix='k'))
+            # ['key2']
+
+            # delete the COLLECTION_NAME collection
+            docstore.delete_collection()
+    """
+
+    def __init__(
+        self,
+        connection_string: str,
+        collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME,
+        collection_metadata: Optional[dict] = None,
+        pre_delete_collection: bool = False,
+        connection: Optional[sqlalchemy.engine.Connection] = None,
+        engine_args: Optional[dict[str, Any]] = None,
+    ) -> None:
+        self.connection_string = connection_string
+        self.collection_name = collection_name
+        self.collection_metadata = collection_metadata
+        self.pre_delete_collection = pre_delete_collection
+        self.engine_args = engine_args or {}
+        # Create a connection if not provided, otherwise use the provided connection
+        self._conn = connection if connection else self.__connect()
+        self.__post_init__()
+
+    def __post_init__(
+        self,
+    ) -> None:
+        """Initialize the store."""
+        ItemStore, CollectionStore = _get_storage_stores()
+        self.CollectionStore = CollectionStore
+        self.ItemStore = ItemStore
+        self.__create_tables_if_not_exists()
+        self.__create_collection()
+
+    def __connect(self) -> sqlalchemy.engine.Connection:
+        engine = sqlalchemy.create_engine(self.connection_string, **self.engine_args)
+        conn = engine.connect()
+        return conn
+
+    def __create_tables_if_not_exists(self) -> None:
+        with self._conn.begin():
+            Base.metadata.create_all(self._conn)
+
+    def __create_collection(self) -> None:
+        if self.pre_delete_collection:
+            self.delete_collection()
+        with Session(self._conn) as session:
+            self.CollectionStore.get_or_create(
+                session, self.collection_name, cmetadata=self.collection_metadata
+            )
+
+    def delete_collection(self) -> None:
+        with Session(self._conn) as session:
+            collection = self.__get_collection(session)
+            if not collection:
+                return
+            session.delete(collection)
+            session.commit()
+
+    def __get_collection(self, session: Session) -> Any:
+        return self.CollectionStore.get_by_name(session, self.collection_name)
+
+    def __del__(self) -> None:
+        if self._conn:
+            self._conn.close()
+
+    def __serialize_value(self, obj: V) -> str:
+        # Serializable objects (e.g. Document) are stored as JSON strings;
+        # plain strings are stored as-is.
+        if isinstance(obj, Serializable):
+            return dumps(obj)
+        return obj
+
+    def __deserialize_value(self, obj: str) -> Any:
+        # Try to revive a serialized object; fall back to the raw string.
+        try:
+            return loads(obj)
+        except Exception:
+            return obj
+
+    def mget(self, keys: Sequence[str]) -> List[Optional[V]]:
+        """Get the values associated with the given keys.
+
+        Args:
+            keys (Sequence[str]): A sequence of keys.
+
+        Returns:
+            A sequence of optional values associated with the keys.
+            If a key is not found, the corresponding value will be None.
+ """ + with Session(self._conn) as session: + collection = self.__get_collection(session) + + items = ( + session.query(self.ItemStore.content, self.ItemStore.custom_id) + .where( + sqlalchemy.and_( + self.ItemStore.custom_id.in_(keys), + self.ItemStore.collection_id == (collection.uuid), + ) + ) + .all() + ) + + ordered_values = {key: None for key in keys} + for item in items: + v = item[0] + val = self.__deserialize_value(v) if v is not None else v + k = item[1] + ordered_values[k] = val + + return [ordered_values[key] for key in keys] + + def mset(self, key_value_pairs: Sequence[Tuple[str, V]]) -> None: + """Set the values for the given keys. + + Args: + key_value_pairs (Sequence[Tuple[str, V]]): A sequence of key-value pairs. + + Returns: + None + """ + with Session(self._conn) as session: + collection = self.__get_collection(session) + if not collection: + raise ValueError("Collection not found") + for id, item in key_value_pairs: + content = self.__serialize_value(item) + item_store = self.ItemStore( + content=content, + custom_id=id, + collection_id=collection.uuid, + ) + session.add(item_store) + session.commit() + + def mdelete(self, keys: Sequence[str]) -> None: + """Delete the given keys and their associated values. + + Args: + keys (Sequence[str]): A sequence of keys to delete. + """ + with Session(self._conn) as session: + collection = self.__get_collection(session) + if not collection: + raise ValueError("Collection not found") + if keys is not None: + stmt = sqlalchemy.delete(self.ItemStore).where( + sqlalchemy.and_( + self.ItemStore.custom_id.in_(keys), + self.ItemStore.collection_id == (collection.uuid), + ) + ) + session.execute(stmt) + session.commit() + + def yield_keys(self, prefix: Optional[str] = None) -> Iterator[str]: + """Get an iterator over keys that match the given prefix. + + Args: + prefix (str, optional): The prefix to match. Defaults to None. + + Returns: + Iterator[str]: An iterator over keys that match the given prefix. 
+ """ + with Session(self._conn) as session: + collection = self.__get_collection(session) + start = 0 + while True: + stop = start + ITERATOR_WINDOW_SIZE + query = session.query(self.ItemStore.custom_id).where( + self.ItemStore.collection_id == (collection.uuid) + ) + if prefix is not None: + query = query.filter(self.ItemStore.custom_id.startswith(prefix)) + items = query.slice(start, stop).all() + + if len(items) == 0: + break + for item in items: + yield item[0] + start += ITERATOR_WINDOW_SIZE + + +SQLDocStore = SQLBaseStore[Document] +SQLStrStore = SQLBaseStore[str] diff --git a/libs/community/tests/integration_tests/storage/test_sql.py b/libs/community/tests/integration_tests/storage/test_sql.py new file mode 100644 index 0000000000..c09b9745c6 --- /dev/null +++ b/libs/community/tests/integration_tests/storage/test_sql.py @@ -0,0 +1,228 @@ +"""Implement integration tests for SQL storage.""" +import os + +from langchain_core.documents import Document + +from langchain_community.storage.sql import SQLDocStore, SQLStrStore + + +def connection_string_from_db_params() -> str: + """Return connection string from database parameters.""" + dbdriver = os.environ.get("TEST_SQL_DBDRIVER", "postgresql+psycopg2") + host = os.environ.get("TEST_SQL_HOST", "localhost") + port = int(os.environ.get("TEST_SQL_PORT", "5432")) + database = os.environ.get("TEST_SQL_DATABASE", "postgres") + user = os.environ.get("TEST_SQL_USER", "postgres") + password = os.environ.get("TEST_SQL_PASSWORD", "postgres") + return f"{dbdriver}://{user}:{password}@{host}:{port}/{database}" + + +CONNECTION_STRING = connection_string_from_db_params() +COLLECTION_NAME = "test_collection" +COLLECTION_NAME_2 = "test_collection_2" + + +def test_str_store_mget() -> None: + store = SQLStrStore( + collection_name=COLLECTION_NAME, + connection_string=CONNECTION_STRING, + pre_delete_collection=True, + ) + store.mset([("key1", "value1"), ("key2", "value2")]) + + values = store.mget(["key1", "key2"]) + assert values == ["value1", "value2"] + + # Test non-existent key + non_existent_value = store.mget(["key3"]) + assert non_existent_value == [None] + store.delete_collection() + + +def test_str_store_mset() -> None: + store = SQLStrStore( + collection_name=COLLECTION_NAME, + connection_string=CONNECTION_STRING, + pre_delete_collection=True, + ) + store.mset([("key1", "value1"), ("key2", "value2")]) + + values = store.mget(["key1", "key2"]) + assert values == ["value1", "value2"] + store.delete_collection() + + +def test_str_store_mdelete() -> None: + store = SQLStrStore( + collection_name=COLLECTION_NAME, + connection_string=CONNECTION_STRING, + pre_delete_collection=True, + ) + store.mset([("key1", "value1"), ("key2", "value2")]) + + store.mdelete(["key1"]) + + values = store.mget(["key1", "key2"]) + assert values == [None, "value2"] + + # Test deleting non-existent key + store.mdelete(["key3"]) # No error should be raised + store.delete_collection() + + +def test_str_store_yield_keys() -> None: + store = SQLStrStore( + collection_name=COLLECTION_NAME, + connection_string=CONNECTION_STRING, + pre_delete_collection=True, + ) + store.mset([("key1", "value1"), ("key2", "value2"), ("key3", "value3")]) + + keys = list(store.yield_keys()) + assert set(keys) == {"key1", "key2", "key3"} + + keys_with_prefix = list(store.yield_keys(prefix="key")) + assert set(keys_with_prefix) == {"key1", "key2", "key3"} + + keys_with_invalid_prefix = list(store.yield_keys(prefix="x")) + assert keys_with_invalid_prefix == [] + store.delete_collection() + + +def 
+def test_str_store_collection() -> None:
+    """Test that collections are isolated within a db."""
+    store_1 = SQLStrStore(
+        collection_name=COLLECTION_NAME,
+        connection_string=CONNECTION_STRING,
+        pre_delete_collection=True,
+    )
+    store_2 = SQLStrStore(
+        collection_name=COLLECTION_NAME_2,
+        connection_string=CONNECTION_STRING,
+        pre_delete_collection=True,
+    )
+
+    store_1.mset([("key1", "value1"), ("key2", "value2")])
+    store_2.mset([("key3", "value3"), ("key4", "value4")])
+
+    values = store_1.mget(["key1", "key2"])
+    assert values == ["value1", "value2"]
+    values = store_1.mget(["key3", "key4"])
+    assert values == [None, None]
+
+    values = store_2.mget(["key1", "key2"])
+    assert values == [None, None]
+    values = store_2.mget(["key3", "key4"])
+    assert values == ["value3", "value4"]
+
+    store_1.delete_collection()
+    store_2.delete_collection()
+
+
+def test_doc_store_mget() -> None:
+    store = SQLDocStore(
+        collection_name=COLLECTION_NAME,
+        connection_string=CONNECTION_STRING,
+        pre_delete_collection=True,
+    )
+    doc_1 = Document(page_content="value1")
+    doc_2 = Document(page_content="value2")
+    store.mset([("key1", doc_1), ("key2", doc_2)])
+
+    values = store.mget(["key1", "key2"])
+    assert values == [doc_1, doc_2]
+
+    # Test non-existent key
+    non_existent_value = store.mget(["key3"])
+    assert non_existent_value == [None]
+    store.delete_collection()
+
+
+def test_doc_store_mset() -> None:
+    store = SQLDocStore(
+        collection_name=COLLECTION_NAME,
+        connection_string=CONNECTION_STRING,
+        pre_delete_collection=True,
+    )
+    doc_1 = Document(page_content="value1")
+    doc_2 = Document(page_content="value2")
+    store.mset([("key1", doc_1), ("key2", doc_2)])
+
+    values = store.mget(["key1", "key2"])
+    assert values == [doc_1, doc_2]
+    store.delete_collection()
+
+
+def test_doc_store_mdelete() -> None:
+    store = SQLDocStore(
+        collection_name=COLLECTION_NAME,
+        connection_string=CONNECTION_STRING,
+        pre_delete_collection=True,
+    )
+    doc_1 = Document(page_content="value1")
+    doc_2 = Document(page_content="value2")
+    store.mset([("key1", doc_1), ("key2", doc_2)])
+
+    store.mdelete(["key1"])
+
+    values = store.mget(["key1", "key2"])
+    assert values == [None, doc_2]
+
+    # Test deleting non-existent key
+    store.mdelete(["key3"])  # No error should be raised
+    store.delete_collection()
+
+
+def test_doc_store_yield_keys() -> None:
+    store = SQLDocStore(
+        collection_name=COLLECTION_NAME,
+        connection_string=CONNECTION_STRING,
+        pre_delete_collection=True,
+    )
+    doc_1 = Document(page_content="value1")
+    doc_2 = Document(page_content="value2")
+    doc_3 = Document(page_content="value3")
+    store.mset([("key1", doc_1), ("key2", doc_2), ("key3", doc_3)])
+
+    keys = list(store.yield_keys())
+    assert set(keys) == {"key1", "key2", "key3"}
+
+    keys_with_prefix = list(store.yield_keys(prefix="key"))
+    assert set(keys_with_prefix) == {"key1", "key2", "key3"}
+
+    keys_with_invalid_prefix = list(store.yield_keys(prefix="x"))
+    assert keys_with_invalid_prefix == []
+    store.delete_collection()
+
+
+def test_doc_store_collection() -> None:
+    """Test that collections are isolated within a db."""
+    store_1 = SQLDocStore(
+        collection_name=COLLECTION_NAME,
+        connection_string=CONNECTION_STRING,
+        pre_delete_collection=True,
+    )
+    store_2 = SQLDocStore(
+        collection_name=COLLECTION_NAME_2,
+        connection_string=CONNECTION_STRING,
+        pre_delete_collection=True,
+    )
+    doc_1 = Document(page_content="value1")
+    doc_2 = Document(page_content="value2")
+    doc_3 = Document(page_content="value3")
+    doc_4 = Document(page_content="value4")
+    store_1.mset([("key1", doc_1), ("key2", doc_2)])
+    store_2.mset([("key3", doc_3), ("key4", doc_4)])
+
+    values = store_1.mget(["key1", "key2"])
+    assert values == [doc_1, doc_2]
+    values = store_1.mget(["key3", "key4"])
+    assert values == [None, None]
+
+    values = store_2.mget(["key1", "key2"])
+    assert values == [None, None]
+    values = store_2.mget(["key3", "key4"])
+    assert values == [doc_3, doc_4]
+
+    store_1.delete_collection()
+    store_2.delete_collection()
diff --git a/libs/community/tests/unit_tests/storage/test_imports.py b/libs/community/tests/unit_tests/storage/test_imports.py
index 27bcec56d4..0cc4914a74 100644
--- a/libs/community/tests/unit_tests/storage/test_imports.py
+++ b/libs/community/tests/unit_tests/storage/test_imports.py
@@ -6,6 +6,8 @@ EXPECTED_ALL = [
     "RedisStore",
     "UpstashRedisByteStore",
     "UpstashRedisStore",
+    "SQLDocStore",
+    "SQLStrStore",
 ]
diff --git a/libs/community/tests/unit_tests/storage/test_sql.py b/libs/community/tests/unit_tests/storage/test_sql.py
new file mode 100644
index 0000000000..1d98841463
--- /dev/null
+++ b/libs/community/tests/unit_tests/storage/test_sql.py
@@ -0,0 +1,7 @@
+"""Lightweight unit test that attempts to import SQLDocStore/SQLStrStore."""
+
+
+def test_import_storage() -> None:
+    """Attempt to import storage modules."""
+    from langchain_community.storage.sql import SQLDocStore, SQLStrStore  # noqa
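
---

Example: running the integration tests above against a local PostgreSQL
instance (a minimal sketch; the credentials and the pytest invocation are
assumptions to adapt to your setup):

    # run from libs/community with a PostgreSQL server reachable on localhost
    import os
    import subprocess

    # connection parameters read by connection_string_from_db_params()
    os.environ.setdefault("TEST_SQL_HOST", "localhost")
    os.environ.setdefault("TEST_SQL_PORT", "5432")
    os.environ.setdefault("TEST_SQL_USER", "postgres")
    os.environ.setdefault("TEST_SQL_PASSWORD", "postgres")

    subprocess.run(
        ["pytest", "tests/integration_tests/storage/test_sql.py"], check=True
    )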