From 2e3af040809d81a7ecb54db07f036a48bb852ce4 Mon Sep 17 00:00:00 2001 From: Christophe Bornet Date: Mon, 29 Jan 2024 01:39:27 +0100 Subject: [PATCH] Use Postponed Evaluation of Annotations in Astra and Cassandra doc loaders (#16694) Minor/cosmetic change --- .../document_loaders/astradb.py | 6 ++++-- .../document_loaders/cassandra.py | 8 +++++--- .../document_loaders/test_astradb.py | 16 ++++++++-------- 3 files changed, 17 insertions(+), 13 deletions(-) diff --git a/libs/community/langchain_community/document_loaders/astradb.py b/libs/community/langchain_community/document_loaders/astradb.py index b8f33b1966..3562d42489 100644 --- a/libs/community/langchain_community/document_loaders/astradb.py +++ b/libs/community/langchain_community/document_loaders/astradb.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import json import logging import threading @@ -31,8 +33,8 @@ class AstraDBLoader(BaseLoader): collection_name: str, token: Optional[str] = None, api_endpoint: Optional[str] = None, - astra_db_client: Optional["AstraDB"] = None, - async_astra_db_client: Optional["AsyncAstraDB"] = None, + astra_db_client: Optional[AstraDB] = None, + async_astra_db_client: Optional[AsyncAstraDB] = None, namespace: Optional[str] = None, filter_criteria: Optional[Dict[str, Any]] = None, projection: Optional[Dict[str, Any]] = None, diff --git a/libs/community/langchain_community/document_loaders/cassandra.py b/libs/community/langchain_community/document_loaders/cassandra.py index ae7c7e86b5..3167711228 100644 --- a/libs/community/langchain_community/document_loaders/cassandra.py +++ b/libs/community/langchain_community/document_loaders/cassandra.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from typing import ( TYPE_CHECKING, Any, @@ -25,9 +27,9 @@ class CassandraLoader(BaseLoader): def __init__( self, table: Optional[str] = None, - session: Optional["Session"] = None, + session: Optional[Session] = None, keyspace: Optional[str] = None, - query: Optional[Union[str, "Statement"]] = None, + query: Optional[Union[str, Statement]] = None, page_content_mapper: Callable[[Any], str] = str, metadata_mapper: Callable[[Any], dict] = lambda _: {}, *, @@ -37,7 +39,7 @@ class CassandraLoader(BaseLoader): query_custom_payload: dict = None, query_execution_profile: Any = _NOT_SET, query_paging_state: Any = None, - query_host: "Host" = None, + query_host: Host = None, query_execute_as: str = None, ) -> None: """ diff --git a/libs/community/tests/integration_tests/document_loaders/test_astradb.py b/libs/community/tests/integration_tests/document_loaders/test_astradb.py index e6b0043428..0a8518885a 100644 --- a/libs/community/tests/integration_tests/document_loaders/test_astradb.py +++ b/libs/community/tests/integration_tests/document_loaders/test_astradb.py @@ -10,6 +10,8 @@ Required to run this test: - optionally this as well (otherwise defaults are used): export ASTRA_DB_KEYSPACE="my_keyspace" """ +from __future__ import annotations + import json import os import uuid @@ -35,7 +37,7 @@ def _has_env_vars() -> bool: @pytest.fixture -def astra_db_collection() -> "AstraDBCollection": +def astra_db_collection() -> AstraDBCollection: from astrapy.db import AstraDB astra_db = AstraDB( @@ -56,7 +58,7 @@ def astra_db_collection() -> "AstraDBCollection": @pytest.fixture -async def async_astra_db_collection() -> "AsyncAstraDBCollection": +async def async_astra_db_collection() -> AsyncAstraDBCollection: from astrapy.db import AsyncAstraDB astra_db = AsyncAstraDB( @@ -79,7 +81,7 @@ async def async_astra_db_collection() -> "AsyncAstraDBCollection": @pytest.mark.requires("astrapy") @pytest.mark.skipif(not _has_env_vars(), reason="Missing Astra DB env. vars") class TestAstraDB: - def test_astradb_loader(self, astra_db_collection: "AstraDBCollection") -> None: + def test_astradb_loader(self, astra_db_collection: AstraDBCollection) -> None: loader = AstraDBLoader( astra_db_collection.collection_name, token=ASTRA_DB_APPLICATION_TOKEN, @@ -106,9 +108,7 @@ class TestAstraDB: "collection": astra_db_collection.collection_name, } - def test_extraction_function( - self, astra_db_collection: "AstraDBCollection" - ) -> None: + def test_extraction_function(self, astra_db_collection: AstraDBCollection) -> None: loader = AstraDBLoader( astra_db_collection.collection_name, token=ASTRA_DB_APPLICATION_TOKEN, @@ -123,7 +123,7 @@ class TestAstraDB: assert doc.page_content == "bar" async def test_astradb_loader_async( - self, async_astra_db_collection: "AsyncAstraDBCollection" + self, async_astra_db_collection: AsyncAstraDBCollection ) -> None: await async_astra_db_collection.insert_many([{"foo": "bar", "baz": "qux"}] * 20) await async_astra_db_collection.insert_many( @@ -157,7 +157,7 @@ class TestAstraDB: } async def test_extraction_function_async( - self, async_astra_db_collection: "AsyncAstraDBCollection" + self, async_astra_db_collection: AsyncAstraDBCollection ) -> None: loader = AstraDBLoader( async_astra_db_collection.collection_name,