community: SQLStrStore/SQLDocStore provide an easy SQL alternative to InMemoryStore to persist data remotely in SQL storage (#15909)

**Description:**

- Implement `SQLStrStore` and `SQLDocStore` classes that inherit from
`BaseStore` to persist data remotely on a SQL server.
- SQL is widely used and sometimes we do not want to install a caching
solution like Redis.
- Multiple issues and comments report that there is no easy remote,
persistent alternative to in-memory storage (users want to replace
`InMemoryStore`), e.g.:
https://github.com/langchain-ai/langchain/issues/14267,
https://github.com/langchain-ai/langchain/issues/15633,
https://github.com/langchain-ai/langchain/issues/14643,
https://stackoverflow.com/questions/77385587/persist-parentdocumentretriever-of-langchain
- This is particularly painful when you want to use
`ParentDocumentRetriever`.
- This implementation is particularly useful when:
     * it's expensive to construct an InMemoryDocstore/dict
     * you want to retrieve documents from remote sources
     * you just want to reuse existing objects
- This implementation integrates well with PGVector: when using PGVector,
you already have a SQL instance running, and `SQLDocStore` is a convenient
way of using that instance to store the documents associated with vectors.
An integration example with `ParentDocumentRetriever` and PGVector is
provided in docs/docs/integrations/stores/sql.ipynb or
[here](https://github.com/gcheron/langchain/blob/sql-store/docs/docs/integrations/stores/sql.ipynb).
- It persists `str` and `Document` objects but can be easily extended.
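
For reference, basic usage mirrors the `InMemoryStore` API. This sketch is
taken from the notebook added in this PR; the connection string and
collection name are placeholders:

```python
from langchain_community.storage import SQLStrStore

# Placeholder connection details; any SQLAlchemy-compatible URL should work.
CONNECTION_STRING = "postgresql+psycopg2://user:pass@localhost:5432/db"

store = SQLStrStore(
    collection_name="test_collection",
    connection_string=CONNECTION_STRING,
)
store.mset([("key1", "value1"), ("key2", "value2")])
print(store.mget(["key1", "key2"]))  # ['value1', 'value2']
store.mdelete(["key1"])
print(list(store.yield_keys()))  # ['key2']
print(list(store.yield_keys(prefix="k")))  # ['key2']
```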

 **Issue:**

Provide an easy SQL alternative to `InMemoryStore`.

---------

Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>

View File

@@ -0,0 +1,186 @@
{
"cells": [
{
"cell_type": "raw",
"metadata": {},
"source": [
"---\n",
"sidebar_label: SQL\n",
"---"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# SQLStore\n",
"\n",
"The `SQLStrStore` and `SQLDocStore` implement remote data access and persistence to store strings or LangChain documents in your SQL instance."
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"['value1', 'value2']\n",
"['key2']\n",
"['key2']\n"
]
}
],
"source": [
"from langchain_community.storage import SQLStrStore\n",
"\n",
"# simple example using an SQLStrStore to store strings\n",
"# same as you would use in \"InMemoryStore\" but using SQL persistence\n",
"CONNECTION_STRING = \"postgresql+psycopg2://user:pass@localhost:5432/db\"\n",
"COLLECTION_NAME = \"test_collection\"\n",
"\n",
"store = SQLStrStore(\n",
" collection_name=COLLECTION_NAME,\n",
" connection_string=CONNECTION_STRING,\n",
")\n",
"store.mset([(\"key1\", \"value1\"), (\"key2\", \"value2\")])\n",
"print(store.mget([\"key1\", \"key2\"]))\n",
"# ['value1', 'value2']\n",
"store.mdelete([\"key1\"])\n",
"print(list(store.yield_keys()))\n",
"# ['key2']\n",
"print(list(store.yield_keys(prefix=\"k\")))\n",
"# ['key2']\n",
"# delete the COLLECTION_NAME collection"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Integration with ParentRetriever and PGVector\n",
"\n",
"When using PGVector, you already have a SQL instance running. Here is a convenient way of using this instance to store documents associated to vectors. "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Prepare the PGVector vectorestore with something like this:"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"from langchain_community.vectorstores import PGVector\n",
"from langchain_openai import OpenAIEmbeddings"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"embeddings = OpenAIEmbeddings()\n",
"vector_db = PGVector.from_existing_index(\n",
" embedding=embeddings,\n",
" collection_name=COLLECTION_NAME,\n",
" connection_string=CONNECTION_STRING,\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Then create the parent retiever using `SQLDocStore` to persist the documents"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from langchain.document_loaders import TextLoader\n",
"from langchain.retrievers import ParentDocumentRetriever\n",
"from langchain.text_splitter import RecursiveCharacterTextSplitter\n",
"from langchain_community.storage import SQLDocStore\n",
"\n",
"CONNECTION_STRING = \"postgresql+psycopg2://user:pass@localhost:5432/db\"\n",
"COLLECTION_NAME = \"state_of_the_union_test\"\n",
"docstore = SQLDocStore(\n",
" collection_name=COLLECTION_NAME,\n",
" connection_string=CONNECTION_STRING,\n",
")\n",
"\n",
"loader = TextLoader(\"./state_of_the_union.txt\")\n",
"documents = loader.load()\n",
"\n",
"parent_splitter = RecursiveCharacterTextSplitter(chunk_size=400)\n",
"child_splitter = RecursiveCharacterTextSplitter(chunk_size=50)\n",
"\n",
"retriever = ParentDocumentRetriever(\n",
" vectorstore=vector_db,\n",
" docstore=docstore,\n",
" child_splitter=child_splitter,\n",
" parent_splitter=parent_splitter,\n",
")\n",
"retriever.add_documents(documents)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Delete a collection"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from langchain_community.storage import SQLStrStore\n",
"\n",
"# delete the COLLECTION_NAME collection\n",
"CONNECTION_STRING = \"postgresql+psycopg2://user:pass@localhost:5432/db\"\n",
"COLLECTION_NAME = \"test_collection\"\n",
"store = SQLStrStore(\n",
" collection_name=COLLECTION_NAME,\n",
" connection_string=CONNECTION_STRING,\n",
")\n",
"store.delete_collection()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.1"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

View File

@@ -11,6 +11,10 @@ from langchain_community.storage.astradb import (
AstraDBStore,
)
from langchain_community.storage.redis import RedisStore
from langchain_community.storage.sql import (
SQLDocStore,
SQLStrStore,
)
from langchain_community.storage.upstash_redis import (
UpstashRedisByteStore,
UpstashRedisStore,
@@ -22,4 +26,6 @@ __all__ = [
"RedisStore",
"UpstashRedisByteStore",
"UpstashRedisStore",
"SQLDocStore",
"SQLStrStore",
]

View File

@@ -0,0 +1,345 @@
"""SQL storage that persists data in a SQL database
and supports data isolation using collections."""
from __future__ import annotations
import uuid
from typing import Any, Generic, Iterator, List, Optional, Sequence, Tuple, TypeVar
import sqlalchemy
from sqlalchemy import JSON, UUID
from sqlalchemy.orm import Session, relationship
try:
from sqlalchemy.orm import declarative_base
except ImportError:
from sqlalchemy.ext.declarative import declarative_base
from langchain_core.documents import Document
from langchain_core.load import Serializable, dumps, loads
from langchain_core.stores import BaseStore
V = TypeVar("V")
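# Number of keys fetched per query when yield_keys() pages through a collection.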
ITERATOR_WINDOW_SIZE = 1000
Base = declarative_base() # type: Any
_LANGCHAIN_DEFAULT_COLLECTION_NAME = "langchain"
class BaseModel(Base):
"""Base model for the SQL stores."""
__abstract__ = True
uuid = sqlalchemy.Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
_classes: Any = None
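# The declarative model classes below are created once and cached in the
# module-level variable above, so instantiating several stores reuses the
# same SQLAlchemy mappings instead of redefining the tables.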
def _get_storage_stores() -> Any:
global _classes
if _classes is not None:
return _classes
class CollectionStore(BaseModel):
"""Collection store."""
__tablename__ = "langchain_storage_collection"
name = sqlalchemy.Column(sqlalchemy.String)
cmetadata = sqlalchemy.Column(JSON)
items = relationship(
"ItemStore",
back_populates="collection",
passive_deletes=True,
)
@classmethod
def get_by_name(
cls, session: Session, name: str
) -> Optional["CollectionStore"]:  # type: ignore
return session.query(cls).filter(cls.name == name).first()
@classmethod
def get_or_create(
cls,
session: Session,
name: str,
cmetadata: Optional[dict] = None,
) -> Tuple["CollectionStore", bool]:
"""
Get or create a collection.
Returns [Collection, bool] where the bool is True if the collection was created.
""" # noqa: E501
created = False
collection = cls.get_by_name(session, name)
if collection:
return collection, created
collection = cls(name=name, cmetadata=cmetadata)
session.add(collection)
session.commit()
created = True
return collection, created
class ItemStore(BaseModel):
"""Item store."""
__tablename__ = "langchain_storage_items"
collection_id = sqlalchemy.Column(
UUID(as_uuid=True),
sqlalchemy.ForeignKey(
f"{CollectionStore.__tablename__}.uuid",
ondelete="CASCADE",
),
)
collection = relationship(CollectionStore, back_populates="items")
content = sqlalchemy.Column(sqlalchemy.String, nullable=True)
# custom_id : any user defined id
custom_id = sqlalchemy.Column(sqlalchemy.String, nullable=True)
_classes = (ItemStore, CollectionStore)
return _classes
class SQLBaseStore(BaseStore[str, V], Generic[V]):
"""SQL storage
Args:
connection_string: SQL connection string that will be passed to SQLAlchemy.
collection_name: The name of the collection to use. (default: langchain)
NOTE: Collections are useful to isolate your data in a given database.
This is not the name of the table, but the name of the collection.
The tables will be created when initializing the store (if they do not exist),
so make sure the user has the right permissions to create tables.
pre_delete_collection: If True, will delete the collection if it exists.
(default: False). Useful for testing.
engine_args: SQLAlchemy's create engine arguments.
Example:
.. code-block:: python
from langchain_community.storage import SQLDocStore
from langchain_community.embeddings.openai import OpenAIEmbeddings
# example using an SQLDocStore to store Document objects for
# a ParentDocumentRetriever
CONNECTION_STRING = "postgresql+psycopg2://user:pass@localhost:5432/db"
COLLECTION_NAME = "state_of_the_union_test"
docstore = SQLDocStore(
collection_name=COLLECTION_NAME,
connection_string=CONNECTION_STRING,
)
child_splitter = RecursiveCharacterTextSplitter(chunk_size=400)
vectorstore = ...
retriever = ParentDocumentRetriever(
vectorstore=vectorstore,
docstore=docstore,
child_splitter=child_splitter,
)
# example using an SQLStrStore to store strings
# same example as in "InMemoryStore" but using SQL persistence
store = SQLStrStore(
collection_name=COLLECTION_NAME,
connection_string=CONNECTION_STRING,
)
store.mset([('key1', 'value1'), ('key2', 'value2')])
store.mget(['key1', 'key2'])
# ['value1', 'value2']
store.mdelete(['key1'])
list(store.yield_keys())
# ['key2']
list(store.yield_keys(prefix='k'))
# ['key2']
# delete the COLLECTION_NAME collection
docstore.delete_collection()
"""
def __init__(
self,
connection_string: str,
collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME,
collection_metadata: Optional[dict] = None,
pre_delete_collection: bool = False,
connection: Optional[sqlalchemy.engine.Connection] = None,
engine_args: Optional[dict[str, Any]] = None,
) -> None:
self.connection_string = connection_string
self.collection_name = collection_name
self.collection_metadata = collection_metadata
self.pre_delete_collection = pre_delete_collection
self.engine_args = engine_args or {}
# Create a connection if not provided, otherwise use the provided connection
self._conn = connection if connection else self.__connect()
self.__post_init__()
def __post_init__(
self,
) -> None:
"""Initialize the store."""
ItemStore, CollectionStore = _get_storage_stores()
self.CollectionStore = CollectionStore
self.ItemStore = ItemStore
self.__create_tables_if_not_exists()
self.__create_collection()
def __connect(self) -> sqlalchemy.engine.Connection:
engine = sqlalchemy.create_engine(self.connection_string, **self.engine_args)
conn = engine.connect()
return conn
def __create_tables_if_not_exists(self) -> None:
with self._conn.begin():
Base.metadata.create_all(self._conn)
def __create_collection(self) -> None:
if self.pre_delete_collection:
self.delete_collection()
with Session(self._conn) as session:
self.CollectionStore.get_or_create(
session, self.collection_name, cmetadata=self.collection_metadata
)
def delete_collection(self) -> None:
with Session(self._conn) as session:
collection = self.__get_collection(session)
if not collection:
return
session.delete(collection)
session.commit()
def __get_collection(self, session: Session) -> Any:
return self.CollectionStore.get_by_name(session, self.collection_name)
def __del__(self) -> None:
if self._conn:
self._conn.close()
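# Values that implement LangChain's Serializable interface (e.g. Document)
# are stored as their JSON dump; anything else (e.g. plain strings) is stored
# as-is, and deserialization falls back to the raw content on failure.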
def __serialize_value(self, obj: V) -> str:
if isinstance(obj, Serializable):
return dumps(obj)
return obj
def __deserialize_value(self, obj: str) -> Any:
try:
return loads(obj)
except Exception:
return obj
def mget(self, keys: Sequence[str]) -> List[Optional[V]]:
"""Get the values associated with the given keys.
Args:
keys (Sequence[str]): A sequence of keys.
Returns:
A sequence of optional values associated with the keys.
If a key is not found, the corresponding value will be None.
"""
with Session(self._conn) as session:
collection = self.__get_collection(session)
items = (
session.query(self.ItemStore.content, self.ItemStore.custom_id)
.where(
sqlalchemy.and_(
self.ItemStore.custom_id.in_(keys),
self.ItemStore.collection_id == (collection.uuid),
)
)
.all()
)
ordered_values = {key: None for key in keys}
for item in items:
v = item[0]
val = self.__deserialize_value(v) if v is not None else v
k = item[1]
ordered_values[k] = val
return [ordered_values[key] for key in keys]
def mset(self, key_value_pairs: Sequence[Tuple[str, V]]) -> None:
"""Set the values for the given keys.
Args:
key_value_pairs (Sequence[Tuple[str, V]]): A sequence of key-value pairs.
Returns:
None
"""
with Session(self._conn) as session:
collection = self.__get_collection(session)
if not collection:
raise ValueError("Collection not found")
for id, item in key_value_pairs:
content = self.__serialize_value(item)
item_store = self.ItemStore(
content=content,
custom_id=id,
collection_id=collection.uuid,
)
session.add(item_store)
session.commit()
def mdelete(self, keys: Sequence[str]) -> None:
"""Delete the given keys and their associated values.
Args:
keys (Sequence[str]): A sequence of keys to delete.
"""
with Session(self._conn) as session:
collection = self.__get_collection(session)
if not collection:
raise ValueError("Collection not found")
if keys is not None:
stmt = sqlalchemy.delete(self.ItemStore).where(
sqlalchemy.and_(
self.ItemStore.custom_id.in_(keys),
self.ItemStore.collection_id == (collection.uuid),
)
)
session.execute(stmt)
session.commit()
def yield_keys(self, prefix: Optional[str] = None) -> Iterator[str]:
"""Get an iterator over keys that match the given prefix.
Args:
prefix (str, optional): The prefix to match. Defaults to None.
Returns:
Iterator[str]: An iterator over keys that match the given prefix.
"""
with Session(self._conn) as session:
collection = self.__get_collection(session)
start = 0
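# Page through the collection in ITERATOR_WINDOW_SIZE chunks so that large
# collections are not loaded into memory at once.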
while True:
stop = start + ITERATOR_WINDOW_SIZE
query = session.query(self.ItemStore.custom_id).where(
self.ItemStore.collection_id == (collection.uuid)
)
if prefix is not None:
query = query.filter(self.ItemStore.custom_id.startswith(prefix))
items = query.slice(start, stop).all()
if len(items) == 0:
break
for item in items:
yield item[0]
start += ITERATOR_WINDOW_SIZE
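# Concrete stores: SQLDocStore persists Document objects (serialized as JSON),
# while SQLStrStore persists plain strings.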
SQLDocStore = SQLBaseStore[Document]
SQLStrStore = SQLBaseStore[str]

View File

@@ -0,0 +1,228 @@
"""Implement integration tests for SQL storage."""
import os
from langchain_core.documents import Document
from langchain_community.storage.sql import SQLDocStore, SQLStrStore
def connection_string_from_db_params() -> str:
"""Return connection string from database parameters."""
dbdriver = os.environ.get("TEST_SQL_DBDRIVER", "postgresql+psycopg2")
host = os.environ.get("TEST_SQL_HOST", "localhost")
port = int(os.environ.get("TEST_SQL_PORT", "5432"))
database = os.environ.get("TEST_SQL_DATABASE", "postgres")
user = os.environ.get("TEST_SQL_USER", "postgres")
password = os.environ.get("TEST_SQL_PASSWORD", "postgres")
return f"{dbdriver}://{user}:{password}@{host}:{port}/{database}"
CONNECTION_STRING = connection_string_from_db_params()
COLLECTION_NAME = "test_collection"
COLLECTION_NAME_2 = "test_collection_2"
def test_str_store_mget() -> None:
store = SQLStrStore(
collection_name=COLLECTION_NAME,
connection_string=CONNECTION_STRING,
pre_delete_collection=True,
)
store.mset([("key1", "value1"), ("key2", "value2")])
values = store.mget(["key1", "key2"])
assert values == ["value1", "value2"]
# Test non-existent key
non_existent_value = store.mget(["key3"])
assert non_existent_value == [None]
store.delete_collection()
def test_str_store_mset() -> None:
store = SQLStrStore(
collection_name=COLLECTION_NAME,
connection_string=CONNECTION_STRING,
pre_delete_collection=True,
)
store.mset([("key1", "value1"), ("key2", "value2")])
values = store.mget(["key1", "key2"])
assert values == ["value1", "value2"]
store.delete_collection()
def test_str_store_mdelete() -> None:
store = SQLStrStore(
collection_name=COLLECTION_NAME,
connection_string=CONNECTION_STRING,
pre_delete_collection=True,
)
store.mset([("key1", "value1"), ("key2", "value2")])
store.mdelete(["key1"])
values = store.mget(["key1", "key2"])
assert values == [None, "value2"]
# Test deleting non-existent key
store.mdelete(["key3"]) # No error should be raised
store.delete_collection()
def test_str_store_yield_keys() -> None:
store = SQLStrStore(
collection_name=COLLECTION_NAME,
connection_string=CONNECTION_STRING,
pre_delete_collection=True,
)
store.mset([("key1", "value1"), ("key2", "value2"), ("key3", "value3")])
keys = list(store.yield_keys())
assert set(keys) == {"key1", "key2", "key3"}
keys_with_prefix = list(store.yield_keys(prefix="key"))
assert set(keys_with_prefix) == {"key1", "key2", "key3"}
keys_with_invalid_prefix = list(store.yield_keys(prefix="x"))
assert keys_with_invalid_prefix == []
store.delete_collection()
def test_str_store_collection() -> None:
"""Test that collections are isolated within a db."""
store_1 = SQLStrStore(
collection_name=COLLECTION_NAME,
connection_string=CONNECTION_STRING,
pre_delete_collection=True,
)
store_2 = SQLStrStore(
collection_name=COLLECTION_NAME_2,
connection_string=CONNECTION_STRING,
pre_delete_collection=True,
)
store_1.mset([("key1", "value1"), ("key2", "value2")])
store_2.mset([("key3", "value3"), ("key4", "value4")])
values = store_1.mget(["key1", "key2"])
assert values == ["value1", "value2"]
values = store_1.mget(["key3", "key4"])
assert values == [None, None]
values = store_2.mget(["key1", "key2"])
assert values == [None, None]
values = store_2.mget(["key3", "key4"])
assert values == ["value3", "value4"]
store_1.delete_collection()
store_2.delete_collection()
def test_doc_store_mget() -> None:
store = SQLDocStore(
collection_name=COLLECTION_NAME,
connection_string=CONNECTION_STRING,
pre_delete_collection=True,
)
doc_1 = Document(page_content="value1")
doc_2 = Document(page_content="value2")
store.mset([("key1", doc_1), ("key2", doc_2)])
values = store.mget(["key1", "key2"])
assert values == [doc_1, doc_2]
# Test non-existent key
non_existent_value = store.mget(["key3"])
assert non_existent_value == [None]
store.delete_collection()
def test_doc_store_mset() -> None:
store = SQLDocStore(
collection_name=COLLECTION_NAME,
connection_string=CONNECTION_STRING,
pre_delete_collection=True,
)
doc_1 = Document(page_content="value1")
doc_2 = Document(page_content="value2")
store.mset([("key1", doc_1), ("key2", doc_2)])
values = store.mget(["key1", "key2"])
assert values == [doc_1, doc_2]
store.delete_collection()
def test_doc_store_mdelete() -> None:
store = SQLDocStore(
collection_name=COLLECTION_NAME,
connection_string=CONNECTION_STRING,
pre_delete_collection=True,
)
doc_1 = Document(page_content="value1")
doc_2 = Document(page_content="value2")
store.mset([("key1", doc_1), ("key2", doc_2)])
store.mdelete(["key1"])
values = store.mget(["key1", "key2"])
assert values == [None, doc_2]
# Test deleting non-existent key
store.mdelete(["key3"]) # No error should be raised
store.delete_collection()
def test_doc_store_yield_keys() -> None:
store = SQLDocStore(
collection_name=COLLECTION_NAME,
connection_string=CONNECTION_STRING,
pre_delete_collection=True,
)
doc_1 = Document(page_content="value1")
doc_2 = Document(page_content="value2")
doc_3 = Document(page_content="value3")
store.mset([("key1", doc_1), ("key2", doc_2), ("key3", doc_3)])
keys = list(store.yield_keys())
assert set(keys) == {"key1", "key2", "key3"}
keys_with_prefix = list(store.yield_keys(prefix="key"))
assert set(keys_with_prefix) == {"key1", "key2", "key3"}
keys_with_invalid_prefix = list(store.yield_keys(prefix="x"))
assert keys_with_invalid_prefix == []
store.delete_collection()
def test_doc_store_collection() -> None:
"""Test that collections are isolated within a db."""
store_1 = SQLDocStore(
collection_name=COLLECTION_NAME,
connection_string=CONNECTION_STRING,
pre_delete_collection=True,
)
store_2 = SQLDocStore(
collection_name=COLLECTION_NAME_2,
connection_string=CONNECTION_STRING,
pre_delete_collection=True,
)
doc_1 = Document(page_content="value1")
doc_2 = Document(page_content="value2")
doc_3 = Document(page_content="value3")
doc_4 = Document(page_content="value4")
store_1.mset([("key1", doc_1), ("key2", doc_2)])
store_2.mset([("key3", doc_3), ("key4", doc_4)])
values = store_1.mget(["key1", "key2"])
assert values == [doc_1, doc_2]
values = store_1.mget(["key3", "key4"])
assert values == [None, None]
values = store_2.mget(["key1", "key2"])
assert values == [None, None]
values = store_2.mget(["key3", "key4"])
assert values == [doc_3, doc_4]
store_1.delete_collection()
store_2.delete_collection()

View File

@@ -6,6 +6,8 @@ EXPECTED_ALL = [
"RedisStore",
"UpstashRedisByteStore",
"UpstashRedisStore",
"SQLDocStore",
"SQLStrStore",
]

View File

@@ -0,0 +1,7 @@
"""Light weight unit test that attempts to import SQLDocStore/SQLStrStore.
"""
def test_import_storage() -> None:
"""Attempt to import storage modules."""
from langchain_community.storage.sql import SQLDocStore, SQLStrStore # noqa