mirror of https://github.com/hwchase17/langchain
community[minor]: Add SQL storage implementation (#22207)
Hello @eyurtsev - package: langchain-community - **Description**: Add SQL implementation for docstore. A new implementation, in line with my other PR ([async PGVector](https://github.com/langchain-ai/langchain-postgres/pull/32), [SQLChatMessageMemory](https://github.com/langchain-ai/langchain/pull/22065)) - Twitter handle: pprados --------- Signed-off-by: ChengZi <chen.zhang@zilliz.com> Co-authored-by: Bagatur <22008038+baskaryan@users.noreply.github.com> Co-authored-by: Piotr Mardziel <piotrm@gmail.com> Co-authored-by: ChengZi <chen.zhang@zilliz.com> Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>pull/22691/head
parent
f2f0e0e13d
commit
9aabb446c5
@ -0,0 +1,266 @@
|
||||
import contextlib
|
||||
from pathlib import Path
|
||||
from typing import (
|
||||
Any,
|
||||
AsyncGenerator,
|
||||
AsyncIterator,
|
||||
Dict,
|
||||
Generator,
|
||||
Iterator,
|
||||
List,
|
||||
Optional,
|
||||
Sequence,
|
||||
Tuple,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
|
||||
from langchain_core.stores import BaseStore
|
||||
from sqlalchemy import (
|
||||
Engine,
|
||||
LargeBinary,
|
||||
and_,
|
||||
create_engine,
|
||||
delete,
|
||||
select,
|
||||
)
|
||||
from sqlalchemy.ext.asyncio import (
|
||||
AsyncEngine,
|
||||
AsyncSession,
|
||||
async_sessionmaker,
|
||||
create_async_engine,
|
||||
)
|
||||
from sqlalchemy.orm import (
|
||||
Mapped,
|
||||
Session,
|
||||
declarative_base,
|
||||
mapped_column,
|
||||
sessionmaker,
|
||||
)
|
||||
|
||||
# Shared declarative base for the ORM models declared in this module.
Base = declarative_base()
|
||||
|
||||
|
||||
def items_equal(x: Any, y: Any) -> bool:
    """Return whether *x* and *y* compare equal (plain ``==`` semantics)."""
    are_equal = x == y
    return are_equal
|
||||
|
||||
|
||||
class LangchainKeyValueStores(Base): # type: ignore[valid-type,misc]
    """Table used to save values."""

    # ATTENTION:
    # Prior to modifying this table, please determine whether
    # we should create migrations for this table to make sure
    # users do not experience data loss.
    __tablename__ = "langchain_key_value_stores"

    # Composite primary key (namespace, key): one row per key within a namespace.
    namespace: Mapped[str] = mapped_column(primary_key=True, index=True, nullable=False)
    key: Mapped[str] = mapped_column(primary_key=True, index=True, nullable=False)
    # Raw bytes payload; column type comes from LargeBinary (no Mapped annotation).
    value = mapped_column(LargeBinary, index=False, nullable=False)
|
||||
|
||||
|
||||
# This is a fix of original SQLStore.
|
||||
# This can be removed once the upstream PR is merged.
|
||||
class SQLStore(BaseStore[str, bytes]):
    """BaseStore interface that works on an SQL database.

    Examples:
        Create a SQLStore instance and perform operations on it:

        .. code-block:: python

            from langchain_rag.storage import SQLStore

            # Instantiate the SQLStore with the root path
            sql_store = SQLStore(namespace="test", db_url="sqlite:///:memory:")

            # Set values for keys
            sql_store.mset([("key1", b"value1"), ("key2", b"value2")])

            # Get values for keys
            values = sql_store.mget(["key1", "key2"])  # Returns [b"value1", b"value2"]

            # Delete keys
            sql_store.mdelete(["key1"])

            # Iterate over keys
            for key in sql_store.yield_keys():
                print(key)

    """

    def __init__(
        self,
        *,
        namespace: str,
        db_url: Optional[Union[str, Path]] = None,
        engine: Optional[Union[Engine, AsyncEngine]] = None,
        engine_kwargs: Optional[Dict[str, Any]] = None,
        async_mode: Optional[bool] = None,
    ):
        """Initialize the store.

        Args:
            namespace: Logical namespace; every key is scoped to it.
            db_url: SQLAlchemy database URL. Mutually exclusive with ``engine``.
            engine: Pre-built sync or async engine. Mutually exclusive with
                ``db_url``.
            engine_kwargs: Extra keyword arguments forwarded to
                ``create_engine`` / ``create_async_engine`` (used only with
                ``db_url``).
            async_mode: When building from ``db_url``, create an async engine.
                Ignored when ``engine`` is supplied (the engine type decides).

        Raises:
            ValueError: If neither or both of ``db_url`` and ``engine`` are
                given.
        """
        if db_url is None and engine is None:
            raise ValueError("Must specify either db_url or engine")

        if db_url is not None and engine is not None:
            raise ValueError("Must specify either db_url or engine, not both")

        _engine: Union[Engine, AsyncEngine]
        if db_url:
            if async_mode is None:
                async_mode = False
            if async_mode:
                _engine = create_async_engine(
                    url=str(db_url),
                    **(engine_kwargs or {}),
                )
            else:
                _engine = create_engine(url=str(db_url), **(engine_kwargs or {}))
        elif engine:
            _engine = engine
        else:
            raise AssertionError("Something went wrong with configuration of engine.")

        # The actual engine type is authoritative for sync vs async behavior;
        # the async_mode flag only steers engine creation from db_url.
        _session_maker: Union[sessionmaker[Session], async_sessionmaker[AsyncSession]]
        if isinstance(_engine, AsyncEngine):
            self.async_mode = True
            _session_maker = async_sessionmaker(bind=_engine)
        else:
            self.async_mode = False
            _session_maker = sessionmaker(bind=_engine)

        self.engine = _engine
        self.dialect = _engine.dialect.name
        self.session_maker = _session_maker
        self.namespace = namespace

    def create_schema(self) -> None:
        """Create the backing table (sync engines only)."""
        Base.metadata.create_all(self.engine)

    async def acreate_schema(self) -> None:
        """Create the backing table using an async engine."""
        assert isinstance(self.engine, AsyncEngine)
        async with self.engine.begin() as session:
            await session.run_sync(Base.metadata.create_all)

    def drop(self) -> None:
        """Drop the backing table (sync engines only)."""
        # Bind to the engine itself: the original opened an ad-hoc connection
        # (``self.engine.connect()``) that was never closed — a connection leak.
        Base.metadata.drop_all(bind=self.engine)

    async def amget(self, keys: Sequence[str]) -> List[Optional[bytes]]:
        """Async: return the value for each key, ``None`` for missing keys."""
        assert isinstance(self.engine, AsyncEngine)
        result: Dict[str, bytes] = {}
        async with self._make_async_session() as session:
            stmt = select(LangchainKeyValueStores).filter(
                and_(
                    LangchainKeyValueStores.key.in_(keys),
                    LangchainKeyValueStores.namespace == self.namespace,
                )
            )
            for v in await session.scalars(stmt):
                result[v.key] = v.value
        # Preserve the caller's key order in the returned list.
        return [result.get(key) for key in keys]

    def mget(self, keys: Sequence[str]) -> List[Optional[bytes]]:
        """Return the value for each key, ``None`` for missing keys."""
        result: Dict[str, bytes] = {}

        with self._make_sync_session() as session:
            stmt = select(LangchainKeyValueStores).filter(
                and_(
                    LangchainKeyValueStores.key.in_(keys),
                    LangchainKeyValueStores.namespace == self.namespace,
                )
            )
            for v in session.scalars(stmt):
                result[v.key] = v.value
        # Preserve the caller's key order in the returned list.
        return [result.get(key) for key in keys]

    async def amset(self, key_value_pairs: Sequence[Tuple[str, bytes]]) -> None:
        """Async: store the given pairs (upsert via delete-then-insert)."""
        async with self._make_async_session() as session:
            await self._amdelete([key for key, _ in key_value_pairs], session)
            session.add_all(
                [
                    LangchainKeyValueStores(namespace=self.namespace, key=k, value=v)
                    for k, v in key_value_pairs
                ]
            )
            await session.commit()

    def mset(self, key_value_pairs: Sequence[Tuple[str, bytes]]) -> None:
        """Store the given pairs (upsert via delete-then-insert)."""
        values: Dict[str, bytes] = dict(key_value_pairs)
        with self._make_sync_session() as session:
            self._mdelete(list(values.keys()), session)
            session.add_all(
                [
                    LangchainKeyValueStores(namespace=self.namespace, key=k, value=v)
                    for k, v in values.items()
                ]
            )
            session.commit()

    def _mdelete(self, keys: Sequence[str], session: Session) -> None:
        """Delete *keys* within this namespace via *session* (no commit)."""
        stmt = delete(LangchainKeyValueStores).filter(
            and_(
                LangchainKeyValueStores.key.in_(keys),
                LangchainKeyValueStores.namespace == self.namespace,
            )
        )
        session.execute(stmt)

    async def _amdelete(self, keys: Sequence[str], session: AsyncSession) -> None:
        """Async: delete *keys* within this namespace via *session* (no commit)."""
        stmt = delete(LangchainKeyValueStores).filter(
            and_(
                LangchainKeyValueStores.key.in_(keys),
                LangchainKeyValueStores.namespace == self.namespace,
            )
        )
        await session.execute(stmt)

    def mdelete(self, keys: Sequence[str]) -> None:
        """Delete the given keys from this namespace."""
        with self._make_sync_session() as session:
            self._mdelete(keys, session)
            session.commit()

    async def amdelete(self, keys: Sequence[str]) -> None:
        """Async: delete the given keys from this namespace."""
        async with self._make_async_session() as session:
            await self._amdelete(keys, session)
            await session.commit()

    def yield_keys(self, *, prefix: Optional[str] = None) -> Iterator[str]:
        """Yield all keys in this namespace, optionally filtered by *prefix*."""
        with self._make_sync_session() as session:
            for v in session.query(LangchainKeyValueStores).filter(  # type: ignore
                LangchainKeyValueStores.namespace == self.namespace
            ):
                if str(v.key).startswith(prefix or ""):
                    yield str(v.key)
            # No explicit close: the context manager already closes the session.

    async def ayield_keys(self, *, prefix: Optional[str] = None) -> AsyncIterator[str]:
        """Async: yield all keys in this namespace, optionally filtered by *prefix*."""
        async with self._make_async_session() as session:
            stmt = select(LangchainKeyValueStores).filter(
                LangchainKeyValueStores.namespace == self.namespace
            )
            for v in await session.scalars(stmt):
                if str(v.key).startswith(prefix or ""):
                    yield str(v.key)
            # No explicit close: the context manager already closes the session.

    @contextlib.contextmanager
    def _make_sync_session(self) -> Generator[Session, None, None]:
        """Make a sync session, refusing to run when the store is async."""
        if self.async_mode:
            raise ValueError(
                "Attempting to use a sync method when async mode is turned on. "
                "Please use the corresponding async method instead."
            )
        with cast(Session, self.session_maker()) as session:
            yield session

    @contextlib.asynccontextmanager
    async def _make_async_session(self) -> AsyncGenerator[AsyncSession, None]:
        """Make an async session, refusing to run when the store is sync."""
        if not self.async_mode:
            raise ValueError(
                "Attempting to use an async method when sync mode is turned on. "
                "Please use the corresponding sync method instead."
            )
        async with cast(AsyncSession, self.session_maker()) as session:
            yield session
|
@ -0,0 +1,186 @@
|
||||
"""Implement integration tests for Redis storage."""
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import Engine, create_engine, text
|
||||
from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine
|
||||
|
||||
from langchain_community.storage import SQLStore
|
||||
|
||||
pytest.importorskip("sqlalchemy")
|
||||
|
||||
|
||||
@pytest.fixture
def sql_engine() -> Engine:
    """Return a sync in-memory SQLite engine (SQL echoed for debugging)."""
    return create_engine(url="sqlite://", echo=True)
|
||||
|
||||
|
||||
@pytest.fixture
def sql_aengine() -> AsyncEngine:
    """Return an async in-memory SQLite engine via aiosqlite (SQL echoed)."""
    return create_async_engine(url="sqlite+aiosqlite:///:memory:", echo=True)
|
||||
|
||||
|
||||
def test_mget(sql_engine: Engine) -> None:
    """mget returns stored values in the order the keys were given."""
    store = SQLStore(engine=sql_engine, namespace="test")
    store.create_schema()
    keys = ["key1", "key2"]
    # Seed rows via raw SQL so mget is exercised independently of mset.
    seed_statements = [
        text(
            "insert into langchain_key_value_stores ('namespace', 'key', 'value') "
            "values('test','key1',:value)"
        ).bindparams(value=b"value1"),
        text(
            "insert into langchain_key_value_stores ('namespace', 'key', 'value') "
            "values('test','key2',:value)"
        ).bindparams(value=b"value2"),
    ]
    with sql_engine.connect() as conn:
        for stmt in seed_statements:
            conn.execute(stmt)
        conn.commit()

    assert store.mget(keys) == [b"value1", b"value2"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_amget(sql_aengine: AsyncEngine) -> None:
    """amget returns stored values in the order the keys were given."""
    store = SQLStore(engine=sql_aengine, namespace="test")
    await store.acreate_schema()
    keys = ["key1", "key2"]
    # Seed rows via raw SQL so amget is exercised independently of amset.
    seed_statements = [
        text(
            "insert into langchain_key_value_stores ('namespace', 'key', 'value') "
            "values('test','key1',:value)"
        ).bindparams(value=b"value1"),
        text(
            "insert into langchain_key_value_stores ('namespace', 'key', 'value') "
            "values('test','key2',:value)"
        ).bindparams(value=b"value2"),
    ]
    async with sql_aengine.connect() as conn:
        for stmt in seed_statements:
            await conn.execute(stmt)
        await conn.commit()

    assert await store.amget(keys) == [b"value1", b"value2"]
|
||||
|
||||
|
||||
def test_mset(sql_engine: Engine) -> None:
    """mset persists every pair under the store's namespace."""
    store = SQLStore(engine=sql_engine, namespace="test")
    store.create_schema()
    store.mset([("key1", b"value1"), ("key2", b"value2")])

    with sql_engine.connect() as conn:
        rows = conn.exec_driver_sql("select * from langchain_key_value_stores")
        assert rows.keys() == ["namespace", "key", "value"]
        # Check (namespace, key) only; payload round-trips are covered by mget.
        assert [(row[0], row[1]) for row in rows] == [("test", "key1"), ("test", "key2")]
        conn.commit()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_amset(sql_aengine: AsyncEngine) -> None:
    """amset persists every pair under the store's namespace."""
    store = SQLStore(engine=sql_aengine, namespace="test")
    await store.acreate_schema()
    await store.amset([("key1", b"value1"), ("key2", b"value2")])

    async with sql_aengine.connect() as conn:
        rows = await conn.exec_driver_sql(
            "select * from langchain_key_value_stores"
        )
        assert rows.keys() == ["namespace", "key", "value"]
        # Check (namespace, key) only; payload round-trips are covered by amget.
        assert [(row[0], row[1]) for row in rows] == [("test", "key1"), ("test", "key2")]
        await conn.commit()
|
||||
|
||||
|
||||
def test_mdelete(sql_engine: Engine) -> None:
    """mdelete removes every row for the given keys."""
    store = SQLStore(engine=sql_engine, namespace="test")
    store.create_schema()
    keys = ["key1", "key2"]
    # Seed rows via raw SQL so mdelete is exercised independently of mset.
    seed_statements = [
        text(
            "insert into langchain_key_value_stores ('namespace', 'key', 'value') "
            "values('test','key1',:value)"
        ).bindparams(value=b"value1"),
        text(
            "insert into langchain_key_value_stores ('namespace', 'key', 'value') "
            "values('test','key2',:value)"
        ).bindparams(value=b"value2"),
    ]
    with sql_engine.connect() as conn:
        for stmt in seed_statements:
            conn.execute(stmt)
        conn.commit()
    store.mdelete(keys)
    with sql_engine.connect() as conn:
        rows = conn.exec_driver_sql("select * from langchain_key_value_stores")
        assert rows.keys() == ["namespace", "key", "value"]
        # Both rows must be gone after the delete.
        assert list(rows) == []
        conn.commit()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_amdelete(sql_aengine: AsyncEngine) -> None:
    """amdelete removes every row for the given keys."""
    store = SQLStore(engine=sql_aengine, namespace="test")
    await store.acreate_schema()
    keys = ["key1", "key2"]
    # Seed rows via raw SQL so amdelete is exercised independently of amset.
    seed_statements = [
        text(
            "insert into langchain_key_value_stores ('namespace', 'key', 'value') "
            "values('test','key1',:value)"
        ).bindparams(value=b"value1"),
        text(
            "insert into langchain_key_value_stores ('namespace', 'key', 'value') "
            "values('test','key2',:value)"
        ).bindparams(value=b"value2"),
    ]
    async with sql_aengine.connect() as conn:
        for stmt in seed_statements:
            await conn.execute(stmt)
        await conn.commit()
    await store.amdelete(keys)
    async with sql_aengine.connect() as conn:
        rows = await conn.exec_driver_sql(
            "select * from langchain_key_value_stores"
        )
        assert rows.keys() == ["namespace", "key", "value"]
        # Both rows must be gone after the delete.
        assert list(rows) == []
        await conn.commit()
|
||||
|
||||
|
||||
def test_yield_keys(sql_engine: Engine) -> None:
    """yield_keys lists all keys and honors the optional prefix filter."""
    store = SQLStore(engine=sql_engine, namespace="test")
    store.create_schema()
    store.mset([("key1", b"value1"), ("key2", b"value2")])
    assert sorted(store.yield_keys()) == ["key1", "key2"]
    assert sorted(store.yield_keys(prefix="key")) == ["key1", "key2"]
    # A prefix matching nothing yields no keys.
    assert sorted(store.yield_keys(prefix="lang")) == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_ayield_keys(sql_aengine: AsyncEngine) -> None:
    """ayield_keys lists all keys and honors the optional prefix filter."""
    store = SQLStore(engine=sql_aengine, namespace="test")
    await store.acreate_schema()
    await store.amset([("key1", b"value1"), ("key2", b"value2")])
    all_keys = [k async for k in store.ayield_keys()]
    assert sorted(all_keys) == ["key1", "key2"]
    prefixed = [k async for k in store.ayield_keys(prefix="key")]
    assert sorted(prefixed) == ["key1", "key2"]
    # A prefix matching nothing yields no keys.
    assert [k async for k in store.ayield_keys(prefix="lang")] == []
|
@ -0,0 +1,89 @@
|
||||
from typing import AsyncGenerator, Generator, cast
|
||||
|
||||
import pytest
|
||||
from langchain.storage._lc_store import create_kv_docstore, create_lc_store
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.stores import BaseStore
|
||||
|
||||
from langchain_community.storage.sql import SQLStore
|
||||
|
||||
|
||||
@pytest.fixture
def sql_store() -> Generator[SQLStore, None, None]:
    """Yield a sync SQLStore backed by an in-memory SQLite database."""
    store = SQLStore(namespace="test", db_url="sqlite://")
    store.create_schema()
    yield store
|
||||
|
||||
|
||||
@pytest.fixture
async def async_sql_store() -> AsyncGenerator[SQLStore, None]:
    """Yield an async SQLStore backed by in-memory SQLite (aiosqlite)."""
    # NOTE(review): an async-generator fixture needs pytest-asyncio support
    # (asyncio_mode=auto or @pytest_asyncio.fixture) — confirm project config.
    store = SQLStore(namespace="test", db_url="sqlite+aiosqlite://", async_mode=True)
    await store.acreate_schema()
    yield store
|
||||
|
||||
|
||||
def test_create_lc_store(sql_store: SQLStore) -> None:
    """A lc-store wrapper round-trips a Document through the SQL store."""
    docstore = cast(BaseStore[str, Document], create_lc_store(sql_store))
    docstore.mset([("key1", Document(page_content="hello", metadata={"key": "value"}))])
    (fetched_doc,) = docstore.mget(["key1"])
    assert fetched_doc is not None
    assert fetched_doc.page_content == "hello"
    assert fetched_doc.metadata == {"key": "value"}
|
||||
|
||||
|
||||
def test_create_kv_store(sql_store: SQLStore) -> None:
    """A kv-docstore wrapper round-trips a Document through the SQL store."""
    docstore = create_kv_docstore(sql_store)
    docstore.mset([("key1", Document(page_content="hello", metadata={"key": "value"}))])
    (fetched_doc,) = docstore.mget(["key1"])
    assert isinstance(fetched_doc, Document)
    assert fetched_doc.page_content == "hello"
    assert fetched_doc.metadata == {"key": "value"}
|
||||
|
||||
|
||||
@pytest.mark.requires("aiosqlite")
# Mark explicitly for pytest-asyncio, consistent with the other async tests
# in this PR; without it the coroutine is skipped in strict asyncio mode.
@pytest.mark.asyncio
async def test_async_create_kv_store(async_sql_store: SQLStore) -> None:
    """Test that a docstore is created from a base store (async path)."""
    docstore = create_kv_docstore(async_sql_store)
    await docstore.amset(
        [("key1", Document(page_content="hello", metadata={"key": "value"}))]
    )
    fetched_doc = (await docstore.amget(["key1"]))[0]
    assert isinstance(fetched_doc, Document)
    assert fetched_doc.page_content == "hello"
    assert fetched_doc.metadata == {"key": "value"}
|
||||
|
||||
|
||||
def test_sample_sql_docstore(sql_store: SQLStore) -> None:
    """End-to-end smoke test: set, get, delete, then iterate keys."""
    # Set values for keys
    sql_store.mset([("key1", b"value1"), ("key2", b"value2")])

    # Get values for keys
    assert sql_store.mget(["key1", "key2"]) == [b"value1", b"value2"]

    # Delete keys
    sql_store.mdelete(["key1"])

    # Iterate over keys: only the undeleted key remains
    assert list(sql_store.yield_keys()) == ["key2"]
|
||||
|
||||
|
||||
@pytest.mark.requires("aiosqlite")
# Mark explicitly for pytest-asyncio, consistent with the other async tests
# in this PR; without it the coroutine is skipped in strict asyncio mode.
@pytest.mark.asyncio
async def test_async_sample_sql_docstore(async_sql_store: SQLStore) -> None:
    """End-to-end async smoke test: set, get, delete, then iterate keys."""
    # Set values for keys
    await async_sql_store.amset([("key1", b"value1"), ("key2", b"value2")])

    # Get values for keys
    values = await async_sql_store.amget(
        ["key1", "key2"]
    )  # Returns [b"value1", b"value2"]
    assert values == [b"value1", b"value2"]

    # Delete keys
    await async_sql_store.amdelete(["key1"])

    # Iterate over keys: only the undeleted key remains
    assert [key async for key in async_sql_store.ayield_keys()] == ["key2"]
|
Loading…
Reference in New Issue