mirror of https://github.com/hwchase17/langchain
community[minor]: Add Cassandra ByteStore (#22064)
parent
fea6b99b16
commit
74947ec894
@ -0,0 +1,188 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from asyncio import InvalidStateError, Task
|
||||||
|
from typing import (
|
||||||
|
TYPE_CHECKING,
|
||||||
|
AsyncIterator,
|
||||||
|
Iterator,
|
||||||
|
List,
|
||||||
|
Optional,
|
||||||
|
Sequence,
|
||||||
|
Tuple,
|
||||||
|
)
|
||||||
|
|
||||||
|
from langchain_core.stores import ByteStore
|
||||||
|
|
||||||
|
from langchain_community.utilities.cassandra import SetupMode, aexecute_cql
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from cassandra.cluster import Session
|
||||||
|
from cassandra.query import PreparedStatement
|
||||||
|
|
||||||
|
CREATE_TABLE_CQL_TEMPLATE = """
|
||||||
|
CREATE TABLE IF NOT EXISTS {keyspace}.{table}
|
||||||
|
(row_id TEXT, body_blob BLOB, PRIMARY KEY (row_id));
|
||||||
|
"""
|
||||||
|
SELECT_TABLE_CQL_TEMPLATE = (
|
||||||
|
"""SELECT row_id, body_blob FROM {keyspace}.{table} WHERE row_id IN ?;"""
|
||||||
|
)
|
||||||
|
SELECT_ALL_TABLE_CQL_TEMPLATE = """SELECT row_id, body_blob FROM {keyspace}.{table};"""
|
||||||
|
INSERT_TABLE_CQL_TEMPLATE = (
|
||||||
|
"""INSERT INTO {keyspace}.{table} (row_id, body_blob) VALUES (?, ?);"""
|
||||||
|
)
|
||||||
|
DELETE_TABLE_CQL_TEMPLATE = """DELETE FROM {keyspace}.{table} WHERE row_id IN ?;"""
|
||||||
|
|
||||||
|
|
||||||
|
class CassandraByteStore(ByteStore):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
table: str,
|
||||||
|
*,
|
||||||
|
session: Optional[Session] = None,
|
||||||
|
keyspace: Optional[str] = None,
|
||||||
|
setup_mode: SetupMode = SetupMode.SYNC,
|
||||||
|
) -> None:
|
||||||
|
if not session or not keyspace:
|
||||||
|
try:
|
||||||
|
from cassio.config import check_resolve_keyspace, check_resolve_session
|
||||||
|
|
||||||
|
self.keyspace = keyspace or check_resolve_keyspace(keyspace)
|
||||||
|
self.session = session or check_resolve_session()
|
||||||
|
except (ImportError, ModuleNotFoundError):
|
||||||
|
raise ImportError(
|
||||||
|
"Could not import a recent cassio package."
|
||||||
|
"Please install it with `pip install --upgrade cassio`."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.keyspace = keyspace
|
||||||
|
self.session = session
|
||||||
|
self.table = table
|
||||||
|
self.select_statement = None
|
||||||
|
self.insert_statement = None
|
||||||
|
self.delete_statement = None
|
||||||
|
|
||||||
|
create_cql = CREATE_TABLE_CQL_TEMPLATE.format(
|
||||||
|
keyspace=self.keyspace,
|
||||||
|
table=self.table,
|
||||||
|
)
|
||||||
|
self.db_setup_task: Optional[Task[None]] = None
|
||||||
|
if setup_mode == SetupMode.ASYNC:
|
||||||
|
self.db_setup_task = asyncio.create_task(
|
||||||
|
aexecute_cql(self.session, create_cql)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.session.execute(create_cql)
|
||||||
|
|
||||||
|
def ensure_db_setup(self) -> None:
|
||||||
|
if self.db_setup_task:
|
||||||
|
try:
|
||||||
|
self.db_setup_task.result()
|
||||||
|
except InvalidStateError:
|
||||||
|
raise ValueError(
|
||||||
|
"Asynchronous setup of the DB not finished. "
|
||||||
|
"NB: AstraDB components sync methods shouldn't be called from the "
|
||||||
|
"event loop. Consider using their async equivalents."
|
||||||
|
)
|
||||||
|
|
||||||
|
async def aensure_db_setup(self) -> None:
|
||||||
|
if self.db_setup_task:
|
||||||
|
await self.db_setup_task
|
||||||
|
|
||||||
|
def get_select_statement(self) -> PreparedStatement:
|
||||||
|
if not self.select_statement:
|
||||||
|
self.select_statement = self.session.prepare(
|
||||||
|
SELECT_TABLE_CQL_TEMPLATE.format(
|
||||||
|
keyspace=self.keyspace, table=self.table
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return self.select_statement
|
||||||
|
|
||||||
|
def get_insert_statement(self) -> PreparedStatement:
|
||||||
|
if not self.insert_statement:
|
||||||
|
self.insert_statement = self.session.prepare(
|
||||||
|
INSERT_TABLE_CQL_TEMPLATE.format(
|
||||||
|
keyspace=self.keyspace, table=self.table
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return self.insert_statement
|
||||||
|
|
||||||
|
def get_delete_statement(self) -> PreparedStatement:
|
||||||
|
if not self.delete_statement:
|
||||||
|
self.delete_statement = self.session.prepare(
|
||||||
|
DELETE_TABLE_CQL_TEMPLATE.format(
|
||||||
|
keyspace=self.keyspace, table=self.table
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return self.delete_statement
|
||||||
|
|
||||||
|
def mget(self, keys: Sequence[str]) -> List[Optional[bytes]]:
|
||||||
|
from cassandra.query import ValueSequence
|
||||||
|
|
||||||
|
self.ensure_db_setup()
|
||||||
|
docs_dict = {}
|
||||||
|
for row in self.session.execute(
|
||||||
|
self.get_select_statement(), [ValueSequence(keys)]
|
||||||
|
):
|
||||||
|
docs_dict[row.row_id] = row.body_blob
|
||||||
|
return [docs_dict.get(key) for key in keys]
|
||||||
|
|
||||||
|
async def amget(self, keys: Sequence[str]) -> List[Optional[bytes]]:
|
||||||
|
from cassandra.query import ValueSequence
|
||||||
|
|
||||||
|
await self.aensure_db_setup()
|
||||||
|
docs_dict = {}
|
||||||
|
for row in await aexecute_cql(
|
||||||
|
self.session, self.get_select_statement(), parameters=[ValueSequence(keys)]
|
||||||
|
):
|
||||||
|
docs_dict[row.row_id] = row.body_blob
|
||||||
|
return [docs_dict.get(key) for key in keys]
|
||||||
|
|
||||||
|
def mset(self, key_value_pairs: Sequence[Tuple[str, bytes]]) -> None:
|
||||||
|
self.ensure_db_setup()
|
||||||
|
insert_statement = self.get_insert_statement()
|
||||||
|
for k, v in key_value_pairs:
|
||||||
|
self.session.execute(insert_statement, (k, v))
|
||||||
|
|
||||||
|
async def amset(self, key_value_pairs: Sequence[Tuple[str, bytes]]) -> None:
|
||||||
|
await self.aensure_db_setup()
|
||||||
|
insert_statement = self.get_insert_statement()
|
||||||
|
for k, v in key_value_pairs:
|
||||||
|
await aexecute_cql(self.session, insert_statement, parameters=(k, v))
|
||||||
|
|
||||||
|
def mdelete(self, keys: Sequence[str]) -> None:
|
||||||
|
from cassandra.query import ValueSequence
|
||||||
|
|
||||||
|
self.ensure_db_setup()
|
||||||
|
self.session.execute(self.get_delete_statement(), [ValueSequence(keys)])
|
||||||
|
|
||||||
|
async def amdelete(self, keys: Sequence[str]) -> None:
|
||||||
|
from cassandra.query import ValueSequence
|
||||||
|
|
||||||
|
await self.aensure_db_setup()
|
||||||
|
await aexecute_cql(
|
||||||
|
self.session, self.get_delete_statement(), parameters=[ValueSequence(keys)]
|
||||||
|
)
|
||||||
|
|
||||||
|
def yield_keys(self, *, prefix: Optional[str] = None) -> Iterator[str]:
|
||||||
|
self.ensure_db_setup()
|
||||||
|
for row in self.session.execute(
|
||||||
|
SELECT_ALL_TABLE_CQL_TEMPLATE.format(
|
||||||
|
keyspace=self.keyspace, table=self.table
|
||||||
|
)
|
||||||
|
):
|
||||||
|
key = row.row_id
|
||||||
|
if not prefix or key.startswith(prefix):
|
||||||
|
yield key
|
||||||
|
|
||||||
|
async def ayield_keys(self, *, prefix: Optional[str] = None) -> AsyncIterator[str]:
|
||||||
|
await self.aensure_db_setup()
|
||||||
|
for row in await aexecute_cql(
|
||||||
|
self.session,
|
||||||
|
SELECT_ALL_TABLE_CQL_TEMPLATE.format(
|
||||||
|
keyspace=self.keyspace, table=self.table
|
||||||
|
),
|
||||||
|
):
|
||||||
|
key = row.row_id
|
||||||
|
if not prefix or key.startswith(prefix):
|
||||||
|
yield key
|
@ -0,0 +1,155 @@
|
|||||||
|
"""Implement integration tests for Cassandra storage."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from langchain_community.storage.cassandra import CassandraByteStore
|
||||||
|
from langchain_community.utilities.cassandra import SetupMode
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from cassandra.cluster import Session
|
||||||
|
|
||||||
|
KEYSPACE = "storage_test_keyspace"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session")
|
||||||
|
def session() -> Session:
|
||||||
|
from cassandra.cluster import Cluster
|
||||||
|
|
||||||
|
cluster = Cluster()
|
||||||
|
session = cluster.connect()
|
||||||
|
session.execute(
|
||||||
|
(
|
||||||
|
f"CREATE KEYSPACE IF NOT EXISTS {KEYSPACE} "
|
||||||
|
f"WITH replication = {{'class': 'SimpleStrategy', 'replication_factor': 1}}"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return session
|
||||||
|
|
||||||
|
|
||||||
|
def init_store(table_name: str, session: Session) -> CassandraByteStore:
|
||||||
|
store = CassandraByteStore(table=table_name, keyspace=KEYSPACE, session=session)
|
||||||
|
store.mset([("key1", b"value1"), ("key2", b"value2")])
|
||||||
|
return store
|
||||||
|
|
||||||
|
|
||||||
|
async def init_async_store(table_name: str, session: Session) -> CassandraByteStore:
|
||||||
|
store = CassandraByteStore(
|
||||||
|
table=table_name, keyspace=KEYSPACE, session=session, setup_mode=SetupMode.ASYNC
|
||||||
|
)
|
||||||
|
await store.amset([("key1", b"value1"), ("key2", b"value2")])
|
||||||
|
return store
|
||||||
|
|
||||||
|
|
||||||
|
def drop_table(table_name: str, session: Session) -> None:
|
||||||
|
session.execute(f"DROP TABLE {KEYSPACE}.{table_name}")
|
||||||
|
|
||||||
|
|
||||||
|
async def test_mget(session: Session) -> None:
|
||||||
|
"""Test CassandraByteStore mget method."""
|
||||||
|
table_name = "lc_test_store_mget"
|
||||||
|
try:
|
||||||
|
store = init_store(table_name, session)
|
||||||
|
assert store.mget(["key1", "key2"]) == [b"value1", b"value2"]
|
||||||
|
assert await store.amget(["key1", "key2"]) == [b"value1", b"value2"]
|
||||||
|
finally:
|
||||||
|
drop_table(table_name, session)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_amget(session: Session) -> None:
|
||||||
|
"""Test CassandraByteStore amget method."""
|
||||||
|
table_name = "lc_test_store_amget"
|
||||||
|
try:
|
||||||
|
store = await init_async_store(table_name, session)
|
||||||
|
assert await store.amget(["key1", "key2"]) == [b"value1", b"value2"]
|
||||||
|
finally:
|
||||||
|
drop_table(table_name, session)
|
||||||
|
|
||||||
|
|
||||||
|
def test_mset(session: Session) -> None:
|
||||||
|
"""Test that multiple keys can be set with CassandraByteStore."""
|
||||||
|
table_name = "lc_test_store_mset"
|
||||||
|
try:
|
||||||
|
init_store(table_name, session)
|
||||||
|
result = session.execute(
|
||||||
|
"SELECT row_id, body_blob FROM storage_test_keyspace.lc_test_store_mset "
|
||||||
|
"WHERE row_id = 'key1';"
|
||||||
|
).one()
|
||||||
|
assert result.body_blob == b"value1"
|
||||||
|
result = session.execute(
|
||||||
|
"SELECT row_id, body_blob FROM storage_test_keyspace.lc_test_store_mset "
|
||||||
|
"WHERE row_id = 'key2';"
|
||||||
|
).one()
|
||||||
|
assert result.body_blob == b"value2"
|
||||||
|
finally:
|
||||||
|
drop_table(table_name, session)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_amset(session: Session) -> None:
|
||||||
|
"""Test that multiple keys can be set with CassandraByteStore."""
|
||||||
|
table_name = "lc_test_store_amset"
|
||||||
|
try:
|
||||||
|
await init_async_store(table_name, session)
|
||||||
|
result = session.execute(
|
||||||
|
"SELECT row_id, body_blob FROM storage_test_keyspace.lc_test_store_amset "
|
||||||
|
"WHERE row_id = 'key1';"
|
||||||
|
).one()
|
||||||
|
assert result.body_blob == b"value1"
|
||||||
|
result = session.execute(
|
||||||
|
"SELECT row_id, body_blob FROM storage_test_keyspace.lc_test_store_amset "
|
||||||
|
"WHERE row_id = 'key2';"
|
||||||
|
).one()
|
||||||
|
assert result.body_blob == b"value2"
|
||||||
|
finally:
|
||||||
|
drop_table(table_name, session)
|
||||||
|
|
||||||
|
|
||||||
|
def test_mdelete(session: Session) -> None:
|
||||||
|
"""Test that deletion works as expected."""
|
||||||
|
table_name = "lc_test_store_mdelete"
|
||||||
|
try:
|
||||||
|
store = init_store(table_name, session)
|
||||||
|
store.mdelete(["key1", "key2"])
|
||||||
|
result = store.mget(["key1", "key2"])
|
||||||
|
assert result == [None, None]
|
||||||
|
finally:
|
||||||
|
drop_table(table_name, session)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_amdelete(session: Session) -> None:
|
||||||
|
"""Test that deletion works as expected."""
|
||||||
|
table_name = "lc_test_store_amdelete"
|
||||||
|
try:
|
||||||
|
store = await init_async_store(table_name, session)
|
||||||
|
await store.amdelete(["key1", "key2"])
|
||||||
|
result = await store.amget(["key1", "key2"])
|
||||||
|
assert result == [None, None]
|
||||||
|
finally:
|
||||||
|
drop_table(table_name, session)
|
||||||
|
|
||||||
|
|
||||||
|
def test_yield_keys(session: Session) -> None:
|
||||||
|
table_name = "lc_test_store_yield_keys"
|
||||||
|
try:
|
||||||
|
store = init_store(table_name, session)
|
||||||
|
assert set(store.yield_keys()) == {"key1", "key2"}
|
||||||
|
assert set(store.yield_keys(prefix="key")) == {"key1", "key2"}
|
||||||
|
assert set(store.yield_keys(prefix="lang")) == set()
|
||||||
|
finally:
|
||||||
|
drop_table(table_name, session)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_ayield_keys(session: Session) -> None:
|
||||||
|
table_name = "lc_test_store_ayield_keys"
|
||||||
|
try:
|
||||||
|
store = await init_async_store(table_name, session)
|
||||||
|
assert {key async for key in store.ayield_keys()} == {"key1", "key2"}
|
||||||
|
assert {key async for key in store.ayield_keys(prefix="key")} == {
|
||||||
|
"key1",
|
||||||
|
"key2",
|
||||||
|
}
|
||||||
|
assert {key async for key in store.ayield_keys(prefix="lang")} == set()
|
||||||
|
finally:
|
||||||
|
drop_table(table_name, session)
|
Loading…
Reference in New Issue