community[minor]: Add async methods to CassandraLoader (#20609)

Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
pull/20620/head
Christophe Bornet 3 months ago committed by GitHub
parent 8c29b7bf35
commit d2d01370bc
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -3,6 +3,7 @@ from __future__ import annotations
from typing import (
TYPE_CHECKING,
Any,
AsyncIterator,
Callable,
Iterator,
Optional,
@ -13,6 +14,7 @@ from typing import (
from langchain_core.documents import Document
from langchain_community.document_loaders.base import BaseLoader
from langchain_community.utilities.cassandra import wrapped_response_future
_NOT_SET = object()
@ -112,3 +114,15 @@ class CassandraLoader(BaseLoader):
yield Document(
page_content=self.page_content_mapper(row), metadata=metadata
)
async def alazy_load(self) -> AsyncIterator[Document]:
for row in await wrapped_response_future(
self.session.execute_async,
self.query,
**self.query_kwargs,
):
metadata = self.metadata.copy()
metadata.update(self.metadata_mapper(row))
yield Document(
page_content=self.page_content_mapper(row), metadata=metadata
)

@ -0,0 +1,24 @@
from __future__ import annotations
import asyncio
from typing import TYPE_CHECKING, Any, Callable
if TYPE_CHECKING:
from cassandra.cluster import ResponseFuture
async def wrapped_response_future(
func: Callable[..., ResponseFuture], *args: Any, **kwargs: Any
) -> Any:
loop = asyncio.get_event_loop()
asyncio_future = loop.create_future()
response_future = func(*args, **kwargs)
def success_handler(_: Any) -> None:
loop.call_soon_threadsafe(asyncio_future.set_result, response_future.result())
def error_handler(exc: BaseException) -> None:
loop.call_soon_threadsafe(asyncio_future.set_exception, exc)
response_future.add_callbacks(success_handler, error_handler)
return await asyncio_future

@ -55,9 +55,9 @@ def keyspace() -> Iterator[str]:
session.execute(f"DROP TABLE IF EXISTS {keyspace}.{CASSANDRA_TABLE}")
def test_loader_table(keyspace: str) -> None:
async def test_loader_table(keyspace: str) -> None:
loader = CassandraLoader(table=CASSANDRA_TABLE)
assert loader.load() == [
expected = [
Document(
page_content="Row(row_id='id1', body_blob='text1')",
metadata={"table": CASSANDRA_TABLE, "keyspace": keyspace},
@ -67,24 +67,28 @@ def test_loader_table(keyspace: str) -> None:
metadata={"table": CASSANDRA_TABLE, "keyspace": keyspace},
),
]
assert loader.load() == expected
assert await loader.aload() == expected
def test_loader_query(keyspace: str) -> None:
async def test_loader_query(keyspace: str) -> None:
loader = CassandraLoader(
query=f"SELECT body_blob FROM {keyspace}.{CASSANDRA_TABLE}"
)
assert loader.load() == [
expected = [
Document(page_content="Row(body_blob='text1')"),
Document(page_content="Row(body_blob='text2')"),
]
assert loader.load() == expected
assert await loader.aload() == expected
def test_loader_page_content_mapper(keyspace: str) -> None:
async def test_loader_page_content_mapper(keyspace: str) -> None:
def mapper(row: Any) -> str:
return str(row.body_blob)
loader = CassandraLoader(table=CASSANDRA_TABLE, page_content_mapper=mapper)
assert loader.load() == [
expected = [
Document(
page_content="text1",
metadata={"table": CASSANDRA_TABLE, "keyspace": keyspace},
@ -94,14 +98,16 @@ def test_loader_page_content_mapper(keyspace: str) -> None:
metadata={"table": CASSANDRA_TABLE, "keyspace": keyspace},
),
]
assert loader.load() == expected
assert await loader.aload() == expected
def test_loader_metadata_mapper(keyspace: str) -> None:
async def test_loader_metadata_mapper(keyspace: str) -> None:
def mapper(row: Any) -> dict:
return {"id": row.row_id}
loader = CassandraLoader(table=CASSANDRA_TABLE, metadata_mapper=mapper)
assert loader.load() == [
expected = [
Document(
page_content="Row(row_id='id1', body_blob='text1')",
metadata={
@ -119,3 +125,5 @@ def test_loader_metadata_mapper(keyspace: str) -> None:
},
),
]
assert loader.load() == expected
assert await loader.aload() == expected

Loading…
Cancel
Save