diff --git a/libs/community/langchain_community/vectorstores/astradb.py b/libs/community/langchain_community/vectorstores/astradb.py index f079246e3a..67751e4410 100644 --- a/libs/community/langchain_community/vectorstores/astradb.py +++ b/libs/community/langchain_community/vectorstores/astradb.py @@ -1,13 +1,12 @@ from __future__ import annotations -import asyncio import uuid import warnings -from asyncio import Task from concurrent.futures import ThreadPoolExecutor from typing import ( TYPE_CHECKING, Any, + Awaitable, Callable, Dict, Iterable, @@ -17,17 +16,21 @@ from typing import ( Tuple, Type, TypeVar, + Union, ) import numpy as np from langchain_core._api.deprecation import deprecated from langchain_core.documents import Document from langchain_core.embeddings import Embeddings -from langchain_core.runnables import run_in_executor from langchain_core.runnables.utils import gather_with_concurrency from langchain_core.utils.iter import batch_iterate from langchain_core.vectorstores import VectorStore +from langchain_community.utilities.astradb import ( + SetupMode, + _AstraDBCollectionEnvironment, +) from langchain_community.vectorstores.utils import maximal_marginal_relevance if TYPE_CHECKING: @@ -167,28 +170,12 @@ class AstraDB(VectorStore): bulk_insert_batch_concurrency: Optional[int] = None, bulk_insert_overwrite_concurrency: Optional[int] = None, bulk_delete_concurrency: Optional[int] = None, + setup_mode: SetupMode = SetupMode.SYNC, pre_delete_collection: bool = False, ) -> None: """ Create an AstraDB vector store object. See class docstring for help. """ - try: - from astrapy.db import AstraDB as LibAstraDB - from astrapy.db import AstraDBCollection - except (ImportError, ModuleNotFoundError): - raise ImportError( - "Could not import a recent astrapy python package. " - "Please install it with `pip install --upgrade astrapy`." - ) - - # Conflicting-arg checks: - if astra_db_client is not None or async_astra_db_client is not None: - if token is not None or api_endpoint is not None: - raise ValueError( - "You cannot pass 'astra_db_client' or 'async_astra_db_client' to " - "AstraDB if passing 'token' and 'api_endpoint'." - ) - self.embedding = embedding self.collection_name = collection_name self.token = token @@ -207,105 +194,35 @@ class AstraDB(VectorStore): bulk_delete_concurrency or DEFAULT_BULK_DELETE_CONCURRENCY ) # "vector-related" settings - self._embedding_dimension: Optional[int] = None self.metric = metric + embedding_dimension: Union[int, Awaitable[int], None] = None + if setup_mode == SetupMode.ASYNC: + embedding_dimension = self._aget_embedding_dimension() + elif setup_mode == SetupMode.SYNC: + embedding_dimension = self._get_embedding_dimension() - self.astra_db = astra_db_client - self.async_astra_db = async_astra_db_client - self.collection = None - self.async_collection = None - - if token and api_endpoint: - self.astra_db = LibAstraDB( - token=self.token, - api_endpoint=self.api_endpoint, - namespace=self.namespace, - ) - try: - from astrapy.db import AsyncAstraDB - - self.async_astra_db = AsyncAstraDB( - token=self.token, - api_endpoint=self.api_endpoint, - namespace=self.namespace, - ) - except (ImportError, ModuleNotFoundError): - pass - - if self.astra_db is not None: - self.collection = AstraDBCollection( - collection_name=self.collection_name, - astra_db=self.astra_db, - ) - - self.async_setup_db_task: Optional[Task] = None - if self.async_astra_db is not None: - from astrapy.db import AsyncAstraDBCollection - - self.async_collection = AsyncAstraDBCollection( - collection_name=self.collection_name, - astra_db=self.async_astra_db, - ) - try: - self.async_setup_db_task = asyncio.create_task( - self._setup_db(pre_delete_collection) - ) - except RuntimeError: - pass - - if self.async_setup_db_task is None: - if not pre_delete_collection: - self._provision_collection() - else: - self.clear() - - def _ensure_astra_db_client(self): # type: ignore[no-untyped-def] - if not self.astra_db: - raise ValueError("Missing AstraDB client") - - async def _setup_db(self, pre_delete_collection: bool) -> None: - if pre_delete_collection: - await self.async_astra_db.delete_collection( # type: ignore[union-attr] - collection_name=self.collection_name, - ) - await self._aprovision_collection() - - async def _ensure_db_setup(self) -> None: - if self.async_setup_db_task: - await self.async_setup_db_task - - def _get_embedding_dimension(self) -> int: - if self._embedding_dimension is None: - self._embedding_dimension = len( - self.embedding.embed_query("This is a sample sentence.") - ) - return self._embedding_dimension - - def _provision_collection(self) -> None: - """ - Run the API invocation to create the collection on the backend. - - Internal-usage method, no object members are set, - other than working on the underlying actual storage. - """ - self.astra_db.create_collection( # type: ignore[union-attr] - dimension=self._get_embedding_dimension(), - collection_name=self.collection_name, - metric=self.metric, + self.astra_env = _AstraDBCollectionEnvironment( + collection_name=collection_name, + token=token, + api_endpoint=api_endpoint, + astra_db_client=astra_db_client, + async_astra_db_client=async_astra_db_client, + namespace=namespace, + setup_mode=setup_mode, + pre_delete_collection=pre_delete_collection, + embedding_dimension=embedding_dimension, + metric=metric, ) + self.astra_db = self.astra_env.astra_db + self.async_astra_db = self.astra_env.async_astra_db + self.collection = self.astra_env.collection + self.async_collection = self.astra_env.async_collection - async def _aprovision_collection(self) -> None: - """ - Run the API invocation to create the collection on the backend. + def _get_embedding_dimension(self) -> int: + return len(self.embedding.embed_query(text="This is a sample sentence.")) - Internal-usage method, no object members are set, - other than working on the underlying actual storage. - """ - await self.async_astra_db.create_collection( # type: ignore[union-attr] - dimension=self._get_embedding_dimension(), - collection_name=self.collection_name, - metric=self.metric, - ) + async def _aget_embedding_dimension(self) -> int: + return len(await self.embedding.aembed_query(text="This is a sample sentence.")) @property def embeddings(self) -> Embeddings: @@ -326,14 +243,12 @@ class AstraDB(VectorStore): def clear(self) -> None: """Empty the collection of all its stored entries.""" - self.delete_collection() - self._provision_collection() + self.astra_env.ensure_db_setup() + self.collection.delete_many({}) async def aclear(self) -> None: """Empty the collection of all its stored entries.""" - await self._ensure_db_setup() - if not self.async_astra_db: - await run_in_executor(None, self.clear) + await self.astra_env.aensure_db_setup() await self.async_collection.delete_many({}) # type: ignore[union-attr] def delete_by_document_id(self, document_id: str) -> bool: @@ -341,7 +256,7 @@ class AstraDB(VectorStore): Remove a single document from the store, given its document_id (str). Return True if a document has indeed been deleted, False if ID not found. """ - self._ensure_astra_db_client() + self.astra_env.ensure_db_setup() deletion_response = self.collection.delete_one(document_id) # type: ignore[union-attr] return ((deletion_response or {}).get("status") or {}).get( "deletedCount", 0 @@ -352,9 +267,7 @@ class AstraDB(VectorStore): Remove a single document from the store, given its document_id (str). Return True if a document has indeed been deleted, False if ID not found. """ - await self._ensure_db_setup() - if not self.async_collection: - return await run_in_executor(None, self.delete_by_document_id, document_id) + await self.astra_env.aensure_db_setup() deletion_response = await self.async_collection.delete_one(document_id) return ((deletion_response or {}).get("status") or {}).get( "deletedCount", 0 @@ -439,8 +352,8 @@ class AstraDB(VectorStore): Stored data is lost and unrecoverable, resources are freed. Use with caution. """ - self._ensure_astra_db_client() - self.astra_db.delete_collection( # type: ignore[union-attr] + self.astra_env.ensure_db_setup() + self.astra_db.delete_collection( collection_name=self.collection_name, ) @@ -451,10 +364,8 @@ class AstraDB(VectorStore): Stored data is lost and unrecoverable, resources are freed. Use with caution. """ - await self._ensure_db_setup() - if not self.async_astra_db: - await run_in_executor(None, self.delete_collection) - await self.async_astra_db.delete_collection( # type: ignore[union-attr] + await self.astra_env.aensure_db_setup() + await self.async_astra_db.delete_collection( collection_name=self.collection_name, ) @@ -569,7 +480,7 @@ class AstraDB(VectorStore): f"unsupported arguments ({', '.join(sorted(kwargs.keys()))}), " "which will be ignored." ) - self._ensure_astra_db_client() + self.astra_env.ensure_db_setup() embedding_vectors = self.embedding.embed_documents(list(texts)) documents_to_insert = self._get_documents_to_insert( @@ -655,22 +566,13 @@ class AstraDB(VectorStore): Returns: List[str]: List of ids of the added texts. """ - await self._ensure_db_setup() - if not self.async_collection: - await super().aadd_texts( - texts, - metadatas, - ids=ids, - batch_size=batch_size, - batch_concurrency=batch_concurrency, - overwrite_concurrency=overwrite_concurrency, - ) if kwargs: warnings.warn( "Method 'aadd_texts' of AstraDB vector store invoked with " f"unsupported arguments ({', '.join(sorted(kwargs.keys()))}), " "which will be ignored." ) + await self.astra_env.aensure_db_setup() embedding_vectors = await self.embedding.aembed_documents(list(texts)) documents_to_insert = self._get_documents_to_insert( @@ -731,7 +633,7 @@ class AstraDB(VectorStore): Returns: List of (Document, score, id), the most similar to the query vector. """ - self._ensure_astra_db_client() + self.astra_env.ensure_db_setup() metadata_parameter = self._filter_to_metadata(filter) # hits = list( @@ -773,15 +675,7 @@ class AstraDB(VectorStore): Returns: List of (Document, score, id), the most similar to the query vector. """ - await self._ensure_db_setup() - if not self.async_collection: - return await run_in_executor( - None, - self.asimilarity_search_with_score_id_by_vector, # type: ignore[arg-type] - embedding, - k, - filter, - ) + await self.astra_env.aensure_db_setup() metadata_parameter = self._filter_to_metadata(filter) # return [ @@ -1010,7 +904,7 @@ class AstraDB(VectorStore): Returns: List of Documents selected by maximal marginal relevance. """ - self._ensure_astra_db_client() + self.astra_env.ensure_db_setup() metadata_parameter = self._filter_to_metadata(filter) prefetch_hits = list( @@ -1051,18 +945,7 @@ class AstraDB(VectorStore): Returns: List of Documents selected by maximal marginal relevance. """ - await self._ensure_db_setup() - if not self.async_collection: - return await run_in_executor( - None, - self.max_marginal_relevance_search_by_vector, - embedding, - k, - fetch_k, - lambda_mult, - filter, - **kwargs, - ) + await self.astra_env.aensure_db_setup() metadata_parameter = self._filter_to_metadata(filter) prefetch_hits = [