Add async methods for the AstraDB VectorStore (#16391)

- **Description**: fully async versions are available for astrapy 0.7+.
For older astrapy versions or if the user provides a sync client without
an async one, the async methods will call the sync ones wrapped in
`run_in_executor`
  - **Twitter handle:** cbornet_
pull/16763/head
Christophe Bornet 5 months ago committed by GitHub
parent f8f2649f12
commit 744070ee85
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -1,9 +1,12 @@
from __future__ import annotations
import asyncio
import uuid
import warnings
from asyncio import Task
from concurrent.futures import ThreadPoolExecutor
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
@ -19,11 +22,17 @@ from typing import (
import numpy as np
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.runnables import run_in_executor
from langchain_core.runnables.utils import gather_with_concurrency
from langchain_core.utils.iter import batch_iterate
from langchain_core.vectorstores import VectorStore
from langchain_community.vectorstores.utils import maximal_marginal_relevance
if TYPE_CHECKING:
from astrapy.db import AstraDB as LibAstraDB
from astrapy.db import AsyncAstraDB
ADBVST = TypeVar("ADBVST", bound="AstraDB")
T = TypeVar("T")
U = TypeVar("U")
@ -144,7 +153,8 @@ class AstraDB(VectorStore):
collection_name: str,
token: Optional[str] = None,
api_endpoint: Optional[str] = None,
astra_db_client: Optional[Any] = None, # 'astrapy.db.AstraDB' if passed
astra_db_client: Optional[LibAstraDB] = None,
async_astra_db_client: Optional[AsyncAstraDB] = None,
namespace: Optional[str] = None,
metric: Optional[str] = None,
batch_size: Optional[int] = None,
@ -157,12 +167,8 @@ class AstraDB(VectorStore):
Create an AstraDB vector store object. See class docstring for help.
"""
try:
from astrapy.db import (
AstraDB as LibAstraDB,
)
from astrapy.db import (
AstraDBCollection as LibAstraDBCollection,
)
from astrapy.db import AstraDB as LibAstraDB
from astrapy.db import AstraDBCollection
except (ImportError, ModuleNotFoundError):
raise ImportError(
"Could not import a recent astrapy python package. "
@ -170,11 +176,11 @@ class AstraDB(VectorStore):
)
# Conflicting-arg checks:
if astra_db_client is not None:
if astra_db_client is not None or async_astra_db_client is not None:
if token is not None or api_endpoint is not None:
raise ValueError(
"You cannot pass 'astra_db_client' to AstraDB if passing "
"'token' and 'api_endpoint'."
"You cannot pass 'astra_db_client' or 'async_astra_db_client' to "
"AstraDB if passing 'token' and 'api_endpoint'."
)
self.embedding = embedding
@ -198,23 +204,69 @@ class AstraDB(VectorStore):
self._embedding_dimension: Optional[int] = None
self.metric = metric
if astra_db_client is not None:
self.astra_db = astra_db_client
else:
self.astra_db = astra_db_client
self.async_astra_db = async_astra_db_client
self.collection = None
self.async_collection = None
if token and api_endpoint:
self.astra_db = LibAstraDB(
token=self.token,
api_endpoint=self.api_endpoint,
namespace=self.namespace,
)
if not pre_delete_collection:
self._provision_collection()
else:
self.clear()
try:
from astrapy.db import AsyncAstraDB
self.collection = LibAstraDBCollection(
collection_name=self.collection_name,
astra_db=self.astra_db,
)
self.async_astra_db = AsyncAstraDB(
token=self.token,
api_endpoint=self.api_endpoint,
namespace=self.namespace,
)
except (ImportError, ModuleNotFoundError):
pass
if self.astra_db is not None:
self.collection = AstraDBCollection(
collection_name=self.collection_name,
astra_db=self.astra_db,
)
self.async_setup_db_task: Optional[Task] = None
if self.async_astra_db is not None:
from astrapy.db import AsyncAstraDBCollection
self.async_collection = AsyncAstraDBCollection(
collection_name=self.collection_name,
astra_db=self.async_astra_db,
)
try:
self.async_setup_db_task = asyncio.create_task(
self._setup_db(pre_delete_collection)
)
except RuntimeError:
pass
if self.async_setup_db_task is None:
if not pre_delete_collection:
self._provision_collection()
else:
self.clear()
def _ensure_astra_db_client(self):
if not self.astra_db:
raise ValueError("Missing AstraDB client")
async def _setup_db(self, pre_delete_collection: bool) -> None:
if pre_delete_collection:
await self.async_astra_db.delete_collection(
collection_name=self.collection_name,
)
await self._aprovision_collection()
async def _ensure_db_setup(self) -> None:
if self.async_setup_db_task:
await self.async_setup_db_task
def _get_embedding_dimension(self) -> int:
if self._embedding_dimension is None:
@ -223,31 +275,31 @@ class AstraDB(VectorStore):
)
return self._embedding_dimension
def _drop_collection(self) -> None:
def _provision_collection(self) -> None:
"""
Drop the collection from storage.
Run the API invocation to create the collection on the backend.
This is meant as an internal-usage method, no members
are set other than actual deletion on the backend.
Internal-usage method, no object members are set,
other than working on the underlying actual storage.
"""
_ = self.astra_db.delete_collection(
self.astra_db.create_collection(
dimension=self._get_embedding_dimension(),
collection_name=self.collection_name,
metric=self.metric,
)
return None
def _provision_collection(self) -> None:
async def _aprovision_collection(self) -> None:
"""
Run the API invocation to create the collection on the backend.
Internal-usage method, no object members are set,
other than working on the underlying actual storage.
"""
_ = self.astra_db.create_collection(
await self.async_astra_db.create_collection(
dimension=self._get_embedding_dimension(),
collection_name=self.collection_name,
metric=self.metric,
)
return None
@property
def embeddings(self) -> Embeddings:
@ -268,16 +320,36 @@ class AstraDB(VectorStore):
def clear(self) -> None:
"""Empty the collection of all its stored entries."""
self._drop_collection()
self.delete_collection()
self._provision_collection()
return None
async def aclear(self) -> None:
"""Empty the collection of all its stored entries."""
await self._ensure_db_setup()
if not self.async_astra_db:
await run_in_executor(None, self.clear)
await self.async_collection.delete_many({})
def delete_by_document_id(self, document_id: str) -> bool:
"""
Remove a single document from the store, given its document_id (str).
Return True if a document has indeed been deleted, False if ID not found.
"""
deletion_response = self.collection.delete(document_id)
self._ensure_astra_db_client()
deletion_response = self.collection.delete_one(document_id)
return ((deletion_response or {}).get("status") or {}).get(
"deletedCount", 0
) == 1
async def adelete_by_document_id(self, document_id: str) -> bool:
"""
Remove a single document from the store, given its document_id (str).
Return True if a document has indeed been deleted, False if ID not found.
"""
await self._ensure_db_setup()
if not self.async_collection:
return await run_in_executor(None, self.delete_by_document_id, document_id)
deletion_response = await self.async_collection.delete_one(document_id)
return ((deletion_response or {}).get("status") or {}).get(
"deletedCount", 0
) == 1
@ -320,6 +392,40 @@ class AstraDB(VectorStore):
)
return True
async def adelete(
self,
ids: Optional[List[str]] = None,
concurrency: Optional[int] = None,
**kwargs: Any,
) -> Optional[bool]:
"""Delete by vector ID or other criteria.
Args:
ids: List of ids to delete.
concurrency (Optional[int]): max number of concurrent delete queries.
Defaults to instance-level setting.
**kwargs: Other keyword arguments that subclasses might use.
Returns:
Optional[bool]: True if deletion is successful,
False otherwise, None if not implemented.
"""
if kwargs:
warnings.warn(
"Method 'adelete' of AstraDB vector store invoked with "
f"unsupported arguments ({', '.join(sorted(kwargs.keys()))}), "
"which will be ignored."
)
if ids is None:
raise ValueError("No ids provided to delete.")
return all(
await gather_with_concurrency(
concurrency, *[self.adelete_by_document_id(doc_id) for doc_id in ids]
)
)
def delete_collection(self) -> None:
"""
Completely delete the collection from the database (as opposed
@ -327,8 +433,88 @@ class AstraDB(VectorStore):
Stored data is lost and unrecoverable, resources are freed.
Use with caution.
"""
self._drop_collection()
return None
self._ensure_astra_db_client()
self.astra_db.delete_collection(
collection_name=self.collection_name,
)
async def adelete_collection(self) -> None:
"""
Completely delete the collection from the database (as opposed
to 'clear()', which empties it only).
Stored data is lost and unrecoverable, resources are freed.
Use with caution.
"""
await self._ensure_db_setup()
if not self.async_astra_db:
await run_in_executor(None, self.delete_collection)
await self.async_astra_db.delete_collection(
collection_name=self.collection_name,
)
@staticmethod
def _get_documents_to_insert(
texts: Iterable[str],
embedding_vectors: List[List[float]],
metadatas: Optional[List[dict]] = None,
ids: Optional[List[str]] = None,
) -> List[DocDict]:
if ids is None:
ids = [uuid.uuid4().hex for _ in texts]
if metadatas is None:
metadatas = [{} for _ in texts]
#
documents_to_insert = [
{
"content": b_txt,
"_id": b_id,
"$vector": b_emb,
"metadata": b_md,
}
for b_txt, b_emb, b_id, b_md in zip(
texts,
embedding_vectors,
ids,
metadatas,
)
]
# make unique by id, keeping the last
uniqued_documents_to_insert = _unique_list(
documents_to_insert[::-1],
lambda document: document["_id"],
)[::-1]
return uniqued_documents_to_insert
@staticmethod
def _get_missing_from_batch(
document_batch: List[DocDict], insert_result: Dict[str, Any]
) -> Tuple[List[str], List[DocDict]]:
if "status" not in insert_result:
raise ValueError(
f"API Exception while running bulk insertion: {str(insert_result)}"
)
batch_inserted = insert_result["status"]["insertedIds"]
# estimation of the preexisting documents that failed
missed_inserted_ids = {document["_id"] for document in document_batch} - set(
batch_inserted
)
errors = insert_result.get("errors", [])
# careful for other sources of error other than "doc already exists"
num_errors = len(errors)
unexpected_errors = any(
error.get("errorCode") != "DOCUMENT_ALREADY_EXISTS" for error in errors
)
if num_errors != len(missed_inserted_ids) or unexpected_errors:
raise ValueError(
f"API Exception while running bulk insertion: {str(errors)}"
)
# deal with the missing insertions as upserts
missing_from_batch = [
document
for document in document_batch
if document["_id"] in missed_inserted_ids
]
return batch_inserted, missing_from_batch
def add_texts(
self,
@ -377,36 +563,12 @@ class AstraDB(VectorStore):
f"unsupported arguments ({', '.join(sorted(kwargs.keys()))}), "
"which will be ignored."
)
self._ensure_astra_db_client()
_texts = list(texts)
if ids is None:
ids = [uuid.uuid4().hex for _ in _texts]
if metadatas is None:
metadatas = [{} for _ in _texts]
#
embedding_vectors = self.embedding.embed_documents(_texts)
documents_to_insert = [
{
"content": b_txt,
"_id": b_id,
"$vector": b_emb,
"metadata": b_md,
}
for b_txt, b_emb, b_id, b_md in zip(
_texts,
embedding_vectors,
ids,
metadatas,
)
]
# make unique by id, keeping the last
uniqued_documents_to_insert = _unique_list(
documents_to_insert[::-1],
lambda document: document["_id"],
)[::-1]
all_ids = []
embedding_vectors = self.embedding.embed_documents(list(texts))
documents_to_insert = self._get_documents_to_insert(
texts, embedding_vectors, metadatas, ids
)
def _handle_batch(document_batch: List[DocDict]) -> List[str]:
im_result = self.collection.insert_many(
@ -414,33 +576,9 @@ class AstraDB(VectorStore):
options={"ordered": False},
partial_failures_allowed=True,
)
if "status" not in im_result:
raise ValueError(
f"API Exception while running bulk insertion: {str(im_result)}"
)
batch_inserted = im_result["status"]["insertedIds"]
# estimation of the preexisting documents that failed
missed_inserted_ids = {
document["_id"] for document in document_batch
} - set(batch_inserted)
errors = im_result.get("errors", [])
# careful for other sources of error other than "doc already exists"
num_errors = len(errors)
unexpected_errors = any(
error.get("errorCode") != "DOCUMENT_ALREADY_EXISTS" for error in errors
)
if num_errors != len(missed_inserted_ids) or unexpected_errors:
raise ValueError(
f"API Exception while running bulk insertion: {str(errors)}"
)
# deal with the missing insertions as upserts
missing_from_batch = [
document
for document in document_batch
if document["_id"] in missed_inserted_ids
]
batch_inserted, missing_from_batch = self._get_missing_from_batch(
document_batch, im_result
)
def _handle_missing_document(missing_document: DocDict) -> str:
replacement_result = self.collection.find_one_and_replace(
@ -459,9 +597,7 @@ class AstraDB(VectorStore):
missing_from_batch,
)
)
upsert_ids = batch_inserted + batch_replaced
return upsert_ids
return batch_inserted + batch_replaced
_b_max_workers = batch_concurrency or self.bulk_insert_batch_concurrency
with ThreadPoolExecutor(max_workers=_b_max_workers) as tpe:
@ -469,13 +605,111 @@ class AstraDB(VectorStore):
_handle_batch,
batch_iterate(
batch_size or self.batch_size,
uniqued_documents_to_insert,
documents_to_insert,
),
)
return [iid for id_list in all_ids_nested for iid in id_list]
async def aadd_texts(
self,
texts: Iterable[str],
metadatas: Optional[List[dict]] = None,
ids: Optional[List[str]] = None,
*,
batch_size: Optional[int] = None,
batch_concurrency: Optional[int] = None,
overwrite_concurrency: Optional[int] = None,
**kwargs: Any,
) -> List[str]:
"""Run texts through the embeddings and add them to the vectorstore.
If passing explicit ids, those entries whose id is in the store already
will be replaced.
Args:
texts (Iterable[str]): Texts to add to the vectorstore.
metadatas (Optional[List[dict]], optional): Optional list of metadatas.
ids (Optional[List[str]], optional): Optional list of ids.
batch_size (Optional[int]): Number of documents in each API call.
Check the underlying Astra DB HTTP API specs for the max value
(20 at the time of writing this). If not provided, defaults
to the instance-level setting.
batch_concurrency (Optional[int]): number of concurrent batch insertions.
Defaults to instance-level setting if not provided.
overwrite_concurrency (Optional[int]): number of concurrent API calls to
process pre-existing documents in each batch.
Defaults to instance-level setting if not provided.
A note on metadata: there are constraints on the allowed field names
in this dictionary, coming from the underlying Astra DB API.
For instance, the `$` (dollar sign) cannot be used in the dict keys.
See this document for details:
docs.datastax.com/en/astra-serverless/docs/develop/dev-with-json.html
Returns:
List[str]: List of ids of the added texts.
"""
await self._ensure_db_setup()
if not self.async_collection:
await super().aadd_texts(
texts,
metadatas,
ids=ids,
batch_size=batch_size,
batch_concurrency=batch_concurrency,
overwrite_concurrency=overwrite_concurrency,
)
if kwargs:
warnings.warn(
"Method 'aadd_texts' of AstraDB vector store invoked with "
f"unsupported arguments ({', '.join(sorted(kwargs.keys()))}), "
"which will be ignored."
)
embedding_vectors = await self.embedding.aembed_documents(list(texts))
documents_to_insert = self._get_documents_to_insert(
texts, embedding_vectors, metadatas, ids
)
async def _handle_batch(document_batch: List[DocDict]) -> List[str]:
im_result = await self.async_collection.insert_many(
documents=document_batch,
options={"ordered": False},
partial_failures_allowed=True,
)
batch_inserted, missing_from_batch = self._get_missing_from_batch(
document_batch, im_result
)
async def _handle_missing_document(missing_document: DocDict) -> str:
replacement_result = await self.async_collection.find_one_and_replace(
filter={"_id": missing_document["_id"]},
replacement=missing_document,
)
return replacement_result["data"]["document"]["_id"]
all_ids = [iid for id_list in all_ids_nested for iid in id_list]
_u_max_workers = (
overwrite_concurrency or self.bulk_insert_overwrite_concurrency
)
batch_replaced = await gather_with_concurrency(
_u_max_workers,
*[_handle_missing_document(doc) for doc in missing_from_batch],
)
return batch_inserted + batch_replaced
_b_max_workers = batch_concurrency or self.bulk_insert_batch_concurrency
all_ids_nested = await gather_with_concurrency(
_b_max_workers,
*[
_handle_batch(batch)
for batch in batch_iterate(
batch_size or self.batch_size,
documents_to_insert,
)
],
)
return all_ids
return [iid for id_list in all_ids_nested for iid in id_list]
def similarity_search_with_score_id_by_vector(
self,
@ -491,6 +725,7 @@ class AstraDB(VectorStore):
Returns:
List of (Document, score, id), the most similar to the query vector.
"""
self._ensure_astra_db_client()
metadata_parameter = self._filter_to_metadata(filter)
#
hits = list(
@ -518,6 +753,52 @@ class AstraDB(VectorStore):
for hit in hits
]
async def asimilarity_search_with_score_id_by_vector(
self,
embedding: List[float],
k: int = 4,
filter: Optional[Dict[str, Any]] = None,
) -> List[Tuple[Document, float, str]]:
"""Return docs most similar to embedding vector.
Args:
embedding (str): Embedding to look up documents similar to.
k (int): Number of Documents to return. Defaults to 4.
Returns:
List of (Document, score, id), the most similar to the query vector.
"""
await self._ensure_db_setup()
if not self.async_collection:
return await run_in_executor(
None,
self.asimilarity_search_with_score_id_by_vector,
embedding,
k,
filter,
)
metadata_parameter = self._filter_to_metadata(filter)
#
return [
(
Document(
page_content=hit["content"],
metadata=hit["metadata"],
),
hit["$similarity"],
hit["_id"],
)
async for hit in self.async_collection.paginated_find(
filter=metadata_parameter,
sort={"$vector": embedding},
options={"limit": k, "includeSimilarity": True},
projection={
"_id": 1,
"content": 1,
"metadata": 1,
},
)
]
def similarity_search_with_score_id(
self,
query: str,
@ -531,6 +812,19 @@ class AstraDB(VectorStore):
filter=filter,
)
async def asimilarity_search_with_score_id(
self,
query: str,
k: int = 4,
filter: Optional[Dict[str, Any]] = None,
) -> List[Tuple[Document, float, str]]:
embedding_vector = await self.embedding.aembed_query(query)
return await self.asimilarity_search_with_score_id_by_vector(
embedding=embedding_vector,
k=k,
filter=filter,
)
def similarity_search_with_score_by_vector(
self,
embedding: List[float],
@ -554,6 +848,33 @@ class AstraDB(VectorStore):
)
]
async def asimilarity_search_with_score_by_vector(
self,
embedding: List[float],
k: int = 4,
filter: Optional[Dict[str, Any]] = None,
) -> List[Tuple[Document, float]]:
"""Return docs most similar to embedding vector.
Args:
embedding (str): Embedding to look up documents similar to.
k (int): Number of Documents to return. Defaults to 4.
Returns:
List of (Document, score), the most similar to the query vector.
"""
return [
(doc, score)
for (
doc,
score,
doc_id,
) in await self.asimilarity_search_with_score_id_by_vector(
embedding=embedding,
k=k,
filter=filter,
)
]
def similarity_search(
self,
query: str,
@ -568,6 +889,20 @@ class AstraDB(VectorStore):
filter=filter,
)
async def asimilarity_search(
self,
query: str,
k: int = 4,
filter: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> List[Document]:
embedding_vector = await self.embedding.aembed_query(query)
return await self.asimilarity_search_by_vector(
embedding_vector,
k,
filter=filter,
)
def similarity_search_by_vector(
self,
embedding: List[float],
@ -584,6 +919,22 @@ class AstraDB(VectorStore):
)
]
async def asimilarity_search_by_vector(
self,
embedding: List[float],
k: int = 4,
filter: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> List[Document]:
return [
doc
for doc, _ in await self.asimilarity_search_with_score_by_vector(
embedding,
k,
filter=filter,
)
]
def similarity_search_with_score(
self,
query: str,
@ -597,6 +948,40 @@ class AstraDB(VectorStore):
filter=filter,
)
async def asimilarity_search_with_score(
self,
query: str,
k: int = 4,
filter: Optional[Dict[str, Any]] = None,
) -> List[Tuple[Document, float]]:
embedding_vector = await self.embedding.aembed_query(query)
return await self.asimilarity_search_with_score_by_vector(
embedding_vector,
k,
filter=filter,
)
@staticmethod
def _get_mmr_hits(embedding, k, lambda_mult, prefetch_hits):
mmr_chosen_indices = maximal_marginal_relevance(
np.array(embedding, dtype=np.float32),
[prefetch_hit["$vector"] for prefetch_hit in prefetch_hits],
k=k,
lambda_mult=lambda_mult,
)
mmr_hits = [
prefetch_hit
for prefetch_index, prefetch_hit in enumerate(prefetch_hits)
if prefetch_index in mmr_chosen_indices
]
return [
Document(
page_content=hit["content"],
metadata=hit["metadata"],
)
for hit in mmr_hits
]
def max_marginal_relevance_search_by_vector(
self,
embedding: List[float],
@ -619,6 +1004,7 @@ class AstraDB(VectorStore):
Returns:
List of Documents selected by maximal marginal relevance.
"""
self._ensure_astra_db_client()
metadata_parameter = self._filter_to_metadata(filter)
prefetch_hits = list(
@ -635,25 +1021,61 @@ class AstraDB(VectorStore):
)
)
mmr_chosen_indices = maximal_marginal_relevance(
np.array(embedding, dtype=np.float32),
[prefetch_hit["$vector"] for prefetch_hit in prefetch_hits],
k=k,
lambda_mult=lambda_mult,
)
mmr_hits = [
prefetch_hit
for prefetch_index, prefetch_hit in enumerate(prefetch_hits)
if prefetch_index in mmr_chosen_indices
]
return [
Document(
page_content=hit["content"],
metadata=hit["metadata"],
return self._get_mmr_hits(embedding, k, lambda_mult, prefetch_hits)
async def amax_marginal_relevance_search_by_vector(
self,
embedding: List[float],
k: int = 4,
fetch_k: int = 20,
lambda_mult: float = 0.5,
filter: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> List[Document]:
"""Return docs selected using the maximal marginal relevance.
Maximal marginal relevance optimizes for similarity to query AND diversity
among selected documents.
Args:
embedding: Embedding to look up documents similar to.
k: Number of Documents to return.
fetch_k: Number of Documents to fetch to pass to MMR algorithm.
lambda_mult: Number between 0 and 1 that determines the degree
of diversity among the results with 0 corresponding
to maximum diversity and 1 to minimum diversity.
Returns:
List of Documents selected by maximal marginal relevance.
"""
await self._ensure_db_setup()
if not self.async_collection:
return await run_in_executor(
None,
self.max_marginal_relevance_search_by_vector,
embedding,
k,
fetch_k,
lambda_mult,
filter,
**kwargs,
)
metadata_parameter = self._filter_to_metadata(filter)
prefetch_hits = [
hit
async for hit in self.async_collection.paginated_find(
filter=metadata_parameter,
sort={"$vector": embedding},
options={"limit": fetch_k, "includeSimilarity": True},
projection={
"_id": 1,
"content": 1,
"metadata": 1,
"$vector": 1,
},
)
for hit in mmr_hits
]
return self._get_mmr_hits(embedding, k, lambda_mult, prefetch_hits)
def max_marginal_relevance_search(
self,
query: str,
@ -686,36 +1108,50 @@ class AstraDB(VectorStore):
filter=filter,
)
@classmethod
def from_texts(
cls: Type[ADBVST],
texts: List[str],
embedding: Embeddings,
metadatas: Optional[List[dict]] = None,
ids: Optional[List[str]] = None,
async def amax_marginal_relevance_search(
self,
query: str,
k: int = 4,
fetch_k: int = 20,
lambda_mult: float = 0.5,
filter: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> ADBVST:
"""Create an Astra DB vectorstore from raw texts.
) -> List[Document]:
"""Return docs selected using the maximal marginal relevance.
Maximal marginal relevance optimizes for similarity to query AND diversity
among selected documents.
Args:
texts (List[str]): the texts to insert.
embedding (Embeddings): the embedding function to use in the store.
metadatas (Optional[List[dict]]): metadata dicts for the texts.
ids (Optional[List[str]]): ids to associate to the texts.
*Additional arguments*: you can pass any argument that you would
to 'add_texts' and/or to the 'AstraDB' class constructor
(see these methods for details). These arguments will be
routed to the respective methods as they are.
query (str): Text to look up documents similar to.
k (int = 4): Number of Documents to return.
fetch_k (int = 20): Number of Documents to fetch to pass to MMR algorithm.
lambda_mult (float = 0.5): Number between 0 and 1 that determines the degree
of diversity among the results with 0 corresponding
to maximum diversity and 1 to minimum diversity.
Optional.
Returns:
an `AstraDb` vectorstore.
List of Documents selected by maximal marginal relevance.
"""
embedding_vector = await self.embedding.aembed_query(query)
return await self.amax_marginal_relevance_search_by_vector(
embedding_vector,
k,
fetch_k,
lambda_mult=lambda_mult,
filter=filter,
)
@classmethod
def _from_kwargs(
cls: Type[ADBVST],
embedding: Embeddings,
**kwargs: Any,
) -> ADBVST:
known_kwargs = {
"collection_name",
"token",
"api_endpoint",
"astra_db_client",
"async_astra_db_client",
"namespace",
"metric",
"batch_size",
@ -738,15 +1174,17 @@ class AstraDB(VectorStore):
token = kwargs.get("token")
api_endpoint = kwargs.get("api_endpoint")
astra_db_client = kwargs.get("astra_db_client")
async_astra_db_client = kwargs.get("async_astra_db_client")
namespace = kwargs.get("namespace")
metric = kwargs.get("metric")
astra_db_store = cls(
return cls(
embedding=embedding,
collection_name=collection_name,
token=token,
api_endpoint=api_endpoint,
astra_db_client=astra_db_client,
async_astra_db_client=async_astra_db_client,
namespace=namespace,
metric=metric,
batch_size=kwargs.get("batch_size"),
@ -756,6 +1194,32 @@ class AstraDB(VectorStore):
),
bulk_delete_concurrency=kwargs.get("bulk_delete_concurrency"),
)
@classmethod
def from_texts(
cls: Type[ADBVST],
texts: List[str],
embedding: Embeddings,
metadatas: Optional[List[dict]] = None,
ids: Optional[List[str]] = None,
**kwargs: Any,
) -> ADBVST:
"""Create an Astra DB vectorstore from raw texts.
Args:
texts (List[str]): the texts to insert.
embedding (Embeddings): the embedding function to use in the store.
metadatas (Optional[List[dict]]): metadata dicts for the texts.
ids (Optional[List[str]]): ids to associate to the texts.
*Additional arguments*: you can pass any argument that you would
to 'add_texts' and/or to the 'AstraDB' class constructor
(see these methods for details). These arguments will be
routed to the respective methods as they are.
Returns:
an `AstraDb` vectorstore.
"""
astra_db_store = AstraDB._from_kwargs(embedding, **kwargs)
astra_db_store.add_texts(
texts=texts,
metadatas=metadatas,
@ -766,6 +1230,41 @@ class AstraDB(VectorStore):
)
return astra_db_store
@classmethod
async def afrom_texts(
cls: Type[ADBVST],
texts: List[str],
embedding: Embeddings,
metadatas: Optional[List[dict]] = None,
ids: Optional[List[str]] = None,
**kwargs: Any,
) -> ADBVST:
"""Create an Astra DB vectorstore from raw texts.
Args:
texts (List[str]): the texts to insert.
embedding (Embeddings): the embedding function to use in the store.
metadatas (Optional[List[dict]]): metadata dicts for the texts.
ids (Optional[List[str]]): ids to associate to the texts.
*Additional arguments*: you can pass any argument that you would
to 'add_texts' and/or to the 'AstraDB' class constructor
(see these methods for details). These arguments will be
routed to the respective methods as they are.
Returns:
an `AstraDb` vectorstore.
"""
astra_db_store = AstraDB._from_kwargs(embedding, **kwargs)
await astra_db_store.aadd_texts(
texts=texts,
metadatas=metadatas,
ids=ids,
batch_size=kwargs.get("batch_size"),
batch_concurrency=kwargs.get("batch_concurrency"),
overwrite_concurrency=kwargs.get("overwrite_concurrency"),
)
return astra_db_store
@classmethod
def from_documents(
cls: Type[ADBVST],

@ -148,6 +148,33 @@ class TestAstraDB:
)
v_store_2.delete_collection()
async def test_astradb_vectorstore_create_delete_async(self) -> None:
"""Create and delete."""
emb = SomeEmbeddings(dimension=2)
# creation by passing the connection secrets
v_store = AstraDB(
embedding=emb,
collection_name="lc_test_1_async",
token=os.environ["ASTRA_DB_APPLICATION_TOKEN"],
api_endpoint=os.environ["ASTRA_DB_API_ENDPOINT"],
namespace=os.environ.get("ASTRA_DB_KEYSPACE"),
)
await v_store.adelete_collection()
# Creation by passing a ready-made astrapy client:
from astrapy.db import AsyncAstraDB
astra_db_client = AsyncAstraDB(
token=os.environ["ASTRA_DB_APPLICATION_TOKEN"],
api_endpoint=os.environ["ASTRA_DB_API_ENDPOINT"],
namespace=os.environ.get("ASTRA_DB_KEYSPACE"),
)
v_store_2 = AstraDB(
embedding=emb,
collection_name="lc_test_2_async",
async_astra_db_client=astra_db_client,
)
await v_store_2.adelete_collection()
def test_astradb_vectorstore_pre_delete_collection(self) -> None:
"""Create and delete."""
emb = SomeEmbeddings(dimension=2)
@ -183,6 +210,41 @@ class TestAstraDB:
finally:
v_store.delete_collection()
async def test_astradb_vectorstore_pre_delete_collection_async(self) -> None:
"""Create and delete."""
emb = SomeEmbeddings(dimension=2)
# creation by passing the connection secrets
v_store = AstraDB(
embedding=emb,
collection_name="lc_test_pre_del_async",
token=os.environ["ASTRA_DB_APPLICATION_TOKEN"],
api_endpoint=os.environ["ASTRA_DB_API_ENDPOINT"],
namespace=os.environ.get("ASTRA_DB_KEYSPACE"),
)
try:
await v_store.aadd_texts(
texts=["aa"],
metadatas=[
{"k": "a", "ord": 0},
],
ids=["a"],
)
res1 = await v_store.asimilarity_search("aa", k=5)
assert len(res1) == 1
v_store = AstraDB(
embedding=emb,
pre_delete_collection=True,
collection_name="lc_test_pre_del_async",
token=os.environ["ASTRA_DB_APPLICATION_TOKEN"],
api_endpoint=os.environ["ASTRA_DB_API_ENDPOINT"],
namespace=os.environ.get("ASTRA_DB_KEYSPACE"),
)
res1 = await v_store.asimilarity_search("aa", k=5)
assert len(res1) == 0
finally:
await v_store.adelete_collection()
def test_astradb_vectorstore_from_x(self) -> None:
"""from_texts and from_documents methods."""
emb = SomeEmbeddings(dimension=2)
@ -200,7 +262,7 @@ class TestAstraDB:
finally:
v_store.delete_collection()
# from_texts
# from_documents
v_store_2 = AstraDB.from_documents(
[
Document(page_content="Hee"),
@ -217,6 +279,42 @@ class TestAstraDB:
finally:
v_store_2.delete_collection()
async def test_astradb_vectorstore_from_x_async(self) -> None:
"""from_texts and from_documents methods."""
emb = SomeEmbeddings(dimension=2)
# from_texts
v_store = await AstraDB.afrom_texts(
texts=["Hi", "Ho"],
embedding=emb,
collection_name="lc_test_ft_async",
token=os.environ["ASTRA_DB_APPLICATION_TOKEN"],
api_endpoint=os.environ["ASTRA_DB_API_ENDPOINT"],
namespace=os.environ.get("ASTRA_DB_KEYSPACE"),
)
try:
assert (await v_store.asimilarity_search("Ho", k=1))[0].page_content == "Ho"
finally:
await v_store.adelete_collection()
# from_documents
v_store_2 = await AstraDB.afrom_documents(
[
Document(page_content="Hee"),
Document(page_content="Hoi"),
],
embedding=emb,
collection_name="lc_test_fd_async",
token=os.environ["ASTRA_DB_APPLICATION_TOKEN"],
api_endpoint=os.environ["ASTRA_DB_API_ENDPOINT"],
namespace=os.environ.get("ASTRA_DB_KEYSPACE"),
)
try:
assert (await v_store_2.asimilarity_search("Hoi", k=1))[
0
].page_content == "Hoi"
finally:
await v_store_2.adelete_collection()
def test_astradb_vectorstore_crud(self, store_someemb: AstraDB) -> None:
"""Basic add/delete/update behaviour."""
res0 = store_someemb.similarity_search("Abc", k=2)
@ -275,25 +373,106 @@ class TestAstraDB:
res4 = store_someemb.similarity_search("ww", k=1, filter={"k": "w"})
assert res4[0].metadata["ord"] == 205
async def test_astradb_vectorstore_crud_async(self, store_someemb: AstraDB) -> None:
"""Basic add/delete/update behaviour."""
res0 = await store_someemb.asimilarity_search("Abc", k=2)
assert res0 == []
# write and check again
await store_someemb.aadd_texts(
texts=["aa", "bb", "cc"],
metadatas=[
{"k": "a", "ord": 0},
{"k": "b", "ord": 1},
{"k": "c", "ord": 2},
],
ids=["a", "b", "c"],
)
res1 = await store_someemb.asimilarity_search("Abc", k=5)
assert {doc.page_content for doc in res1} == {"aa", "bb", "cc"}
# partial overwrite and count total entries
await store_someemb.aadd_texts(
texts=["cc", "dd"],
metadatas=[
{"k": "c_new", "ord": 102},
{"k": "d_new", "ord": 103},
],
ids=["c", "d"],
)
res2 = await store_someemb.asimilarity_search("Abc", k=10)
assert len(res2) == 4
# pick one that was just updated and check its metadata
res3 = await store_someemb.asimilarity_search_with_score_id(
query="cc", k=1, filter={"k": "c_new"}
)
print(str(res3))
doc3, score3, id3 = res3[0]
assert doc3.page_content == "cc"
assert doc3.metadata == {"k": "c_new", "ord": 102}
assert score3 > 0.999 # leaving some leeway for approximations...
assert id3 == "c"
# delete and count again
del1_res = await store_someemb.adelete(["b"])
assert del1_res is True
del2_res = await store_someemb.adelete(["a", "c", "Z!"])
assert del2_res is False # a non-existing ID was supplied
assert len(await store_someemb.asimilarity_search("xy", k=10)) == 1
# clear store
await store_someemb.aclear()
assert await store_someemb.asimilarity_search("Abc", k=2) == []
# add_documents with "ids" arg passthrough
await store_someemb.aadd_documents(
[
Document(page_content="vv", metadata={"k": "v", "ord": 204}),
Document(page_content="ww", metadata={"k": "w", "ord": 205}),
],
ids=["v", "w"],
)
assert len(await store_someemb.asimilarity_search("xy", k=10)) == 2
res4 = await store_someemb.asimilarity_search("ww", k=1, filter={"k": "w"})
assert res4[0].metadata["ord"] == 205
@staticmethod
def _v_from_i(i: int, N: int) -> str:
angle = 2 * math.pi * i / N
vector = [math.cos(angle), math.sin(angle)]
return json.dumps(vector)
def test_astradb_vectorstore_mmr(self, store_parseremb: AstraDB) -> None:
"""
MMR testing. We work on the unit circle with angle multiples
of 2*pi/20 and prepare a store with known vectors for a controlled
MMR outcome.
"""
def _v_from_i(i: int, N: int) -> str:
angle = 2 * math.pi * i / N
vector = [math.cos(angle), math.sin(angle)]
return json.dumps(vector)
i_vals = [0, 4, 5, 13]
N_val = 20
store_parseremb.add_texts(
[_v_from_i(i, N_val) for i in i_vals], metadatas=[{"i": i} for i in i_vals]
[self._v_from_i(i, N_val) for i in i_vals],
metadatas=[{"i": i} for i in i_vals],
)
res1 = store_parseremb.max_marginal_relevance_search(
_v_from_i(3, N_val),
self._v_from_i(3, N_val),
k=2,
fetch_k=3,
)
res_i_vals = {doc.metadata["i"] for doc in res1}
assert res_i_vals == {0, 4}
async def test_astradb_vectorstore_mmr_async(
self, store_parseremb: AstraDB
) -> None:
"""
MMR testing. We work on the unit circle with angle multiples
of 2*pi/20 and prepare a store with known vectors for a controlled
MMR outcome.
"""
i_vals = [0, 4, 5, 13]
N_val = 20
await store_parseremb.aadd_texts(
[self._v_from_i(i, N_val) for i in i_vals],
metadatas=[{"i": i} for i in i_vals],
)
res1 = await store_parseremb.amax_marginal_relevance_search(
self._v_from_i(3, N_val),
k=2,
fetch_k=3,
)
@ -381,6 +560,25 @@ class TestAstraDB:
sco_near, sco_far = scores
assert abs(1 - sco_near) < 0.001 and abs(sco_far) < 0.001
async def test_astradb_vectorstore_similarity_scale_async(
self, store_parseremb: AstraDB
) -> None:
"""Scale of the similarity scores."""
await store_parseremb.aadd_texts(
texts=[
json.dumps([1, 1]),
json.dumps([-1, -1]),
],
ids=["near", "far"],
)
res1 = await store_parseremb.asimilarity_search_with_score(
json.dumps([0.5, 0.5]),
k=2,
)
scores = [sco for _, sco in res1]
sco_near, sco_far = scores
assert abs(1 - sco_near) < 0.001 and abs(sco_far) < 0.001
def test_astradb_vectorstore_massive_delete(self, store_someemb: AstraDB) -> None:
"""Larger-scale bulk deletes."""
M = 50
@ -458,6 +656,40 @@ class TestAstraDB:
finally:
v_store.delete_collection()
async def test_astradb_vectorstore_custom_params_async(self) -> None:
"""Custom batch size and concurrency params."""
emb = SomeEmbeddings(dimension=2)
v_store = AstraDB(
embedding=emb,
collection_name="lc_test_c_async",
token=os.environ["ASTRA_DB_APPLICATION_TOKEN"],
api_endpoint=os.environ["ASTRA_DB_API_ENDPOINT"],
namespace=os.environ.get("ASTRA_DB_KEYSPACE"),
batch_size=17,
bulk_insert_batch_concurrency=13,
bulk_insert_overwrite_concurrency=7,
bulk_delete_concurrency=19,
)
try:
# add_texts
N = 50
texts = [str(i + 1 / 7.0) for i in range(N)]
ids = ["doc_%i" % i for i in range(N)]
await v_store.aadd_texts(texts=texts, ids=ids)
await v_store.aadd_texts(
texts=texts,
ids=ids,
batch_size=19,
batch_concurrency=7,
overwrite_concurrency=13,
)
#
await v_store.adelete(ids[: N // 2])
await v_store.adelete(ids[N // 2 :], concurrency=23)
#
finally:
await v_store.adelete_collection()
def test_astradb_vectorstore_metrics(self) -> None:
"""
Different choices of similarity metric.

Loading…
Cancel
Save