From d2d01370bc72bdbaea7f9e4be88ede3e90dcfb45 Mon Sep 17 00:00:00 2001 From: Christophe Bornet Date: Thu, 18 Apr 2024 21:45:20 +0200 Subject: [PATCH] community[minor]: Add async methods to CassandraLoader (#20609) Co-authored-by: Eugene Yurtsev --- .../document_loaders/cassandra.py | 14 +++++++++++ .../utilities/cassandra.py | 24 +++++++++++++++++++ .../document_loaders/test_cassandra.py | 24 ++++++++++++------- 3 files changed, 54 insertions(+), 8 deletions(-) create mode 100644 libs/community/langchain_community/utilities/cassandra.py diff --git a/libs/community/langchain_community/document_loaders/cassandra.py b/libs/community/langchain_community/document_loaders/cassandra.py index cc2e8f7d62..083a32d0a9 100644 --- a/libs/community/langchain_community/document_loaders/cassandra.py +++ b/libs/community/langchain_community/document_loaders/cassandra.py @@ -3,6 +3,7 @@ from __future__ import annotations from typing import ( TYPE_CHECKING, Any, + AsyncIterator, Callable, Iterator, Optional, @@ -13,6 +14,7 @@ from typing import ( from langchain_core.documents import Document from langchain_community.document_loaders.base import BaseLoader +from langchain_community.utilities.cassandra import wrapped_response_future _NOT_SET = object() @@ -112,3 +114,15 @@ class CassandraLoader(BaseLoader): yield Document( page_content=self.page_content_mapper(row), metadata=metadata ) + + async def alazy_load(self) -> AsyncIterator[Document]: + for row in await wrapped_response_future( + self.session.execute_async, + self.query, + **self.query_kwargs, + ): + metadata = self.metadata.copy() + metadata.update(self.metadata_mapper(row)) + yield Document( + page_content=self.page_content_mapper(row), metadata=metadata + ) diff --git a/libs/community/langchain_community/utilities/cassandra.py b/libs/community/langchain_community/utilities/cassandra.py new file mode 100644 index 0000000000..6ef65b71f4 --- /dev/null +++ b/libs/community/langchain_community/utilities/cassandra.py @@ -0,0 +1,24 @@ +from __future__ import annotations + +import asyncio +from typing import TYPE_CHECKING, Any, Callable + +if TYPE_CHECKING: + from cassandra.cluster import ResponseFuture + + +async def wrapped_response_future( + func: Callable[..., ResponseFuture], *args: Any, **kwargs: Any +) -> Any: + loop = asyncio.get_event_loop() + asyncio_future = loop.create_future() + response_future = func(*args, **kwargs) + + def success_handler(_: Any) -> None: + loop.call_soon_threadsafe(asyncio_future.set_result, response_future.result()) + + def error_handler(exc: BaseException) -> None: + loop.call_soon_threadsafe(asyncio_future.set_exception, exc) + + response_future.add_callbacks(success_handler, error_handler) + return await asyncio_future diff --git a/libs/community/tests/integration_tests/document_loaders/test_cassandra.py b/libs/community/tests/integration_tests/document_loaders/test_cassandra.py index a93a6abba6..e154e2b5cd 100644 --- a/libs/community/tests/integration_tests/document_loaders/test_cassandra.py +++ b/libs/community/tests/integration_tests/document_loaders/test_cassandra.py @@ -55,9 +55,9 @@ def keyspace() -> Iterator[str]: session.execute(f"DROP TABLE IF EXISTS {keyspace}.{CASSANDRA_TABLE}") -def test_loader_table(keyspace: str) -> None: +async def test_loader_table(keyspace: str) -> None: loader = CassandraLoader(table=CASSANDRA_TABLE) - assert loader.load() == [ + expected = [ Document( page_content="Row(row_id='id1', body_blob='text1')", metadata={"table": CASSANDRA_TABLE, "keyspace": keyspace}, @@ -67,24 +67,28 @@ def test_loader_table(keyspace: str) -> None: metadata={"table": CASSANDRA_TABLE, "keyspace": keyspace}, ), ] + assert loader.load() == expected + assert await loader.aload() == expected -def test_loader_query(keyspace: str) -> None: +async def test_loader_query(keyspace: str) -> None: loader = CassandraLoader( query=f"SELECT body_blob FROM {keyspace}.{CASSANDRA_TABLE}" ) - assert loader.load() == [ + expected = [ Document(page_content="Row(body_blob='text1')"), Document(page_content="Row(body_blob='text2')"), ] + assert loader.load() == expected + assert await loader.aload() == expected -def test_loader_page_content_mapper(keyspace: str) -> None: +async def test_loader_page_content_mapper(keyspace: str) -> None: def mapper(row: Any) -> str: return str(row.body_blob) loader = CassandraLoader(table=CASSANDRA_TABLE, page_content_mapper=mapper) - assert loader.load() == [ + expected = [ Document( page_content="text1", metadata={"table": CASSANDRA_TABLE, "keyspace": keyspace}, @@ -94,14 +98,16 @@ def test_loader_page_content_mapper(keyspace: str) -> None: metadata={"table": CASSANDRA_TABLE, "keyspace": keyspace}, ), ] + assert loader.load() == expected + assert await loader.aload() == expected -def test_loader_metadata_mapper(keyspace: str) -> None: +async def test_loader_metadata_mapper(keyspace: str) -> None: def mapper(row: Any) -> dict: return {"id": row.row_id} loader = CassandraLoader(table=CASSANDRA_TABLE, metadata_mapper=mapper) - assert loader.load() == [ + expected = [ Document( page_content="Row(row_id='id1', body_blob='text1')", metadata={ @@ -119,3 +125,5 @@ def test_loader_metadata_mapper(keyspace: str) -> None: }, ), ] + assert loader.load() == expected + assert await loader.aload() == expected