# Mirror of https://github.com/hwchase17/langchain
# (synced 2024-11-18 09:25:54 +00:00)
# Commit ca2d4078f3: "Adds async methods to AstraDBCache"
# File: 170 lines, 5.9 KiB, Python
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import inspect
|
|
from asyncio import InvalidStateError, Task
|
|
from enum import Enum
|
|
from typing import TYPE_CHECKING, Awaitable, Optional, Union
|
|
|
|
if TYPE_CHECKING:
|
|
from astrapy.db import (
|
|
AstraDB,
|
|
AsyncAstraDB,
|
|
)
|
|
|
|
|
|
class SetupMode(Enum):
    """Strategy for preparing the backing collection of a collection environment.

    As interpreted by ``_AstraDBCollectionEnvironment`` in this module:

    * ``SYNC``  -- optional pre-delete, then create, via blocking calls made
      inside the constructor.
    * ``ASYNC`` -- the same steps wrapped in a coroutine scheduled as an
      asyncio task.
    * ``OFF``   -- no setup is performed.
    """

    SYNC = 1  # blocking setup during construction
    ASYNC = 2  # setup scheduled as an asyncio task
    OFF = 3  # no setup at all
|
|
|
|
|
|
class _AstraDBEnvironment:
|
|
def __init__(
|
|
self,
|
|
token: Optional[str] = None,
|
|
api_endpoint: Optional[str] = None,
|
|
astra_db_client: Optional[AstraDB] = None,
|
|
async_astra_db_client: Optional[AsyncAstraDB] = None,
|
|
namespace: Optional[str] = None,
|
|
) -> None:
|
|
self.token = token
|
|
self.api_endpoint = api_endpoint
|
|
astra_db = astra_db_client
|
|
async_astra_db = async_astra_db_client
|
|
self.namespace = namespace
|
|
|
|
try:
|
|
from astrapy.db import (
|
|
AstraDB,
|
|
AsyncAstraDB,
|
|
)
|
|
except (ImportError, ModuleNotFoundError):
|
|
raise ImportError(
|
|
"Could not import a recent astrapy python package. "
|
|
"Please install it with `pip install --upgrade astrapy`."
|
|
)
|
|
|
|
# Conflicting-arg checks:
|
|
if astra_db_client is not None or async_astra_db_client is not None:
|
|
if token is not None or api_endpoint is not None:
|
|
raise ValueError(
|
|
"You cannot pass 'astra_db_client' or 'async_astra_db_client' to "
|
|
"AstraDBEnvironment if passing 'token' and 'api_endpoint'."
|
|
)
|
|
|
|
if token and api_endpoint:
|
|
astra_db = AstraDB(
|
|
token=token,
|
|
api_endpoint=api_endpoint,
|
|
namespace=self.namespace,
|
|
)
|
|
async_astra_db = AsyncAstraDB(
|
|
token=token,
|
|
api_endpoint=api_endpoint,
|
|
namespace=self.namespace,
|
|
)
|
|
|
|
if astra_db:
|
|
self.astra_db = astra_db
|
|
if async_astra_db:
|
|
self.async_astra_db = async_astra_db
|
|
else:
|
|
self.async_astra_db = AsyncAstraDB(
|
|
token=self.astra_db.token,
|
|
api_endpoint=self.astra_db.base_url,
|
|
api_path=self.astra_db.api_path,
|
|
api_version=self.astra_db.api_version,
|
|
namespace=self.astra_db.namespace,
|
|
)
|
|
elif async_astra_db:
|
|
self.async_astra_db = async_astra_db
|
|
self.astra_db = AstraDB(
|
|
token=self.async_astra_db.token,
|
|
api_endpoint=self.async_astra_db.base_url,
|
|
api_path=self.async_astra_db.api_path,
|
|
api_version=self.async_astra_db.api_version,
|
|
namespace=self.async_astra_db.namespace,
|
|
)
|
|
else:
|
|
raise ValueError(
|
|
"Must provide 'astra_db_client' or 'async_astra_db_client' or "
|
|
"'token' and 'api_endpoint'"
|
|
)
|
|
|
|
|
|
class _AstraDBCollectionEnvironment(_AstraDBEnvironment):
    """Astra DB environment bound to a single collection.

    Exposes ``collection`` (sync) and ``async_collection`` (async) handles for
    ``collection_name``. Depending on ``setup_mode``, the constructor either
    (optionally deletes and then) creates the collection with blocking calls
    (``SetupMode.SYNC``), schedules the same steps as an asyncio task
    (``SetupMode.ASYNC``), or does nothing (``SetupMode.OFF``).
    """

    def __init__(
        self,
        collection_name: str,
        token: Optional[str] = None,
        api_endpoint: Optional[str] = None,
        astra_db_client: Optional[AstraDB] = None,
        async_astra_db_client: Optional[AsyncAstraDB] = None,
        namespace: Optional[str] = None,
        setup_mode: SetupMode = SetupMode.SYNC,
        pre_delete_collection: bool = False,
        embedding_dimension: Union[int, Awaitable[int], None] = None,
        metric: Optional[str] = None,
    ) -> None:
        """Build the client pair and collection handles, then run setup.

        Args:
            collection_name: Name of the collection to bind to (and create).
            token: Forwarded to ``_AstraDBEnvironment``.
            api_endpoint: Forwarded to ``_AstraDBEnvironment``.
            astra_db_client: Forwarded to ``_AstraDBEnvironment``.
            async_astra_db_client: Forwarded to ``_AstraDBEnvironment``.
            namespace: Forwarded to ``_AstraDBEnvironment``.
            setup_mode: Selects how (or whether) the collection is set up.
            pre_delete_collection: If true, the collection is deleted before
                it is created (ignored with ``SetupMode.OFF``).
            embedding_dimension: Passed as ``dimension`` to
                ``create_collection``. It may be an awaitable only with
                ``SetupMode.ASYNC``, where it is awaited inside the task.
            metric: Passed as ``metric`` to ``create_collection``.

        Raises:
            ValueError: With ``SetupMode.SYNC`` and an awaitable
                ``embedding_dimension``. It is raised before anything is
                deleted or created.
            RuntimeError: With ``SetupMode.ASYNC`` when called without a
                running event loop.
        """
        from astrapy.db import AstraDBCollection, AsyncAstraDBCollection

        super().__init__(
            token, api_endpoint, astra_db_client, async_astra_db_client, namespace
        )
        self.collection_name = collection_name
        self.collection = AstraDBCollection(
            collection_name=collection_name,
            astra_db=self.astra_db,
        )

        self.async_collection = AsyncAstraDBCollection(
            collection_name=collection_name,
            astra_db=self.async_astra_db,
        )

        # Set only in SetupMode.ASYNC; consulted by (a)ensure_db_setup.
        self.async_setup_db_task: Optional[Task] = None
        if setup_mode == SetupMode.ASYNC:
            async_astra_db = self.async_astra_db

            async def _setup_db() -> None:
                if pre_delete_collection:
                    await async_astra_db.delete_collection(collection_name)
                if inspect.isawaitable(embedding_dimension):
                    dimension = await embedding_dimension
                else:
                    dimension = embedding_dimension
                await async_astra_db.create_collection(
                    collection_name, dimension=dimension, metric=metric
                )

            # Fetch the running loop *before* creating the coroutine. With no
            # running loop this raises the same RuntimeError as
            # asyncio.create_task would, but does not leave an un-awaited
            # `_setup_db()` coroutine behind (which would trigger a warning).
            loop = asyncio.get_running_loop()
            self.async_setup_db_task = loop.create_task(_setup_db())
        elif setup_mode == SetupMode.SYNC:
            # Validate before any destructive call: an awaitable dimension
            # cannot be resolved synchronously, so fail without having
            # deleted the existing collection.
            if inspect.isawaitable(embedding_dimension):
                raise ValueError(
                    "Cannot use an awaitable embedding_dimension with setup_mode "
                    "set to SetupMode.SYNC"
                )
            if pre_delete_collection:
                self.astra_db.delete_collection(collection_name)
            self.astra_db.create_collection(
                collection_name,
                dimension=embedding_dimension,  # type: ignore[arg-type]
                metric=metric,
            )
        # SetupMode.OFF: no setup is performed.

    def ensure_db_setup(self) -> None:
        """Check, without awaiting, that the async setup task (if any) is done.

        If the task finished with an exception, ``Task.result()`` re-raises it
        here (a cancelled task raises ``CancelledError``).

        Raises:
            ValueError: If the setup task has not finished yet.
        """
        if self.async_setup_db_task is not None:
            try:
                self.async_setup_db_task.result()
            except InvalidStateError:
                # The InvalidStateError context adds nothing for the caller.
                raise ValueError(
                    "Asynchronous setup of the DB not finished. "
                    "NB: AstraDB components sync methods shouldn't be called from the "
                    "event loop. Consider using their async equivalents."
                ) from None

    async def aensure_db_setup(self) -> None:
        """Await the async setup task, if one was scheduled.

        Any exception raised by the setup task propagates to the caller.
        """
        if self.async_setup_db_task is not None:
            await self.async_setup_db_task
|