mirror of
https://github.com/hwchase17/langchain
synced 2024-11-06 03:20:49 +00:00
cfc225ecb3
**Description:** - Implement `SQLStrStore` and `SQLDocStore` classes that inherit from `BaseStore` to allow persisting data remotely on a SQL server. - SQL is widely used, and sometimes we do not want to install a caching solution like Redis. - Multiple issues/comments complain that there is no easy remote and persistent solution that is not in memory (users want to replace InMemoryStore), e.g., https://github.com/langchain-ai/langchain/issues/14267, https://github.com/langchain-ai/langchain/issues/15633, https://github.com/langchain-ai/langchain/issues/14643, https://stackoverflow.com/questions/77385587/persist-parentdocumentretriever-of-langchain - This is particularly painful when using `ParentDocumentRetriever`. - This implementation is particularly useful when: * it's expensive to construct an InMemoryDocstore/dict * you want to retrieve documents from remote sources * you just want to reuse existing objects - This implementation integrates well with PGVector: when using PGVector, you already have a SQL instance running, and `SQLDocStore` is a convenient way of using this instance to store documents associated with vectors. An integration example with ParentDocumentRetriever and PGVector is provided in docs/docs/integrations/stores/sql.ipynb or [here](https://github.com/gcheron/langchain/blob/sql-store/docs/docs/integrations/stores/sql.ipynb). - It persists `str` and `Document` objects but can be easily extended. **Issue:** Provide an easy SQL alternative to `InMemoryStore`. --------- Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
229 lines
6.9 KiB
Python
229 lines
6.9 KiB
Python
"""Implement integration tests for SQL storage."""
|
|
import os
|
|
|
|
from langchain_core.documents import Document
|
|
|
|
from langchain_community.storage.sql import SQLDocStore, SQLStrStore
|
|
|
|
|
|
def connection_string_from_db_params() -> str:
    """Return a SQLAlchemy connection URL built from ``TEST_SQL_*`` env vars.

    Every component has a default, which yields
    ``postgresql+psycopg2://postgres:postgres@localhost:5432/postgres``
    when no variables are set. The user name and password are
    percent-encoded so that credentials containing URL-reserved characters
    (``@``, ``:``, ``/``, ``#``, ...) do not corrupt the resulting URL.

    Raises:
        ValueError: if ``TEST_SQL_PORT`` is not an integer.
    """
    from urllib.parse import quote_plus

    dbdriver = os.environ.get("TEST_SQL_DBDRIVER", "postgresql+psycopg2")
    host = os.environ.get("TEST_SQL_HOST", "localhost")
    # int() fails fast on a non-numeric port instead of building a bad URL.
    port = int(os.environ.get("TEST_SQL_PORT", "5432"))
    database = os.environ.get("TEST_SQL_DATABASE", "postgres")
    # Escape credentials: a raw "@" or "/" in a password would otherwise be
    # parsed as the start of the host or the database path.
    user = quote_plus(os.environ.get("TEST_SQL_USER", "postgres"))
    password = quote_plus(os.environ.get("TEST_SQL_PASSWORD", "postgres"))
    return f"{dbdriver}://{user}:{password}@{host}:{port}/{database}"
|
|
|
|
|
|
# Shared by every test below: one database URL and two distinct collection
# names, so that the isolation tests can use both collections side by side.
CONNECTION_STRING = connection_string_from_db_params()
COLLECTION_NAME, COLLECTION_NAME_2 = "test_collection", "test_collection_2"
|
|
|
|
|
|
def test_str_store_mget() -> None:
    """``SQLStrStore.mget`` returns stored strings and ``None`` for unknown keys."""
    store = SQLStrStore(
        collection_name=COLLECTION_NAME,
        connection_string=CONNECTION_STRING,
        pre_delete_collection=True,
    )
    # try/finally so the collection is dropped even if an assertion fails.
    try:
        store.mset([("key1", "value1"), ("key2", "value2")])

        values = store.mget(["key1", "key2"])
        assert values == ["value1", "value2"]

        # Test non-existent key
        non_existent_value = store.mget(["key3"])
        assert non_existent_value == [None]
    finally:
        store.delete_collection()
