From 789cd5198d92f900a77f50c4117798a2edd35ee5 Mon Sep 17 00:00:00 2001 From: Christophe Bornet Date: Thu, 15 Feb 2024 15:52:56 +0100 Subject: [PATCH] community[patch]: Use astrapy built-in pagination prefetch in AstraDBLoader (#17569) --- .../document_loaders/astradb.py | 73 +++++-------------- .../document_loaders/test_astradb.py | 16 ++-- 2 files changed, 25 insertions(+), 64 deletions(-) diff --git a/libs/community/langchain_community/document_loaders/astradb.py b/libs/community/langchain_community/document_loaders/astradb.py index 5106af9d67..8dae5c4528 100644 --- a/libs/community/langchain_community/document_loaders/astradb.py +++ b/libs/community/langchain_community/document_loaders/astradb.py @@ -2,8 +2,6 @@ from __future__ import annotations import json import logging -import threading -from queue import Queue from typing import ( TYPE_CHECKING, Any, @@ -16,7 +14,6 @@ from typing import ( ) from langchain_core.documents import Document -from langchain_core.runnables import run_in_executor from langchain_community.document_loaders.base import BaseLoader from langchain_community.utilities.astradb import _AstraDBEnvironment @@ -33,6 +30,7 @@ class AstraDBLoader(BaseLoader): def __init__( self, collection_name: str, + *, token: Optional[str] = None, api_endpoint: Optional[str] = None, astra_db_client: Optional[AstraDB] = None, @@ -65,38 +63,27 @@ class AstraDBLoader(BaseLoader): return list(self.lazy_load()) def lazy_load(self) -> Iterator[Document]: - queue = Queue(self.nb_prefetched) # type: ignore - t = threading.Thread(target=self.fetch_results, args=(queue,)) - t.start() - while True: - doc = queue.get() - if doc is None: - break - yield doc - t.join() + for doc in self.collection.paginated_find( + filter=self.filter, + options=self.find_options, + projection=self.projection, + sort=None, + prefetched=self.nb_prefetched, + ): + yield Document( + page_content=self.extraction_function(doc), + metadata={ + "namespace": self.collection.astra_db.namespace, + "api_endpoint": self.collection.astra_db.base_url, + "collection": self.collection_name, + }, + ) async def aload(self) -> List[Document]: """Load data into Document objects.""" return [doc async for doc in self.alazy_load()] async def alazy_load(self) -> AsyncIterator[Document]: - if not self.astra_env.async_astra_db: - iterator = run_in_executor( - None, - self.collection.paginated_find, - filter=self.filter, - options=self.find_options, - projection=self.projection, - sort=None, - prefetched=True, - ) - done = object() - while True: - item = await run_in_executor(None, lambda it: next(it, done), iterator) - if item is done: - break - yield item # type: ignore[misc] - return async_collection = await self.astra_env.async_astra_db.collection( self.collection_name ) @@ -105,7 +92,7 @@ class AstraDBLoader(BaseLoader): options=self.find_options, projection=self.projection, sort=None, - prefetched=True, + prefetched=self.nb_prefetched, ): yield Document( page_content=self.extraction_function(doc), @@ -115,29 +102,3 @@ class AstraDBLoader(BaseLoader): "collection": self.collection_name, }, ) - - def fetch_results(self, queue: Queue): # type: ignore[no-untyped-def] - self.fetch_page_result(queue) - while self.find_options.get("pageState"): - self.fetch_page_result(queue) - queue.put(None) - - def fetch_page_result(self, queue: Queue): # type: ignore[no-untyped-def] - res = self.collection.find( - filter=self.filter, - options=self.find_options, - projection=self.projection, - sort=None, - ) - self.find_options["pageState"] = res["data"].get("nextPageState") - for doc in res["data"]["documents"]: - queue.put( - Document( - page_content=self.extraction_function(doc), - metadata={ - "namespace": self.collection.astra_db.namespace, - "api_endpoint": self.collection.astra_db.base_url, - "collection": self.collection.collection_name, - }, - ) - ) diff --git a/libs/community/tests/integration_tests/document_loaders/test_astradb.py b/libs/community/tests/integration_tests/document_loaders/test_astradb.py index 8f9146aacb..b0a1104f82 100644 --- a/libs/community/tests/integration_tests/document_loaders/test_astradb.py +++ b/libs/community/tests/integration_tests/document_loaders/test_astradb.py @@ -15,7 +15,7 @@ from __future__ import annotations import json import os import uuid -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, AsyncIterator, Iterator import pytest @@ -37,12 +37,12 @@ def _has_env_vars() -> bool: @pytest.fixture -def astra_db_collection() -> AstraDBCollection: +def astra_db_collection() -> Iterator[AstraDBCollection]: from astrapy.db import AstraDB astra_db = AstraDB( - token=ASTRA_DB_APPLICATION_TOKEN, - api_endpoint=ASTRA_DB_API_ENDPOINT, + token=ASTRA_DB_APPLICATION_TOKEN or "", + api_endpoint=ASTRA_DB_API_ENDPOINT or "", namespace=ASTRA_DB_KEYSPACE, ) collection_name = f"lc_test_loader_{str(uuid.uuid4()).split('-')[0]}" @@ -58,12 +58,12 @@ def astra_db_collection() -> AstraDBCollection: @pytest.fixture -async def async_astra_db_collection() -> AsyncAstraDBCollection: +async def async_astra_db_collection() -> AsyncIterator[AsyncAstraDBCollection]: from astrapy.db import AsyncAstraDB astra_db = AsyncAstraDB( - token=ASTRA_DB_APPLICATION_TOKEN, - api_endpoint=ASTRA_DB_API_ENDPOINT, + token=ASTRA_DB_APPLICATION_TOKEN or "", + api_endpoint=ASTRA_DB_API_ENDPOINT or "", namespace=ASTRA_DB_KEYSPACE, ) collection_name = f"lc_test_loader_{str(uuid.uuid4()).split('-')[0]}" @@ -167,5 +167,5 @@ class TestAstraDB: find_options={"limit": 30}, extraction_function=lambda x: x["foo"], ) - doc = await anext(loader.alazy_load()) # type: ignore[name-defined] + doc = await loader.alazy_load().__anext__() assert doc.page_content == "bar"