mirror of
https://github.com/hwchase17/langchain
synced 2024-11-04 06:00:26 +00:00
Use Postponed Evaluation of Annotations in Astra and Cassandra doc loaders (#16694)
Minor/cosmetic change
This commit is contained in:
parent
bc7607a4e9
commit
2e3af04080
@ -1,3 +1,5 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import threading
|
import threading
|
||||||
@ -31,8 +33,8 @@ class AstraDBLoader(BaseLoader):
|
|||||||
collection_name: str,
|
collection_name: str,
|
||||||
token: Optional[str] = None,
|
token: Optional[str] = None,
|
||||||
api_endpoint: Optional[str] = None,
|
api_endpoint: Optional[str] = None,
|
||||||
astra_db_client: Optional["AstraDB"] = None,
|
astra_db_client: Optional[AstraDB] = None,
|
||||||
async_astra_db_client: Optional["AsyncAstraDB"] = None,
|
async_astra_db_client: Optional[AsyncAstraDB] = None,
|
||||||
namespace: Optional[str] = None,
|
namespace: Optional[str] = None,
|
||||||
filter_criteria: Optional[Dict[str, Any]] = None,
|
filter_criteria: Optional[Dict[str, Any]] = None,
|
||||||
projection: Optional[Dict[str, Any]] = None,
|
projection: Optional[Dict[str, Any]] = None,
|
||||||
|
@ -1,3 +1,5 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import (
|
from typing import (
|
||||||
TYPE_CHECKING,
|
TYPE_CHECKING,
|
||||||
Any,
|
Any,
|
||||||
@ -25,9 +27,9 @@ class CassandraLoader(BaseLoader):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
table: Optional[str] = None,
|
table: Optional[str] = None,
|
||||||
session: Optional["Session"] = None,
|
session: Optional[Session] = None,
|
||||||
keyspace: Optional[str] = None,
|
keyspace: Optional[str] = None,
|
||||||
query: Optional[Union[str, "Statement"]] = None,
|
query: Optional[Union[str, Statement]] = None,
|
||||||
page_content_mapper: Callable[[Any], str] = str,
|
page_content_mapper: Callable[[Any], str] = str,
|
||||||
metadata_mapper: Callable[[Any], dict] = lambda _: {},
|
metadata_mapper: Callable[[Any], dict] = lambda _: {},
|
||||||
*,
|
*,
|
||||||
@ -37,7 +39,7 @@ class CassandraLoader(BaseLoader):
|
|||||||
query_custom_payload: dict = None,
|
query_custom_payload: dict = None,
|
||||||
query_execution_profile: Any = _NOT_SET,
|
query_execution_profile: Any = _NOT_SET,
|
||||||
query_paging_state: Any = None,
|
query_paging_state: Any = None,
|
||||||
query_host: "Host" = None,
|
query_host: Host = None,
|
||||||
query_execute_as: str = None,
|
query_execute_as: str = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
|
@ -10,6 +10,8 @@ Required to run this test:
|
|||||||
- optionally this as well (otherwise defaults are used):
|
- optionally this as well (otherwise defaults are used):
|
||||||
export ASTRA_DB_KEYSPACE="my_keyspace"
|
export ASTRA_DB_KEYSPACE="my_keyspace"
|
||||||
"""
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import uuid
|
import uuid
|
||||||
@ -35,7 +37,7 @@ def _has_env_vars() -> bool:
|
|||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def astra_db_collection() -> "AstraDBCollection":
|
def astra_db_collection() -> AstraDBCollection:
|
||||||
from astrapy.db import AstraDB
|
from astrapy.db import AstraDB
|
||||||
|
|
||||||
astra_db = AstraDB(
|
astra_db = AstraDB(
|
||||||
@ -56,7 +58,7 @@ def astra_db_collection() -> "AstraDBCollection":
|
|||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
async def async_astra_db_collection() -> "AsyncAstraDBCollection":
|
async def async_astra_db_collection() -> AsyncAstraDBCollection:
|
||||||
from astrapy.db import AsyncAstraDB
|
from astrapy.db import AsyncAstraDB
|
||||||
|
|
||||||
astra_db = AsyncAstraDB(
|
astra_db = AsyncAstraDB(
|
||||||
@ -79,7 +81,7 @@ async def async_astra_db_collection() -> "AsyncAstraDBCollection":
|
|||||||
@pytest.mark.requires("astrapy")
|
@pytest.mark.requires("astrapy")
|
||||||
@pytest.mark.skipif(not _has_env_vars(), reason="Missing Astra DB env. vars")
|
@pytest.mark.skipif(not _has_env_vars(), reason="Missing Astra DB env. vars")
|
||||||
class TestAstraDB:
|
class TestAstraDB:
|
||||||
def test_astradb_loader(self, astra_db_collection: "AstraDBCollection") -> None:
|
def test_astradb_loader(self, astra_db_collection: AstraDBCollection) -> None:
|
||||||
loader = AstraDBLoader(
|
loader = AstraDBLoader(
|
||||||
astra_db_collection.collection_name,
|
astra_db_collection.collection_name,
|
||||||
token=ASTRA_DB_APPLICATION_TOKEN,
|
token=ASTRA_DB_APPLICATION_TOKEN,
|
||||||
@ -106,9 +108,7 @@ class TestAstraDB:
|
|||||||
"collection": astra_db_collection.collection_name,
|
"collection": astra_db_collection.collection_name,
|
||||||
}
|
}
|
||||||
|
|
||||||
def test_extraction_function(
|
def test_extraction_function(self, astra_db_collection: AstraDBCollection) -> None:
|
||||||
self, astra_db_collection: "AstraDBCollection"
|
|
||||||
) -> None:
|
|
||||||
loader = AstraDBLoader(
|
loader = AstraDBLoader(
|
||||||
astra_db_collection.collection_name,
|
astra_db_collection.collection_name,
|
||||||
token=ASTRA_DB_APPLICATION_TOKEN,
|
token=ASTRA_DB_APPLICATION_TOKEN,
|
||||||
@ -123,7 +123,7 @@ class TestAstraDB:
|
|||||||
assert doc.page_content == "bar"
|
assert doc.page_content == "bar"
|
||||||
|
|
||||||
async def test_astradb_loader_async(
|
async def test_astradb_loader_async(
|
||||||
self, async_astra_db_collection: "AsyncAstraDBCollection"
|
self, async_astra_db_collection: AsyncAstraDBCollection
|
||||||
) -> None:
|
) -> None:
|
||||||
await async_astra_db_collection.insert_many([{"foo": "bar", "baz": "qux"}] * 20)
|
await async_astra_db_collection.insert_many([{"foo": "bar", "baz": "qux"}] * 20)
|
||||||
await async_astra_db_collection.insert_many(
|
await async_astra_db_collection.insert_many(
|
||||||
@ -157,7 +157,7 @@ class TestAstraDB:
|
|||||||
}
|
}
|
||||||
|
|
||||||
async def test_extraction_function_async(
|
async def test_extraction_function_async(
|
||||||
self, async_astra_db_collection: "AsyncAstraDBCollection"
|
self, async_astra_db_collection: AsyncAstraDBCollection
|
||||||
) -> None:
|
) -> None:
|
||||||
loader = AstraDBLoader(
|
loader = AstraDBLoader(
|
||||||
async_astra_db_collection.collection_name,
|
async_astra_db_collection.collection_name,
|
||||||
|
Loading…
Reference in New Issue
Block a user