import contextlib
from pathlib import Path
from typing import (
    Any,
    AsyncGenerator,
    AsyncIterator,
    Dict,
    Generator,
    Iterator,
    List,
    Optional,
    Sequence,
    Tuple,
    Union,
    cast,
)

from langchain_core.stores import BaseStore
from sqlalchemy import (
    Engine,
    LargeBinary,
    and_,
    create_engine,
    delete,
    select,
)
from sqlalchemy.ext.asyncio import (
    AsyncEngine,
    AsyncSession,
    async_sessionmaker,
    create_async_engine,
)
from sqlalchemy.orm import (
    Mapped,
    Session,
    declarative_base,
    mapped_column,
    sessionmaker,
)

Base = declarative_base()


def items_equal(x: Any, y: Any) -> bool:
    """Return whether *x* and *y* compare equal with ``==``."""
    return x == y


class LangchainKeyValueStores(Base):  # type: ignore[valid-type,misc]
    """Table used to save values."""

    # ATTENTION:
    # Prior to modifying this table, please determine whether
    # we should create migrations for this table to make sure
    # users do not experience data loss.
    __tablename__ = "langchain_key_value_stores"

    # (namespace, key) is the composite primary key; value holds the raw bytes.
    namespace: Mapped[str] = mapped_column(primary_key=True, index=True, nullable=False)
    key: Mapped[str] = mapped_column(primary_key=True, index=True, nullable=False)
    value = mapped_column(LargeBinary, index=False, nullable=False)


