mirror of
https://github.com/hwchase17/langchain
synced 2024-11-16 06:13:16 +00:00
community: SQLStrStore/SQLDocStore provide an easy SQL alternative to InMemoryStore
to persist data remotely in a SQL storage (#15909)
**Description:** - Implement `SQLStrStore` and `SQLDocStore` classes that inherits from `BaseStore` to allow to persist data remotely on a SQL server. - SQL is widely used and sometimes we do not want to install a caching solution like Redis. - Multiple issues/comments complain that there is no easy remote and persistent solution that are not in memory (users want to replace InMemoryStore), e.g., https://github.com/langchain-ai/langchain/issues/14267, https://github.com/langchain-ai/langchain/issues/15633, https://github.com/langchain-ai/langchain/issues/14643, https://stackoverflow.com/questions/77385587/persist-parentdocumentretriever-of-langchain - This is particularly painful when wanting to use `ParentDocumentRetriever ` - This implementation is particularly useful when: * it's expensive to construct an InMemoryDocstore/dict * you want to retrieve documents from remote sources * you just want to reuse existing objects - This implementation integrates well with PGVector, indeed, when using PGVector, you already have a SQL instance running. `SQLDocStore` is a convenient way of using this instance to store documents associated to vectors. An integration example with ParentDocumentRetriever and PGVector is provided in docs/docs/integrations/stores/sql.ipynb or [here](https://github.com/gcheron/langchain/blob/sql-store/docs/docs/integrations/stores/sql.ipynb). - It persists `str` and `Document` objects but can be easily extended. **Issue:** Provide an easy SQL alternative to `InMemoryStore`. --------- Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
This commit is contained in:
parent
26b2ad6d5b
commit
cfc225ecb3
186
docs/docs/integrations/stores/sql.ipynb
Normal file
186
docs/docs/integrations/stores/sql.ipynb
Normal file
@ -0,0 +1,186 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "raw",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"---\n",
|
||||
"sidebar_label: SQL\n",
|
||||
"---"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# SQLStore\n",
|
||||
"\n",
|
||||
"The `SQLStrStore` and `SQLDocStore` implement remote data access and persistence to store strings or LangChain documents in your SQL instance."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"['value1', 'value2']\n",
|
||||
"['key2']\n",
|
||||
"['key2']\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from langchain_community.storage import SQLStrStore\n",
|
||||
"\n",
|
||||
"# simple example using an SQLStrStore to store strings\n",
|
||||
"# same as you would use in \"InMemoryStore\" but using SQL persistence\n",
|
||||
"CONNECTION_STRING = \"postgresql+psycopg2://user:pass@localhost:5432/db\"\n",
|
||||
"COLLECTION_NAME = \"test_collection\"\n",
|
||||
"\n",
|
||||
"store = SQLStrStore(\n",
|
||||
" collection_name=COLLECTION_NAME,\n",
|
||||
" connection_string=CONNECTION_STRING,\n",
|
||||
")\n",
|
||||
"store.mset([(\"key1\", \"value1\"), (\"key2\", \"value2\")])\n",
|
||||
"print(store.mget([\"key1\", \"key2\"]))\n",
|
||||
"# ['value1', 'value2']\n",
|
||||
"store.mdelete([\"key1\"])\n",
|
||||
"print(list(store.yield_keys()))\n",
|
||||
"# ['key2']\n",
|
||||
"print(list(store.yield_keys(prefix=\"k\")))\n",
|
||||
"# ['key2']\n",
|
||||
"# delete the COLLECTION_NAME collection"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Integration with ParentRetriever and PGVector\n",
|
||||
"\n",
|
||||
"When using PGVector, you already have a SQL instance running. Here is a convenient way of using this instance to store documents associated to vectors. "
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Prepare the PGVector vectorestore with something like this:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain_community.vectorstores import PGVector\n",
|
||||
"from langchain_openai import OpenAIEmbeddings"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"embeddings = OpenAIEmbeddings()\n",
|
||||
"vector_db = PGVector.from_existing_index(\n",
|
||||
" embedding=embeddings,\n",
|
||||
" collection_name=COLLECTION_NAME,\n",
|
||||
" connection_string=CONNECTION_STRING,\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Then create the parent retiever using `SQLDocStore` to persist the documents"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.document_loaders import TextLoader\n",
|
||||
"from langchain.retrievers import ParentDocumentRetriever\n",
|
||||
"from langchain.text_splitter import RecursiveCharacterTextSplitter\n",
|
||||
"from langchain_community.storage import SQLDocStore\n",
|
||||
"\n",
|
||||
"CONNECTION_STRING = \"postgresql+psycopg2://user:pass@localhost:5432/db\"\n",
|
||||
"COLLECTION_NAME = \"state_of_the_union_test\"\n",
|
||||
"docstore = SQLDocStore(\n",
|
||||
" collection_name=COLLECTION_NAME,\n",
|
||||
" connection_string=CONNECTION_STRING,\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"loader = TextLoader(\"./state_of_the_union.txt\")\n",
|
||||
"documents = loader.load()\n",
|
||||
"\n",
|
||||
"parent_splitter = RecursiveCharacterTextSplitter(chunk_size=400)\n",
|
||||
"child_splitter = RecursiveCharacterTextSplitter(chunk_size=50)\n",
|
||||
"\n",
|
||||
"retriever = ParentDocumentRetriever(\n",
|
||||
" vectorstore=vector_db,\n",
|
||||
" docstore=docstore,\n",
|
||||
" child_splitter=child_splitter,\n",
|
||||
" parent_splitter=parent_splitter,\n",
|
||||
")\n",
|
||||
"retriever.add_documents(documents)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Delete a collection"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain_community.storage import SQLStrStore\n",
|
||||
"\n",
|
||||
"# delete the COLLECTION_NAME collection\n",
|
||||
"CONNECTION_STRING = \"postgresql+psycopg2://user:pass@localhost:5432/db\"\n",
|
||||
"COLLECTION_NAME = \"test_collection\"\n",
|
||||
"store = SQLStrStore(\n",
|
||||
" collection_name=COLLECTION_NAME,\n",
|
||||
" connection_string=CONNECTION_STRING,\n",
|
||||
")\n",
|
||||
"store.delete_collection()"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.1"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
@ -11,6 +11,10 @@ from langchain_community.storage.astradb import (
|
||||
AstraDBStore,
|
||||
)
|
||||
from langchain_community.storage.redis import RedisStore
|
||||
from langchain_community.storage.sql import (
|
||||
SQLDocStore,
|
||||
SQLStrStore,
|
||||
)
|
||||
from langchain_community.storage.upstash_redis import (
|
||||
UpstashRedisByteStore,
|
||||
UpstashRedisStore,
|
||||
@ -22,4 +26,6 @@ __all__ = [
|
||||
"RedisStore",
|
||||
"UpstashRedisByteStore",
|
||||
"UpstashRedisStore",
|
||||
"SQLDocStore",
|
||||
"SQLStrStore",
|
||||
]
|
||||
|
345
libs/community/langchain_community/storage/sql.py
Normal file
345
libs/community/langchain_community/storage/sql.py
Normal file
@ -0,0 +1,345 @@
|
||||
"""SQL storage that persists data in a SQL database
|
||||
and supports data isolation using collections."""
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from typing import Any, Generic, Iterator, List, Optional, Sequence, Tuple, TypeVar
|
||||
|
||||
import sqlalchemy
|
||||
from sqlalchemy import JSON, UUID
|
||||
from sqlalchemy.orm import Session, relationship
|
||||
|
||||
try:
|
||||
from sqlalchemy.orm import declarative_base
|
||||
except ImportError:
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.load import Serializable, dumps, loads
|
||||
from langchain_core.stores import BaseStore
|
||||
|
||||
V = TypeVar("V")
|
||||
|
||||
ITERATOR_WINDOW_SIZE = 1000
|
||||
|
||||
Base = declarative_base() # type: Any
|
||||
|
||||
|
||||
_LANGCHAIN_DEFAULT_COLLECTION_NAME = "langchain"
|
||||
|
||||
|
||||
class BaseModel(Base):
|
||||
"""Base model for the SQL stores."""
|
||||
|
||||
__abstract__ = True
|
||||
uuid = sqlalchemy.Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
|
||||
|
||||
_classes: Any = None
|
||||
|
||||
|
||||
def _get_storage_stores() -> Any:
|
||||
global _classes
|
||||
if _classes is not None:
|
||||
return _classes
|
||||
|
||||
class CollectionStore(BaseModel):
|
||||
"""Collection store."""
|
||||
|
||||
__tablename__ = "langchain_storage_collection"
|
||||
|
||||
name = sqlalchemy.Column(sqlalchemy.String)
|
||||
cmetadata = sqlalchemy.Column(JSON)
|
||||
|
||||
items = relationship(
|
||||
"ItemStore",
|
||||
back_populates="collection",
|
||||
passive_deletes=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_by_name(
|
||||
cls, session: Session, name: str
|
||||
) -> Optional["CollectionStore"]:
|
||||
# type: ignore
|
||||
return session.query(cls).filter(cls.name == name).first()
|
||||
|
||||
@classmethod
|
||||
def get_or_create(
|
||||
cls,
|
||||
session: Session,
|
||||
name: str,
|
||||
cmetadata: Optional[dict] = None,
|
||||
) -> Tuple["CollectionStore", bool]:
|
||||
"""
|
||||
Get or create a collection.
|
||||
Returns [Collection, bool] where the bool is True if the collection was created.
|
||||
""" # noqa: E501
|
||||
created = False
|
||||
collection = cls.get_by_name(session, name)
|
||||
if collection:
|
||||
return collection, created
|
||||
|
||||
collection = cls(name=name, cmetadata=cmetadata)
|
||||
session.add(collection)
|
||||
session.commit()
|
||||
created = True
|
||||
return collection, created
|
||||
|
||||
class ItemStore(BaseModel):
|
||||
"""Item store."""
|
||||
|
||||
__tablename__ = "langchain_storage_items"
|
||||
|
||||
collection_id = sqlalchemy.Column(
|
||||
UUID(as_uuid=True),
|
||||
sqlalchemy.ForeignKey(
|
||||
f"{CollectionStore.__tablename__}.uuid",
|
||||
ondelete="CASCADE",
|
||||
),
|
||||
)
|
||||
collection = relationship(CollectionStore, back_populates="items")
|
||||
|
||||
content = sqlalchemy.Column(sqlalchemy.String, nullable=True)
|
||||
|
||||
# custom_id : any user defined id
|
||||
custom_id = sqlalchemy.Column(sqlalchemy.String, nullable=True)
|
||||
|
||||
_classes = (ItemStore, CollectionStore)
|
||||
|
||||
return _classes
|
||||
|
||||
|
||||
class SQLBaseStore(BaseStore[str, V], Generic[V]):
|
||||
"""SQL storage
|
||||
|
||||
Args:
|
||||
connection_string: SQL connection string that will be passed to SQLAlchemy.
|
||||
collection_name: The name of the collection to use. (default: langchain)
|
||||
NOTE: Collections are useful to isolate your data in a given a database.
|
||||
This is not the name of the table, but the name of the collection.
|
||||
The tables will be created when initializing the store (if not exists)
|
||||
So, make sure the user has the right permissions to create tables.
|
||||
pre_delete_collection: If True, will delete the collection if it exists.
|
||||
(default: False). Useful for testing.
|
||||
engine_args: SQLAlchemy's create engine arguments.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain_community.storage import SQLDocStore
|
||||
from langchain_community.embeddings.openai import OpenAIEmbeddings
|
||||
|
||||
# example using an SQLDocStore to store Document objects for
|
||||
# a ParentDocumentRetriever
|
||||
CONNECTION_STRING = "postgresql+psycopg2://user:pass@localhost:5432/db"
|
||||
COLLECTION_NAME = "state_of_the_union_test"
|
||||
docstore = SQLDocStore(
|
||||
collection_name=COLLECTION_NAME,
|
||||
connection_string=CONNECTION_STRING,
|
||||
)
|
||||
child_splitter = RecursiveCharacterTextSplitter(chunk_size=400)
|
||||
vectorstore = ...
|
||||
|
||||
retriever = ParentDocumentRetriever(
|
||||
vectorstore=vectorstore,
|
||||
docstore=docstore,
|
||||
child_splitter=child_splitter,
|
||||
)
|
||||
|
||||
# example using an SQLStrStore to store strings
|
||||
# same example as in "InMemoryStore" but using SQL persistence
|
||||
store = SQLDocStore(
|
||||
collection_name=COLLECTION_NAME,
|
||||
connection_string=CONNECTION_STRING,
|
||||
)
|
||||
store.mset([('key1', 'value1'), ('key2', 'value2')])
|
||||
store.mget(['key1', 'key2'])
|
||||
# ['value1', 'value2']
|
||||
store.mdelete(['key1'])
|
||||
list(store.yield_keys())
|
||||
# ['key2']
|
||||
list(store.yield_keys(prefix='k'))
|
||||
# ['key2']
|
||||
|
||||
# delete the COLLECTION_NAME collection
|
||||
docstore.delete_collection()
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
connection_string: str,
|
||||
collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME,
|
||||
collection_metadata: Optional[dict] = None,
|
||||
pre_delete_collection: bool = False,
|
||||
connection: Optional[sqlalchemy.engine.Connection] = None,
|
||||
engine_args: Optional[dict[str, Any]] = None,
|
||||
) -> None:
|
||||
self.connection_string = connection_string
|
||||
self.collection_name = collection_name
|
||||
self.collection_metadata = collection_metadata
|
||||
self.pre_delete_collection = pre_delete_collection
|
||||
self.engine_args = engine_args or {}
|
||||
# Create a connection if not provided, otherwise use the provided connection
|
||||
self._conn = connection if connection else self.__connect()
|
||||
self.__post_init__()
|
||||
|
||||
def __post_init__(
|
||||
self,
|
||||
) -> None:
|
||||
"""Initialize the store."""
|
||||
ItemStore, CollectionStore = _get_storage_stores()
|
||||
self.CollectionStore = CollectionStore
|
||||
self.ItemStore = ItemStore
|
||||
self.__create_tables_if_not_exists()
|
||||
self.__create_collection()
|
||||
|
||||
def __connect(self) -> sqlalchemy.engine.Connection:
|
||||
engine = sqlalchemy.create_engine(self.connection_string, **self.engine_args)
|
||||
conn = engine.connect()
|
||||
return conn
|
||||
|
||||
def __create_tables_if_not_exists(self) -> None:
|
||||
with self._conn.begin():
|
||||
Base.metadata.create_all(self._conn)
|
||||
|
||||
def __create_collection(self) -> None:
|
||||
if self.pre_delete_collection:
|
||||
self.delete_collection()
|
||||
with Session(self._conn) as session:
|
||||
self.CollectionStore.get_or_create(
|
||||
session, self.collection_name, cmetadata=self.collection_metadata
|
||||
)
|
||||
|
||||
def delete_collection(self) -> None:
|
||||
with Session(self._conn) as session:
|
||||
collection = self.__get_collection(session)
|
||||
if not collection:
|
||||
return
|
||||
session.delete(collection)
|
||||
session.commit()
|
||||
|
||||
def __get_collection(self, session: Session) -> Any:
|
||||
return self.CollectionStore.get_by_name(session, self.collection_name)
|
||||
|
||||
def __del__(self) -> None:
|
||||
if self._conn:
|
||||
self._conn.close()
|
||||
|
||||
def __serialize_value(self, obj: V) -> str:
|
||||
if isinstance(obj, Serializable):
|
||||
return dumps(obj)
|
||||
return obj
|
||||
|
||||
def __deserialize_value(self, obj: V) -> str:
|
||||
try:
|
||||
return loads(obj)
|
||||
except Exception:
|
||||
return obj
|
||||
|
||||
def mget(self, keys: Sequence[str]) -> List[Optional[V]]:
|
||||
"""Get the values associated with the given keys.
|
||||
|
||||
Args:
|
||||
keys (Sequence[str]): A sequence of keys.
|
||||
|
||||
Returns:
|
||||
A sequence of optional values associated with the keys.
|
||||
If a key is not found, the corresponding value will be None.
|
||||
"""
|
||||
with Session(self._conn) as session:
|
||||
collection = self.__get_collection(session)
|
||||
|
||||
items = (
|
||||
session.query(self.ItemStore.content, self.ItemStore.custom_id)
|
||||
.where(
|
||||
sqlalchemy.and_(
|
||||
self.ItemStore.custom_id.in_(keys),
|
||||
self.ItemStore.collection_id == (collection.uuid),
|
||||
)
|
||||
)
|
||||
.all()
|
||||
)
|
||||
|
||||
ordered_values = {key: None for key in keys}
|
||||
for item in items:
|
||||
v = item[0]
|
||||
val = self.__deserialize_value(v) if v is not None else v
|
||||
k = item[1]
|
||||
ordered_values[k] = val
|
||||
|
||||
return [ordered_values[key] for key in keys]
|
||||
|
||||
def mset(self, key_value_pairs: Sequence[Tuple[str, V]]) -> None:
|
||||
"""Set the values for the given keys.
|
||||
|
||||
Args:
|
||||
key_value_pairs (Sequence[Tuple[str, V]]): A sequence of key-value pairs.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
with Session(self._conn) as session:
|
||||
collection = self.__get_collection(session)
|
||||
if not collection:
|
||||
raise ValueError("Collection not found")
|
||||
for id, item in key_value_pairs:
|
||||
content = self.__serialize_value(item)
|
||||
item_store = self.ItemStore(
|
||||
content=content,
|
||||
custom_id=id,
|
||||
collection_id=collection.uuid,
|
||||
)
|
||||
session.add(item_store)
|
||||
session.commit()
|
||||
|
||||
def mdelete(self, keys: Sequence[str]) -> None:
|
||||
"""Delete the given keys and their associated values.
|
||||
|
||||
Args:
|
||||
keys (Sequence[str]): A sequence of keys to delete.
|
||||
"""
|
||||
with Session(self._conn) as session:
|
||||
collection = self.__get_collection(session)
|
||||
if not collection:
|
||||
raise ValueError("Collection not found")
|
||||
if keys is not None:
|
||||
stmt = sqlalchemy.delete(self.ItemStore).where(
|
||||
sqlalchemy.and_(
|
||||
self.ItemStore.custom_id.in_(keys),
|
||||
self.ItemStore.collection_id == (collection.uuid),
|
||||
)
|
||||
)
|
||||
session.execute(stmt)
|
||||
session.commit()
|
||||
|
||||
def yield_keys(self, prefix: Optional[str] = None) -> Iterator[str]:
|
||||
"""Get an iterator over keys that match the given prefix.
|
||||
|
||||
Args:
|
||||
prefix (str, optional): The prefix to match. Defaults to None.
|
||||
|
||||
Returns:
|
||||
Iterator[str]: An iterator over keys that match the given prefix.
|
||||
"""
|
||||
with Session(self._conn) as session:
|
||||
collection = self.__get_collection(session)
|
||||
start = 0
|
||||
while True:
|
||||
stop = start + ITERATOR_WINDOW_SIZE
|
||||
query = session.query(self.ItemStore.custom_id).where(
|
||||
self.ItemStore.collection_id == (collection.uuid)
|
||||
)
|
||||
if prefix is not None:
|
||||
query = query.filter(self.ItemStore.custom_id.startswith(prefix))
|
||||
items = query.slice(start, stop).all()
|
||||
|
||||
if len(items) == 0:
|
||||
break
|
||||
for item in items:
|
||||
yield item[0]
|
||||
start += ITERATOR_WINDOW_SIZE
|
||||
|
||||
|
||||
SQLDocStore = SQLBaseStore[Document]
|
||||
SQLStrStore = SQLBaseStore[str]
|
228
libs/community/tests/integration_tests/storage/test_sql.py
Normal file
228
libs/community/tests/integration_tests/storage/test_sql.py
Normal file
@ -0,0 +1,228 @@
|
||||
"""Implement integration tests for SQL storage."""
|
||||
import os
|
||||
|
||||
from langchain_core.documents import Document
|
||||
|
||||
from langchain_community.storage.sql import SQLDocStore, SQLStrStore
|
||||
|
||||
|
||||
def connection_string_from_db_params() -> str:
|
||||
"""Return connection string from database parameters."""
|
||||
dbdriver = os.environ.get("TEST_SQL_DBDRIVER", "postgresql+psycopg2")
|
||||
host = os.environ.get("TEST_SQL_HOST", "localhost")
|
||||
port = int(os.environ.get("TEST_SQL_PORT", "5432"))
|
||||
database = os.environ.get("TEST_SQL_DATABASE", "postgres")
|
||||
user = os.environ.get("TEST_SQL_USER", "postgres")
|
||||
password = os.environ.get("TEST_SQL_PASSWORD", "postgres")
|
||||
return f"{dbdriver}://{user}:{password}@{host}:{port}/{database}"
|
||||
|
||||
|
||||
CONNECTION_STRING = connection_string_from_db_params()
|
||||
COLLECTION_NAME = "test_collection"
|
||||
COLLECTION_NAME_2 = "test_collection_2"
|
||||
|
||||
|
||||
def test_str_store_mget() -> None:
|
||||
store = SQLStrStore(
|
||||
collection_name=COLLECTION_NAME,
|
||||
connection_string=CONNECTION_STRING,
|
||||
pre_delete_collection=True,
|
||||
)
|
||||
store.mset([("key1", "value1"), ("key2", "value2")])
|
||||
|
||||
values = store.mget(["key1", "key2"])
|
||||
assert values == ["value1", "value2"]
|
||||
|
||||
# Test non-existent key
|
||||
non_existent_value = store.mget(["key3"])
|
||||
assert non_existent_value == [None]
|
||||
store.delete_collection()
|
||||
|
||||
|
||||
def test_str_store_mset() -> None:
|
||||
store = SQLStrStore(
|
||||
collection_name=COLLECTION_NAME,
|
||||
connection_string=CONNECTION_STRING,
|
||||
pre_delete_collection=True,
|
||||
)
|
||||
store.mset([("key1", "value1"), ("key2", "value2")])
|
||||
|
||||
values = store.mget(["key1", "key2"])
|
||||
assert values == ["value1", "value2"]
|
||||
store.delete_collection()
|
||||
|
||||
|
||||
def test_str_store_mdelete() -> None:
|
||||
store = SQLStrStore(
|
||||
collection_name=COLLECTION_NAME,
|
||||
connection_string=CONNECTION_STRING,
|
||||
pre_delete_collection=True,
|
||||
)
|
||||
store.mset([("key1", "value1"), ("key2", "value2")])
|
||||
|
||||
store.mdelete(["key1"])
|
||||
|
||||
values = store.mget(["key1", "key2"])
|
||||
assert values == [None, "value2"]
|
||||
|
||||
# Test deleting non-existent key
|
||||
store.mdelete(["key3"]) # No error should be raised
|
||||
store.delete_collection()
|
||||
|
||||
|
||||
def test_str_store_yield_keys() -> None:
|
||||
store = SQLStrStore(
|
||||
collection_name=COLLECTION_NAME,
|
||||
connection_string=CONNECTION_STRING,
|
||||
pre_delete_collection=True,
|
||||
)
|
||||
store.mset([("key1", "value1"), ("key2", "value2"), ("key3", "value3")])
|
||||
|
||||
keys = list(store.yield_keys())
|
||||
assert set(keys) == {"key1", "key2", "key3"}
|
||||
|
||||
keys_with_prefix = list(store.yield_keys(prefix="key"))
|
||||
assert set(keys_with_prefix) == {"key1", "key2", "key3"}
|
||||
|
||||
keys_with_invalid_prefix = list(store.yield_keys(prefix="x"))
|
||||
assert keys_with_invalid_prefix == []
|
||||
store.delete_collection()
|
||||
|
||||
|
||||
def test_str_store_collection() -> None:
|
||||
"""Test that collections are isolated within a db."""
|
||||
store_1 = SQLStrStore(
|
||||
collection_name=COLLECTION_NAME,
|
||||
connection_string=CONNECTION_STRING,
|
||||
pre_delete_collection=True,
|
||||
)
|
||||
store_2 = SQLStrStore(
|
||||
collection_name=COLLECTION_NAME_2,
|
||||
connection_string=CONNECTION_STRING,
|
||||
pre_delete_collection=True,
|
||||
)
|
||||
|
||||
store_1.mset([("key1", "value1"), ("key2", "value2")])
|
||||
store_2.mset([("key3", "value3"), ("key4", "value4")])
|
||||
|
||||
values = store_1.mget(["key1", "key2"])
|
||||
assert values == ["value1", "value2"]
|
||||
values = store_1.mget(["key3", "key4"])
|
||||
assert values == [None, None]
|
||||
|
||||
values = store_2.mget(["key1", "key2"])
|
||||
assert values == [None, None]
|
||||
values = store_2.mget(["key3", "key4"])
|
||||
assert values == ["value3", "value4"]
|
||||
|
||||
store_1.delete_collection()
|
||||
store_2.delete_collection()
|
||||
|
||||
|
||||
def test_doc_store_mget() -> None:
|
||||
store = SQLDocStore(
|
||||
collection_name=COLLECTION_NAME,
|
||||
connection_string=CONNECTION_STRING,
|
||||
pre_delete_collection=True,
|
||||
)
|
||||
doc_1 = Document(page_content="value1")
|
||||
doc_2 = Document(page_content="value2")
|
||||
store.mset([("key1", doc_1), ("key2", doc_2)])
|
||||
|
||||
values = store.mget(["key1", "key2"])
|
||||
assert values == [doc_1, doc_2]
|
||||
|
||||
# Test non-existent key
|
||||
non_existent_value = store.mget(["key3"])
|
||||
assert non_existent_value == [None]
|
||||
store.delete_collection()
|
||||
|
||||
|
||||
def test_doc_store_mset() -> None:
|
||||
store = SQLDocStore(
|
||||
collection_name=COLLECTION_NAME,
|
||||
connection_string=CONNECTION_STRING,
|
||||
pre_delete_collection=True,
|
||||
)
|
||||
doc_1 = Document(page_content="value1")
|
||||
doc_2 = Document(page_content="value2")
|
||||
store.mset([("key1", doc_1), ("key2", doc_2)])
|
||||
|
||||
values = store.mget(["key1", "key2"])
|
||||
assert values == [doc_1, doc_2]
|
||||
store.delete_collection()
|
||||
|
||||
|
||||
def test_doc_store_mdelete() -> None:
|
||||
store = SQLDocStore(
|
||||
collection_name=COLLECTION_NAME,
|
||||
connection_string=CONNECTION_STRING,
|
||||
pre_delete_collection=True,
|
||||
)
|
||||
doc_1 = Document(page_content="value1")
|
||||
doc_2 = Document(page_content="value2")
|
||||
store.mset([("key1", doc_1), ("key2", doc_2)])
|
||||
|
||||
store.mdelete(["key1"])
|
||||
|
||||
values = store.mget(["key1", "key2"])
|
||||
assert values == [None, doc_2]
|
||||
|
||||
# Test deleting non-existent key
|
||||
store.mdelete(["key3"]) # No error should be raised
|
||||
store.delete_collection()
|
||||
|
||||
|
||||
def test_doc_store_yield_keys() -> None:
|
||||
store = SQLDocStore(
|
||||
collection_name=COLLECTION_NAME,
|
||||
connection_string=CONNECTION_STRING,
|
||||
pre_delete_collection=True,
|
||||
)
|
||||
doc_1 = Document(page_content="value1")
|
||||
doc_2 = Document(page_content="value2")
|
||||
doc_3 = Document(page_content="value3")
|
||||
store.mset([("key1", doc_1), ("key2", doc_2), ("key3", doc_3)])
|
||||
|
||||
keys = list(store.yield_keys())
|
||||
assert set(keys) == {"key1", "key2", "key3"}
|
||||
|
||||
keys_with_prefix = list(store.yield_keys(prefix="key"))
|
||||
assert set(keys_with_prefix) == {"key1", "key2", "key3"}
|
||||
|
||||
keys_with_invalid_prefix = list(store.yield_keys(prefix="x"))
|
||||
assert keys_with_invalid_prefix == []
|
||||
store.delete_collection()
|
||||
|
||||
|
||||
def test_doc_store_collection() -> None:
|
||||
"""Test that collections are isolated within a db."""
|
||||
store_1 = SQLDocStore(
|
||||
collection_name=COLLECTION_NAME,
|
||||
connection_string=CONNECTION_STRING,
|
||||
pre_delete_collection=True,
|
||||
)
|
||||
store_2 = SQLDocStore(
|
||||
collection_name=COLLECTION_NAME_2,
|
||||
connection_string=CONNECTION_STRING,
|
||||
pre_delete_collection=True,
|
||||
)
|
||||
doc_1 = Document(page_content="value1")
|
||||
doc_2 = Document(page_content="value2")
|
||||
doc_3 = Document(page_content="value3")
|
||||
doc_4 = Document(page_content="value4")
|
||||
store_1.mset([("key1", doc_1), ("key2", doc_2)])
|
||||
store_2.mset([("key3", doc_3), ("key4", doc_4)])
|
||||
|
||||
values = store_1.mget(["key1", "key2"])
|
||||
assert values == [doc_1, doc_2]
|
||||
values = store_1.mget(["key3", "key4"])
|
||||
assert values == [None, None]
|
||||
|
||||
values = store_2.mget(["key1", "key2"])
|
||||
assert values == [None, None]
|
||||
values = store_2.mget(["key3", "key4"])
|
||||
assert values == [doc_3, doc_4]
|
||||
|
||||
store_1.delete_collection()
|
||||
store_2.delete_collection()
|
@ -6,6 +6,8 @@ EXPECTED_ALL = [
|
||||
"RedisStore",
|
||||
"UpstashRedisByteStore",
|
||||
"UpstashRedisStore",
|
||||
"SQLDocStore",
|
||||
"SQLStrStore",
|
||||
]
|
||||
|
||||
|
||||
|
7
libs/community/tests/unit_tests/storage/test_sql.py
Normal file
7
libs/community/tests/unit_tests/storage/test_sql.py
Normal file
@ -0,0 +1,7 @@
|
||||
"""Light weight unit test that attempts to import SQLDocStore/SQLStrStore.
|
||||
"""
|
||||
|
||||
|
||||
def test_import_storage() -> None:
|
||||
"""Attempt to import storage modules."""
|
||||
from langchain_community.storage.sql import SQLDocStore, SQLStrStore # noqa
|
Loading…
Reference in New Issue
Block a user