mirror of
https://github.com/hwchase17/langchain
synced 2024-11-16 06:13:16 +00:00
community: SQLStrStore/SQLDocStore provide an easy SQL alternative to InMemoryStore
to persist data remotely in a SQL storage (#15909)
**Description:** - Implement `SQLStrStore` and `SQLDocStore` classes that inherits from `BaseStore` to allow to persist data remotely on a SQL server. - SQL is widely used and sometimes we do not want to install a caching solution like Redis. - Multiple issues/comments complain that there is no easy remote and persistent solution that are not in memory (users want to replace InMemoryStore), e.g., https://github.com/langchain-ai/langchain/issues/14267, https://github.com/langchain-ai/langchain/issues/15633, https://github.com/langchain-ai/langchain/issues/14643, https://stackoverflow.com/questions/77385587/persist-parentdocumentretriever-of-langchain - This is particularly painful when wanting to use `ParentDocumentRetriever ` - This implementation is particularly useful when: * it's expensive to construct an InMemoryDocstore/dict * you want to retrieve documents from remote sources * you just want to reuse existing objects - This implementation integrates well with PGVector, indeed, when using PGVector, you already have a SQL instance running. `SQLDocStore` is a convenient way of using this instance to store documents associated to vectors. An integration example with ParentDocumentRetriever and PGVector is provided in docs/docs/integrations/stores/sql.ipynb or [here](https://github.com/gcheron/langchain/blob/sql-store/docs/docs/integrations/stores/sql.ipynb). - It persists `str` and `Document` objects but can be easily extended. **Issue:** Provide an easy SQL alternative to `InMemoryStore`. --------- Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
This commit is contained in:
parent
26b2ad6d5b
commit
cfc225ecb3
186
docs/docs/integrations/stores/sql.ipynb
Normal file
186
docs/docs/integrations/stores/sql.ipynb
Normal file
@ -0,0 +1,186 @@
|
|||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "raw",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"---\n",
|
||||||
|
"sidebar_label: SQL\n",
|
||||||
|
"---"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"# SQLStore\n",
|
||||||
|
"\n",
|
||||||
|
"The `SQLStrStore` and `SQLDocStore` implement remote data access and persistence to store strings or LangChain documents in your SQL instance."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 5,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"['value1', 'value2']\n",
|
||||||
|
"['key2']\n",
|
||||||
|
"['key2']\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"from langchain_community.storage import SQLStrStore\n",
|
||||||
|
"\n",
|
||||||
|
"# simple example using an SQLStrStore to store strings\n",
|
||||||
|
"# same as you would use in \"InMemoryStore\" but using SQL persistence\n",
|
||||||
|
"CONNECTION_STRING = \"postgresql+psycopg2://user:pass@localhost:5432/db\"\n",
|
||||||
|
"COLLECTION_NAME = \"test_collection\"\n",
|
||||||
|
"\n",
|
||||||
|
"store = SQLStrStore(\n",
|
||||||
|
" collection_name=COLLECTION_NAME,\n",
|
||||||
|
" connection_string=CONNECTION_STRING,\n",
|
||||||
|
")\n",
|
||||||
|
"store.mset([(\"key1\", \"value1\"), (\"key2\", \"value2\")])\n",
|
||||||
|
"print(store.mget([\"key1\", \"key2\"]))\n",
|
||||||
|
"# ['value1', 'value2']\n",
|
||||||
|
"store.mdelete([\"key1\"])\n",
|
||||||
|
"print(list(store.yield_keys()))\n",
|
||||||
|
"# ['key2']\n",
|
||||||
|
"print(list(store.yield_keys(prefix=\"k\")))\n",
|
||||||
|
"# ['key2']\n",
|
||||||
|
"# delete the COLLECTION_NAME collection"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Integration with ParentRetriever and PGVector\n",
|
||||||
|
"\n",
|
||||||
|
"When using PGVector, you already have a SQL instance running. Here is a convenient way of using this instance to store documents associated to vectors. "
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"Prepare the PGVector vectorestore with something like this:"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 2,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"from langchain_community.vectorstores import PGVector\n",
|
||||||
|
"from langchain_openai import OpenAIEmbeddings"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"embeddings = OpenAIEmbeddings()\n",
|
||||||
|
"vector_db = PGVector.from_existing_index(\n",
|
||||||
|
" embedding=embeddings,\n",
|
||||||
|
" collection_name=COLLECTION_NAME,\n",
|
||||||
|
" connection_string=CONNECTION_STRING,\n",
|
||||||
|
")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"Then create the parent retiever using `SQLDocStore` to persist the documents"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"from langchain.document_loaders import TextLoader\n",
|
||||||
|
"from langchain.retrievers import ParentDocumentRetriever\n",
|
||||||
|
"from langchain.text_splitter import RecursiveCharacterTextSplitter\n",
|
||||||
|
"from langchain_community.storage import SQLDocStore\n",
|
||||||
|
"\n",
|
||||||
|
"CONNECTION_STRING = \"postgresql+psycopg2://user:pass@localhost:5432/db\"\n",
|
||||||
|
"COLLECTION_NAME = \"state_of_the_union_test\"\n",
|
||||||
|
"docstore = SQLDocStore(\n",
|
||||||
|
" collection_name=COLLECTION_NAME,\n",
|
||||||
|
" connection_string=CONNECTION_STRING,\n",
|
||||||
|
")\n",
|
||||||
|
"\n",
|
||||||
|
"loader = TextLoader(\"./state_of_the_union.txt\")\n",
|
||||||
|
"documents = loader.load()\n",
|
||||||
|
"\n",
|
||||||
|
"parent_splitter = RecursiveCharacterTextSplitter(chunk_size=400)\n",
|
||||||
|
"child_splitter = RecursiveCharacterTextSplitter(chunk_size=50)\n",
|
||||||
|
"\n",
|
||||||
|
"retriever = ParentDocumentRetriever(\n",
|
||||||
|
" vectorstore=vector_db,\n",
|
||||||
|
" docstore=docstore,\n",
|
||||||
|
" child_splitter=child_splitter,\n",
|
||||||
|
" parent_splitter=parent_splitter,\n",
|
||||||
|
")\n",
|
||||||
|
"retriever.add_documents(documents)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Delete a collection"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"from langchain_community.storage import SQLStrStore\n",
|
||||||
|
"\n",
|
||||||
|
"# delete the COLLECTION_NAME collection\n",
|
||||||
|
"CONNECTION_STRING = \"postgresql+psycopg2://user:pass@localhost:5432/db\"\n",
|
||||||
|
"COLLECTION_NAME = \"test_collection\"\n",
|
||||||
|
"store = SQLStrStore(\n",
|
||||||
|
" collection_name=COLLECTION_NAME,\n",
|
||||||
|
" connection_string=CONNECTION_STRING,\n",
|
||||||
|
")\n",
|
||||||
|
"store.delete_collection()"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"kernelspec": {
|
||||||
|
"display_name": "Python 3 (ipykernel)",
|
||||||
|
"language": "python",
|
||||||
|
"name": "python3"
|
||||||
|
},
|
||||||
|
"language_info": {
|
||||||
|
"codemirror_mode": {
|
||||||
|
"name": "ipython",
|
||||||
|
"version": 3
|
||||||
|
},
|
||||||
|
"file_extension": ".py",
|
||||||
|
"mimetype": "text/x-python",
|
||||||
|
"name": "python",
|
||||||
|
"nbconvert_exporter": "python",
|
||||||
|
"pygments_lexer": "ipython3",
|
||||||
|
"version": "3.10.1"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 2
|
||||||
|
}
|
@ -11,6 +11,10 @@ from langchain_community.storage.astradb import (
|
|||||||
AstraDBStore,
|
AstraDBStore,
|
||||||
)
|
)
|
||||||
from langchain_community.storage.redis import RedisStore
|
from langchain_community.storage.redis import RedisStore
|
||||||
|
from langchain_community.storage.sql import (
|
||||||
|
SQLDocStore,
|
||||||
|
SQLStrStore,
|
||||||
|
)
|
||||||
from langchain_community.storage.upstash_redis import (
|
from langchain_community.storage.upstash_redis import (
|
||||||
UpstashRedisByteStore,
|
UpstashRedisByteStore,
|
||||||
UpstashRedisStore,
|
UpstashRedisStore,
|
||||||
@ -22,4 +26,6 @@ __all__ = [
|
|||||||
"RedisStore",
|
"RedisStore",
|
||||||
"UpstashRedisByteStore",
|
"UpstashRedisByteStore",
|
||||||
"UpstashRedisStore",
|
"UpstashRedisStore",
|
||||||
|
"SQLDocStore",
|
||||||
|
"SQLStrStore",
|
||||||
]
|
]
|
||||||
|
345
libs/community/langchain_community/storage/sql.py
Normal file
345
libs/community/langchain_community/storage/sql.py
Normal file
@ -0,0 +1,345 @@
|
|||||||
|
"""SQL storage that persists data in a SQL database
|
||||||
|
and supports data isolation using collections."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
from typing import Any, Generic, Iterator, List, Optional, Sequence, Tuple, TypeVar
|
||||||
|
|
||||||
|
import sqlalchemy
|
||||||
|
from sqlalchemy import JSON, UUID
|
||||||
|
from sqlalchemy.orm import Session, relationship
|
||||||
|
|
||||||
|
try:
|
||||||
|
from sqlalchemy.orm import declarative_base
|
||||||
|
except ImportError:
|
||||||
|
from sqlalchemy.ext.declarative import declarative_base
|
||||||
|
|
||||||
|
from langchain_core.documents import Document
|
||||||
|
from langchain_core.load import Serializable, dumps, loads
|
||||||
|
from langchain_core.stores import BaseStore
|
||||||
|
|
||||||
|
V = TypeVar("V")
|
||||||
|
|
||||||
|
ITERATOR_WINDOW_SIZE = 1000
|
||||||
|
|
||||||
|
Base = declarative_base() # type: Any
|
||||||
|
|
||||||
|
|
||||||
|
_LANGCHAIN_DEFAULT_COLLECTION_NAME = "langchain"
|
||||||
|
|
||||||
|
|
||||||
|
class BaseModel(Base):
|
||||||
|
"""Base model for the SQL stores."""
|
||||||
|
|
||||||
|
__abstract__ = True
|
||||||
|
uuid = sqlalchemy.Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||||
|
|
||||||
|
|
||||||
|
_classes: Any = None
|
||||||
|
|
||||||
|
|
||||||
|
def _get_storage_stores() -> Any:
|
||||||
|
global _classes
|
||||||
|
if _classes is not None:
|
||||||
|
return _classes
|
||||||
|
|
||||||
|
class CollectionStore(BaseModel):
|
||||||
|
"""Collection store."""
|
||||||
|
|
||||||
|
__tablename__ = "langchain_storage_collection"
|
||||||
|
|
||||||
|
name = sqlalchemy.Column(sqlalchemy.String)
|
||||||
|
cmetadata = sqlalchemy.Column(JSON)
|
||||||
|
|
||||||
|
items = relationship(
|
||||||
|
"ItemStore",
|
||||||
|
back_populates="collection",
|
||||||
|
passive_deletes=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_by_name(
|
||||||
|
cls, session: Session, name: str
|
||||||
|
) -> Optional["CollectionStore"]:
|
||||||
|
# type: ignore
|
||||||
|
return session.query(cls).filter(cls.name == name).first()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_or_create(
|
||||||
|
cls,
|
||||||
|
session: Session,
|
||||||
|
name: str,
|
||||||
|
cmetadata: Optional[dict] = None,
|
||||||
|
) -> Tuple["CollectionStore", bool]:
|
||||||
|
"""
|
||||||
|
Get or create a collection.
|
||||||
|
Returns [Collection, bool] where the bool is True if the collection was created.
|
||||||
|
""" # noqa: E501
|
||||||
|
created = False
|
||||||
|
collection = cls.get_by_name(session, name)
|
||||||
|
if collection:
|
||||||
|
return collection, created
|
||||||
|
|
||||||
|
collection = cls(name=name, cmetadata=cmetadata)
|
||||||
|
session.add(collection)
|
||||||
|
session.commit()
|
||||||
|
created = True
|
||||||
|
return collection, created
|
||||||
|
|
||||||
|
class ItemStore(BaseModel):
|
||||||
|
"""Item store."""
|
||||||
|
|
||||||
|
__tablename__ = "langchain_storage_items"
|
||||||
|
|
||||||
|
collection_id = sqlalchemy.Column(
|
||||||
|
UUID(as_uuid=True),
|
||||||
|
sqlalchemy.ForeignKey(
|
||||||
|
f"{CollectionStore.__tablename__}.uuid",
|
||||||
|
ondelete="CASCADE",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
collection = relationship(CollectionStore, back_populates="items")
|
||||||
|
|
||||||
|
content = sqlalchemy.Column(sqlalchemy.String, nullable=True)
|
||||||
|
|
||||||
|
# custom_id : any user defined id
|
||||||
|
custom_id = sqlalchemy.Column(sqlalchemy.String, nullable=True)
|
||||||
|
|
||||||
|
_classes = (ItemStore, CollectionStore)
|
||||||
|
|
||||||
|
return _classes
|
||||||
|
|
||||||
|
|
||||||
|
class SQLBaseStore(BaseStore[str, V], Generic[V]):
|
||||||
|
"""SQL storage
|
||||||
|
|
||||||
|
Args:
|
||||||
|
connection_string: SQL connection string that will be passed to SQLAlchemy.
|
||||||
|
collection_name: The name of the collection to use. (default: langchain)
|
||||||
|
NOTE: Collections are useful to isolate your data in a given a database.
|
||||||
|
This is not the name of the table, but the name of the collection.
|
||||||
|
The tables will be created when initializing the store (if not exists)
|
||||||
|
So, make sure the user has the right permissions to create tables.
|
||||||
|
pre_delete_collection: If True, will delete the collection if it exists.
|
||||||
|
(default: False). Useful for testing.
|
||||||
|
engine_args: SQLAlchemy's create engine arguments.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
from langchain_community.storage import SQLDocStore
|
||||||
|
from langchain_community.embeddings.openai import OpenAIEmbeddings
|
||||||
|
|
||||||
|
# example using an SQLDocStore to store Document objects for
|
||||||
|
# a ParentDocumentRetriever
|
||||||
|
CONNECTION_STRING = "postgresql+psycopg2://user:pass@localhost:5432/db"
|
||||||
|
COLLECTION_NAME = "state_of_the_union_test"
|
||||||
|
docstore = SQLDocStore(
|
||||||
|
collection_name=COLLECTION_NAME,
|
||||||
|
connection_string=CONNECTION_STRING,
|
||||||
|
)
|
||||||
|
child_splitter = RecursiveCharacterTextSplitter(chunk_size=400)
|
||||||
|
vectorstore = ...
|
||||||
|
|
||||||
|
retriever = ParentDocumentRetriever(
|
||||||
|
vectorstore=vectorstore,
|
||||||
|
docstore=docstore,
|
||||||
|
child_splitter=child_splitter,
|
||||||
|
)
|
||||||
|
|
||||||
|
# example using an SQLStrStore to store strings
|
||||||
|
# same example as in "InMemoryStore" but using SQL persistence
|
||||||
|
store = SQLDocStore(
|
||||||
|
collection_name=COLLECTION_NAME,
|
||||||
|
connection_string=CONNECTION_STRING,
|
||||||
|
)
|
||||||
|
store.mset([('key1', 'value1'), ('key2', 'value2')])
|
||||||
|
store.mget(['key1', 'key2'])
|
||||||
|
# ['value1', 'value2']
|
||||||
|
store.mdelete(['key1'])
|
||||||
|
list(store.yield_keys())
|
||||||
|
# ['key2']
|
||||||
|
list(store.yield_keys(prefix='k'))
|
||||||
|
# ['key2']
|
||||||
|
|
||||||
|
# delete the COLLECTION_NAME collection
|
||||||
|
docstore.delete_collection()
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
connection_string: str,
|
||||||
|
collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME,
|
||||||
|
collection_metadata: Optional[dict] = None,
|
||||||
|
pre_delete_collection: bool = False,
|
||||||
|
connection: Optional[sqlalchemy.engine.Connection] = None,
|
||||||
|
engine_args: Optional[dict[str, Any]] = None,
|
||||||
|
) -> None:
|
||||||
|
self.connection_string = connection_string
|
||||||
|
self.collection_name = collection_name
|
||||||
|
self.collection_metadata = collection_metadata
|
||||||
|
self.pre_delete_collection = pre_delete_collection
|
||||||
|
self.engine_args = engine_args or {}
|
||||||
|
# Create a connection if not provided, otherwise use the provided connection
|
||||||
|
self._conn = connection if connection else self.__connect()
|
||||||
|
self.__post_init__()
|
||||||
|
|
||||||
|
def __post_init__(
|
||||||
|
self,
|
||||||
|
) -> None:
|
||||||
|
"""Initialize the store."""
|
||||||
|
ItemStore, CollectionStore = _get_storage_stores()
|
||||||
|
self.CollectionStore = CollectionStore
|
||||||
|
self.ItemStore = ItemStore
|
||||||
|
self.__create_tables_if_not_exists()
|
||||||
|
self.__create_collection()
|
||||||
|
|
||||||
|
def __connect(self) -> sqlalchemy.engine.Connection:
|
||||||
|
engine = sqlalchemy.create_engine(self.connection_string, **self.engine_args)
|
||||||
|
conn = engine.connect()
|
||||||
|
return conn
|
||||||
|
|
||||||
|
def __create_tables_if_not_exists(self) -> None:
|
||||||
|
with self._conn.begin():
|
||||||
|
Base.metadata.create_all(self._conn)
|
||||||
|
|
||||||
|
def __create_collection(self) -> None:
|
||||||
|
if self.pre_delete_collection:
|
||||||
|
self.delete_collection()
|
||||||
|
with Session(self._conn) as session:
|
||||||
|
self.CollectionStore.get_or_create(
|
||||||
|
session, self.collection_name, cmetadata=self.collection_metadata
|
||||||
|
)
|
||||||
|
|
||||||
|
def delete_collection(self) -> None:
|
||||||
|
with Session(self._conn) as session:
|
||||||
|
collection = self.__get_collection(session)
|
||||||
|
if not collection:
|
||||||
|
return
|
||||||
|
session.delete(collection)
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
def __get_collection(self, session: Session) -> Any:
|
||||||
|
return self.CollectionStore.get_by_name(session, self.collection_name)
|
||||||
|
|
||||||
|
def __del__(self) -> None:
|
||||||
|
if self._conn:
|
||||||
|
self._conn.close()
|
||||||
|
|
||||||
|
def __serialize_value(self, obj: V) -> str:
|
||||||
|
if isinstance(obj, Serializable):
|
||||||
|
return dumps(obj)
|
||||||
|
return obj
|
||||||
|
|
||||||
|
def __deserialize_value(self, obj: V) -> str:
|
||||||
|
try:
|
||||||
|
return loads(obj)
|
||||||
|
except Exception:
|
||||||
|
return obj
|
||||||
|
|
||||||
|
def mget(self, keys: Sequence[str]) -> List[Optional[V]]:
|
||||||
|
"""Get the values associated with the given keys.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
keys (Sequence[str]): A sequence of keys.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A sequence of optional values associated with the keys.
|
||||||
|
If a key is not found, the corresponding value will be None.
|
||||||
|
"""
|
||||||
|
with Session(self._conn) as session:
|
||||||
|
collection = self.__get_collection(session)
|
||||||
|
|
||||||
|
items = (
|
||||||
|
session.query(self.ItemStore.content, self.ItemStore.custom_id)
|
||||||
|
.where(
|
||||||
|
sqlalchemy.and_(
|
||||||
|
self.ItemStore.custom_id.in_(keys),
|
||||||
|
self.ItemStore.collection_id == (collection.uuid),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
|
||||||
|
ordered_values = {key: None for key in keys}
|
||||||
|
for item in items:
|
||||||
|
v = item[0]
|
||||||
|
val = self.__deserialize_value(v) if v is not None else v
|
||||||
|
k = item[1]
|
||||||
|
ordered_values[k] = val
|
||||||
|
|
||||||
|
return [ordered_values[key] for key in keys]
|
||||||
|
|
||||||
|
def mset(self, key_value_pairs: Sequence[Tuple[str, V]]) -> None:
|
||||||
|
"""Set the values for the given keys.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
key_value_pairs (Sequence[Tuple[str, V]]): A sequence of key-value pairs.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
None
|
||||||
|
"""
|
||||||
|
with Session(self._conn) as session:
|
||||||
|
collection = self.__get_collection(session)
|
||||||
|
if not collection:
|
||||||
|
raise ValueError("Collection not found")
|
||||||
|
for id, item in key_value_pairs:
|
||||||
|
content = self.__serialize_value(item)
|
||||||
|
item_store = self.ItemStore(
|
||||||
|
content=content,
|
||||||
|
custom_id=id,
|
||||||
|
collection_id=collection.uuid,
|
||||||
|
)
|
||||||
|
session.add(item_store)
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
def mdelete(self, keys: Sequence[str]) -> None:
|
||||||
|
"""Delete the given keys and their associated values.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
keys (Sequence[str]): A sequence of keys to delete.
|
||||||
|
"""
|
||||||
|
with Session(self._conn) as session:
|
||||||
|
collection = self.__get_collection(session)
|
||||||
|
if not collection:
|
||||||
|
raise ValueError("Collection not found")
|
||||||
|
if keys is not None:
|
||||||
|
stmt = sqlalchemy.delete(self.ItemStore).where(
|
||||||
|
sqlalchemy.and_(
|
||||||
|
self.ItemStore.custom_id.in_(keys),
|
||||||
|
self.ItemStore.collection_id == (collection.uuid),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
session.execute(stmt)
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
def yield_keys(self, prefix: Optional[str] = None) -> Iterator[str]:
|
||||||
|
"""Get an iterator over keys that match the given prefix.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prefix (str, optional): The prefix to match. Defaults to None.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Iterator[str]: An iterator over keys that match the given prefix.
|
||||||
|
"""
|
||||||
|
with Session(self._conn) as session:
|
||||||
|
collection = self.__get_collection(session)
|
||||||
|
start = 0
|
||||||
|
while True:
|
||||||
|
stop = start + ITERATOR_WINDOW_SIZE
|
||||||
|
query = session.query(self.ItemStore.custom_id).where(
|
||||||
|
self.ItemStore.collection_id == (collection.uuid)
|
||||||
|
)
|
||||||
|
if prefix is not None:
|
||||||
|
query = query.filter(self.ItemStore.custom_id.startswith(prefix))
|
||||||
|
items = query.slice(start, stop).all()
|
||||||
|
|
||||||
|
if len(items) == 0:
|
||||||
|
break
|
||||||
|
for item in items:
|
||||||
|
yield item[0]
|
||||||
|
start += ITERATOR_WINDOW_SIZE
|
||||||
|
|
||||||
|
|
||||||
|
SQLDocStore = SQLBaseStore[Document]
|
||||||
|
SQLStrStore = SQLBaseStore[str]
|
228
libs/community/tests/integration_tests/storage/test_sql.py
Normal file
228
libs/community/tests/integration_tests/storage/test_sql.py
Normal file
@ -0,0 +1,228 @@
|
|||||||
|
"""Implement integration tests for SQL storage."""
|
||||||
|
import os
|
||||||
|
|
||||||
|
from langchain_core.documents import Document
|
||||||
|
|
||||||
|
from langchain_community.storage.sql import SQLDocStore, SQLStrStore
|
||||||
|
|
||||||
|
|
||||||
|
def connection_string_from_db_params() -> str:
|
||||||
|
"""Return connection string from database parameters."""
|
||||||
|
dbdriver = os.environ.get("TEST_SQL_DBDRIVER", "postgresql+psycopg2")
|
||||||
|
host = os.environ.get("TEST_SQL_HOST", "localhost")
|
||||||
|
port = int(os.environ.get("TEST_SQL_PORT", "5432"))
|
||||||
|
database = os.environ.get("TEST_SQL_DATABASE", "postgres")
|
||||||
|
user = os.environ.get("TEST_SQL_USER", "postgres")
|
||||||
|
password = os.environ.get("TEST_SQL_PASSWORD", "postgres")
|
||||||
|
return f"{dbdriver}://{user}:{password}@{host}:{port}/{database}"
|
||||||
|
|
||||||
|
|
||||||
|
CONNECTION_STRING = connection_string_from_db_params()
|
||||||
|
COLLECTION_NAME = "test_collection"
|
||||||
|
COLLECTION_NAME_2 = "test_collection_2"
|
||||||
|
|
||||||
|
|
||||||
|
def test_str_store_mget() -> None:
|
||||||
|
store = SQLStrStore(
|
||||||
|
collection_name=COLLECTION_NAME,
|
||||||
|
connection_string=CONNECTION_STRING,
|
||||||
|
pre_delete_collection=True,
|
||||||
|
)
|
||||||
|
store.mset([("key1", "value1"), ("key2", "value2")])
|
||||||
|
|
||||||
|
values = store.mget(["key1", "key2"])
|
||||||
|
assert values == ["value1", "value2"]
|
||||||
|
|
||||||
|
# Test non-existent key
|
||||||
|
non_existent_value = store.mget(["key3"])
|
||||||
|
assert non_existent_value == [None]
|
||||||
|
store.delete_collection()
|
||||||
|
|
||||||
|
|
||||||
|
def test_str_store_mset() -> None:
|
||||||
|
store = SQLStrStore(
|
||||||
|
collection_name=COLLECTION_NAME,
|
||||||
|
connection_string=CONNECTION_STRING,
|
||||||
|
pre_delete_collection=True,
|
||||||
|
)
|
||||||
|
store.mset([("key1", "value1"), ("key2", "value2")])
|
||||||
|
|
||||||
|
values = store.mget(["key1", "key2"])
|
||||||
|
assert values == ["value1", "value2"]
|
||||||
|
store.delete_collection()
|
||||||
|
|
||||||
|
|
||||||
|
def test_str_store_mdelete() -> None:
|
||||||
|
store = SQLStrStore(
|
||||||
|
collection_name=COLLECTION_NAME,
|
||||||
|
connection_string=CONNECTION_STRING,
|
||||||
|
pre_delete_collection=True,
|
||||||
|
)
|
||||||
|
store.mset([("key1", "value1"), ("key2", "value2")])
|
||||||
|
|
||||||
|
store.mdelete(["key1"])
|
||||||
|
|
||||||
|
values = store.mget(["key1", "key2"])
|
||||||
|
assert values == [None, "value2"]
|
||||||
|
|
||||||
|
# Test deleting non-existent key
|
||||||
|
store.mdelete(["key3"]) # No error should be raised
|
||||||
|
store.delete_collection()
|
||||||
|
|
||||||
|
|
||||||
|
def test_str_store_yield_keys() -> None:
|
||||||
|
store = SQLStrStore(
|
||||||
|
collection_name=COLLECTION_NAME,
|
||||||
|
connection_string=CONNECTION_STRING,
|
||||||
|
pre_delete_collection=True,
|
||||||
|
)
|
||||||
|
store.mset([("key1", "value1"), ("key2", "value2"), ("key3", "value3")])
|
||||||
|
|
||||||
|
keys = list(store.yield_keys())
|
||||||
|
assert set(keys) == {"key1", "key2", "key3"}
|
||||||
|
|
||||||
|
keys_with_prefix = list(store.yield_keys(prefix="key"))
|
||||||
|
assert set(keys_with_prefix) == {"key1", "key2", "key3"}
|
||||||
|
|
||||||
|
keys_with_invalid_prefix = list(store.yield_keys(prefix="x"))
|
||||||
|
assert keys_with_invalid_prefix == []
|
||||||
|
store.delete_collection()
|
||||||
|
|
||||||
|
|
||||||
|
def test_str_store_collection() -> None:
|
||||||
|
"""Test that collections are isolated within a db."""
|
||||||
|
store_1 = SQLStrStore(
|
||||||
|
collection_name=COLLECTION_NAME,
|
||||||
|
connection_string=CONNECTION_STRING,
|
||||||
|
pre_delete_collection=True,
|
||||||
|
)
|
||||||
|
store_2 = SQLStrStore(
|
||||||
|
collection_name=COLLECTION_NAME_2,
|
||||||
|
connection_string=CONNECTION_STRING,
|
||||||
|
pre_delete_collection=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
store_1.mset([("key1", "value1"), ("key2", "value2")])
|
||||||
|
store_2.mset([("key3", "value3"), ("key4", "value4")])
|
||||||
|
|
||||||
|
values = store_1.mget(["key1", "key2"])
|
||||||
|
assert values == ["value1", "value2"]
|
||||||
|
values = store_1.mget(["key3", "key4"])
|
||||||
|
assert values == [None, None]
|
||||||
|
|
||||||
|
values = store_2.mget(["key1", "key2"])
|
||||||
|
assert values == [None, None]
|
||||||
|
values = store_2.mget(["key3", "key4"])
|
||||||
|
assert values == ["value3", "value4"]
|
||||||
|
|
||||||
|
store_1.delete_collection()
|
||||||
|
store_2.delete_collection()
|
||||||
|
|
||||||
|
|
||||||
|
def test_doc_store_mget() -> None:
|
||||||
|
store = SQLDocStore(
|
||||||
|
collection_name=COLLECTION_NAME,
|
||||||
|
connection_string=CONNECTION_STRING,
|
||||||
|
pre_delete_collection=True,
|
||||||
|
)
|
||||||
|
doc_1 = Document(page_content="value1")
|
||||||
|
doc_2 = Document(page_content="value2")
|
||||||
|
store.mset([("key1", doc_1), ("key2", doc_2)])
|
||||||
|
|
||||||
|
values = store.mget(["key1", "key2"])
|
||||||
|
assert values == [doc_1, doc_2]
|
||||||
|
|
||||||
|
# Test non-existent key
|
||||||
|
non_existent_value = store.mget(["key3"])
|
||||||
|
assert non_existent_value == [None]
|
||||||
|
store.delete_collection()
|
||||||
|
|
||||||
|
|
||||||
|
def test_doc_store_mset() -> None:
|
||||||
|
store = SQLDocStore(
|
||||||
|
collection_name=COLLECTION_NAME,
|
||||||
|
connection_string=CONNECTION_STRING,
|
||||||
|
pre_delete_collection=True,
|
||||||
|
)
|
||||||
|
doc_1 = Document(page_content="value1")
|
||||||
|
doc_2 = Document(page_content="value2")
|
||||||
|
store.mset([("key1", doc_1), ("key2", doc_2)])
|
||||||
|
|
||||||
|
values = store.mget(["key1", "key2"])
|
||||||
|
assert values == [doc_1, doc_2]
|
||||||
|
store.delete_collection()
|
||||||
|
|
||||||
|
|
||||||
|
def test_doc_store_mdelete() -> None:
|
||||||
|
store = SQLDocStore(
|
||||||
|
collection_name=COLLECTION_NAME,
|
||||||
|
connection_string=CONNECTION_STRING,
|
||||||
|
pre_delete_collection=True,
|
||||||
|
)
|
||||||
|
doc_1 = Document(page_content="value1")
|
||||||
|
doc_2 = Document(page_content="value2")
|
||||||
|
store.mset([("key1", doc_1), ("key2", doc_2)])
|
||||||
|
|
||||||
|
store.mdelete(["key1"])
|
||||||
|
|
||||||
|
values = store.mget(["key1", "key2"])
|
||||||
|
assert values == [None, doc_2]
|
||||||
|
|
||||||
|
# Test deleting non-existent key
|
||||||
|
store.mdelete(["key3"]) # No error should be raised
|
||||||
|
store.delete_collection()
|
||||||
|
|
||||||
|
|
||||||
|
def test_doc_store_yield_keys() -> None:
|
||||||
|
store = SQLDocStore(
|
||||||
|
collection_name=COLLECTION_NAME,
|
||||||
|
connection_string=CONNECTION_STRING,
|
||||||
|
pre_delete_collection=True,
|
||||||
|
)
|
||||||
|
doc_1 = Document(page_content="value1")
|
||||||
|
doc_2 = Document(page_content="value2")
|
||||||
|
doc_3 = Document(page_content="value3")
|
||||||
|
store.mset([("key1", doc_1), ("key2", doc_2), ("key3", doc_3)])
|
||||||
|
|
||||||
|
keys = list(store.yield_keys())
|
||||||
|
assert set(keys) == {"key1", "key2", "key3"}
|
||||||
|
|
||||||
|
keys_with_prefix = list(store.yield_keys(prefix="key"))
|
||||||
|
assert set(keys_with_prefix) == {"key1", "key2", "key3"}
|
||||||
|
|
||||||
|
keys_with_invalid_prefix = list(store.yield_keys(prefix="x"))
|
||||||
|
assert keys_with_invalid_prefix == []
|
||||||
|
store.delete_collection()
|
||||||
|
|
||||||
|
|
||||||
|
def test_doc_store_collection() -> None:
|
||||||
|
"""Test that collections are isolated within a db."""
|
||||||
|
store_1 = SQLDocStore(
|
||||||
|
collection_name=COLLECTION_NAME,
|
||||||
|
connection_string=CONNECTION_STRING,
|
||||||
|
pre_delete_collection=True,
|
||||||
|
)
|
||||||
|
store_2 = SQLDocStore(
|
||||||
|
collection_name=COLLECTION_NAME_2,
|
||||||
|
connection_string=CONNECTION_STRING,
|
||||||
|
pre_delete_collection=True,
|
||||||
|
)
|
||||||
|
doc_1 = Document(page_content="value1")
|
||||||
|
doc_2 = Document(page_content="value2")
|
||||||
|
doc_3 = Document(page_content="value3")
|
||||||
|
doc_4 = Document(page_content="value4")
|
||||||
|
store_1.mset([("key1", doc_1), ("key2", doc_2)])
|
||||||
|
store_2.mset([("key3", doc_3), ("key4", doc_4)])
|
||||||
|
|
||||||
|
values = store_1.mget(["key1", "key2"])
|
||||||
|
assert values == [doc_1, doc_2]
|
||||||
|
values = store_1.mget(["key3", "key4"])
|
||||||
|
assert values == [None, None]
|
||||||
|
|
||||||
|
values = store_2.mget(["key1", "key2"])
|
||||||
|
assert values == [None, None]
|
||||||
|
values = store_2.mget(["key3", "key4"])
|
||||||
|
assert values == [doc_3, doc_4]
|
||||||
|
|
||||||
|
store_1.delete_collection()
|
||||||
|
store_2.delete_collection()
|
@ -6,6 +6,8 @@ EXPECTED_ALL = [
|
|||||||
"RedisStore",
|
"RedisStore",
|
||||||
"UpstashRedisByteStore",
|
"UpstashRedisByteStore",
|
||||||
"UpstashRedisStore",
|
"UpstashRedisStore",
|
||||||
|
"SQLDocStore",
|
||||||
|
"SQLStrStore",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
7
libs/community/tests/unit_tests/storage/test_sql.py
Normal file
7
libs/community/tests/unit_tests/storage/test_sql.py
Normal file
@ -0,0 +1,7 @@
|
|||||||
|
"""Light weight unit test that attempts to import SQLDocStore/SQLStrStore.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def test_import_storage() -> None:
|
||||||
|
"""Attempt to import storage modules."""
|
||||||
|
from langchain_community.storage.sql import SQLDocStore, SQLStrStore # noqa
|
Loading…
Reference in New Issue
Block a user