astradb[patch]: Add AstraDBStore to langchain-astradb package (#17789)

Co-authored-by: Erick Friis <erick@langchain.dev>
pull/17659/head^2
Christophe Bornet 8 months ago committed by GitHub
parent 4e28888d45
commit bebe401b1a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -15,6 +15,7 @@ from typing import (
TypeVar, TypeVar,
) )
from langchain_core._api.deprecation import deprecated
from langchain_core.stores import BaseStore, ByteStore from langchain_core.stores import BaseStore, ByteStore
from langchain_community.utilities.astradb import ( from langchain_community.utilities.astradb import (
@ -124,6 +125,11 @@ class AstraDBBaseStore(Generic[V], BaseStore[str, V], ABC):
yield key yield key
@deprecated(
since="0.0.22",
removal="0.2.0",
alternative_import="langchain_astradb.AstraDBStore",
)
class AstraDBStore(AstraDBBaseStore[Any]): class AstraDBStore(AstraDBBaseStore[Any]):
"""BaseStore implementation using DataStax AstraDB as the underlying store. """BaseStore implementation using DataStax AstraDB as the underlying store.
@ -143,6 +149,11 @@ class AstraDBStore(AstraDBBaseStore[Any]):
return value return value
@deprecated(
since="0.0.22",
removal="0.2.0",
alternative_import="langchain_astradb.AstraDBByteStore",
)
class AstraDBByteStore(AstraDBBaseStore[bytes], ByteStore): class AstraDBByteStore(AstraDBBaseStore[bytes], ByteStore):
"""ByteStore implementation using DataStax AstraDB as the underlying store. """ByteStore implementation using DataStax AstraDB as the underlying store.

@ -1,5 +1,8 @@
from langchain_astradb.storage import AstraDBByteStore, AstraDBStore
from langchain_astradb.vectorstores import AstraDBVectorStore from langchain_astradb.vectorstores import AstraDBVectorStore
__all__ = [ __all__ = [
"AstraDBByteStore",
"AstraDBStore",
"AstraDBVectorStore", "AstraDBVectorStore",
] ]

@ -0,0 +1,217 @@
from __future__ import annotations
import base64
from abc import ABC, abstractmethod
from typing import (
Any,
AsyncIterator,
Generic,
Iterator,
List,
Optional,
Sequence,
Tuple,
TypeVar,
)
from astrapy.db import AstraDB, AsyncAstraDB
from langchain_core.stores import BaseStore, ByteStore
from langchain_astradb.utils.astradb import (
SetupMode,
_AstraDBCollectionEnvironment,
)
V = TypeVar("V")
class AstraDBBaseStore(Generic[V], BaseStore[str, V], ABC):
"""Base class for the DataStax AstraDB data store."""
def __init__(self, *args: Any, **kwargs: Any) -> None:
self.astra_env = _AstraDBCollectionEnvironment(*args, **kwargs)
self.collection = self.astra_env.collection
self.async_collection = self.astra_env.async_collection
@abstractmethod
def decode_value(self, value: Any) -> Optional[V]:
"""Decodes value from Astra DB"""
@abstractmethod
def encode_value(self, value: Optional[V]) -> Any:
"""Encodes value for Astra DB"""
def mget(self, keys: Sequence[str]) -> List[Optional[V]]:
self.astra_env.ensure_db_setup()
docs_dict = {}
for doc in self.collection.paginated_find(filter={"_id": {"$in": list(keys)}}):
docs_dict[doc["_id"]] = doc.get("value")
return [self.decode_value(docs_dict.get(key)) for key in keys]
async def amget(self, keys: Sequence[str]) -> List[Optional[V]]:
await self.astra_env.aensure_db_setup()
docs_dict = {}
async for doc in self.async_collection.paginated_find(
filter={"_id": {"$in": list(keys)}}
):
docs_dict[doc["_id"]] = doc.get("value")
return [self.decode_value(docs_dict.get(key)) for key in keys]
def mset(self, key_value_pairs: Sequence[Tuple[str, V]]) -> None:
self.astra_env.ensure_db_setup()
for k, v in key_value_pairs:
self.collection.upsert({"_id": k, "value": self.encode_value(v)})
async def amset(self, key_value_pairs: Sequence[Tuple[str, V]]) -> None:
await self.astra_env.aensure_db_setup()
for k, v in key_value_pairs:
await self.async_collection.upsert(
{"_id": k, "value": self.encode_value(v)}
)
def mdelete(self, keys: Sequence[str]) -> None:
self.astra_env.ensure_db_setup()
self.collection.delete_many(filter={"_id": {"$in": list(keys)}})
async def amdelete(self, keys: Sequence[str]) -> None:
await self.astra_env.aensure_db_setup()
await self.async_collection.delete_many(filter={"_id": {"$in": list(keys)}})
def yield_keys(self, *, prefix: Optional[str] = None) -> Iterator[str]:
self.astra_env.ensure_db_setup()
docs = self.collection.paginated_find()
for doc in docs:
key = doc["_id"]
if not prefix or key.startswith(prefix):
yield key
async def ayield_keys(self, *, prefix: Optional[str] = None) -> AsyncIterator[str]:
await self.astra_env.aensure_db_setup()
async for doc in self.async_collection.paginated_find():
key = doc["_id"]
if not prefix or key.startswith(prefix):
yield key
class AstraDBStore(AstraDBBaseStore[Any]):
def __init__(
self,
collection_name: str,
*,
token: Optional[str] = None,
api_endpoint: Optional[str] = None,
astra_db_client: Optional[AstraDB] = None,
namespace: Optional[str] = None,
async_astra_db_client: Optional[AsyncAstraDB] = None,
pre_delete_collection: bool = False,
setup_mode: SetupMode = SetupMode.SYNC,
) -> None:
"""BaseStore implementation using DataStax AstraDB as the underlying store.
The value type can be any type serializable by json.dumps.
Can be used to store embeddings with the CacheBackedEmbeddings.
Documents in the AstraDB collection will have the format
.. code-block:: json
{
"_id": "<key>",
"value": <value>
}
Args:
collection_name: name of the Astra DB collection to create/use.
token: API token for Astra DB usage.
api_endpoint: full URL to the API endpoint,
such as `https://<DB-ID>-us-east1.apps.astra.datastax.com`.
astra_db_client: *alternative to token+api_endpoint*,
you can pass an already-created 'astrapy.db.AstraDB' instance.
async_astra_db_client: *alternative to token+api_endpoint*,
you can pass an already-created 'astrapy.db.AsyncAstraDB' instance.
namespace: namespace (aka keyspace) where the
collection is created. Defaults to the database's "default namespace".
setup_mode: mode used to create the Astra DB collection (SYNC, ASYNC or
OFF).
pre_delete_collection: whether to delete the collection
before creating it. If False and the collection already exists,
the collection will be used as is.
"""
super().__init__(
collection_name=collection_name,
token=token,
api_endpoint=api_endpoint,
astra_db_client=astra_db_client,
async_astra_db_client=async_astra_db_client,
namespace=namespace,
setup_mode=setup_mode,
pre_delete_collection=pre_delete_collection,
)
def decode_value(self, value: Any) -> Any:
return value
def encode_value(self, value: Any) -> Any:
return value
class AstraDBByteStore(AstraDBBaseStore[bytes], ByteStore):
def __init__(
self,
*,
collection_name: str,
token: Optional[str] = None,
api_endpoint: Optional[str] = None,
astra_db_client: Optional[AstraDB] = None,
namespace: Optional[str] = None,
async_astra_db_client: Optional[AsyncAstraDB] = None,
pre_delete_collection: bool = False,
setup_mode: SetupMode = SetupMode.SYNC,
) -> None:
"""ByteStore implementation using DataStax AstraDB as the underlying store.
The bytes values are converted to base64 encoded strings
Documents in the AstraDB collection will have the format
.. code-block:: json
{
"_id": "<key>",
"value": "<byte64 string value>"
}
Args:
collection_name: name of the Astra DB collection to create/use.
token: API token for Astra DB usage.
api_endpoint: full URL to the API endpoint,
such as `https://<DB-ID>-us-east1.apps.astra.datastax.com`.
astra_db_client: *alternative to token+api_endpoint*,
you can pass an already-created 'astrapy.db.AstraDB' instance.
async_astra_db_client: *alternative to token+api_endpoint*,
you can pass an already-created 'astrapy.db.AsyncAstraDB' instance.
namespace: namespace (aka keyspace) where the
collection is created. Defaults to the database's "default namespace".
setup_mode: mode used to create the Astra DB collection (SYNC, ASYNC or
OFF).
pre_delete_collection: whether to delete the collection
before creating it. If False and the collection already exists,
the collection will be used as is.
"""
super().__init__(
collection_name=collection_name,
token=token,
api_endpoint=api_endpoint,
astra_db_client=astra_db_client,
async_astra_db_client=async_astra_db_client,
namespace=namespace,
setup_mode=setup_mode,
pre_delete_collection=pre_delete_collection,
)
def decode_value(self, value: Any) -> Optional[bytes]:
if value is None:
return None
return base64.b64decode(value)
def encode_value(self, value: Optional[bytes]) -> Any:
if value is None:
return None
return base64.b64encode(value).decode("ascii")

@ -0,0 +1,142 @@
from __future__ import annotations
import asyncio
import inspect
from asyncio import InvalidStateError, Task
from enum import Enum
from typing import Awaitable, Optional, Union
from astrapy.db import AstraDB, AsyncAstraDB
class SetupMode(Enum):
SYNC = 1
ASYNC = 2
OFF = 3
class _AstraDBEnvironment:
def __init__(
self,
token: Optional[str] = None,
api_endpoint: Optional[str] = None,
astra_db_client: Optional[AstraDB] = None,
async_astra_db_client: Optional[AsyncAstraDB] = None,
namespace: Optional[str] = None,
) -> None:
self.token = token
self.api_endpoint = api_endpoint
astra_db = astra_db_client
async_astra_db = async_astra_db_client
self.namespace = namespace
# Conflicting-arg checks:
if astra_db_client is not None or async_astra_db_client is not None:
if token is not None or api_endpoint is not None:
raise ValueError(
"You cannot pass 'astra_db_client' or 'async_astra_db_client' to "
"AstraDBEnvironment if passing 'token' and 'api_endpoint'."
)
if token and api_endpoint:
astra_db = AstraDB(
token=token,
api_endpoint=api_endpoint,
namespace=self.namespace,
)
async_astra_db = AsyncAstraDB(
token=token,
api_endpoint=api_endpoint,
namespace=self.namespace,
)
if astra_db:
self.astra_db = astra_db
if async_astra_db:
self.async_astra_db = async_astra_db
else:
self.async_astra_db = self.astra_db.to_async()
elif async_astra_db:
self.async_astra_db = async_astra_db
self.astra_db = self.async_astra_db.to_sync()
else:
raise ValueError(
"Must provide 'astra_db_client' or 'async_astra_db_client' or "
"'token' and 'api_endpoint'"
)
class _AstraDBCollectionEnvironment(_AstraDBEnvironment):
def __init__(
self,
collection_name: str,
token: Optional[str] = None,
api_endpoint: Optional[str] = None,
astra_db_client: Optional[AstraDB] = None,
async_astra_db_client: Optional[AsyncAstraDB] = None,
namespace: Optional[str] = None,
setup_mode: SetupMode = SetupMode.SYNC,
pre_delete_collection: bool = False,
embedding_dimension: Union[int, Awaitable[int], None] = None,
metric: Optional[str] = None,
) -> None:
from astrapy.db import AstraDBCollection, AsyncAstraDBCollection
super().__init__(
token, api_endpoint, astra_db_client, async_astra_db_client, namespace
)
self.collection_name = collection_name
self.collection = AstraDBCollection(
collection_name=collection_name,
astra_db=self.astra_db,
)
self.async_collection = AsyncAstraDBCollection(
collection_name=collection_name,
astra_db=self.async_astra_db,
)
self.async_setup_db_task: Optional[Task] = None
if setup_mode == SetupMode.ASYNC:
async_astra_db = self.async_astra_db
async def _setup_db() -> None:
if pre_delete_collection:
await async_astra_db.delete_collection(collection_name)
if inspect.isawaitable(embedding_dimension):
dimension = await embedding_dimension
else:
dimension = embedding_dimension
await async_astra_db.create_collection(
collection_name, dimension=dimension, metric=metric
)
self.async_setup_db_task = asyncio.create_task(_setup_db())
elif setup_mode == SetupMode.SYNC:
if pre_delete_collection:
self.astra_db.delete_collection(collection_name)
if inspect.isawaitable(embedding_dimension):
raise ValueError(
"Cannot use an awaitable embedding_dimension with async_setup "
"set to False"
)
self.astra_db.create_collection(
collection_name,
dimension=embedding_dimension, # type: ignore[arg-type]
metric=metric,
)
def ensure_db_setup(self) -> None:
if self.async_setup_db_task:
try:
self.async_setup_db_task.result()
except InvalidStateError:
raise ValueError(
"Asynchronous setup of the DB not finished. "
"NB: AstraDB components sync methods shouldn't be called from the "
"event loop. Consider using their async equivalents."
)
async def aensure_db_setup(self) -> None:
if self.async_setup_db_task:
await self.async_setup_db_task

@ -0,0 +1,176 @@
"""Implement integration tests for AstraDB storage."""
from __future__ import annotations
import os
import pytest
from astrapy.db import AstraDB, AsyncAstraDB
from langchain_astradb.storage import AstraDBByteStore, AstraDBStore
from langchain_astradb.utils.astradb import SetupMode
def _has_env_vars() -> bool:
return all(
[
"ASTRA_DB_APPLICATION_TOKEN" in os.environ,
"ASTRA_DB_API_ENDPOINT" in os.environ,
]
)
@pytest.fixture
def astra_db() -> AstraDB:
from astrapy.db import AstraDB
return AstraDB(
token=os.environ["ASTRA_DB_APPLICATION_TOKEN"],
api_endpoint=os.environ["ASTRA_DB_API_ENDPOINT"],
namespace=os.environ.get("ASTRA_DB_KEYSPACE"),
)
@pytest.fixture
def async_astra_db() -> AsyncAstraDB:
from astrapy.db import AsyncAstraDB
return AsyncAstraDB(
token=os.environ["ASTRA_DB_APPLICATION_TOKEN"],
api_endpoint=os.environ["ASTRA_DB_API_ENDPOINT"],
namespace=os.environ.get("ASTRA_DB_KEYSPACE"),
)
def init_store(astra_db: AstraDB, collection_name: str) -> AstraDBStore:
store = AstraDBStore(collection_name=collection_name, astra_db_client=astra_db)
store.mset([("key1", [0.1, 0.2]), ("key2", "value2")])
return store
def init_bytestore(astra_db: AstraDB, collection_name: str) -> AstraDBByteStore:
store = AstraDBByteStore(collection_name=collection_name, astra_db_client=astra_db)
store.mset([("key1", b"value1"), ("key2", b"value2")])
return store
async def init_async_store(
async_astra_db: AsyncAstraDB, collection_name: str
) -> AstraDBStore:
store = AstraDBStore(
collection_name=collection_name,
async_astra_db_client=async_astra_db,
setup_mode=SetupMode.ASYNC,
)
await store.amset([("key1", [0.1, 0.2]), ("key2", "value2")])
return store
@pytest.mark.skipif(not _has_env_vars(), reason="Missing Astra DB env. vars")
class TestAstraDBStore:
def test_mget(self, astra_db: AstraDB) -> None:
"""Test AstraDBStore mget method."""
collection_name = "lc_test_store_mget"
try:
store = init_store(astra_db, collection_name)
assert store.mget(["key1", "key2"]) == [[0.1, 0.2], "value2"]
finally:
astra_db.delete_collection(collection_name)
async def test_amget(self, async_astra_db: AsyncAstraDB) -> None:
"""Test AstraDBStore amget method."""
collection_name = "lc_test_store_mget"
try:
store = await init_async_store(async_astra_db, collection_name)
assert await store.amget(["key1", "key2"]) == [[0.1, 0.2], "value2"]
finally:
await async_astra_db.delete_collection(collection_name)
def test_mset(self, astra_db: AstraDB) -> None:
"""Test that multiple keys can be set with AstraDBStore."""
collection_name = "lc_test_store_mset"
try:
store = init_store(astra_db, collection_name)
result = store.collection.find_one({"_id": "key1"})
assert result["data"]["document"]["value"] == [0.1, 0.2]
result = store.collection.find_one({"_id": "key2"})
assert result["data"]["document"]["value"] == "value2"
finally:
astra_db.delete_collection(collection_name)
async def test_amset(self, async_astra_db: AsyncAstraDB) -> None:
"""Test that multiple keys can be set with AstraDBStore."""
collection_name = "lc_test_store_mset"
try:
store = await init_async_store(async_astra_db, collection_name)
result = await store.async_collection.find_one({"_id": "key1"})
assert result["data"]["document"]["value"] == [0.1, 0.2]
result = await store.async_collection.find_one({"_id": "key2"})
assert result["data"]["document"]["value"] == "value2"
finally:
await async_astra_db.delete_collection(collection_name)
def test_mdelete(self, astra_db: AstraDB) -> None:
"""Test that deletion works as expected."""
collection_name = "lc_test_store_mdelete"
try:
store = init_store(astra_db, collection_name)
store.mdelete(["key1", "key2"])
result = store.mget(["key1", "key2"])
assert result == [None, None]
finally:
astra_db.delete_collection(collection_name)
async def test_amdelete(self, async_astra_db: AsyncAstraDB) -> None:
"""Test that deletion works as expected."""
collection_name = "lc_test_store_mdelete"
try:
store = await init_async_store(async_astra_db, collection_name)
await store.amdelete(["key1", "key2"])
result = await store.amget(["key1", "key2"])
assert result == [None, None]
finally:
await async_astra_db.delete_collection(collection_name)
def test_yield_keys(self, astra_db: AstraDB) -> None:
collection_name = "lc_test_store_yield_keys"
try:
store = init_store(astra_db, collection_name)
assert set(store.yield_keys()) == {"key1", "key2"}
assert set(store.yield_keys(prefix="key")) == {"key1", "key2"}
assert set(store.yield_keys(prefix="lang")) == set()
finally:
astra_db.delete_collection(collection_name)
async def test_ayield_keys(self, async_astra_db: AsyncAstraDB) -> None:
collection_name = "lc_test_store_yield_keys"
try:
store = await init_async_store(async_astra_db, collection_name)
assert {key async for key in store.ayield_keys()} == {"key1", "key2"}
assert {key async for key in store.ayield_keys(prefix="key")} == {
"key1",
"key2",
}
assert {key async for key in store.ayield_keys(prefix="lang")} == set()
finally:
await async_astra_db.delete_collection(collection_name)
def test_bytestore_mget(self, astra_db: AstraDB) -> None:
"""Test AstraDBByteStore mget method."""
collection_name = "lc_test_bytestore_mget"
try:
store = init_bytestore(astra_db, collection_name)
assert store.mget(["key1", "key2"]) == [b"value1", b"value2"]
finally:
astra_db.delete_collection(collection_name)
def test_bytestore_mset(self, astra_db: AstraDB) -> None:
"""Test that multiple keys can be set with AstraDBByteStore."""
collection_name = "lc_test_bytestore_mset"
try:
store = init_bytestore(astra_db, collection_name)
result = store.collection.find_one({"_id": "key1"})
assert result["data"]["document"]["value"] == "dmFsdWUx"
result = store.collection.find_one({"_id": "key2"})
assert result["data"]["document"]["value"] == "dmFsdWUy"
finally:
astra_db.delete_collection(collection_name)

@ -1,6 +1,8 @@
from langchain_astradb import __all__ from langchain_astradb import __all__
EXPECTED_ALL = [ EXPECTED_ALL = [
"AstraDBByteStore",
"AstraDBStore",
"AstraDBVectorStore", "AstraDBVectorStore",
] ]

Loading…
Cancel
Save