community[minor]: Add Cassandra ByteStore (#22064)

pull/22052/merge
Christophe Bornet 3 weeks ago committed by GitHub
parent fea6b99b16
commit 74947ec894
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -14,7 +14,7 @@ from typing import (
from langchain_core.documents import Document
from langchain_community.document_loaders.base import BaseLoader
from langchain_community.utilities.cassandra import wrapped_response_future
from langchain_community.utilities.cassandra import aexecute_cql
_NOT_SET = object()
@ -118,11 +118,7 @@ class CassandraLoader(BaseLoader):
)
async def alazy_load(self) -> AsyncIterator[Document]:
for row in await wrapped_response_future(
self.session.execute_async,
self.query,
**self.query_kwargs,
):
for row in await aexecute_cql(self.session, self.query, **self.query_kwargs):
metadata = self.metadata.copy()
metadata.update(self.metadata_mapper(row))
yield Document(

@ -0,0 +1,188 @@
from __future__ import annotations
import asyncio
from asyncio import InvalidStateError, Task
from typing import (
TYPE_CHECKING,
AsyncIterator,
Iterator,
List,
Optional,
Sequence,
Tuple,
)
from langchain_core.stores import ByteStore
from langchain_community.utilities.cassandra import SetupMode, aexecute_cql
if TYPE_CHECKING:
from cassandra.cluster import Session
from cassandra.query import PreparedStatement
# CQL templates for the key/value table. The `{keyspace}` and `{table}`
# placeholders are filled in with `str.format` before a statement is executed
# or prepared; the `?` markers are bind parameters of the prepared statements.

# DDL: one row per key; `row_id` (TEXT) is the primary key, `body_blob` the value.
CREATE_TABLE_CQL_TEMPLATE = """
CREATE TABLE IF NOT EXISTS {keyspace}.{table}
(row_id TEXT, body_blob BLOB, PRIMARY KEY (row_id));
"""
# Fetch the rows whose key is in the bound collection of keys.
SELECT_TABLE_CQL_TEMPLATE = (
"""SELECT row_id, body_blob FROM {keyspace}.{table} WHERE row_id IN ?;"""
)
# Fetch every row of the table (used to enumerate keys).
SELECT_ALL_TABLE_CQL_TEMPLATE = """SELECT row_id, body_blob FROM {keyspace}.{table};"""
# Write a single (key, value) pair.
INSERT_TABLE_CQL_TEMPLATE = (
"""INSERT INTO {keyspace}.{table} (row_id, body_blob) VALUES (?, ?);"""
)
# Delete the rows whose key is in the bound collection of keys.
DELETE_TABLE_CQL_TEMPLATE = """DELETE FROM {keyspace}.{table} WHERE row_id IN ?;"""
class CassandraByteStore(ByteStore):
    """A ``ByteStore`` that keeps its key/value pairs in a Cassandra table.

    Every entry is one row of ``{keyspace}.{table}``: the key is stored in the
    ``row_id`` TEXT column (the primary key) and the value in the ``body_blob``
    BLOB column. The table is created if it does not exist when the store is
    constructed.
    """

    def __init__(
        self,
        table: str,
        *,
        session: Optional[Session] = None,
        keyspace: Optional[str] = None,
        setup_mode: SetupMode = SetupMode.SYNC,
    ) -> None:
        """Create the store and make sure its backing table exists.

        Args:
            table: Name of the table holding the key/value rows.
            session: Cassandra driver session. When omitted, it is resolved
                through ``cassio.config.check_resolve_session``.
            keyspace: Keyspace containing the table. When omitted, it is
                resolved through ``cassio.config.check_resolve_keyspace``.
            setup_mode: ``SetupMode.ASYNC`` schedules the table creation as an
                asyncio task (``asyncio.create_task``); any other value runs
                the ``CREATE TABLE`` statement synchronously right away.

        Raises:
            ImportError: If ``session`` or ``keyspace`` is missing and the
                ``cassio`` helpers cannot be imported.
        """
        if not session or not keyspace:
            # Only the import is guarded, so errors raised while resolving the
            # session/keyspace are not mislabelled as a missing package.
            try:
                from cassio.config import check_resolve_keyspace, check_resolve_session
            except ImportError as e:
                raise ImportError(
                    "Could not import a recent cassio package. "
                    "Please install it with `pip install --upgrade cassio`."
                ) from e
            self.keyspace = keyspace or check_resolve_keyspace(keyspace)
            self.session = session or check_resolve_session()
        else:
            self.keyspace = keyspace
            self.session = session
        self.table = table

        # Prepared statements are created lazily by the get_*_statement methods.
        self.select_statement: Optional[PreparedStatement] = None
        self.insert_statement: Optional[PreparedStatement] = None
        self.delete_statement: Optional[PreparedStatement] = None

        create_cql = CREATE_TABLE_CQL_TEMPLATE.format(
            keyspace=self.keyspace,
            table=self.table,
        )
        # Stays None when the table is created synchronously.
        self.db_setup_task: Optional[Task[None]] = None
        if setup_mode == SetupMode.ASYNC:
            self.db_setup_task = asyncio.create_task(
                aexecute_cql(self.session, create_cql)
            )
        else:
            self.session.execute(create_cql)

    def ensure_db_setup(self) -> None:
        """Check, without waiting, that a pending asynchronous setup is done.

        Does nothing when no setup task was scheduled. If the task finished
        with an exception, ``Task.result()`` re-raises it here.

        Raises:
            ValueError: If the asynchronous setup task has not completed yet.
        """
        if self.db_setup_task:
            try:
                self.db_setup_task.result()
            except InvalidStateError as e:
                raise ValueError(
                    "Asynchronous setup of the DB not finished. "
                    "NB: Cassandra components sync methods shouldn't be called "
                    "from the event loop. Consider using their async equivalents."
                ) from e

    async def aensure_db_setup(self) -> None:
        """Await the asynchronous setup task, if one was scheduled."""
        if self.db_setup_task:
            await self.db_setup_task

    def get_select_statement(self) -> PreparedStatement:
        """Return the prepared multi-key SELECT, preparing it on first use."""
        if self.select_statement is None:
            self.select_statement = self.session.prepare(
                SELECT_TABLE_CQL_TEMPLATE.format(
                    keyspace=self.keyspace, table=self.table
                )
            )
        return self.select_statement

    def get_insert_statement(self) -> PreparedStatement:
        """Return the prepared single-row INSERT, preparing it on first use."""
        if self.insert_statement is None:
            self.insert_statement = self.session.prepare(
                INSERT_TABLE_CQL_TEMPLATE.format(
                    keyspace=self.keyspace, table=self.table
                )
            )
        return self.insert_statement

    def get_delete_statement(self) -> PreparedStatement:
        """Return the prepared multi-key DELETE, preparing it on first use."""
        if self.delete_statement is None:
            self.delete_statement = self.session.prepare(
                DELETE_TABLE_CQL_TEMPLATE.format(
                    keyspace=self.keyspace, table=self.table
                )
            )
        return self.delete_statement

    def mget(self, keys: Sequence[str]) -> List[Optional[bytes]]:
        """Return the values stored under ``keys``, in the same order.

        Keys that have no row in the table yield ``None``.
        """
        from cassandra.query import ValueSequence

        self.ensure_db_setup()
        if not keys:
            # Nothing to look up: skip the `IN ()` round trip.
            return []
        docs_dict = {}
        for row in self.session.execute(
            self.get_select_statement(), [ValueSequence(keys)]
        ):
            docs_dict[row.row_id] = row.body_blob
        return [docs_dict.get(key) for key in keys]

    async def amget(self, keys: Sequence[str]) -> List[Optional[bytes]]:
        """Async counterpart of ``mget``: values for ``keys``, ``None`` if absent."""
        from cassandra.query import ValueSequence

        await self.aensure_db_setup()
        if not keys:
            # Nothing to look up: skip the `IN ()` round trip.
            return []
        docs_dict = {}
        for row in await aexecute_cql(
            self.session, self.get_select_statement(), parameters=[ValueSequence(keys)]
        ):
            docs_dict[row.row_id] = row.body_blob
        return [docs_dict.get(key) for key in keys]

    def mset(self, key_value_pairs: Sequence[Tuple[str, bytes]]) -> None:
        """Write each ``(key, value)`` pair with one INSERT per pair."""
        self.ensure_db_setup()
        insert_statement = self.get_insert_statement()
        for k, v in key_value_pairs:
            self.session.execute(insert_statement, (k, v))

    async def amset(self, key_value_pairs: Sequence[Tuple[str, bytes]]) -> None:
        """Async counterpart of ``mset``; the INSERTs are awaited one by one."""
        await self.aensure_db_setup()
        insert_statement = self.get_insert_statement()
        for k, v in key_value_pairs:
            await aexecute_cql(self.session, insert_statement, parameters=(k, v))

    def mdelete(self, keys: Sequence[str]) -> None:
        """Delete the rows stored under ``keys`` with a single DELETE."""
        from cassandra.query import ValueSequence

        self.ensure_db_setup()
        if not keys:
            # Nothing to delete: skip the `IN ()` round trip.
            return
        self.session.execute(self.get_delete_statement(), [ValueSequence(keys)])

    async def amdelete(self, keys: Sequence[str]) -> None:
        """Async counterpart of ``mdelete``."""
        from cassandra.query import ValueSequence

        await self.aensure_db_setup()
        if not keys:
            # Nothing to delete: skip the `IN ()` round trip.
            return
        await aexecute_cql(
            self.session, self.get_delete_statement(), parameters=[ValueSequence(keys)]
        )

    def yield_keys(self, *, prefix: Optional[str] = None) -> Iterator[str]:
        """Yield the keys of the table, optionally only those starting with ``prefix``.

        All rows are selected and the prefix filter is applied on the client.
        """
        self.ensure_db_setup()
        for row in self.session.execute(
            SELECT_ALL_TABLE_CQL_TEMPLATE.format(
                keyspace=self.keyspace, table=self.table
            )
        ):
            key = row.row_id
            if not prefix or key.startswith(prefix):
                yield key

    async def ayield_keys(self, *, prefix: Optional[str] = None) -> AsyncIterator[str]:
        """Async counterpart of ``yield_keys``; the filter is applied on the client."""
        await self.aensure_db_setup()
        for row in await aexecute_cql(
            self.session,
            SELECT_ALL_TABLE_CQL_TEMPLATE.format(
                keyspace=self.keyspace, table=self.table
            ),
        ):
            key = row.row_id
            if not prefix or key.startswith(prefix):
                yield key

@ -5,7 +5,7 @@ from enum import Enum
from typing import TYPE_CHECKING, Any, Callable
if TYPE_CHECKING:
from cassandra.cluster import ResponseFuture
from cassandra.cluster import ResponseFuture, Session
async def wrapped_response_future(
@ -35,6 +35,10 @@ async def wrapped_response_future(
return await asyncio_future
async def aexecute_cql(session: Session, query: str, **kwargs: Any) -> Any:
    """Execute ``query`` through ``session.execute_async`` and await its result.

    Args:
        session: Cassandra session whose ``execute_async`` method is used.
        query: The query to run.
        **kwargs: Passed through unchanged to ``session.execute_async`` via
            ``wrapped_response_future``.

    Returns:
        Whatever awaiting ``wrapped_response_future`` produces.
    """
    pending_result = wrapped_response_future(session.execute_async, query, **kwargs)
    return await pending_result
class SetupMode(Enum):
    """Selects how a component performs its database setup."""

    # Run the setup synchronously.
    SYNC = 1
    # Schedule the setup asynchronously.
    ASYNC = 2

@ -0,0 +1,155 @@
"""Implement integration tests for Cassandra storage."""
from __future__ import annotations

from typing import TYPE_CHECKING, Iterator

import pytest

from langchain_community.storage.cassandra import CassandraByteStore
from langchain_community.utilities.cassandra import SetupMode

if TYPE_CHECKING:
    from cassandra.cluster import Session
# Keyspace used by every test in this module; the `session` fixture creates it
# if it does not exist yet.
KEYSPACE = "storage_test_keyspace"
@pytest.fixture(scope="session")
def session() -> Iterator[Session]:
    """Provide a Cassandra session shared by the whole test run.

    Connects with a default ``Cluster()``, creates the test keyspace if it is
    missing, and shuts the cluster down once the test session is over so that
    driver connections and threads are not leaked.
    """
    from cassandra.cluster import Cluster

    cluster = Cluster()
    try:
        cassandra_session = cluster.connect()
        cassandra_session.execute(
            (
                f"CREATE KEYSPACE IF NOT EXISTS {KEYSPACE} "
                f"WITH replication = {{'class': 'SimpleStrategy', 'replication_factor': 1}}"
            )
        )
        yield cassandra_session
    finally:
        # Runs on teardown and also if connecting or keyspace creation fails.
        cluster.shutdown()
def init_store(table_name: str, session: Session) -> CassandraByteStore:
    """Build a store on ``table_name`` (default setup mode) seeded with two keys.

    The store is written through ``mset`` with ``key1``/``key2`` mapped to
    ``b"value1"``/``b"value2"`` and then returned.
    """
    seed_pairs = [("key1", b"value1"), ("key2", b"value2")]
    byte_store = CassandraByteStore(
        session=session,
        keyspace=KEYSPACE,
        table=table_name,
    )
    byte_store.mset(seed_pairs)
    return byte_store
async def init_async_store(table_name: str, session: Session) -> CassandraByteStore:
    """Build a store on ``table_name`` with ``SetupMode.ASYNC`` seeded with two keys.

    The store is written through ``amset`` with ``key1``/``key2`` mapped to
    ``b"value1"``/``b"value2"`` and then returned.
    """
    seed_pairs = [("key1", b"value1"), ("key2", b"value2")]
    byte_store = CassandraByteStore(
        session=session,
        keyspace=KEYSPACE,
        table=table_name,
        setup_mode=SetupMode.ASYNC,
    )
    await byte_store.amset(seed_pairs)
    return byte_store
def drop_table(table_name: str, session: Session) -> None:
    """Drop ``table_name`` from the test keyspace, if it exists.

    The tests call this from ``finally`` blocks. ``IF EXISTS`` keeps the cleanup
    from raising, and hiding the original failure, when the table was never
    created.
    """
    session.execute(f"DROP TABLE IF EXISTS {KEYSPACE}.{table_name}")
async def test_mget(session: Session) -> None:
    """A synchronously set-up store returns the seeded values via mget and amget."""
    table_name = "lc_test_store_mget"
    wanted_keys = ["key1", "key2"]
    expected_values = [b"value1", b"value2"]
    try:
        byte_store = init_store(table_name, session)
        sync_values = byte_store.mget(wanted_keys)
        assert sync_values == expected_values
        async_values = await byte_store.amget(wanted_keys)
        assert async_values == expected_values
    finally:
        drop_table(table_name, session)
async def test_amget(session: Session) -> None:
    """An asynchronously set-up store returns the seeded values via amget."""
    table_name = "lc_test_store_amget"
    try:
        byte_store = await init_async_store(table_name, session)
        fetched = await byte_store.amget(["key1", "key2"])
        assert fetched == [b"value1", b"value2"]
    finally:
        drop_table(table_name, session)
def test_mset(session: Session) -> None:
    """Test that multiple keys can be set with CassandraByteStore."""
    table_name = "lc_test_store_mset"
    try:
        init_store(table_name, session)
        # Read back with raw CQL so the check does not depend on the store's
        # own read path. The query is built from KEYSPACE/table_name so it
        # cannot drift from the table that was actually written.
        for key, expected in (("key1", b"value1"), ("key2", b"value2")):
            result = session.execute(
                f"SELECT row_id, body_blob FROM {KEYSPACE}.{table_name} "
                f"WHERE row_id = '{key}';"
            ).one()
            assert result is not None, f"no row stored for {key!r}"
            assert result.body_blob == expected
    finally:
        drop_table(table_name, session)
async def test_amset(session: Session) -> None:
    """Test that multiple keys can be set asynchronously with CassandraByteStore."""
    table_name = "lc_test_store_amset"
    try:
        await init_async_store(table_name, session)
        # Read back with raw CQL so the check does not depend on the store's
        # own read path. The query is built from KEYSPACE/table_name so it
        # cannot drift from the table that was actually written.
        for key, expected in (("key1", b"value1"), ("key2", b"value2")):
            result = session.execute(
                f"SELECT row_id, body_blob FROM {KEYSPACE}.{table_name} "
                f"WHERE row_id = '{key}';"
            ).one()
            assert result is not None, f"no row stored for {key!r}"
            assert result.body_blob == expected
    finally:
        drop_table(table_name, session)
def test_mdelete(session: Session) -> None:
    """After mdelete, mget reports the deleted keys as missing."""
    table_name = "lc_test_store_mdelete"
    try:
        byte_store = init_store(table_name, session)
        byte_store.mdelete(["key1", "key2"])
        remaining = byte_store.mget(["key1", "key2"])
        assert remaining == [None, None]
    finally:
        drop_table(table_name, session)
async def test_amdelete(session: Session) -> None:
    """After amdelete, amget reports the deleted keys as missing."""
    table_name = "lc_test_store_amdelete"
    try:
        byte_store = await init_async_store(table_name, session)
        await byte_store.amdelete(["key1", "key2"])
        remaining = await byte_store.amget(["key1", "key2"])
        assert remaining == [None, None]
    finally:
        drop_table(table_name, session)
def test_yield_keys(session: Session) -> None:
    """yield_keys lists every key and honours the optional prefix filter."""
    table_name = "lc_test_store_yield_keys"
    try:
        byte_store = init_store(table_name, session)
        # No prefix: every stored key is returned.
        unfiltered = set(byte_store.yield_keys())
        assert unfiltered == {"key1", "key2"}
        # A prefix shared by both keys keeps both.
        matching = set(byte_store.yield_keys(prefix="key"))
        assert matching == {"key1", "key2"}
        # A prefix matching nothing yields no keys.
        unmatched = set(byte_store.yield_keys(prefix="lang"))
        assert unmatched == set()
    finally:
        drop_table(table_name, session)
async def test_ayield_keys(session: Session) -> None:
    """ayield_keys lists every key and honours the optional prefix filter."""
    table_name = "lc_test_store_ayield_keys"
    try:
        byte_store = await init_async_store(table_name, session)
        # No prefix: every stored key is returned.
        unfiltered = {key async for key in byte_store.ayield_keys()}
        assert unfiltered == {"key1", "key2"}
        # A prefix shared by both keys keeps both.
        matching = {key async for key in byte_store.ayield_keys(prefix="key")}
        assert matching == {"key1", "key2"}
        # A prefix matching nothing yields no keys.
        unmatched = {key async for key in byte_store.ayield_keys(prefix="lang")}
        assert unmatched == set()
    finally:
        drop_table(table_name, session)
Loading…
Cancel
Save