# This is a fix of the original SQLStore.
# It can be removed once the corresponding PR is merged.
class SQLStore(BaseStore[str, bytes]):
    """BaseStore interface that works on an SQL database.

    Every row is scoped by ``namespace``, so several stores can share the
    same table. The store is either synchronous or asynchronous, depending
    on the kind of engine it is bound to; calling a method of the other
    flavour raises ``ValueError``.

    Examples:
        Create a SQLStore instance and perform operations on it:

        .. code-block:: python

            from langchain_rag.storage import SQLStore

            # Instantiate the SQLStore with a database URL
            sql_store = SQLStore(namespace="test", db_url="sqlite:///:memory:")
            sql_store.create_schema()

            # Set values for keys
            sql_store.mset([("key1", b"value1"), ("key2", b"value2")])

            # Get values for keys
            values = sql_store.mget(["key1", "key2"])
            # Returns [b"value1", b"value2"]

            # Delete keys
            sql_store.mdelete(["key1"])

            # Iterate over keys
            for key in sql_store.yield_keys():
                print(key)

    """

    def __init__(
        self,
        *,
        namespace: str,
        db_url: Optional[Union[str, Path]] = None,
        engine: Optional[Union[Engine, AsyncEngine]] = None,
        engine_kwargs: Optional[Dict[str, Any]] = None,
        async_mode: Optional[bool] = None,
    ):
        """Bind the store to a database.

        Args:
            namespace: Value stored in the ``namespace`` column of every row
                written by this store and used to filter every read.
            db_url: Database URL used to create a new engine. Mutually
                exclusive with ``engine``.
            engine: Existing sync or async engine to reuse. Mutually
                exclusive with ``db_url``.
            engine_kwargs: Extra keyword arguments forwarded to the engine
                factory when ``db_url`` is given.
            async_mode: Only consulted when ``db_url`` is given: a true value
                creates an async engine, ``None`` is treated as ``False``.
                With ``engine``, the mode follows the engine's type.

        Raises:
            ValueError: If neither or both of ``db_url`` and ``engine`` are
                provided.
        """
        if db_url is None and engine is None:
            raise ValueError("Must specify either db_url or engine")

        if db_url is not None and engine is not None:
            raise ValueError("Must specify either db_url or engine, not both")

        _engine: Union[Engine, AsyncEngine]
        if db_url:
            if async_mode is None:
                async_mode = False
            if async_mode:
                _engine = create_async_engine(
                    url=str(db_url),
                    **(engine_kwargs or {}),
                )
            else:
                _engine = create_engine(url=str(db_url), **(engine_kwargs or {}))
        elif engine:
            _engine = engine
        else:
            raise AssertionError("Something went wrong with configuration of engine.")

        # The effective mode is derived from the engine actually in use, so a
        # caller-supplied engine decides it regardless of ``async_mode``.
        _session_maker: Union[sessionmaker[Session], async_sessionmaker[AsyncSession]]
        if isinstance(_engine, AsyncEngine):
            self.async_mode = True
            _session_maker = async_sessionmaker(bind=_engine)
        else:
            self.async_mode = False
            _session_maker = sessionmaker(bind=_engine)

        self.engine = _engine
        self.dialect = _engine.dialect.name
        self.session_maker = _session_maker
        self.namespace = namespace

    def create_schema(self) -> None:
        """Create the tables declared on ``Base`` using the store's engine.

        Intended for a synchronous engine; stores bound to an async engine
        should use :meth:`acreate_schema`.
        """
        Base.metadata.create_all(self.engine)

    async def acreate_schema(self) -> None:
        """Create the tables declared on ``Base`` through the async engine."""
        assert isinstance(self.engine, AsyncEngine)
        async with self.engine.begin() as session:
            await session.run_sync(Base.metadata.create_all)

    def drop(self) -> None:
        """Drop the tables declared on ``Base``.

        The engine itself is passed as the bind (rather than an ad-hoc
        connection) so that SQLAlchemy opens, commits and releases the
        connection on its own; nothing is left open or uncommitted.
        """
        Base.metadata.drop_all(bind=self.engine)

    def _keys_clause(self, keys: Sequence[str]) -> Any:
        """Build the WHERE clause matching *keys* inside this namespace."""
        return and_(
            LangchainKeyValueStores.key.in_(keys),
            LangchainKeyValueStores.namespace == self.namespace,
        )

    async def amget(self, keys: Sequence[str]) -> List[Optional[bytes]]:
        """Asynchronously fetch the values stored under *keys*.

        Returns:
            A list aligned with *keys*; entries are ``None`` for keys that
            have no row in this namespace.
        """
        assert isinstance(self.engine, AsyncEngine)
        result: Dict[str, bytes] = {}
        async with self._make_async_session() as session:
            stmt = select(LangchainKeyValueStores).filter(self._keys_clause(keys))
            for v in await session.scalars(stmt):
                result[v.key] = v.value
        return [result.get(key) for key in keys]

    def mget(self, keys: Sequence[str]) -> List[Optional[bytes]]:
        """Fetch the values stored under *keys*.

        Returns:
            A list aligned with *keys*; entries are ``None`` for keys that
            have no row in this namespace.
        """
        result: Dict[str, bytes] = {}
        with self._make_sync_session() as session:
            stmt = select(LangchainKeyValueStores).filter(self._keys_clause(keys))
            for v in session.scalars(stmt):
                result[v.key] = v.value
        return [result.get(key) for key in keys]

    async def amset(self, key_value_pairs: Sequence[Tuple[str, bytes]]) -> None:
        """Asynchronously store *key_value_pairs*, replacing existing rows.

        Existing rows for the given keys are deleted and the new rows added
        in the same session before a single commit.
        """
        # Collapse duplicate keys (last value wins), exactly like ``mset``;
        # inserting the same primary key twice would fail.
        values: Dict[str, bytes] = dict(key_value_pairs)
        async with self._make_async_session() as session:
            await self._amdelete(list(values), session)
            session.add_all(
                [
                    LangchainKeyValueStores(namespace=self.namespace, key=k, value=v)
                    for k, v in values.items()
                ]
            )
            await session.commit()

    def mset(self, key_value_pairs: Sequence[Tuple[str, bytes]]) -> None:
        """Store *key_value_pairs*, replacing existing rows.

        Existing rows for the given keys are deleted and the new rows added
        in the same session before a single commit.
        """
        # Collapse duplicate keys (last value wins) before writing.
        values: Dict[str, bytes] = dict(key_value_pairs)
        with self._make_sync_session() as session:
            self._mdelete(list(values), session)
            session.add_all(
                [
                    LangchainKeyValueStores(namespace=self.namespace, key=k, value=v)
                    for k, v in values.items()
                ]
            )
            session.commit()

    def _mdelete(self, keys: Sequence[str], session: Session) -> None:
        """Delete *keys* of this namespace in *session* without committing."""
        stmt = delete(LangchainKeyValueStores).filter(self._keys_clause(keys))
        session.execute(stmt)

    async def _amdelete(self, keys: Sequence[str], session: AsyncSession) -> None:
        """Delete *keys* of this namespace in *session* without committing."""
        stmt = delete(LangchainKeyValueStores).filter(self._keys_clause(keys))
        await session.execute(stmt)

    def mdelete(self, keys: Sequence[str]) -> None:
        """Delete the rows stored under *keys* and commit."""
        with self._make_sync_session() as session:
            self._mdelete(keys, session)
            session.commit()

    async def amdelete(self, keys: Sequence[str]) -> None:
        """Asynchronously delete the rows stored under *keys* and commit."""
        async with self._make_async_session() as session:
            await self._amdelete(keys, session)
            await session.commit()

    def yield_keys(self, *, prefix: Optional[str] = None) -> Iterator[str]:
        """Yield the keys of this namespace that start with *prefix*.

        Args:
            prefix: Only keys starting with this string are yielded; ``None``
                (or an empty string) yields every key.
        """
        wanted = prefix or ""
        # Only the key column is selected; the stored values are not needed.
        # The prefix test stays in Python so it is an exact, case-sensitive
        # ``str.startswith`` on every backend.
        stmt = select(LangchainKeyValueStores.key).where(
            LangchainKeyValueStores.namespace == self.namespace
        )
        with self._make_sync_session() as session:
            for key in session.scalars(stmt):
                if str(key).startswith(wanted):
                    yield str(key)

    async def ayield_keys(self, *, prefix: Optional[str] = None) -> AsyncIterator[str]:
        """Asynchronously yield the keys of this namespace starting with *prefix*.

        Args:
            prefix: Only keys starting with this string are yielded; ``None``
                (or an empty string) yields every key.
        """
        wanted = prefix or ""
        # Only the key column is selected; the stored values are not needed.
        stmt = select(LangchainKeyValueStores.key).where(
            LangchainKeyValueStores.namespace == self.namespace
        )
        async with self._make_async_session() as session:
            for key in await session.scalars(stmt):
                if str(key).startswith(wanted):
                    yield str(key)

    @contextlib.contextmanager
    def _make_sync_session(self) -> Generator[Session, None, None]:
        """Make a sync session.

        Raises:
            ValueError: If the store is bound to an async engine.
        """
        if self.async_mode:
            raise ValueError(
                "Attempting to use a sync method in when async mode is turned on. "
                "Please use the corresponding async method instead."
            )
        with cast(Session, self.session_maker()) as session:
            yield cast(Session, session)

    @contextlib.asynccontextmanager
    async def _make_async_session(self) -> AsyncGenerator[AsyncSession, None]:
        """Make an async session.

        Raises:
            ValueError: If the store is bound to a sync engine.
        """
        if not self.async_mode:
            raise ValueError(
                "Attempting to use an async method in when sync mode is turned on. "
                "Please use the corresponding sync method instead."
            )
        async with cast(AsyncSession, self.session_maker()) as session:
            yield cast(AsyncSession, session)