diff --git a/libs/community/langchain_community/vectorstores/astradb.py b/libs/community/langchain_community/vectorstores/astradb.py
index 1c71d3f7b8..7d59bc91eb 100644
--- a/libs/community/langchain_community/vectorstores/astradb.py
+++ b/libs/community/langchain_community/vectorstores/astradb.py
@@ -1,9 +1,12 @@
 from __future__ import annotations
 
+import asyncio
 import uuid
 import warnings
+from asyncio import Task
 from concurrent.futures import ThreadPoolExecutor
 from typing import (
+    TYPE_CHECKING,
     Any,
     Callable,
     Dict,
@@ -19,11 +22,17 @@ from typing import (
 import numpy as np
 from langchain_core.documents import Document
 from langchain_core.embeddings import Embeddings
+from langchain_core.runnables import run_in_executor
+from langchain_core.runnables.utils import gather_with_concurrency
 from langchain_core.utils.iter import batch_iterate
 from langchain_core.vectorstores import VectorStore
 
 from langchain_community.vectorstores.utils import maximal_marginal_relevance
 
+if TYPE_CHECKING:
+    from astrapy.db import AstraDB as LibAstraDB
+    from astrapy.db import AsyncAstraDB
+
 ADBVST = TypeVar("ADBVST", bound="AstraDB")
 T = TypeVar("T")
 U = TypeVar("U")
@@ -144,7 +153,8 @@ class AstraDB(VectorStore):
         collection_name: str,
         token: Optional[str] = None,
         api_endpoint: Optional[str] = None,
-        astra_db_client: Optional[Any] = None,  # 'astrapy.db.AstraDB' if passed
+        astra_db_client: Optional[LibAstraDB] = None,
+        async_astra_db_client: Optional[AsyncAstraDB] = None,
         namespace: Optional[str] = None,
         metric: Optional[str] = None,
         batch_size: Optional[int] = None,
@@ -157,12 +167,8 @@
         Create an AstraDB vector store object. See class docstring for help.
         """
         try:
-            from astrapy.db import (
-                AstraDB as LibAstraDB,
-            )
-            from astrapy.db import (
-                AstraDBCollection as LibAstraDBCollection,
-            )
+            from astrapy.db import AstraDB as LibAstraDB
+            from astrapy.db import AstraDBCollection
         except (ImportError, ModuleNotFoundError):
             raise ImportError(
                 "Could not import a recent astrapy python package. "
@@ -170,11 +176,11 @@
             )
 
         # Conflicting-arg checks:
-        if astra_db_client is not None:
+        if astra_db_client is not None or async_astra_db_client is not None:
             if token is not None or api_endpoint is not None:
                 raise ValueError(
-                    "You cannot pass 'astra_db_client' to AstraDB if passing "
-                    "'token' and 'api_endpoint'."
+                    "You cannot pass 'astra_db_client' or 'async_astra_db_client' to "
+                    "AstraDB if passing 'token' and 'api_endpoint'."
                 )
 
         self.embedding = embedding
@@ -198,23 +204,69 @@
         self._embedding_dimension: Optional[int] = None
         self.metric = metric
 
-        if astra_db_client is not None:
-            self.astra_db = astra_db_client
-        else:
+        self.astra_db = astra_db_client
+        self.async_astra_db = async_astra_db_client
+        self.collection = None
+        self.async_collection = None
+
+        if token and api_endpoint:
             self.astra_db = LibAstraDB(
                 token=self.token,
                 api_endpoint=self.api_endpoint,
                 namespace=self.namespace,
             )
-        if not pre_delete_collection:
-            self._provision_collection()
-        else:
-            self.clear()
+            try:
+                from astrapy.db import AsyncAstraDB
 
-        self.collection = LibAstraDBCollection(
-            collection_name=self.collection_name,
-            astra_db=self.astra_db,
-        )
+                self.async_astra_db = AsyncAstraDB(
+                    token=self.token,
+                    api_endpoint=self.api_endpoint,
+                    namespace=self.namespace,
+                )
+            except (ImportError, ModuleNotFoundError):
+                pass
+
+        if self.astra_db is not None:
+            self.collection = AstraDBCollection(
+                collection_name=self.collection_name,
+                astra_db=self.astra_db,
+            )
+
+        self.async_setup_db_task: Optional[Task] = None
+        if self.async_astra_db is not None:
+            from astrapy.db import AsyncAstraDBCollection
+
+            self.async_collection = AsyncAstraDBCollection(
+                collection_name=self.collection_name,
+                astra_db=self.async_astra_db,
+            )
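+            # Note: asyncio.create_task below requires a running event loop;
+            # in a purely synchronous context it raises RuntimeError, and
+            # collection provisioning then falls back to the sync path.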
""" - _ = self.astra_db.create_collection( + await self.async_astra_db.create_collection( dimension=self._get_embedding_dimension(), collection_name=self.collection_name, metric=self.metric, ) - return None @property def embeddings(self) -> Embeddings: @@ -268,16 +320,36 @@ class AstraDB(VectorStore): def clear(self) -> None: """Empty the collection of all its stored entries.""" - self._drop_collection() + self.delete_collection() self._provision_collection() - return None + + async def aclear(self) -> None: + """Empty the collection of all its stored entries.""" + await self._ensure_db_setup() + if not self.async_astra_db: + await run_in_executor(None, self.clear) + await self.async_collection.delete_many({}) def delete_by_document_id(self, document_id: str) -> bool: """ Remove a single document from the store, given its document_id (str). Return True if a document has indeed been deleted, False if ID not found. """ - deletion_response = self.collection.delete(document_id) + self._ensure_astra_db_client() + deletion_response = self.collection.delete_one(document_id) + return ((deletion_response or {}).get("status") or {}).get( + "deletedCount", 0 + ) == 1 + + async def adelete_by_document_id(self, document_id: str) -> bool: + """ + Remove a single document from the store, given its document_id (str). + Return True if a document has indeed been deleted, False if ID not found. + """ + await self._ensure_db_setup() + if not self.async_collection: + return await run_in_executor(None, self.delete_by_document_id, document_id) + deletion_response = await self.async_collection.delete_one(document_id) return ((deletion_response or {}).get("status") or {}).get( "deletedCount", 0 ) == 1 @@ -320,6 +392,40 @@ class AstraDB(VectorStore): ) return True + async def adelete( + self, + ids: Optional[List[str]] = None, + concurrency: Optional[int] = None, + **kwargs: Any, + ) -> Optional[bool]: + """Delete by vector ID or other criteria. + + Args: + ids: List of ids to delete. + concurrency (Optional[int]): max number of concurrent delete queries. + Defaults to instance-level setting. + **kwargs: Other keyword arguments that subclasses might use. + + Returns: + Optional[bool]: True if deletion is successful, + False otherwise, None if not implemented. + """ + if kwargs: + warnings.warn( + "Method 'adelete' of AstraDB vector store invoked with " + f"unsupported arguments ({', '.join(sorted(kwargs.keys()))}), " + "which will be ignored." + ) + + if ids is None: + raise ValueError("No ids provided to delete.") + + return all( + await gather_with_concurrency( + concurrency, *[self.adelete_by_document_id(doc_id) for doc_id in ids] + ) + ) + def delete_collection(self) -> None: """ Completely delete the collection from the database (as opposed @@ -327,8 +433,88 @@ class AstraDB(VectorStore): Stored data is lost and unrecoverable, resources are freed. Use with caution. """ - self._drop_collection() - return None + self._ensure_astra_db_client() + self.astra_db.delete_collection( + collection_name=self.collection_name, + ) + + async def adelete_collection(self) -> None: + """ + Completely delete the collection from the database (as opposed + to 'clear()', which empties it only). + Stored data is lost and unrecoverable, resources are freed. + Use with caution. 
+ """ + await self._ensure_db_setup() + if not self.async_astra_db: + await run_in_executor(None, self.delete_collection) + await self.async_astra_db.delete_collection( + collection_name=self.collection_name, + ) + + @staticmethod + def _get_documents_to_insert( + texts: Iterable[str], + embedding_vectors: List[List[float]], + metadatas: Optional[List[dict]] = None, + ids: Optional[List[str]] = None, + ) -> List[DocDict]: + if ids is None: + ids = [uuid.uuid4().hex for _ in texts] + if metadatas is None: + metadatas = [{} for _ in texts] + # + documents_to_insert = [ + { + "content": b_txt, + "_id": b_id, + "$vector": b_emb, + "metadata": b_md, + } + for b_txt, b_emb, b_id, b_md in zip( + texts, + embedding_vectors, + ids, + metadatas, + ) + ] + # make unique by id, keeping the last + uniqued_documents_to_insert = _unique_list( + documents_to_insert[::-1], + lambda document: document["_id"], + )[::-1] + return uniqued_documents_to_insert + + @staticmethod + def _get_missing_from_batch( + document_batch: List[DocDict], insert_result: Dict[str, Any] + ) -> Tuple[List[str], List[DocDict]]: + if "status" not in insert_result: + raise ValueError( + f"API Exception while running bulk insertion: {str(insert_result)}" + ) + batch_inserted = insert_result["status"]["insertedIds"] + # estimation of the preexisting documents that failed + missed_inserted_ids = {document["_id"] for document in document_batch} - set( + batch_inserted + ) + errors = insert_result.get("errors", []) + # careful for other sources of error other than "doc already exists" + num_errors = len(errors) + unexpected_errors = any( + error.get("errorCode") != "DOCUMENT_ALREADY_EXISTS" for error in errors + ) + if num_errors != len(missed_inserted_ids) or unexpected_errors: + raise ValueError( + f"API Exception while running bulk insertion: {str(errors)}" + ) + # deal with the missing insertions as upserts + missing_from_batch = [ + document + for document in document_batch + if document["_id"] in missed_inserted_ids + ] + return batch_inserted, missing_from_batch def add_texts( self, @@ -377,36 +563,12 @@ class AstraDB(VectorStore): f"unsupported arguments ({', '.join(sorted(kwargs.keys()))}), " "which will be ignored." 
@@ -377,36 +563,12 @@
                 f"unsupported arguments ({', '.join(sorted(kwargs.keys()))}), "
                 "which will be ignored."
             )
+        self._ensure_astra_db_client()
 
         _texts = list(texts)
-        if ids is None:
-            ids = [uuid.uuid4().hex for _ in _texts]
-        if metadatas is None:
-            metadatas = [{} for _ in _texts]
-        #
-        embedding_vectors = self.embedding.embed_documents(_texts)
-
-        documents_to_insert = [
-            {
-                "content": b_txt,
-                "_id": b_id,
-                "$vector": b_emb,
-                "metadata": b_md,
-            }
-            for b_txt, b_emb, b_id, b_md in zip(
-                _texts,
-                embedding_vectors,
-                ids,
-                metadatas,
-            )
-        ]
-        # make unique by id, keeping the last
-        uniqued_documents_to_insert = _unique_list(
-            documents_to_insert[::-1],
-            lambda document: document["_id"],
-        )[::-1]
-
-        all_ids = []
+        embedding_vectors = self.embedding.embed_documents(_texts)
+        documents_to_insert = self._get_documents_to_insert(
+            _texts, embedding_vectors, metadatas, ids
+        )
 
         def _handle_batch(document_batch: List[DocDict]) -> List[str]:
             im_result = self.collection.insert_many(
@@ -414,33 +576,9 @@
                 options={"ordered": False},
                 partial_failures_allowed=True,
             )
-            if "status" not in im_result:
-                raise ValueError(
-                    f"API Exception while running bulk insertion: {str(im_result)}"
-                )
-
-            batch_inserted = im_result["status"]["insertedIds"]
-            # estimation of the preexisting documents that failed
-            missed_inserted_ids = {
-                document["_id"] for document in document_batch
-            } - set(batch_inserted)
-            errors = im_result.get("errors", [])
-            # careful for other sources of error other than "doc already exists"
-            num_errors = len(errors)
-            unexpected_errors = any(
-                error.get("errorCode") != "DOCUMENT_ALREADY_EXISTS" for error in errors
-            )
-            if num_errors != len(missed_inserted_ids) or unexpected_errors:
-                raise ValueError(
-                    f"API Exception while running bulk insertion: {str(errors)}"
-                )
-
-            # deal with the missing insertions as upserts
-            missing_from_batch = [
-                document
-                for document in document_batch
-                if document["_id"] in missed_inserted_ids
-            ]
+            batch_inserted, missing_from_batch = self._get_missing_from_batch(
+                document_batch, im_result
+            )
 
             def _handle_missing_document(missing_document: DocDict) -> str:
                 replacement_result = self.collection.find_one_and_replace(
@@ -459,9 +597,7 @@
                     missing_from_batch,
                 )
             )
-
-            upsert_ids = batch_inserted + batch_replaced
-            return upsert_ids
+            return batch_inserted + batch_replaced
 
         _b_max_workers = batch_concurrency or self.bulk_insert_batch_concurrency
         with ThreadPoolExecutor(max_workers=_b_max_workers) as tpe:
@@ -469,13 +605,111 @@
                 _handle_batch,
                 batch_iterate(
                     batch_size or self.batch_size,
-                    uniqued_documents_to_insert,
+                    documents_to_insert,
                 ),
             )
+        return [iid for id_list in all_ids_nested for iid in id_list]
+
+    async def aadd_texts(
+        self,
+        texts: Iterable[str],
+        metadatas: Optional[List[dict]] = None,
+        ids: Optional[List[str]] = None,
+        *,
+        batch_size: Optional[int] = None,
+        batch_concurrency: Optional[int] = None,
+        overwrite_concurrency: Optional[int] = None,
+        **kwargs: Any,
+    ) -> List[str]:
+        """Run texts through the embeddings and add them to the vectorstore.
+
+        If passing explicit ids, those entries whose id is in the store already
+        will be replaced.
+
+        Args:
+            texts (Iterable[str]): Texts to add to the vectorstore.
+            metadatas (Optional[List[dict]], optional): Optional list of metadatas.
+            ids (Optional[List[str]], optional): Optional list of ids.
+            batch_size (Optional[int]): Number of documents in each API call.
+                Check the underlying Astra DB HTTP API specs for the max value
+                (20 at the time of writing this). If not provided, defaults
+                to the instance-level setting.
+            batch_concurrency (Optional[int]): number of concurrent batch insertions.
+                Defaults to instance-level setting if not provided.
+            overwrite_concurrency (Optional[int]): number of concurrent API calls to
+                process pre-existing documents in each batch.
+                Defaults to instance-level setting if not provided.
+
+        A note on metadata: there are constraints on the allowed field names
+        in this dictionary, coming from the underlying Astra DB API.
+        For instance, the `$` (dollar sign) cannot be used in the dict keys.
+        See this document for details:
+            docs.datastax.com/en/astra-serverless/docs/develop/dev-with-json.html
+
+        Returns:
+            List[str]: List of ids of the added texts.
+        """
+        await self._ensure_db_setup()
+        if not self.async_collection:
+            return await super().aadd_texts(
+                texts,
+                metadatas,
+                ids=ids,
+                batch_size=batch_size,
+                batch_concurrency=batch_concurrency,
+                overwrite_concurrency=overwrite_concurrency,
+            )
+        if kwargs:
+            warnings.warn(
+                "Method 'aadd_texts' of AstraDB vector store invoked with "
+                f"unsupported arguments ({', '.join(sorted(kwargs.keys()))}), "
+                "which will be ignored."
+            )
+
+        _texts = list(texts)
+        embedding_vectors = await self.embedding.aembed_documents(_texts)
+        documents_to_insert = self._get_documents_to_insert(
+            _texts, embedding_vectors, metadatas, ids
+        )
+
+        async def _handle_batch(document_batch: List[DocDict]) -> List[str]:
+            im_result = await self.async_collection.insert_many(
+                documents=document_batch,
+                options={"ordered": False},
+                partial_failures_allowed=True,
+            )
+            batch_inserted, missing_from_batch = self._get_missing_from_batch(
+                document_batch, im_result
+            )
+
+            async def _handle_missing_document(missing_document: DocDict) -> str:
+                replacement_result = await self.async_collection.find_one_and_replace(
+                    filter={"_id": missing_document["_id"]},
+                    replacement=missing_document,
+                )
+                return replacement_result["data"]["document"]["_id"]
 
-        all_ids = [iid for id_list in all_ids_nested for iid in id_list]
+            _u_max_workers = (
+                overwrite_concurrency or self.bulk_insert_overwrite_concurrency
+            )
+            batch_replaced = await gather_with_concurrency(
+                _u_max_workers,
+                *[_handle_missing_document(doc) for doc in missing_from_batch],
+            )
+            return batch_inserted + batch_replaced
+
+        _b_max_workers = batch_concurrency or self.bulk_insert_batch_concurrency
+        all_ids_nested = await gather_with_concurrency(
+            _b_max_workers,
+            *[
+                _handle_batch(batch)
+                for batch in batch_iterate(
+                    batch_size or self.batch_size,
+                    documents_to_insert,
+                )
+            ],
+        )
 
-        return all_ids
+        return [iid for id_list in all_ids_nested for iid in id_list]
 
     def similarity_search_with_score_id_by_vector(
         self,
@@ -491,6 +725,7 @@
         Returns:
             List of (Document, score, id), the most similar to the query vector.
         """
+        self._ensure_astra_db_client()
         metadata_parameter = self._filter_to_metadata(filter)
         #
         hits = list(
@@ -518,6 +753,52 @@
             for hit in hits
         ]
 
+    async def asimilarity_search_with_score_id_by_vector(
+        self,
+        embedding: List[float],
+        k: int = 4,
+        filter: Optional[Dict[str, Any]] = None,
+    ) -> List[Tuple[Document, float, str]]:
+        """Return docs most similar to embedding vector.
+
+        Args:
+            embedding (List[float]): Embedding to look up documents similar to.
+            k (int): Number of Documents to return. Defaults to 4.
+        Returns:
+            List of (Document, score, id), the most similar to the query vector.
+        """
+        await self._ensure_db_setup()
+        if not self.async_collection:
+            return await run_in_executor(
+                None,
+                self.similarity_search_with_score_id_by_vector,
+                embedding,
+                k,
+                filter,
+            )
+        metadata_parameter = self._filter_to_metadata(filter)
+        #
+        return [
+            (
+                Document(
+                    page_content=hit["content"],
+                    metadata=hit["metadata"],
+                ),
+                hit["$similarity"],
+                hit["_id"],
+            )
+            async for hit in self.async_collection.paginated_find(
+                filter=metadata_parameter,
+                sort={"$vector": embedding},
+                options={"limit": k, "includeSimilarity": True},
+                projection={
+                    "_id": 1,
+                    "content": 1,
+                    "metadata": 1,
+                },
+            )
+        ]
+
     def similarity_search_with_score_id(
         self,
         query: str,
@@ -531,6 +812,19 @@
             filter=filter,
         )
 
+    async def asimilarity_search_with_score_id(
+        self,
+        query: str,
+        k: int = 4,
+        filter: Optional[Dict[str, Any]] = None,
+    ) -> List[Tuple[Document, float, str]]:
+        embedding_vector = await self.embedding.aembed_query(query)
+        return await self.asimilarity_search_with_score_id_by_vector(
+            embedding=embedding_vector,
+            k=k,
+            filter=filter,
+        )
+
     def similarity_search_with_score_by_vector(
         self,
         embedding: List[float],
@@ -554,6 +848,33 @@
             )
         ]
 
+    async def asimilarity_search_with_score_by_vector(
+        self,
+        embedding: List[float],
+        k: int = 4,
+        filter: Optional[Dict[str, Any]] = None,
+    ) -> List[Tuple[Document, float]]:
+        """Return docs most similar to embedding vector.
+
+        Args:
+            embedding (List[float]): Embedding to look up documents similar to.
+            k (int): Number of Documents to return. Defaults to 4.
+        Returns:
+            List of (Document, score), the most similar to the query vector.
+        """
+        return [
+            (doc, score)
+            for (
+                doc,
+                score,
+                doc_id,
+            ) in await self.asimilarity_search_with_score_id_by_vector(
+                embedding=embedding,
+                k=k,
+                filter=filter,
+            )
+        ]
+
     def similarity_search(
         self,
         query: str,
@@ -568,6 +889,20 @@
             filter=filter,
         )
 
+    async def asimilarity_search(
+        self,
+        query: str,
+        k: int = 4,
+        filter: Optional[Dict[str, Any]] = None,
+        **kwargs: Any,
+    ) -> List[Document]:
+        embedding_vector = await self.embedding.aembed_query(query)
+        return await self.asimilarity_search_by_vector(
+            embedding_vector,
+            k,
+            filter=filter,
+        )
+
     def similarity_search_by_vector(
         self,
         embedding: List[float],
@@ -584,6 +919,22 @@
             )
         ]
 
+    async def asimilarity_search_by_vector(
+        self,
+        embedding: List[float],
+        k: int = 4,
+        filter: Optional[Dict[str, Any]] = None,
+        **kwargs: Any,
+    ) -> List[Document]:
+        return [
+            doc
+            for doc, _ in await self.asimilarity_search_with_score_by_vector(
+                embedding,
+                k,
+                filter=filter,
+            )
+        ]
+
     def similarity_search_with_score(
         self,
         query: str,
@@ -597,6 +948,40 @@
             filter=filter,
         )
 
+    async def asimilarity_search_with_score(
+        self,
+        query: str,
+        k: int = 4,
+        filter: Optional[Dict[str, Any]] = None,
+    ) -> List[Tuple[Document, float]]:
+        embedding_vector = await self.embedding.aembed_query(query)
+        return await self.asimilarity_search_with_score_by_vector(
+            embedding_vector,
+            k,
+            filter=filter,
+        )
+
+    @staticmethod
+    def _get_mmr_hits(
+        embedding: List[float],
+        k: int,
+        lambda_mult: float,
+        prefetch_hits: List[DocDict],
+    ) -> List[Document]:
+        mmr_chosen_indices = maximal_marginal_relevance(
+            np.array(embedding, dtype=np.float32),
+            [prefetch_hit["$vector"] for prefetch_hit in prefetch_hits],
+            k=k,
+            lambda_mult=lambda_mult,
+        )
+        mmr_hits = [
+            prefetch_hit
+            for prefetch_index, prefetch_hit in enumerate(prefetch_hits)
+            if prefetch_index in mmr_chosen_indices
+        ]
+        return [
+            Document(
+                page_content=hit["content"],
+                metadata=hit["metadata"],
+            )
+            for hit in mmr_hits
+        ]
+
     def max_marginal_relevance_search_by_vector(
         self,
         embedding: List[float],
@@ -619,6 +1004,7 @@
         Returns:
             List of Documents selected by maximal marginal relevance.
         """
+        self._ensure_astra_db_client()
         metadata_parameter = self._filter_to_metadata(filter)
 
         prefetch_hits = list(
@@ -635,25 +1021,61 @@
             )
         )
 
-        mmr_chosen_indices = maximal_marginal_relevance(
-            np.array(embedding, dtype=np.float32),
-            [prefetch_hit["$vector"] for prefetch_hit in prefetch_hits],
-            k=k,
-            lambda_mult=lambda_mult,
-        )
-        mmr_hits = [
-            prefetch_hit
-            for prefetch_index, prefetch_hit in enumerate(prefetch_hits)
-            if prefetch_index in mmr_chosen_indices
-        ]
-        return [
-            Document(
-                page_content=hit["content"],
-                metadata=hit["metadata"],
+        return self._get_mmr_hits(embedding, k, lambda_mult, prefetch_hits)
+
+    async def amax_marginal_relevance_search_by_vector(
+        self,
+        embedding: List[float],
+        k: int = 4,
+        fetch_k: int = 20,
+        lambda_mult: float = 0.5,
+        filter: Optional[Dict[str, Any]] = None,
+        **kwargs: Any,
+    ) -> List[Document]:
+        """Return docs selected using the maximal marginal relevance.
+        Maximal marginal relevance optimizes for similarity to query AND diversity
+        among selected documents.
+        Args:
+            embedding: Embedding to look up documents similar to.
+            k: Number of Documents to return.
+            fetch_k: Number of Documents to fetch to pass to MMR algorithm.
+            lambda_mult: Number between 0 and 1 that determines the degree
+                of diversity among the results with 0 corresponding
+                to maximum diversity and 1 to minimum diversity.
+        Returns:
+            List of Documents selected by maximal marginal relevance.
+        """
+        await self._ensure_db_setup()
+        if not self.async_collection:
+            return await run_in_executor(
+                None,
+                self.max_marginal_relevance_search_by_vector,
+                embedding,
+                k,
+                fetch_k,
+                lambda_mult,
+                filter,
+                **kwargs,
+            )
+        metadata_parameter = self._filter_to_metadata(filter)
+
+        prefetch_hits = [
+            hit
+            async for hit in self.async_collection.paginated_find(
+                filter=metadata_parameter,
+                sort={"$vector": embedding},
+                options={"limit": fetch_k, "includeSimilarity": True},
+                projection={
+                    "_id": 1,
+                    "content": 1,
+                    "metadata": 1,
+                    "$vector": 1,
+                },
             )
-            for hit in mmr_hits
         ]
 
+        return self._get_mmr_hits(embedding, k, lambda_mult, prefetch_hits)
+
     def max_marginal_relevance_search(
         self,
         query: str,
@@ -686,36 +1108,50 @@
             filter=filter,
         )
 
-    @classmethod
-    def from_texts(
-        cls: Type[ADBVST],
-        texts: List[str],
-        embedding: Embeddings,
-        metadatas: Optional[List[dict]] = None,
-        ids: Optional[List[str]] = None,
+    async def amax_marginal_relevance_search(
+        self,
+        query: str,
+        k: int = 4,
+        fetch_k: int = 20,
+        lambda_mult: float = 0.5,
+        filter: Optional[Dict[str, Any]] = None,
         **kwargs: Any,
-    ) -> ADBVST:
-        """Create an Astra DB vectorstore from raw texts.
-
+    ) -> List[Document]:
+        """Return docs selected using the maximal marginal relevance.
+        Maximal marginal relevance optimizes for similarity to query AND diversity
+        among selected documents.
         Args:
-            texts (List[str]): the texts to insert.
-            embedding (Embeddings): the embedding function to use in the store.
-            metadatas (Optional[List[dict]]): metadata dicts for the texts.
-            ids (Optional[List[str]]): ids to associate to the texts.
-            *Additional arguments*: you can pass any argument that you would
-                to 'add_texts' and/or to the 'AstraDB' class constructor
-                (see these methods for details). These arguments will be
-                routed to the respective methods as they are.
-
+            query (str): Text to look up documents similar to.
+            k (int = 4): Number of Documents to return.
+            fetch_k (int = 20): Number of Documents to fetch to pass to MMR algorithm.
+            lambda_mult (float = 0.5): Number between 0 and 1 that determines the degree
+                of diversity among the results with 0 corresponding
+                to maximum diversity and 1 to minimum diversity.
+                Optional.
         Returns:
-            an `AstraDb` vectorstore.
+            List of Documents selected by maximal marginal relevance.
         """
+        embedding_vector = await self.embedding.aembed_query(query)
+        return await self.amax_marginal_relevance_search_by_vector(
+            embedding_vector,
+            k,
+            fetch_k,
+            lambda_mult=lambda_mult,
+            filter=filter,
+        )
+
+    @classmethod
+    def _from_kwargs(
+        cls: Type[ADBVST],
+        embedding: Embeddings,
+        **kwargs: Any,
+    ) -> ADBVST:
         known_kwargs = {
             "collection_name",
             "token",
             "api_endpoint",
             "astra_db_client",
+            "async_astra_db_client",
             "namespace",
             "metric",
             "batch_size",
@@ -738,15 +1174,17 @@
         token = kwargs.get("token")
         api_endpoint = kwargs.get("api_endpoint")
         astra_db_client = kwargs.get("astra_db_client")
+        async_astra_db_client = kwargs.get("async_astra_db_client")
         namespace = kwargs.get("namespace")
         metric = kwargs.get("metric")
 
-        astra_db_store = cls(
+        return cls(
             embedding=embedding,
             collection_name=collection_name,
             token=token,
             api_endpoint=api_endpoint,
             astra_db_client=astra_db_client,
+            async_astra_db_client=async_astra_db_client,
             namespace=namespace,
             metric=metric,
             batch_size=kwargs.get("batch_size"),
@@ -756,6 +1194,32 @@
             ),
             bulk_delete_concurrency=kwargs.get("bulk_delete_concurrency"),
         )
+
+    @classmethod
+    def from_texts(
+        cls: Type[ADBVST],
+        texts: List[str],
+        embedding: Embeddings,
+        metadatas: Optional[List[dict]] = None,
+        ids: Optional[List[str]] = None,
+        **kwargs: Any,
+    ) -> ADBVST:
+        """Create an Astra DB vectorstore from raw texts.
+
+        Args:
+            texts (List[str]): the texts to insert.
+            embedding (Embeddings): the embedding function to use in the store.
+            metadatas (Optional[List[dict]]): metadata dicts for the texts.
+            ids (Optional[List[str]]): ids to associate to the texts.
+            *Additional arguments*: you can pass any argument that you would
+                to 'add_texts' and/or to the 'AstraDB' class constructor
+                (see these methods for details). These arguments will be
+                routed to the respective methods as they are.
+
+        Returns:
+            an `AstraDB` vectorstore.
+        """
+        astra_db_store = cls._from_kwargs(embedding, **kwargs)
         astra_db_store.add_texts(
             texts=texts,
             metadatas=metadatas,
@@ -766,6 +1230,41 @@
         )
         return astra_db_store
 
+    @classmethod
+    async def afrom_texts(
+        cls: Type[ADBVST],
+        texts: List[str],
+        embedding: Embeddings,
+        metadatas: Optional[List[dict]] = None,
+        ids: Optional[List[str]] = None,
+        **kwargs: Any,
+    ) -> ADBVST:
+        """Create an Astra DB vectorstore from raw texts.
+
+        Args:
+            texts (List[str]): the texts to insert.
+            embedding (Embeddings): the embedding function to use in the store.
+            metadatas (Optional[List[dict]]): metadata dicts for the texts.
+            ids (Optional[List[str]]): ids to associate to the texts.
+            *Additional arguments*: you can pass any argument that you would
+                to 'add_texts' and/or to the 'AstraDB' class constructor
+                (see these methods for details). These arguments will be
+                routed to the respective methods as they are.
+
+        Returns:
+            an `AstraDB` vectorstore.
+        """
+        astra_db_store = cls._from_kwargs(embedding, **kwargs)
+        await astra_db_store.aadd_texts(
+            texts=texts,
+            metadatas=metadatas,
+            ids=ids,
+            batch_size=kwargs.get("batch_size"),
+            batch_concurrency=kwargs.get("batch_concurrency"),
+            overwrite_concurrency=kwargs.get("overwrite_concurrency"),
+        )
+        return astra_db_store
+
     @classmethod
     def from_documents(
         cls: Type[ADBVST],
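That concludes the library-side changes; the integration tests follow. As a quick orientation for reviewers, this is roughly how the new async surface is meant to be driven end to end (a minimal sketch under assumptions: the ASTRA_DB_* environment variables are set, and FakeEmbeddings is just a stand-in for any Embeddings implementation):

```python
import asyncio
import os

from langchain_community.embeddings import FakeEmbeddings
from langchain_community.vectorstores import AstraDB


async def main() -> None:
    # Build the store and insert the texts in one async call.
    store = await AstraDB.afrom_texts(
        texts=["Hi", "Ho"],
        embedding=FakeEmbeddings(size=16),
        collection_name="lc_example_async",
        token=os.environ["ASTRA_DB_APPLICATION_TOKEN"],
        api_endpoint=os.environ["ASTRA_DB_API_ENDPOINT"],
    )
    try:
        hits = await store.asimilarity_search("Ho", k=1)
        print(hits[0].page_content)
        # adelete reports False when any supplied id does not exist:
        print(await store.adelete(["no-such-id"]))
    finally:
        await store.adelete_collection()


asyncio.run(main())
```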
diff --git a/libs/community/tests/integration_tests/vectorstores/test_astradb.py b/libs/community/tests/integration_tests/vectorstores/test_astradb.py
index 4263b56161..f652342e5c 100644
--- a/libs/community/tests/integration_tests/vectorstores/test_astradb.py
+++ b/libs/community/tests/integration_tests/vectorstores/test_astradb.py
@@ -148,6 +148,33 @@ class TestAstraDB:
         )
         v_store_2.delete_collection()
 
+    async def test_astradb_vectorstore_create_delete_async(self) -> None:
+        """Create and delete."""
+        emb = SomeEmbeddings(dimension=2)
+        # creation by passing the connection secrets
+        v_store = AstraDB(
+            embedding=emb,
+            collection_name="lc_test_1_async",
+            token=os.environ["ASTRA_DB_APPLICATION_TOKEN"],
+            api_endpoint=os.environ["ASTRA_DB_API_ENDPOINT"],
+            namespace=os.environ.get("ASTRA_DB_KEYSPACE"),
+        )
+        await v_store.adelete_collection()
+        # Creation by passing a ready-made astrapy client:
+        from astrapy.db import AsyncAstraDB
+
+        astra_db_client = AsyncAstraDB(
+            token=os.environ["ASTRA_DB_APPLICATION_TOKEN"],
+            api_endpoint=os.environ["ASTRA_DB_API_ENDPOINT"],
+            namespace=os.environ.get("ASTRA_DB_KEYSPACE"),
+        )
+        v_store_2 = AstraDB(
+            embedding=emb,
+            collection_name="lc_test_2_async",
+            async_astra_db_client=astra_db_client,
+        )
+        await v_store_2.adelete_collection()
+
     def test_astradb_vectorstore_pre_delete_collection(self) -> None:
         """Create and delete."""
         emb = SomeEmbeddings(dimension=2)
@@ -183,6 +210,41 @@
         finally:
             v_store.delete_collection()
 
+    async def test_astradb_vectorstore_pre_delete_collection_async(self) -> None:
+        """Create and delete."""
+        emb = SomeEmbeddings(dimension=2)
+        # creation by passing the connection secrets
+
+        v_store = AstraDB(
+            embedding=emb,
+            collection_name="lc_test_pre_del_async",
+            token=os.environ["ASTRA_DB_APPLICATION_TOKEN"],
+            api_endpoint=os.environ["ASTRA_DB_API_ENDPOINT"],
+            namespace=os.environ.get("ASTRA_DB_KEYSPACE"),
+        )
+        try:
+            await v_store.aadd_texts(
+                texts=["aa"],
+                metadatas=[
+                    {"k": "a", "ord": 0},
+                ],
+                ids=["a"],
+            )
+            res1 = await v_store.asimilarity_search("aa", k=5)
+            assert len(res1) == 1
+            v_store = AstraDB(
+                embedding=emb,
+                pre_delete_collection=True,
+                collection_name="lc_test_pre_del_async",
+                token=os.environ["ASTRA_DB_APPLICATION_TOKEN"],
+                api_endpoint=os.environ["ASTRA_DB_API_ENDPOINT"],
+                namespace=os.environ.get("ASTRA_DB_KEYSPACE"),
+            )
+            res1 = await v_store.asimilarity_search("aa", k=5)
+            assert len(res1) == 0
+        finally:
+            await v_store.adelete_collection()
+
     def test_astradb_vectorstore_from_x(self) -> None:
         """from_texts and from_documents methods."""
         emb = SomeEmbeddings(dimension=2)
@@ -200,7 +262,7 @@
         finally:
             v_store.delete_collection()
 
-        # from_texts
+        # from_documents
         v_store_2 = AstraDB.from_documents(
             [
                 Document(page_content="Hee"),
@@ -217,6 +279,42 @@
         finally:
             v_store_2.delete_collection()
 
+    async def test_astradb_vectorstore_from_x_async(self) -> None:
+        """from_texts and from_documents methods."""
+        emb = SomeEmbeddings(dimension=2)
+        # from_texts
+        v_store = await AstraDB.afrom_texts(
+            texts=["Hi", "Ho"],
+            embedding=emb,
+            collection_name="lc_test_ft_async",
+            token=os.environ["ASTRA_DB_APPLICATION_TOKEN"],
+            api_endpoint=os.environ["ASTRA_DB_API_ENDPOINT"],
+            namespace=os.environ.get("ASTRA_DB_KEYSPACE"),
+        )
+        try:
+            assert (await v_store.asimilarity_search("Ho", k=1))[0].page_content == "Ho"
+        finally:
+            await v_store.adelete_collection()
+
+        # from_documents
+        v_store_2 = await AstraDB.afrom_documents(
+            [
+                Document(page_content="Hee"),
+                Document(page_content="Hoi"),
+            ],
+            embedding=emb,
+            collection_name="lc_test_fd_async",
+            token=os.environ["ASTRA_DB_APPLICATION_TOKEN"],
+            api_endpoint=os.environ["ASTRA_DB_API_ENDPOINT"],
+            namespace=os.environ.get("ASTRA_DB_KEYSPACE"),
+        )
+        try:
+            assert (await v_store_2.asimilarity_search("Hoi", k=1))[
+                0
+            ].page_content == "Hoi"
+        finally:
+            await v_store_2.adelete_collection()
+
     def test_astradb_vectorstore_crud(self, store_someemb: AstraDB) -> None:
         """Basic add/delete/update behaviour."""
         res0 = store_someemb.similarity_search("Abc", k=2)
@@ -275,25 +373,106 @@
         res4 = store_someemb.similarity_search("ww", k=1, filter={"k": "w"})
         assert res4[0].metadata["ord"] == 205
 
+    async def test_astradb_vectorstore_crud_async(self, store_someemb: AstraDB) -> None:
+        """Basic add/delete/update behaviour."""
+        res0 = await store_someemb.asimilarity_search("Abc", k=2)
+        assert res0 == []
+        # write and check again
+        await store_someemb.aadd_texts(
+            texts=["aa", "bb", "cc"],
+            metadatas=[
+                {"k": "a", "ord": 0},
+                {"k": "b", "ord": 1},
+                {"k": "c", "ord": 2},
+            ],
+            ids=["a", "b", "c"],
+        )
+        res1 = await store_someemb.asimilarity_search("Abc", k=5)
+        assert {doc.page_content for doc in res1} == {"aa", "bb", "cc"}
+        # partial overwrite and count total entries
+        await store_someemb.aadd_texts(
+            texts=["cc", "dd"],
+            metadatas=[
+                {"k": "c_new", "ord": 102},
+                {"k": "d_new", "ord": 103},
+            ],
+            ids=["c", "d"],
+        )
+        res2 = await store_someemb.asimilarity_search("Abc", k=10)
+        assert len(res2) == 4
+        # pick one that was just updated and check its metadata
+        res3 = await store_someemb.asimilarity_search_with_score_id(
+            query="cc", k=1, filter={"k": "c_new"}
+        )
+        doc3, score3, id3 = res3[0]
+        assert doc3.page_content == "cc"
+        assert doc3.metadata == {"k": "c_new", "ord": 102}
+        assert score3 > 0.999  # leaving some leeway for approximations...
+        assert id3 == "c"
+        # delete and count again
+        del1_res = await store_someemb.adelete(["b"])
+        assert del1_res is True
+        del2_res = await store_someemb.adelete(["a", "c", "Z!"])
+        assert del2_res is False  # a non-existing ID was supplied
+        assert len(await store_someemb.asimilarity_search("xy", k=10)) == 1
+        # clear store
+        await store_someemb.aclear()
+        assert await store_someemb.asimilarity_search("Abc", k=2) == []
+        # add_documents with "ids" arg passthrough
+        await store_someemb.aadd_documents(
+            [
+                Document(page_content="vv", metadata={"k": "v", "ord": 204}),
+                Document(page_content="ww", metadata={"k": "w", "ord": 205}),
+            ],
+            ids=["v", "w"],
+        )
+        assert len(await store_someemb.asimilarity_search("xy", k=10)) == 2
+        res4 = await store_someemb.asimilarity_search("ww", k=1, filter={"k": "w"})
+        assert res4[0].metadata["ord"] == 205
+
+    @staticmethod
+    def _v_from_i(i: int, N: int) -> str:
+        angle = 2 * math.pi * i / N
+        vector = [math.cos(angle), math.sin(angle)]
+        return json.dumps(vector)
+
     def test_astradb_vectorstore_mmr(self, store_parseremb: AstraDB) -> None:
         """
         MMR testing. We work on the unit circle with angle multiples
         of 2*pi/20 and prepare a store with known vectors for a controlled
         MMR outcome.
         """
-
-        def _v_from_i(i: int, N: int) -> str:
-            angle = 2 * math.pi * i / N
-            vector = [math.cos(angle), math.sin(angle)]
-            return json.dumps(vector)
-
         i_vals = [0, 4, 5, 13]
         N_val = 20
         store_parseremb.add_texts(
-            [_v_from_i(i, N_val) for i in i_vals], metadatas=[{"i": i} for i in i_vals]
+            [self._v_from_i(i, N_val) for i in i_vals],
+            metadatas=[{"i": i} for i in i_vals],
         )
         res1 = store_parseremb.max_marginal_relevance_search(
-            _v_from_i(3, N_val),
+            self._v_from_i(3, N_val),
+            k=2,
+            fetch_k=3,
+        )
+        res_i_vals = {doc.metadata["i"] for doc in res1}
+        assert res_i_vals == {0, 4}
+
+    async def test_astradb_vectorstore_mmr_async(
+        self, store_parseremb: AstraDB
+    ) -> None:
+        """
+        MMR testing. We work on the unit circle with angle multiples
+        of 2*pi/20 and prepare a store with known vectors for a controlled
+        MMR outcome.
+        """
+        i_vals = [0, 4, 5, 13]
+        N_val = 20
+        await store_parseremb.aadd_texts(
+            [self._v_from_i(i, N_val) for i in i_vals],
+            metadatas=[{"i": i} for i in i_vals],
+        )
+        res1 = await store_parseremb.amax_marginal_relevance_search(
+            self._v_from_i(3, N_val),
             k=2,
             fetch_k=3,
         )
         res_i_vals = {doc.metadata["i"] for doc in res1}
         assert res_i_vals == {0, 4}
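Reviewer note on the MMR test geometry (both sync and async variants above): the texts are JSON-encoded unit-circle points, so the store's parser embeddings decode them back into exact vectors and every cosine similarity is known in advance; points i and j have similarity cos(2*pi*(i-j)/N). A quick standalone check (`_v_from_i` reproduced from the tests above):

```python
import json
import math


def _v_from_i(i: int, N: int) -> str:
    angle = 2 * math.pi * i / N
    vector = [math.cos(angle), math.sin(angle)]
    return json.dumps(vector)


# Query point i=3 prefetches i=4, i=5 and i=0 (fetch_k=3); with k=2 the
# MMR result {0, 4} drops i=5 as redundant with the already-picked i=4.
print(_v_from_i(0, 20))  # [1.0, 0.0]
```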
@@ -381,6 +560,25 @@
         sco_near, sco_far = scores
         assert abs(1 - sco_near) < 0.001 and abs(sco_far) < 0.001
 
+    async def test_astradb_vectorstore_similarity_scale_async(
+        self, store_parseremb: AstraDB
+    ) -> None:
+        """Scale of the similarity scores."""
+        await store_parseremb.aadd_texts(
+            texts=[
+                json.dumps([1, 1]),
+                json.dumps([-1, -1]),
+            ],
+            ids=["near", "far"],
+        )
+        res1 = await store_parseremb.asimilarity_search_with_score(
+            json.dumps([0.5, 0.5]),
+            k=2,
+        )
+        scores = [sco for _, sco in res1]
+        sco_near, sco_far = scores
+        assert abs(1 - sco_near) < 0.001 and abs(sco_far) < 0.001
+
     def test_astradb_vectorstore_massive_delete(self, store_someemb: AstraDB) -> None:
         """Larger-scale bulk deletes."""
         M = 50
@@ -458,6 +656,40 @@
         finally:
             v_store.delete_collection()
 
+    async def test_astradb_vectorstore_custom_params_async(self) -> None:
+        """Custom batch size and concurrency params."""
+        emb = SomeEmbeddings(dimension=2)
+        v_store = AstraDB(
+            embedding=emb,
+            collection_name="lc_test_c_async",
+            token=os.environ["ASTRA_DB_APPLICATION_TOKEN"],
+            api_endpoint=os.environ["ASTRA_DB_API_ENDPOINT"],
+            namespace=os.environ.get("ASTRA_DB_KEYSPACE"),
+            batch_size=17,
+            bulk_insert_batch_concurrency=13,
+            bulk_insert_overwrite_concurrency=7,
+            bulk_delete_concurrency=19,
+        )
+        try:
+            # add_texts
+            N = 50
+            texts = [str(i + 1 / 7.0) for i in range(N)]
+            ids = ["doc_%i" % i for i in range(N)]
+            await v_store.aadd_texts(texts=texts, ids=ids)
+            await v_store.aadd_texts(
+                texts=texts,
+                ids=ids,
+                batch_size=19,
+                batch_concurrency=7,
+                overwrite_concurrency=13,
+            )
+            #
+            await v_store.adelete(ids[: N // 2])
+            await v_store.adelete(ids[N // 2 :], concurrency=23)
+            #
+        finally:
+            await v_store.adelete_collection()
+
     def test_astradb_vectorstore_metrics(self) -> None:
         """
         Different choices of similarity metric.