|
|
|
|
|
|
def test_str_store_mset() -> None:
    """Values written with ``SQLStrStore.mset`` can be read back with ``mget``."""
    store = SQLStrStore(
        collection_name=COLLECTION_NAME,
        connection_string=CONNECTION_STRING,
        pre_delete_collection=True,
    )
    # try/finally so the collection is dropped even if an assertion fails.
    try:
        store.mset([("key1", "value1"), ("key2", "value2")])

        values = store.mget(["key1", "key2"])
        assert values == ["value1", "value2"]
    finally:
        store.delete_collection()
|
|
|
|
|
|
def test_str_store_mdelete() -> None:
    """``SQLStrStore.mdelete`` removes only the given keys and ignores unknown ones."""
    store = SQLStrStore(
        collection_name=COLLECTION_NAME,
        connection_string=CONNECTION_STRING,
        pre_delete_collection=True,
    )
    # try/finally so the collection is dropped even if an assertion fails.
    try:
        store.mset([("key1", "value1"), ("key2", "value2")])

        store.mdelete(["key1"])

        values = store.mget(["key1", "key2"])
        assert values == [None, "value2"]

        # Test deleting non-existent key
        store.mdelete(["key3"])  # No error should be raised
    finally:
        store.delete_collection()
|
|
|
|
|
|
def test_str_store_yield_keys() -> None:
    """``SQLStrStore.yield_keys`` lists all keys and honors the ``prefix`` filter."""
    store = SQLStrStore(
        collection_name=COLLECTION_NAME,
        connection_string=CONNECTION_STRING,
        pre_delete_collection=True,
    )
    # try/finally so the collection is dropped even if an assertion fails.
    try:
        store.mset([("key1", "value1"), ("key2", "value2"), ("key3", "value3")])

        # Compared as sets: the test does not rely on any key ordering.
        keys = list(store.yield_keys())
        assert set(keys) == {"key1", "key2", "key3"}

        keys_with_prefix = list(store.yield_keys(prefix="key"))
        assert set(keys_with_prefix) == {"key1", "key2", "key3"}

        keys_with_invalid_prefix = list(store.yield_keys(prefix="x"))
        assert keys_with_invalid_prefix == []
    finally:
        store.delete_collection()
|
|
|
|
|
|
def test_str_store_collection() -> None:
    """Test that collections are isolated within a db."""
    store_1 = SQLStrStore(
        collection_name=COLLECTION_NAME,
        connection_string=CONNECTION_STRING,
        pre_delete_collection=True,
    )
    store_2 = SQLStrStore(
        collection_name=COLLECTION_NAME_2,
        connection_string=CONNECTION_STRING,
        pre_delete_collection=True,
    )

    try:
        store_1.mset([("key1", "value1"), ("key2", "value2")])
        store_2.mset([("key3", "value3"), ("key4", "value4")])

        # Each store sees only its own keys.
        values = store_1.mget(["key1", "key2"])
        assert values == ["value1", "value2"]
        values = store_1.mget(["key3", "key4"])
        assert values == [None, None]

        values = store_2.mget(["key1", "key2"])
        assert values == [None, None]
        values = store_2.mget(["key3", "key4"])
        assert values == ["value3", "value4"]
    finally:
        # Nested finally: the second collection is dropped even if an
        # assertion or the first drop fails.
        try:
            store_1.delete_collection()
        finally:
            store_2.delete_collection()
|
|
|
|
|
|
def test_doc_store_mget() -> None:
    """``SQLDocStore.mget`` returns stored documents and ``None`` for unknown keys."""
    store = SQLDocStore(
        collection_name=COLLECTION_NAME,
        connection_string=CONNECTION_STRING,
        pre_delete_collection=True,
    )
    # try/finally so the collection is dropped even if an assertion fails.
    try:
        doc_1 = Document(page_content="value1")
        doc_2 = Document(page_content="value2")
        store.mset([("key1", doc_1), ("key2", doc_2)])

        values = store.mget(["key1", "key2"])
        assert values == [doc_1, doc_2]

        # Test non-existent key
        non_existent_value = store.mget(["key3"])
        assert non_existent_value == [None]
    finally:
        store.delete_collection()
