From 19ebc7418e1f86b6616e6e6199157a7a67b10ad6 Mon Sep 17 00:00:00 2001 From: Christophe Bornet Date: Fri, 16 Feb 2024 17:28:16 +0100 Subject: [PATCH] community: Use _AstraDBCollectionEnvironment in AstraDB VectorStore (community) (#17635) Another PR will be done for the langchain-astradb package. Note: for future PRs, devs will be done in the partner package only. This one is just to align with the rest of the components in the community package and it fixes a bunch of issues. --- .../vectorstores/astradb.py | 209 ++++-------------- 1 file changed, 46 insertions(+), 163 deletions(-) diff --git a/libs/community/langchain_community/vectorstores/astradb.py b/libs/community/langchain_community/vectorstores/astradb.py index f079246e3a..67751e4410 100644 --- a/libs/community/langchain_community/vectorstores/astradb.py +++ b/libs/community/langchain_community/vectorstores/astradb.py @@ -1,13 +1,12 @@ from __future__ import annotations -import asyncio import uuid import warnings -from asyncio import Task from concurrent.futures import ThreadPoolExecutor from typing import ( TYPE_CHECKING, Any, + Awaitable, Callable, Dict, Iterable, @@ -17,17 +16,21 @@ from typing import ( Tuple, Type, TypeVar, + Union, ) import numpy as np from langchain_core._api.deprecation import deprecated from langchain_core.documents import Document from langchain_core.embeddings import Embeddings -from langchain_core.runnables import run_in_executor from langchain_core.runnables.utils import gather_with_concurrency from langchain_core.utils.iter import batch_iterate from langchain_core.vectorstores import VectorStore +from langchain_community.utilities.astradb import ( + SetupMode, + _AstraDBCollectionEnvironment, +) from langchain_community.vectorstores.utils import maximal_marginal_relevance if TYPE_CHECKING: @@ -167,28 +170,12 @@ class AstraDB(VectorStore): bulk_insert_batch_concurrency: Optional[int] = None, bulk_insert_overwrite_concurrency: Optional[int] = None, bulk_delete_concurrency: Optional[int] = None, + setup_mode: SetupMode = SetupMode.SYNC, pre_delete_collection: bool = False, ) -> None: """ Create an AstraDB vector store object. See class docstring for help. """ - try: - from astrapy.db import AstraDB as LibAstraDB - from astrapy.db import AstraDBCollection - except (ImportError, ModuleNotFoundError): - raise ImportError( - "Could not import a recent astrapy python package. " - "Please install it with `pip install --upgrade astrapy`." - ) - - # Conflicting-arg checks: - if astra_db_client is not None or async_astra_db_client is not None: - if token is not None or api_endpoint is not None: - raise ValueError( - "You cannot pass 'astra_db_client' or 'async_astra_db_client' to " - "AstraDB if passing 'token' and 'api_endpoint'." - ) - self.embedding = embedding self.collection_name = collection_name self.token = token @@ -207,105 +194,35 @@ class AstraDB(VectorStore): bulk_delete_concurrency or DEFAULT_BULK_DELETE_CONCURRENCY ) # "vector-related" settings - self._embedding_dimension: Optional[int] = None self.metric = metric + embedding_dimension: Union[int, Awaitable[int], None] = None + if setup_mode == SetupMode.ASYNC: + embedding_dimension = self._aget_embedding_dimension() + elif setup_mode == SetupMode.SYNC: + embedding_dimension = self._get_embedding_dimension() - self.astra_db = astra_db_client - self.async_astra_db = async_astra_db_client - self.collection = None - self.async_collection = None - - if token and api_endpoint: - self.astra_db = LibAstraDB( - token=self.token, - api_endpoint=self.api_endpoint, - namespace=self.namespace, - ) - try: - from astrapy.db import AsyncAstraDB - - self.async_astra_db = AsyncAstraDB( - token=self.token, - api_endpoint=self.api_endpoint, - namespace=self.namespace, - ) - except (ImportError, ModuleNotFoundError): - pass - - if self.astra_db is not None: - self.collection = AstraDBCollection( - collection_name=self.collection_name, - astra_db=self.astra_db, - ) - - self.async_setup_db_task: Optional[Task] = None - if self.async_astra_db is not None: - from astrapy.db import AsyncAstraDBCollection - - self.async_collection = AsyncAstraDBCollection( - collection_name=self.collection_name, - astra_db=self.async_astra_db, - ) - try: - self.async_setup_db_task = asyncio.create_task( - self._setup_db(pre_delete_collection) - ) - except RuntimeError: - pass - - if self.async_setup_db_task is None: - if not pre_delete_collection: - self._provision_collection() - else: - self.clear() - - def _ensure_astra_db_client(self): # type: ignore[no-untyped-def] - if not self.astra_db: - raise ValueError("Missing AstraDB client") - - async def _setup_db(self, pre_delete_collection: bool) -> None: - if pre_delete_collection: - await self.async_astra_db.delete_collection( # type: ignore[union-attr] - collection_name=self.collection_name, - ) - await self._aprovision_collection() - - async def _ensure_db_setup(self) -> None: - if self.async_setup_db_task: - await self.async_setup_db_task - - def _get_embedding_dimension(self) -> int: - if self._embedding_dimension is None: - self._embedding_dimension = len( - self.embedding.embed_query("This is a sample sentence.") - ) - return self._embedding_dimension - - def _provision_collection(self) -> None: - """ - Run the API invocation to create the collection on the backend. - - Internal-usage method, no object members are set, - other than working on the underlying actual storage. - """ - self.astra_db.create_collection( # type: ignore[union-attr] - dimension=self._get_embedding_dimension(), - collection_name=self.collection_name, - metric=self.metric, + self.astra_env = _AstraDBCollectionEnvironment( + collection_name=collection_name, + token=token, + api_endpoint=api_endpoint, + astra_db_client=astra_db_client, + async_astra_db_client=async_astra_db_client, + namespace=namespace, + setup_mode=setup_mode, + pre_delete_collection=pre_delete_collection, + embedding_dimension=embedding_dimension, + metric=metric, ) + self.astra_db = self.astra_env.astra_db + self.async_astra_db = self.astra_env.async_astra_db + self.collection = self.astra_env.collection + self.async_collection = self.astra_env.async_collection - async def _aprovision_collection(self) -> None: - """ - Run the API invocation to create the collection on the backend. + def _get_embedding_dimension(self) -> int: + return len(self.embedding.embed_query(text="This is a sample sentence.")) - Internal-usage method, no object members are set, - other than working on the underlying actual storage. - """ - await self.async_astra_db.create_collection( # type: ignore[union-attr] - dimension=self._get_embedding_dimension(), - collection_name=self.collection_name, - metric=self.metric, - ) + async def _aget_embedding_dimension(self) -> int: + return len(await self.embedding.aembed_query(text="This is a sample sentence.")) @property def embeddings(self) -> Embeddings: @@ -326,14 +243,12 @@ class AstraDB(VectorStore): def clear(self) -> None: """Empty the collection of all its stored entries.""" - self.delete_collection() - self._provision_collection() + self.astra_env.ensure_db_setup() + self.collection.delete_many({}) async def aclear(self) -> None: """Empty the collection of all its stored entries.""" - await self._ensure_db_setup() - if not self.async_astra_db: - await run_in_executor(None, self.clear) + await self.astra_env.aensure_db_setup() await self.async_collection.delete_many({}) # type: ignore[union-attr] def delete_by_document_id(self, document_id: str) -> bool: @@ -341,7 +256,7 @@ class AstraDB(VectorStore): Remove a single document from the store, given its document_id (str). Return True if a document has indeed been deleted, False if ID not found. """ - self._ensure_astra_db_client() + self.astra_env.ensure_db_setup() deletion_response = self.collection.delete_one(document_id) # type: ignore[union-attr] return ((deletion_response or {}).get("status") or {}).get( "deletedCount", 0 @@ -352,9 +267,7 @@ class AstraDB(VectorStore): Remove a single document from the store, given its document_id (str). Return True if a document has indeed been deleted, False if ID not found. """ - await self._ensure_db_setup() - if not self.async_collection: - return await run_in_executor(None, self.delete_by_document_id, document_id) + await self.astra_env.aensure_db_setup() deletion_response = await self.async_collection.delete_one(document_id) return ((deletion_response or {}).get("status") or {}).get( "deletedCount", 0 @@ -439,8 +352,8 @@ class AstraDB(VectorStore): Stored data is lost and unrecoverable, resources are freed. Use with caution. """ - self._ensure_astra_db_client() - self.astra_db.delete_collection( # type: ignore[union-attr] + self.astra_env.ensure_db_setup() + self.astra_db.delete_collection( collection_name=self.collection_name, ) @@ -451,10 +364,8 @@ class AstraDB(VectorStore): Stored data is lost and unrecoverable, resources are freed. Use with caution. """ - await self._ensure_db_setup() - if not self.async_astra_db: - await run_in_executor(None, self.delete_collection) - await self.async_astra_db.delete_collection( # type: ignore[union-attr] + await self.astra_env.aensure_db_setup() + await self.async_astra_db.delete_collection( collection_name=self.collection_name, ) @@ -569,7 +480,7 @@ class AstraDB(VectorStore): f"unsupported arguments ({', '.join(sorted(kwargs.keys()))}), " "which will be ignored." ) - self._ensure_astra_db_client() + self.astra_env.ensure_db_setup() embedding_vectors = self.embedding.embed_documents(list(texts)) documents_to_insert = self._get_documents_to_insert( @@ -655,22 +566,13 @@ class AstraDB(VectorStore): Returns: List[str]: List of ids of the added texts. """ - await self._ensure_db_setup() - if not self.async_collection: - await super().aadd_texts( - texts, - metadatas, - ids=ids, - batch_size=batch_size, - batch_concurrency=batch_concurrency, - overwrite_concurrency=overwrite_concurrency, - ) if kwargs: warnings.warn( "Method 'aadd_texts' of AstraDB vector store invoked with " f"unsupported arguments ({', '.join(sorted(kwargs.keys()))}), " "which will be ignored." ) + await self.astra_env.aensure_db_setup() embedding_vectors = await self.embedding.aembed_documents(list(texts)) documents_to_insert = self._get_documents_to_insert( @@ -731,7 +633,7 @@ class AstraDB(VectorStore): Returns: List of (Document, score, id), the most similar to the query vector. """ - self._ensure_astra_db_client() + self.astra_env.ensure_db_setup() metadata_parameter = self._filter_to_metadata(filter) # hits = list( @@ -773,15 +675,7 @@ class AstraDB(VectorStore): Returns: List of (Document, score, id), the most similar to the query vector. """ - await self._ensure_db_setup() - if not self.async_collection: - return await run_in_executor( - None, - self.asimilarity_search_with_score_id_by_vector, # type: ignore[arg-type] - embedding, - k, - filter, - ) + await self.astra_env.aensure_db_setup() metadata_parameter = self._filter_to_metadata(filter) # return [ @@ -1010,7 +904,7 @@ class AstraDB(VectorStore): Returns: List of Documents selected by maximal marginal relevance. """ - self._ensure_astra_db_client() + self.astra_env.ensure_db_setup() metadata_parameter = self._filter_to_metadata(filter) prefetch_hits = list( @@ -1051,18 +945,7 @@ class AstraDB(VectorStore): Returns: List of Documents selected by maximal marginal relevance. """ - await self._ensure_db_setup() - if not self.async_collection: - return await run_in_executor( - None, - self.max_marginal_relevance_search_by_vector, - embedding, - k, - fetch_k, - lambda_mult, - filter, - **kwargs, - ) + await self.astra_env.aensure_db_setup() metadata_parameter = self._filter_to_metadata(filter) prefetch_hits = [