docs: AstraDB VectorStore docstring (#17834)

pull/17899/head
Christophe Bornet 6 months ago committed by GitHub
parent 2f2b77602e
commit 5019951a5d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -71,73 +71,6 @@ def _unique_list(lst: List[T], key: Callable[[T], U]) -> List[T]:
alternative_import="langchain_astradb.AstraDBVectorStore", alternative_import="langchain_astradb.AstraDBVectorStore",
) )
class AstraDB(VectorStore): class AstraDB(VectorStore):
"""Wrapper around DataStax Astra DB for vector-store workloads.
To use it, you need a recent installation of the `astrapy` library
and an Astra DB cloud database.
For quickstart and details, visit:
docs.datastax.com/en/astra/home/astra.html
Example:
.. code-block:: python
from langchain_community.vectorstores import AstraDB
from langchain_community.embeddings.openai import OpenAIEmbeddings
embeddings = OpenAIEmbeddings()
vectorstore = AstraDB(
embedding=embeddings,
collection_name="my_store",
token="AstraCS:...",
api_endpoint="https://<DB-ID>-us-east1.apps.astra.datastax.com"
)
vectorstore.add_texts(["Giraffes", "All good here"])
results = vectorstore.similarity_search("Everything's ok", k=1)
Constructor Args (only keyword-arguments accepted):
embedding (Embeddings): embedding function to use.
collection_name (str): name of the Astra DB collection to create/use.
token (Optional[str]): API token for Astra DB usage.
api_endpoint (Optional[str]): full URL to the API endpoint,
such as "https://<DB-ID>-us-east1.apps.astra.datastax.com".
astra_db_client (Optional[Any]): *alternative to token+api_endpoint*,
you can pass an already-created 'astrapy.db.AstraDB' instance.
namespace (Optional[str]): namespace (aka keyspace) where the
collection is created. Defaults to the database's "default namespace".
metric (Optional[str]): similarity function to use out of those
available in Astra DB. If left out, it will use Astra DB API's
defaults (i.e. "cosine" - but, for performance reasons,
"dot_product" is suggested if embeddings are normalized to one).
Advanced arguments (coming with sensible defaults):
batch_size (Optional[int]): Size of batches for bulk insertions.
bulk_insert_batch_concurrency (Optional[int]): Number of threads
to insert batches concurrently.
bulk_insert_overwrite_concurrency (Optional[int]): Number of
threads in a batch to insert pre-existing entries.
bulk_delete_concurrency (Optional[int]): Number of threads
(for deleting multiple rows concurrently).
pre_delete_collection (Optional[bool]): whether to delete the collection
before creating it. If False and the collection already exists,
the collection will be used as is.
A note on concurrency: as a rule of thumb, on a typical client machine
it is suggested to keep the quantity
bulk_insert_batch_concurrency * bulk_insert_overwrite_concurrency
much below 1000 to avoid exhausting the client multithreading/networking
resources. The hardcoded defaults are somewhat conservative to meet
most machines' specs, but a sensible choice to test may be:
bulk_insert_batch_concurrency = 80
bulk_insert_overwrite_concurrency = 10
A bit of experimentation is required to nail the best results here,
depending on both the machine/network specs and the expected workload
(specifically, how often a write is an update of an existing id).
Remember you can pass concurrency settings to individual calls to
add_texts and add_documents as well.
"""
@staticmethod @staticmethod
def _filter_to_metadata(filter_dict: Optional[Dict[str, Any]]) -> Dict[str, Any]: def _filter_to_metadata(filter_dict: Optional[Dict[str, Any]]) -> Dict[str, Any]:
if filter_dict is None: if filter_dict is None:
@ -173,8 +106,71 @@ class AstraDB(VectorStore):
setup_mode: SetupMode = SetupMode.SYNC, setup_mode: SetupMode = SetupMode.SYNC,
pre_delete_collection: bool = False, pre_delete_collection: bool = False,
) -> None: ) -> None:
""" """Wrapper around DataStax Astra DB for vector-store workloads.
Create an AstraDB vector store object. See class docstring for help.
For quickstart and details, visit
https://docs.datastax.com/en/astra/astra-db-vector/
Example:
.. code-block:: python
from langchain_community.vectorstores import AstraDB
from langchain_openai.embeddings import OpenAIEmbeddings
embeddings = OpenAIEmbeddings()
vectorstore = AstraDB(
embedding=embeddings,
collection_name="my_store",
token="AstraCS:...",
api_endpoint="https://<DB-ID>-<REGION>.apps.astra.datastax.com"
)
vectorstore.add_texts(["Giraffes", "All good here"])
results = vectorstore.similarity_search("Everything's ok", k=1)
Args:
embedding: embedding function to use.
collection_name: name of the Astra DB collection to create/use.
token: API token for Astra DB usage.
api_endpoint: full URL to the API endpoint, such as
`https://<DB-ID>-us-east1.apps.astra.datastax.com`.
astra_db_client: *alternative to token+api_endpoint*,
you can pass an already-created 'astrapy.db.AstraDB' instance.
async_astra_db_client: *alternative to token+api_endpoint*,
you can pass an already-created 'astrapy.db.AsyncAstraDB' instance.
namespace: namespace (aka keyspace) where the collection is created.
Defaults to the database's "default namespace".
metric: similarity function to use out of those available in Astra DB.
If left out, it will use Astra DB API's defaults (i.e. "cosine" - but,
for performance reasons, "dot_product" is suggested if embeddings are
normalized to one).
batch_size: Size of batches for bulk insertions.
bulk_insert_batch_concurrency: Number of threads or coroutines to insert
batches concurrently.
bulk_insert_overwrite_concurrency: Number of threads or coroutines in a
batch to insert pre-existing entries.
bulk_delete_concurrency: Number of threads (for deleting multiple rows
concurrently).
pre_delete_collection: whether to delete the collection before creating it.
If False and the collection already exists, the collection will be used
as is.
Note:
For concurrency in synchronous :meth:`~add_texts`, as a rule of thumb, on a
typical client machine it is suggested to keep the quantity
bulk_insert_batch_concurrency * bulk_insert_overwrite_concurrency
much below 1000 to avoid exhausting the client multithreading/networking
resources. The hardcoded defaults are somewhat conservative to meet
most machines' specs, but a sensible choice to test may be:
- bulk_insert_batch_concurrency = 80
- bulk_insert_overwrite_concurrency = 10
A bit of experimentation is required to nail the best results here,
depending on both the machine/network specs and the expected workload
(specifically, how often a write is an update of an existing id).
Remember you can pass concurrency settings to individual calls to
:meth:`~add_texts` and :meth:`~add_documents` as well.
""" """
self.embedding = embedding self.embedding = embedding
self.collection_name = collection_name self.collection_name = collection_name
@ -253,8 +249,13 @@ class AstraDB(VectorStore):
def delete_by_document_id(self, document_id: str) -> bool: def delete_by_document_id(self, document_id: str) -> bool:
""" """
Remove a single document from the store, given its document_id (str). Remove a single document from the store, given its document ID.
Return True if a document has indeed been deleted, False if ID not found.
Args:
document_id: The document ID
Returns:
True if a document has indeed been deleted, False if ID not found.
""" """
self.astra_env.ensure_db_setup() self.astra_env.ensure_db_setup()
deletion_response = self.collection.delete_one(document_id) # type: ignore[union-attr] deletion_response = self.collection.delete_one(document_id) # type: ignore[union-attr]
@ -264,8 +265,13 @@ class AstraDB(VectorStore):
async def adelete_by_document_id(self, document_id: str) -> bool: async def adelete_by_document_id(self, document_id: str) -> bool:
""" """
Remove a single document from the store, given its document_id (str). Remove a single document from the store, given its document ID.
Return True if a document has indeed been deleted, False if ID not found.
Args:
document_id: The document ID
Returns:
True if a document has indeed been deleted, False if ID not found.
""" """
await self.astra_env.aensure_db_setup() await self.astra_env.aensure_db_setup()
deletion_response = await self.async_collection.delete_one(document_id) deletion_response = await self.async_collection.delete_one(document_id)
@ -282,13 +288,12 @@ class AstraDB(VectorStore):
"""Delete by vector ids. """Delete by vector ids.
Args: Args:
ids (Optional[List[str]]): List of ids to delete. ids: List of ids to delete.
concurrency (Optional[int]): max number of threads issuing concurrency: max number of threads issuing single-doc delete requests.
single-doc delete requests. Defaults to instance-level setting. Defaults to instance-level setting.
Returns: Returns:
Optional[bool]: True if deletion is successful, True if deletion is successful, False otherwise.
False otherwise, None if not implemented.
""" """
if kwargs: if kwargs:
@ -317,17 +322,16 @@ class AstraDB(VectorStore):
concurrency: Optional[int] = None, concurrency: Optional[int] = None,
**kwargs: Any, **kwargs: Any,
) -> Optional[bool]: ) -> Optional[bool]:
"""Delete by vector ID or other criteria. """Delete by vector ids.
Args: Args:
ids: List of ids to delete. ids: List of ids to delete.
concurrency (Optional[int]): max number of concurrent delete queries. concurrency: max concurrency of single-doc delete requests.
Defaults to instance-level setting. Defaults to instance-level setting.
**kwargs: Other keyword arguments that subclasses might use. **kwargs: Other keyword arguments that subclasses might use.
Returns: Returns:
Optional[bool]: True if deletion is successful, True if deletion is successful, False otherwise.
False otherwise, None if not implemented.
""" """
if kwargs: if kwargs:
warnings.warn( warnings.warn(
@ -348,7 +352,7 @@ class AstraDB(VectorStore):
def delete_collection(self) -> None: def delete_collection(self) -> None:
""" """
Completely delete the collection from the database (as opposed Completely delete the collection from the database (as opposed
to 'clear()', which empties it only). to :meth:`~clear`, which empties it only).
Stored data is lost and unrecoverable, resources are freed. Stored data is lost and unrecoverable, resources are freed.
Use with caution. Use with caution.
""" """
@ -360,7 +364,7 @@ class AstraDB(VectorStore):
async def adelete_collection(self) -> None: async def adelete_collection(self) -> None:
""" """
Completely delete the collection from the database (as opposed Completely delete the collection from the database (as opposed
to 'clear()', which empties it only). to :meth:`~aclear`, which empties it only).
Stored data is lost and unrecoverable, resources are freed. Stored data is lost and unrecoverable, resources are freed.
Use with caution. Use with caution.
""" """
@ -450,28 +454,29 @@ class AstraDB(VectorStore):
will be replaced. will be replaced.
Args: Args:
texts (Iterable[str]): Texts to add to the vectorstore. texts: Texts to add to the vectorstore.
metadatas (Optional[List[dict]], optional): Optional list of metadatas. metadatas: Optional list of metadatas.
ids (Optional[List[str]], optional): Optional list of ids. ids: Optional list of ids.
batch_size (Optional[int]): Number of documents in each API call. batch_size: Number of documents in each API call.
Check the underlying Astra DB HTTP API specs for the max value Check the underlying Astra DB HTTP API specs for the max value
(20 at the time of writing this). If not provided, defaults (20 at the time of writing this). If not provided, defaults
to the instance-level setting. to the instance-level setting.
batch_concurrency (Optional[int]): number of threads to process batch_concurrency: number of threads to process
insertion batches concurrently. Defaults to instance-level insertion batches concurrently. Defaults to instance-level
setting if not provided. setting if not provided.
overwrite_concurrency (Optional[int]): number of threads to process overwrite_concurrency: number of threads to process
pre-existing documents in each batch (which require individual pre-existing documents in each batch (which require individual
API calls). Defaults to instance-level setting if not provided. API calls). Defaults to instance-level setting if not provided.
A note on metadata: there are constraints on the allowed field names Note:
in this dictionary, coming from the underlying Astra DB API. There are constraints on the allowed field names
For instance, the `$` (dollar sign) cannot be used in the dict keys. in the metadata dictionaries, coming from the underlying Astra DB API.
See this document for details: For instance, the `$` (dollar sign) cannot be used in the dict keys.
docs.datastax.com/en/astra-serverless/docs/develop/dev-with-json.html See this document for details:
https://docs.datastax.com/en/astra/astra-db-vector/api-reference/data-api.html
Returns: Returns:
List[str]: List of ids of the added texts. The list of ids of the added texts.
""" """
if kwargs: if kwargs:
@ -488,7 +493,7 @@ class AstraDB(VectorStore):
) )
def _handle_batch(document_batch: List[DocDict]) -> List[str]: def _handle_batch(document_batch: List[DocDict]) -> List[str]:
im_result = self.collection.insert_many( # type: ignore[union-attr] im_result = self.collection.insert_many(
documents=document_batch, documents=document_batch,
options={"ordered": False}, options={"ordered": False},
partial_failures_allowed=True, partial_failures_allowed=True,
@ -498,7 +503,7 @@ class AstraDB(VectorStore):
) )
def _handle_missing_document(missing_document: DocDict) -> str: def _handle_missing_document(missing_document: DocDict) -> str:
replacement_result = self.collection.find_one_and_replace( # type: ignore[union-attr] replacement_result = self.collection.find_one_and_replace(
filter={"_id": missing_document["_id"]}, filter={"_id": missing_document["_id"]},
replacement=missing_document, replacement=missing_document,
) )
@ -544,27 +549,29 @@ class AstraDB(VectorStore):
will be replaced. will be replaced.
Args: Args:
texts (Iterable[str]): Texts to add to the vectorstore. texts: Texts to add to the vectorstore.
metadatas (Optional[List[dict]], optional): Optional list of metadatas. metadatas: Optional list of metadatas.
ids (Optional[List[str]], optional): Optional list of ids. ids: Optional list of ids.
batch_size (Optional[int]): Number of documents in each API call. batch_size: Number of documents in each API call.
Check the underlying Astra DB HTTP API specs for the max value Check the underlying Astra DB HTTP API specs for the max value
(20 at the time of writing this). If not provided, defaults (20 at the time of writing this). If not provided, defaults
to the instance-level setting. to the instance-level setting.
batch_concurrency (Optional[int]): number of concurrent batch insertions. batch_concurrency: number of coroutines to process
Defaults to instance-level setting if not provided. insertion batches concurrently. Defaults to instance-level
overwrite_concurrency (Optional[int]): number of concurrent API calls to setting if not provided.
process pre-existing documents in each batch. overwrite_concurrency: number of coroutines to process
Defaults to instance-level setting if not provided. pre-existing documents in each batch (which require individual
API calls). Defaults to instance-level setting if not provided.
A note on metadata: there are constraints on the allowed field names
in this dictionary, coming from the underlying Astra DB API. Note:
For instance, the `$` (dollar sign) cannot be used in the dict keys. There are constraints on the allowed field names
See this document for details: in the metadata dictionaries, coming from the underlying Astra DB API.
docs.datastax.com/en/astra-serverless/docs/develop/dev-with-json.html For instance, the `$` (dollar sign) cannot be used in the dict keys.
See this document for details:
https://docs.datastax.com/en/astra/astra-db-vector/api-reference/data-api.html
Returns: Returns:
List[str]: List of ids of the added texts. The list of ids of the added texts.
""" """
if kwargs: if kwargs:
warnings.warn( warnings.warn(
@ -580,7 +587,7 @@ class AstraDB(VectorStore):
) )
async def _handle_batch(document_batch: List[DocDict]) -> List[str]: async def _handle_batch(document_batch: List[DocDict]) -> List[str]:
im_result = await self.async_collection.insert_many( # type: ignore[union-attr] im_result = await self.async_collection.insert_many(
documents=document_batch, documents=document_batch,
options={"ordered": False}, options={"ordered": False},
partial_failures_allowed=True, partial_failures_allowed=True,
@ -590,7 +597,7 @@ class AstraDB(VectorStore):
) )
async def _handle_missing_document(missing_document: DocDict) -> str: async def _handle_missing_document(missing_document: DocDict) -> str:
replacement_result = await self.async_collection.find_one_and_replace( # type: ignore[union-attr] replacement_result = await self.async_collection.find_one_and_replace(
filter={"_id": missing_document["_id"]}, filter={"_id": missing_document["_id"]},
replacement=missing_document, replacement=missing_document,
) )
@ -625,19 +632,21 @@ class AstraDB(VectorStore):
k: int = 4, k: int = 4,
filter: Optional[Dict[str, Any]] = None, filter: Optional[Dict[str, Any]] = None,
) -> List[Tuple[Document, float, str]]: ) -> List[Tuple[Document, float, str]]:
"""Return docs most similar to embedding vector. """Return docs most similar to embedding vector with score and id.
Args: Args:
embedding (str): Embedding to look up documents similar to. embedding: Embedding to look up documents similar to.
k (int): Number of Documents to return. Defaults to 4. k: Number of Documents to return. Defaults to 4.
filter: Filter on the metadata to apply.
Returns: Returns:
List of (Document, score, id), the most similar to the query vector. The list of (Document, score, id), the most similar to the query vector.
""" """
self.astra_env.ensure_db_setup() self.astra_env.ensure_db_setup()
metadata_parameter = self._filter_to_metadata(filter) metadata_parameter = self._filter_to_metadata(filter)
# #
hits = list( hits = list(
self.collection.paginated_find( # type: ignore[union-attr] self.collection.paginated_find(
filter=metadata_parameter, filter=metadata_parameter,
sort={"$vector": embedding}, sort={"$vector": embedding},
options={"limit": k, "includeSimilarity": True}, options={"limit": k, "includeSimilarity": True},
@ -667,13 +676,15 @@ class AstraDB(VectorStore):
k: int = 4, k: int = 4,
filter: Optional[Dict[str, Any]] = None, filter: Optional[Dict[str, Any]] = None,
) -> List[Tuple[Document, float, str]]: ) -> List[Tuple[Document, float, str]]:
"""Return docs most similar to embedding vector. """Return docs most similar to embedding vector with score and id.
Args: Args:
embedding (str): Embedding to look up documents similar to. embedding: Embedding to look up documents similar to.
k (int): Number of Documents to return. Defaults to 4. k: Number of Documents to return. Defaults to 4.
filter: Filter on the metadata to apply.
Returns: Returns:
List of (Document, score, id), the most similar to the query vector. The list of (Document, score, id), the most similar to the query vector.
""" """
await self.astra_env.aensure_db_setup() await self.astra_env.aensure_db_setup()
metadata_parameter = self._filter_to_metadata(filter) metadata_parameter = self._filter_to_metadata(filter)
@ -705,6 +716,16 @@ class AstraDB(VectorStore):
k: int = 4, k: int = 4,
filter: Optional[Dict[str, Any]] = None, filter: Optional[Dict[str, Any]] = None,
) -> List[Tuple[Document, float, str]]: ) -> List[Tuple[Document, float, str]]:
"""Return docs most similar to the query with score and id.
Args:
query: Query to look up documents similar to.
k: Number of Documents to return. Defaults to 4.
filter: Filter on the metadata to apply.
Returns:
The list of (Document, score, id), the most similar to the query.
"""
embedding_vector = self.embedding.embed_query(query) embedding_vector = self.embedding.embed_query(query)
return self.similarity_search_with_score_id_by_vector( return self.similarity_search_with_score_id_by_vector(
embedding=embedding_vector, embedding=embedding_vector,
@ -718,6 +739,16 @@ class AstraDB(VectorStore):
k: int = 4, k: int = 4,
filter: Optional[Dict[str, Any]] = None, filter: Optional[Dict[str, Any]] = None,
) -> List[Tuple[Document, float, str]]: ) -> List[Tuple[Document, float, str]]:
"""Return docs most similar to the query with score and id.
Args:
query: Query to look up documents similar to.
k: Number of Documents to return. Defaults to 4.
filter: Filter on the metadata to apply.
Returns:
The list of (Document, score, id), the most similar to the query.
"""
embedding_vector = await self.embedding.aembed_query(query) embedding_vector = await self.embedding.aembed_query(query)
return await self.asimilarity_search_with_score_id_by_vector( return await self.asimilarity_search_with_score_id_by_vector(
embedding=embedding_vector, embedding=embedding_vector,
@ -731,13 +762,15 @@ class AstraDB(VectorStore):
k: int = 4, k: int = 4,
filter: Optional[Dict[str, Any]] = None, filter: Optional[Dict[str, Any]] = None,
) -> List[Tuple[Document, float]]: ) -> List[Tuple[Document, float]]:
"""Return docs most similar to embedding vector. """Return docs most similar to embedding vector with score.
Args: Args:
embedding (str): Embedding to look up documents similar to. embedding: Embedding to look up documents similar to.
k (int): Number of Documents to return. Defaults to 4. k: Number of Documents to return. Defaults to 4.
filter: Filter on the metadata to apply.
Returns: Returns:
List of (Document, score), the most similar to the query vector. The list of (Document, score), the most similar to the query vector.
""" """
return [ return [
(doc, score) (doc, score)
@ -754,13 +787,15 @@ class AstraDB(VectorStore):
k: int = 4, k: int = 4,
filter: Optional[Dict[str, Any]] = None, filter: Optional[Dict[str, Any]] = None,
) -> List[Tuple[Document, float]]: ) -> List[Tuple[Document, float]]:
"""Return docs most similar to embedding vector. """Return docs most similar to embedding vector with score.
Args: Args:
embedding (str): Embedding to look up documents similar to. embedding: Embedding to look up documents similar to.
k (int): Number of Documents to return. Defaults to 4. k: Number of Documents to return. Defaults to 4.
filter: Filter on the metadata to apply.
Returns: Returns:
List of (Document, score), the most similar to the query vector. The list of (Document, score), the most similar to the query vector.
""" """
return [ return [
(doc, score) (doc, score)
@ -782,6 +817,16 @@ class AstraDB(VectorStore):
filter: Optional[Dict[str, Any]] = None, filter: Optional[Dict[str, Any]] = None,
**kwargs: Any, **kwargs: Any,
) -> List[Document]: ) -> List[Document]:
"""Return docs most similar to query.
Args:
query: Query to look up documents similar to.
k: Number of Documents to return. Defaults to 4.
filter: Filter on the metadata to apply.
Returns:
The list of Documents most similar to the query.
"""
embedding_vector = self.embedding.embed_query(query) embedding_vector = self.embedding.embed_query(query)
return self.similarity_search_by_vector( return self.similarity_search_by_vector(
embedding_vector, embedding_vector,
@ -796,6 +841,16 @@ class AstraDB(VectorStore):
filter: Optional[Dict[str, Any]] = None, filter: Optional[Dict[str, Any]] = None,
**kwargs: Any, **kwargs: Any,
) -> List[Document]: ) -> List[Document]:
"""Return docs most similar to query.
Args:
query: Query to look up documents similar to.
k: Number of Documents to return. Defaults to 4.
filter: Filter on the metadata to apply.
Returns:
The list of Documents most similar to the query.
"""
embedding_vector = await self.embedding.aembed_query(query) embedding_vector = await self.embedding.aembed_query(query)
return await self.asimilarity_search_by_vector( return await self.asimilarity_search_by_vector(
embedding_vector, embedding_vector,
@ -810,6 +865,16 @@ class AstraDB(VectorStore):
filter: Optional[Dict[str, Any]] = None, filter: Optional[Dict[str, Any]] = None,
**kwargs: Any, **kwargs: Any,
) -> List[Document]: ) -> List[Document]:
"""Return docs most similar to embedding vector.
Args:
embedding: Embedding to look up documents similar to.
k: Number of Documents to return. Defaults to 4.
filter: Filter on the metadata to apply.
Returns:
The list of Documents most similar to the query vector.
"""
return [ return [
doc doc
for doc, _ in self.similarity_search_with_score_by_vector( for doc, _ in self.similarity_search_with_score_by_vector(
@ -826,6 +891,16 @@ class AstraDB(VectorStore):
filter: Optional[Dict[str, Any]] = None, filter: Optional[Dict[str, Any]] = None,
**kwargs: Any, **kwargs: Any,
) -> List[Document]: ) -> List[Document]:
"""Return docs most similar to embedding vector.
Args:
embedding: Embedding to look up documents similar to.
k: Number of Documents to return. Defaults to 4.
filter: Filter on the metadata to apply.
Returns:
The list of Documents most similar to the query vector.
"""
return [ return [
doc doc
for doc, _ in await self.asimilarity_search_with_score_by_vector( for doc, _ in await self.asimilarity_search_with_score_by_vector(
@ -841,6 +916,16 @@ class AstraDB(VectorStore):
k: int = 4, k: int = 4,
filter: Optional[Dict[str, Any]] = None, filter: Optional[Dict[str, Any]] = None,
) -> List[Tuple[Document, float]]: ) -> List[Tuple[Document, float]]:
"""Return docs most similar to query with score.
Args:
query: Query to look up documents similar to.
k: Number of Documents to return. Defaults to 4.
filter: Filter on the metadata to apply.
Returns:
The list of (Document, score), the most similar to the query vector.
"""
embedding_vector = self.embedding.embed_query(query) embedding_vector = self.embedding.embed_query(query)
return self.similarity_search_with_score_by_vector( return self.similarity_search_with_score_by_vector(
embedding_vector, embedding_vector,
@ -854,6 +939,16 @@ class AstraDB(VectorStore):
k: int = 4, k: int = 4,
filter: Optional[Dict[str, Any]] = None, filter: Optional[Dict[str, Any]] = None,
) -> List[Tuple[Document, float]]: ) -> List[Tuple[Document, float]]:
"""Return docs most similar to query with score.
Args:
query: Query to look up documents similar to.
k: Number of Documents to return. Defaults to 4.
filter: Filter on the metadata to apply.
Returns:
The list of (Document, score), the most similar to the query vector.
"""
embedding_vector = await self.embedding.aembed_query(query) embedding_vector = await self.embedding.aembed_query(query)
return await self.asimilarity_search_with_score_by_vector( return await self.asimilarity_search_with_score_by_vector(
embedding_vector, embedding_vector,
@ -862,7 +957,9 @@ class AstraDB(VectorStore):
) )
@staticmethod @staticmethod
def _get_mmr_hits(embedding, k, lambda_mult, prefetch_hits): # type: ignore[no-untyped-def] def _get_mmr_hits(
embedding: List[float], k: int, lambda_mult: float, prefetch_hits: List[DocDict]
) -> List[Document]:
mmr_chosen_indices = maximal_marginal_relevance( mmr_chosen_indices = maximal_marginal_relevance(
np.array(embedding, dtype=np.float32), np.array(embedding, dtype=np.float32),
[prefetch_hit["$vector"] for prefetch_hit in prefetch_hits], [prefetch_hit["$vector"] for prefetch_hit in prefetch_hits],
@ -892,23 +989,27 @@ class AstraDB(VectorStore):
**kwargs: Any, **kwargs: Any,
) -> List[Document]: ) -> List[Document]:
"""Return docs selected using the maximal marginal relevance. """Return docs selected using the maximal marginal relevance.
Maximal marginal relevance optimizes for similarity to query AND diversity Maximal marginal relevance optimizes for similarity to query AND diversity
among selected documents. among selected documents.
Args: Args:
embedding: Embedding to look up documents similar to. embedding: Embedding to look up documents similar to.
k: Number of Documents to return. k: Number of Documents to return.
fetch_k: Number of Documents to fetch to pass to MMR algorithm. fetch_k: Number of Documents to fetch to pass to MMR algorithm.
lambda_mult: Number between 0 and 1 that determines the degree lambda_mult: Number between 0 and 1 that determines the degree
of diversity among the results with 0 corresponding of diversity among the results with 0 corresponding
to maximum diversity and 1 to minimum diversity. to maximum diversity and 1 to minimum diversity.
filter: Filter on the metadata to apply.
Returns: Returns:
List of Documents selected by maximal marginal relevance. The list of Documents selected by maximal marginal relevance.
""" """
self.astra_env.ensure_db_setup() self.astra_env.ensure_db_setup()
metadata_parameter = self._filter_to_metadata(filter) metadata_parameter = self._filter_to_metadata(filter)
prefetch_hits = list( prefetch_hits = list(
self.collection.paginated_find( # type: ignore[union-attr] self.collection.paginated_find(
filter=metadata_parameter, filter=metadata_parameter,
sort={"$vector": embedding}, sort={"$vector": embedding},
options={"limit": fetch_k, "includeSimilarity": True}, options={"limit": fetch_k, "includeSimilarity": True},
@ -933,17 +1034,21 @@ class AstraDB(VectorStore):
**kwargs: Any, **kwargs: Any,
) -> List[Document]: ) -> List[Document]:
"""Return docs selected using the maximal marginal relevance. """Return docs selected using the maximal marginal relevance.
Maximal marginal relevance optimizes for similarity to query AND diversity Maximal marginal relevance optimizes for similarity to query AND diversity
among selected documents. among selected documents.
Args: Args:
embedding: Embedding to look up documents similar to. embedding: Embedding to look up documents similar to.
k: Number of Documents to return. k: Number of Documents to return.
fetch_k: Number of Documents to fetch to pass to MMR algorithm. fetch_k: Number of Documents to fetch to pass to MMR algorithm.
lambda_mult: Number between 0 and 1 that determines the degree lambda_mult: Number between 0 and 1 that determines the degree
of diversity among the results with 0 corresponding of diversity among the results with 0 corresponding
to maximum diversity and 1 to minimum diversity. to maximum diversity and 1 to minimum diversity.
filter: Filter on the metadata to apply.
Returns: Returns:
List of Documents selected by maximal marginal relevance. The list of Documents selected by maximal marginal relevance.
""" """
await self.astra_env.aensure_db_setup() await self.astra_env.aensure_db_setup()
metadata_parameter = self._filter_to_metadata(filter) metadata_parameter = self._filter_to_metadata(filter)
@ -975,18 +1080,21 @@ class AstraDB(VectorStore):
**kwargs: Any, **kwargs: Any,
) -> List[Document]: ) -> List[Document]:
"""Return docs selected using the maximal marginal relevance. """Return docs selected using the maximal marginal relevance.
Maximal marginal relevance optimizes for similarity to query AND diversity Maximal marginal relevance optimizes for similarity to query AND diversity
among selected documents. among selected documents.
Args: Args:
query (str): Text to look up documents similar to. query: Query to look up documents similar to.
k (int = 4): Number of Documents to return. k: Number of Documents to return.
fetch_k (int = 20): Number of Documents to fetch to pass to MMR algorithm. fetch_k: Number of Documents to fetch to pass to MMR algorithm.
lambda_mult (float = 0.5): Number between 0 and 1 that determines the degree lambda_mult: Number between 0 and 1 that determines the degree
of diversity among the results with 0 corresponding of diversity among the results with 0 corresponding
to maximum diversity and 1 to minimum diversity. to maximum diversity and 1 to minimum diversity.
Optional. filter: Filter on the metadata to apply.
Returns: Returns:
List of Documents selected by maximal marginal relevance. The list of Documents selected by maximal marginal relevance.
""" """
embedding_vector = self.embedding.embed_query(query) embedding_vector = self.embedding.embed_query(query)
return self.max_marginal_relevance_search_by_vector( return self.max_marginal_relevance_search_by_vector(
@ -1007,18 +1115,21 @@ class AstraDB(VectorStore):
**kwargs: Any, **kwargs: Any,
) -> List[Document]: ) -> List[Document]:
"""Return docs selected using the maximal marginal relevance. """Return docs selected using the maximal marginal relevance.
Maximal marginal relevance optimizes for similarity to query AND diversity Maximal marginal relevance optimizes for similarity to query AND diversity
among selected documents. among selected documents.
Args: Args:
query (str): Text to look up documents similar to. query: Query to look up documents similar to.
k (int = 4): Number of Documents to return. k: Number of Documents to return.
fetch_k (int = 20): Number of Documents to fetch to pass to MMR algorithm. fetch_k: Number of Documents to fetch to pass to MMR algorithm.
lambda_mult (float = 0.5): Number between 0 and 1 that determines the degree lambda_mult: Number between 0 and 1 that determines the degree
of diversity among the results with 0 corresponding of diversity among the results with 0 corresponding
to maximum diversity and 1 to minimum diversity. to maximum diversity and 1 to minimum diversity.
Optional. filter: Filter on the metadata to apply.
Returns: Returns:
List of Documents selected by maximal marginal relevance. The list of Documents selected by maximal marginal relevance.
""" """
embedding_vector = await self.embedding.aembed_query(query) embedding_vector = await self.embedding.aembed_query(query)
return await self.amax_marginal_relevance_search_by_vector( return await self.amax_marginal_relevance_search_by_vector(
@ -1096,12 +1207,12 @@ class AstraDB(VectorStore):
"""Create an Astra DB vectorstore from raw texts. """Create an Astra DB vectorstore from raw texts.
Args: Args:
texts (List[str]): the texts to insert. texts: the texts to insert.
embedding (Embeddings): the embedding function to use in the store. embedding: the embedding function to use in the store.
metadatas (Optional[List[dict]]): metadata dicts for the texts. metadatas: metadata dicts for the texts.
ids (Optional[List[str]]): ids to associate to the texts. ids: ids to associate to the texts.
*Additional arguments*: you can pass any argument that you would **kwargs: you can pass any argument that you would
to 'add_texts' and/or to the 'AstraDB' class constructor to :meth:`~add_texts` and/or to the 'AstraDB' constructor
(see these methods for details). These arguments will be (see these methods for details). These arguments will be
routed to the respective methods as they are. routed to the respective methods as they are.
@ -1131,12 +1242,12 @@ class AstraDB(VectorStore):
"""Create an Astra DB vectorstore from raw texts. """Create an Astra DB vectorstore from raw texts.
Args: Args:
texts (List[str]): the texts to insert. texts: the texts to insert.
embedding (Embeddings): the embedding function to use in the store. embedding: the embedding function to use in the store.
metadatas (Optional[List[dict]]): metadata dicts for the texts. metadatas: metadata dicts for the texts.
ids (Optional[List[str]]): ids to associate to the texts. ids: ids to associate to the texts.
*Additional arguments*: you can pass any argument that you would **kwargs: you can pass any argument that you would
to 'add_texts' and/or to the 'AstraDB' class constructor to :meth:`~add_texts` and/or to the 'AstraDB' constructor
(see these methods for details). These arguments will be (see these methods for details). These arguments will be
routed to the respective methods as they are. routed to the respective methods as they are.

Loading…
Cancel
Save