|
|
|
|
|
|
def test_doc_store_mset() -> None:
    """Documents written with ``SQLDocStore.mset`` can be read back with ``mget``."""
    store = SQLDocStore(
        collection_name=COLLECTION_NAME,
        connection_string=CONNECTION_STRING,
        pre_delete_collection=True,
    )
    # try/finally so the collection is dropped even if an assertion fails.
    try:
        doc_1 = Document(page_content="value1")
        doc_2 = Document(page_content="value2")
        store.mset([("key1", doc_1), ("key2", doc_2)])

        values = store.mget(["key1", "key2"])
        assert values == [doc_1, doc_2]
    finally:
        store.delete_collection()
|
|
|
|
|
|
def test_doc_store_mdelete() -> None:
    """``SQLDocStore.mdelete`` removes only the given keys and ignores unknown ones."""
    store = SQLDocStore(
        collection_name=COLLECTION_NAME,
        connection_string=CONNECTION_STRING,
        pre_delete_collection=True,
    )
    # try/finally so the collection is dropped even if an assertion fails.
    try:
        doc_1 = Document(page_content="value1")
        doc_2 = Document(page_content="value2")
        store.mset([("key1", doc_1), ("key2", doc_2)])

        store.mdelete(["key1"])

        values = store.mget(["key1", "key2"])
        assert values == [None, doc_2]

        # Test deleting non-existent key
        store.mdelete(["key3"])  # No error should be raised
    finally:
        store.delete_collection()
|
|
|
|
|
|
def test_doc_store_yield_keys() -> None:
    """``SQLDocStore.yield_keys`` lists all keys and honors the ``prefix`` filter."""
    store = SQLDocStore(
        collection_name=COLLECTION_NAME,
        connection_string=CONNECTION_STRING,
        pre_delete_collection=True,
    )
    # try/finally so the collection is dropped even if an assertion fails.
    try:
        doc_1 = Document(page_content="value1")
        doc_2 = Document(page_content="value2")
        doc_3 = Document(page_content="value3")
        store.mset([("key1", doc_1), ("key2", doc_2), ("key3", doc_3)])

        # Compared as sets: the test does not rely on any key ordering.
        keys = list(store.yield_keys())
        assert set(keys) == {"key1", "key2", "key3"}

        keys_with_prefix = list(store.yield_keys(prefix="key"))
        assert set(keys_with_prefix) == {"key1", "key2", "key3"}

        keys_with_invalid_prefix = list(store.yield_keys(prefix="x"))
        assert keys_with_invalid_prefix == []
    finally:
        store.delete_collection()
|
|
|
|
|
|
def test_doc_store_collection() -> None:
    """Test that collections are isolated within a db."""
    store_1 = SQLDocStore(
        collection_name=COLLECTION_NAME,
        connection_string=CONNECTION_STRING,
        pre_delete_collection=True,
    )
    store_2 = SQLDocStore(
        collection_name=COLLECTION_NAME_2,
        connection_string=CONNECTION_STRING,
        pre_delete_collection=True,
    )

    try:
        doc_1 = Document(page_content="value1")
        doc_2 = Document(page_content="value2")
        doc_3 = Document(page_content="value3")
        doc_4 = Document(page_content="value4")
        store_1.mset([("key1", doc_1), ("key2", doc_2)])
        store_2.mset([("key3", doc_3), ("key4", doc_4)])

        # Each store sees only its own keys.
        values = store_1.mget(["key1", "key2"])
        assert values == [doc_1, doc_2]
        values = store_1.mget(["key3", "key4"])
        assert values == [None, None]

        values = store_2.mget(["key1", "key2"])
        assert values == [None, None]
        values = store_2.mget(["key3", "key4"])
        assert values == [doc_3, doc_4]
    finally:
        # Nested finally: the second collection is dropped even if an
        # assertion or the first drop fails.
        try:
            store_1.delete_collection()
        finally:
            store_2.delete_collection()
|