@ -1,9 +1,12 @@
from __future__ import annotations
from __future__ import annotations
import asyncio
import uuid
import uuid
import warnings
import warnings
from asyncio import Task
from concurrent . futures import ThreadPoolExecutor
from concurrent . futures import ThreadPoolExecutor
from typing import (
from typing import (
TYPE_CHECKING ,
Any ,
Any ,
Callable ,
Callable ,
Dict ,
Dict ,
@ -19,11 +22,17 @@ from typing import (
import numpy as np
import numpy as np
from langchain_core . documents import Document
from langchain_core . documents import Document
from langchain_core . embeddings import Embeddings
from langchain_core . embeddings import Embeddings
from langchain_core . runnables import run_in_executor
from langchain_core . runnables . utils import gather_with_concurrency
from langchain_core . utils . iter import batch_iterate
from langchain_core . utils . iter import batch_iterate
from langchain_core . vectorstores import VectorStore
from langchain_core . vectorstores import VectorStore
from langchain_community . vectorstores . utils import maximal_marginal_relevance
from langchain_community . vectorstores . utils import maximal_marginal_relevance
if TYPE_CHECKING :
from astrapy . db import AstraDB as LibAstraDB
from astrapy . db import AsyncAstraDB
ADBVST = TypeVar ( " ADBVST " , bound = " AstraDB " )
ADBVST = TypeVar ( " ADBVST " , bound = " AstraDB " )
T = TypeVar ( " T " )
T = TypeVar ( " T " )
U = TypeVar ( " U " )
U = TypeVar ( " U " )
@ -144,7 +153,8 @@ class AstraDB(VectorStore):
collection_name : str ,
collection_name : str ,
token : Optional [ str ] = None ,
token : Optional [ str ] = None ,
api_endpoint : Optional [ str ] = None ,
api_endpoint : Optional [ str ] = None ,
astra_db_client : Optional [ Any ] = None , # 'astrapy.db.AstraDB' if passed
astra_db_client : Optional [ LibAstraDB ] = None ,
async_astra_db_client : Optional [ AsyncAstraDB ] = None ,
namespace : Optional [ str ] = None ,
namespace : Optional [ str ] = None ,
metric : Optional [ str ] = None ,
metric : Optional [ str ] = None ,
batch_size : Optional [ int ] = None ,
batch_size : Optional [ int ] = None ,
@ -157,12 +167,8 @@ class AstraDB(VectorStore):
Create an AstraDB vector store object . See class docstring for help .
Create an AstraDB vector store object . See class docstring for help .
"""
"""
try :
try :
from astrapy . db import (
from astrapy . db import AstraDB as LibAstraDB
AstraDB as LibAstraDB ,
from astrapy . db import AstraDBCollection
)
from astrapy . db import (
AstraDBCollection as LibAstraDBCollection ,
)
except ( ImportError , ModuleNotFoundError ) :
except ( ImportError , ModuleNotFoundError ) :
raise ImportError (
raise ImportError (
" Could not import a recent astrapy python package. "
" Could not import a recent astrapy python package. "
@ -170,11 +176,11 @@ class AstraDB(VectorStore):
)
)
# Conflicting-arg checks:
# Conflicting-arg checks:
if astra_db_client is not None :
if astra_db_client is not None or async_astra_db_client is not None :
if token is not None or api_endpoint is not None :
if token is not None or api_endpoint is not None :
raise ValueError (
raise ValueError (
" You cannot pass ' astra_db_client ' to AstraDB if passing "
" You cannot pass ' astra_db_client ' or ' async_astra_db_client ' to "
" 'token ' and ' api_endpoint ' . "
" AstraDB if passing 'token ' and ' api_endpoint ' . "
)
)
self . embedding = embedding
self . embedding = embedding
@ -198,23 +204,69 @@ class AstraDB(VectorStore):
self . _embedding_dimension : Optional [ int ] = None
self . _embedding_dimension : Optional [ int ] = None
self . metric = metric
self . metric = metric
if astra_db_client is not None :
self . astra_db = astra_db_client
self . astra_db = astra_db_client
self . async_astra_db = async_astra_db_client
else :
self . collection = None
self . async_collection = None
if token and api_endpoint :
self . astra_db = LibAstraDB (
self . astra_db = LibAstraDB (
token = self . token ,
token = self . token ,
api_endpoint = self . api_endpoint ,
api_endpoint = self . api_endpoint ,
namespace = self . namespace ,
namespace = self . namespace ,
)
)
if not pre_delete_collection :
try :
self . _provision_collection ( )
from astrapy . db import AsyncAstraDB
else :
self . clear ( )
self . collection = LibAstraDBCollection (
self . async_astra_db = AsyncAstraDB (
collection_name = self . collection_name ,
token = self . token ,
astra_db = self . astra_db ,
api_endpoint = self . api_endpoint ,
)
namespace = self . namespace ,
)
except ( ImportError , ModuleNotFoundError ) :
pass
if self . astra_db is not None :
self . collection = AstraDBCollection (
collection_name = self . collection_name ,
astra_db = self . astra_db ,
)
self . async_setup_db_task : Optional [ Task ] = None
if self . async_astra_db is not None :
from astrapy . db import AsyncAstraDBCollection
self . async_collection = AsyncAstraDBCollection (
collection_name = self . collection_name ,
astra_db = self . async_astra_db ,
)
try :
self . async_setup_db_task = asyncio . create_task (
self . _setup_db ( pre_delete_collection )
)
except RuntimeError :
pass
if self . async_setup_db_task is None :
if not pre_delete_collection :
self . _provision_collection ( )
else :
self . clear ( )
def _ensure_astra_db_client ( self ) :
if not self . astra_db :
raise ValueError ( " Missing AstraDB client " )
async def _setup_db ( self , pre_delete_collection : bool ) - > None :
if pre_delete_collection :
await self . async_astra_db . delete_collection (
collection_name = self . collection_name ,
)
await self . _aprovision_collection ( )
async def _ensure_db_setup ( self ) - > None :
if self . async_setup_db_task :
await self . async_setup_db_task
def _get_embedding_dimension ( self ) - > int :
def _get_embedding_dimension ( self ) - > int :
if self . _embedding_dimension is None :
if self . _embedding_dimension is None :
@ -223,31 +275,31 @@ class AstraDB(VectorStore):
)
)
return self . _embedding_dimension
return self . _embedding_dimension
def _ drop _collection( self ) - > None :
def _ provision _collection( self ) - > None :
"""
"""
Drop the collection from storage .
Run the API invocation to create the collection on the backend .
This is meant as an i nternal- usage method , no members
I nternal- usage method , no object members are set ,
are set other than actual deletion on the backend .
other than working on the underlying actual storage .
"""
"""
_ = self . astra_db . delete_collection (
self . astra_db . create_collection (
dimension = self . _get_embedding_dimension ( ) ,
collection_name = self . collection_name ,
collection_name = self . collection_name ,
metric = self . metric ,
)
)
return None
def _ provision_collection( self ) - > None :
async def _ a provision_collection( self ) - > None :
"""
"""
Run the API invocation to create the collection on the backend .
Run the API invocation to create the collection on the backend .
Internal - usage method , no object members are set ,
Internal - usage method , no object members are set ,
other than working on the underlying actual storage .
other than working on the underlying actual storage .
"""
"""
_ = self . astra_db. create_collection (
await self . async_ astra_db. create_collection (
dimension = self . _get_embedding_dimension ( ) ,
dimension = self . _get_embedding_dimension ( ) ,
collection_name = self . collection_name ,
collection_name = self . collection_name ,
metric = self . metric ,
metric = self . metric ,
)
)
return None
@property
@property
def embeddings ( self ) - > Embeddings :
def embeddings ( self ) - > Embeddings :
@ -268,16 +320,36 @@ class AstraDB(VectorStore):
def clear ( self ) - > None :
def clear ( self ) - > None :
""" Empty the collection of all its stored entries. """
""" Empty the collection of all its stored entries. """
self . _drop _collection( )
self . delete _collection( )
self . _provision_collection ( )
self . _provision_collection ( )
return None
async def aclear ( self ) - > None :
""" Empty the collection of all its stored entries. """
await self . _ensure_db_setup ( )
if not self . async_astra_db :
await run_in_executor ( None , self . clear )
await self . async_collection . delete_many ( { } )
def delete_by_document_id ( self , document_id : str ) - > bool :
def delete_by_document_id ( self , document_id : str ) - > bool :
"""
"""
Remove a single document from the store , given its document_id ( str ) .
Remove a single document from the store , given its document_id ( str ) .
Return True if a document has indeed been deleted , False if ID not found .
Return True if a document has indeed been deleted , False if ID not found .
"""
"""
deletion_response = self . collection . delete ( document_id )
self . _ensure_astra_db_client ( )
deletion_response = self . collection . delete_one ( document_id )
return ( ( deletion_response or { } ) . get ( " status " ) or { } ) . get (
" deletedCount " , 0
) == 1
async def adelete_by_document_id ( self , document_id : str ) - > bool :
"""
Remove a single document from the store , given its document_id ( str ) .
Return True if a document has indeed been deleted , False if ID not found .
"""
await self . _ensure_db_setup ( )
if not self . async_collection :
return await run_in_executor ( None , self . delete_by_document_id , document_id )
deletion_response = await self . async_collection . delete_one ( document_id )
return ( ( deletion_response or { } ) . get ( " status " ) or { } ) . get (
return ( ( deletion_response or { } ) . get ( " status " ) or { } ) . get (
" deletedCount " , 0
" deletedCount " , 0
) == 1
) == 1
@ -320,6 +392,40 @@ class AstraDB(VectorStore):
)
)
return True
return True
async def adelete (
self ,
ids : Optional [ List [ str ] ] = None ,
concurrency : Optional [ int ] = None ,
* * kwargs : Any ,
) - > Optional [ bool ] :
""" Delete by vector ID or other criteria.
Args :
ids : List of ids to delete .
concurrency ( Optional [ int ] ) : max number of concurrent delete queries .
Defaults to instance - level setting .
* * kwargs : Other keyword arguments that subclasses might use .
Returns :
Optional [ bool ] : True if deletion is successful ,
False otherwise , None if not implemented .
"""
if kwargs :
warnings . warn (
" Method ' adelete ' of AstraDB vector store invoked with "
f " unsupported arguments ( { ' , ' . join ( sorted ( kwargs . keys ( ) ) ) } ), "
" which will be ignored. "
)
if ids is None :
raise ValueError ( " No ids provided to delete. " )
return all (
await gather_with_concurrency (
concurrency , * [ self . adelete_by_document_id ( doc_id ) for doc_id in ids ]
)
)
def delete_collection ( self ) - > None :
def delete_collection ( self ) - > None :
"""
"""
Completely delete the collection from the database ( as opposed
Completely delete the collection from the database ( as opposed
@ -327,8 +433,88 @@ class AstraDB(VectorStore):
Stored data is lost and unrecoverable , resources are freed .
Stored data is lost and unrecoverable , resources are freed .
Use with caution .
Use with caution .
"""
"""
self . _drop_collection ( )
self . _ensure_astra_db_client ( )
return None
self . astra_db . delete_collection (
collection_name = self . collection_name ,
)
async def adelete_collection ( self ) - > None :
"""
Completely delete the collection from the database ( as opposed
to ' clear() ' , which empties it only ) .
Stored data is lost and unrecoverable , resources are freed .
Use with caution .
"""
await self . _ensure_db_setup ( )
if not self . async_astra_db :
await run_in_executor ( None , self . delete_collection )
await self . async_astra_db . delete_collection (
collection_name = self . collection_name ,
)
@staticmethod
def _get_documents_to_insert (
texts : Iterable [ str ] ,
embedding_vectors : List [ List [ float ] ] ,
metadatas : Optional [ List [ dict ] ] = None ,
ids : Optional [ List [ str ] ] = None ,
) - > List [ DocDict ] :
if ids is None :
ids = [ uuid . uuid4 ( ) . hex for _ in texts ]
if metadatas is None :
metadatas = [ { } for _ in texts ]
#
documents_to_insert = [
{
" content " : b_txt ,
" _id " : b_id ,
" $vector " : b_emb ,
" metadata " : b_md ,
}
for b_txt , b_emb , b_id , b_md in zip (
texts ,
embedding_vectors ,
ids ,
metadatas ,
)
]
# make unique by id, keeping the last
uniqued_documents_to_insert = _unique_list (
documents_to_insert [ : : - 1 ] ,
lambda document : document [ " _id " ] ,
) [ : : - 1 ]
return uniqued_documents_to_insert
@staticmethod
def _get_missing_from_batch (
document_batch : List [ DocDict ] , insert_result : Dict [ str , Any ]
) - > Tuple [ List [ str ] , List [ DocDict ] ] :
if " status " not in insert_result :
raise ValueError (
f " API Exception while running bulk insertion: { str ( insert_result ) } "
)
batch_inserted = insert_result [ " status " ] [ " insertedIds " ]
# estimation of the preexisting documents that failed
missed_inserted_ids = { document [ " _id " ] for document in document_batch } - set (
batch_inserted
)
errors = insert_result . get ( " errors " , [ ] )
# careful for other sources of error other than "doc already exists"
num_errors = len ( errors )
unexpected_errors = any (
error . get ( " errorCode " ) != " DOCUMENT_ALREADY_EXISTS " for error in errors
)
if num_errors != len ( missed_inserted_ids ) or unexpected_errors :
raise ValueError (
f " API Exception while running bulk insertion: { str ( errors ) } "
)
# deal with the missing insertions as upserts
missing_from_batch = [
document
for document in document_batch
if document [ " _id " ] in missed_inserted_ids
]
return batch_inserted , missing_from_batch
def add_texts (
def add_texts (
self ,
self ,
@ -377,36 +563,12 @@ class AstraDB(VectorStore):
f " unsupported arguments ( { ' , ' . join ( sorted ( kwargs . keys ( ) ) ) } ), "
f " unsupported arguments ( { ' , ' . join ( sorted ( kwargs . keys ( ) ) ) } ), "
" which will be ignored. "
" which will be ignored. "
)
)
self . _ensure_astra_db_client ( )
_texts = list ( texts )
embedding_vectors = self . embedding . embed_documents ( list ( texts ) )
if ids is None :
documents_to_insert = self . _get_documents_to_insert (
ids = [ uuid . uuid4 ( ) . hex for _ in _texts ]
texts , embedding_vectors , metadatas , ids
if metadatas is None :
)
metadatas = [ { } for _ in _texts ]
#
embedding_vectors = self . embedding . embed_documents ( _texts )
documents_to_insert = [
{
" content " : b_txt ,
" _id " : b_id ,
" $vector " : b_emb ,
" metadata " : b_md ,
}
for b_txt , b_emb , b_id , b_md in zip (
_texts ,
embedding_vectors ,
ids ,
metadatas ,
)
]
# make unique by id, keeping the last
uniqued_documents_to_insert = _unique_list (
documents_to_insert [ : : - 1 ] ,
lambda document : document [ " _id " ] ,
) [ : : - 1 ]
all_ids = [ ]
def _handle_batch ( document_batch : List [ DocDict ] ) - > List [ str ] :
def _handle_batch ( document_batch : List [ DocDict ] ) - > List [ str ] :
im_result = self . collection . insert_many (
im_result = self . collection . insert_many (
@ -414,33 +576,9 @@ class AstraDB(VectorStore):
options = { " ordered " : False } ,
options = { " ordered " : False } ,
partial_failures_allowed = True ,
partial_failures_allowed = True ,
)
)
if " status " not in im_result :
batch_inserted , missing_from_batch = self . _get_missing_from_batch (
raise ValueError (
document_batch , im_result
f " API Exception while running bulk insertion: { str ( im_result ) } "
)
)
batch_inserted = im_result [ " status " ] [ " insertedIds " ]
# estimation of the preexisting documents that failed
missed_inserted_ids = {
document [ " _id " ] for document in document_batch
} - set ( batch_inserted )
errors = im_result . get ( " errors " , [ ] )
# careful for other sources of error other than "doc already exists"
num_errors = len ( errors )
unexpected_errors = any (
error . get ( " errorCode " ) != " DOCUMENT_ALREADY_EXISTS " for error in errors
)
if num_errors != len ( missed_inserted_ids ) or unexpected_errors :
raise ValueError (
f " API Exception while running bulk insertion: { str ( errors ) } "
)
# deal with the missing insertions as upserts
missing_from_batch = [
document
for document in document_batch
if document [ " _id " ] in missed_inserted_ids
]
def _handle_missing_document ( missing_document : DocDict ) - > str :
def _handle_missing_document ( missing_document : DocDict ) - > str :
replacement_result = self . collection . find_one_and_replace (
replacement_result = self . collection . find_one_and_replace (
@ -459,9 +597,7 @@ class AstraDB(VectorStore):
missing_from_batch ,
missing_from_batch ,
)
)
)
)
return batch_inserted + batch_replaced
upsert_ids = batch_inserted + batch_replaced
return upsert_ids
_b_max_workers = batch_concurrency or self . bulk_insert_batch_concurrency
_b_max_workers = batch_concurrency or self . bulk_insert_batch_concurrency
with ThreadPoolExecutor ( max_workers = _b_max_workers ) as tpe :
with ThreadPoolExecutor ( max_workers = _b_max_workers ) as tpe :
@ -469,13 +605,111 @@ class AstraDB(VectorStore):
_handle_batch ,
_handle_batch ,
batch_iterate (
batch_iterate (
batch_size or self . batch_size ,
batch_size or self . batch_size ,
uniqued_ documents_to_insert,
documents_to_insert,
) ,
) ,
)
)
return [ iid for id_list in all_ids_nested for iid in id_list ]
async def aadd_texts (
self ,
texts : Iterable [ str ] ,
metadatas : Optional [ List [ dict ] ] = None ,
ids : Optional [ List [ str ] ] = None ,
* ,
batch_size : Optional [ int ] = None ,
batch_concurrency : Optional [ int ] = None ,
overwrite_concurrency : Optional [ int ] = None ,
* * kwargs : Any ,
) - > List [ str ] :
""" Run texts through the embeddings and add them to the vectorstore.
If passing explicit ids , those entries whose id is in the store already
will be replaced .
Args :
texts ( Iterable [ str ] ) : Texts to add to the vectorstore .
metadatas ( Optional [ List [ dict ] ] , optional ) : Optional list of metadatas .
ids ( Optional [ List [ str ] ] , optional ) : Optional list of ids .
batch_size ( Optional [ int ] ) : Number of documents in each API call .
Check the underlying Astra DB HTTP API specs for the max value
( 20 at the time of writing this ) . If not provided , defaults
to the instance - level setting .
batch_concurrency ( Optional [ int ] ) : number of concurrent batch insertions .
Defaults to instance - level setting if not provided .
overwrite_concurrency ( Optional [ int ] ) : number of concurrent API calls to
process pre - existing documents in each batch .
Defaults to instance - level setting if not provided .
A note on metadata : there are constraints on the allowed field names
in this dictionary , coming from the underlying Astra DB API .
For instance , the ` $ ` ( dollar sign ) cannot be used in the dict keys .
See this document for details :
docs . datastax . com / en / astra - serverless / docs / develop / dev - with - json . html
Returns :
List [ str ] : List of ids of the added texts .
"""
await self . _ensure_db_setup ( )
if not self . async_collection :
await super ( ) . aadd_texts (
texts ,
metadatas ,
ids = ids ,
batch_size = batch_size ,
batch_concurrency = batch_concurrency ,
overwrite_concurrency = overwrite_concurrency ,
)
if kwargs :
warnings . warn (
" Method ' aadd_texts ' of AstraDB vector store invoked with "
f " unsupported arguments ( { ' , ' . join ( sorted ( kwargs . keys ( ) ) ) } ), "
" which will be ignored. "
)
embedding_vectors = await self . embedding . aembed_documents ( list ( texts ) )
documents_to_insert = self . _get_documents_to_insert (
texts , embedding_vectors , metadatas , ids
)
async def _handle_batch ( document_batch : List [ DocDict ] ) - > List [ str ] :
im_result = await self . async_collection . insert_many (
documents = document_batch ,
options = { " ordered " : False } ,
partial_failures_allowed = True ,
)
batch_inserted , missing_from_batch = self . _get_missing_from_batch (
document_batch , im_result
)
async def _handle_missing_document ( missing_document : DocDict ) - > str :
replacement_result = await self . async_collection . find_one_and_replace (
filter = { " _id " : missing_document [ " _id " ] } ,
replacement = missing_document ,
)
return replacement_result [ " data " ] [ " document " ] [ " _id " ]
all_ids = [ iid for id_list in all_ids_nested for iid in id_list ]
_u_max_workers = (
overwrite_concurrency or self . bulk_insert_overwrite_concurrency
)
batch_replaced = await gather_with_concurrency (
_u_max_workers ,
* [ _handle_missing_document ( doc ) for doc in missing_from_batch ] ,
)
return batch_inserted + batch_replaced
_b_max_workers = batch_concurrency or self . bulk_insert_batch_concurrency
all_ids_nested = await gather_with_concurrency (
_b_max_workers ,
* [
_handle_batch ( batch )
for batch in batch_iterate (
batch_size or self . batch_size ,
documents_to_insert ,
)
] ,
)
return all_ids
return [ iid for id_list in all_ids _nested for iid in id_list ]
def similarity_search_with_score_id_by_vector (
def similarity_search_with_score_id_by_vector (
self ,
self ,
@ -491,6 +725,7 @@ class AstraDB(VectorStore):
Returns :
Returns :
List of ( Document , score , id ) , the most similar to the query vector .
List of ( Document , score , id ) , the most similar to the query vector .
"""
"""
self . _ensure_astra_db_client ( )
metadata_parameter = self . _filter_to_metadata ( filter )
metadata_parameter = self . _filter_to_metadata ( filter )
#
#
hits = list (
hits = list (
@ -518,6 +753,52 @@ class AstraDB(VectorStore):
for hit in hits
for hit in hits
]
]
async def asimilarity_search_with_score_id_by_vector (
self ,
embedding : List [ float ] ,
k : int = 4 ,
filter : Optional [ Dict [ str , Any ] ] = None ,
) - > List [ Tuple [ Document , float , str ] ] :
""" Return docs most similar to embedding vector.
Args :
embedding ( str ) : Embedding to look up documents similar to .
k ( int ) : Number of Documents to return . Defaults to 4.
Returns :
List of ( Document , score , id ) , the most similar to the query vector .
"""
await self . _ensure_db_setup ( )
if not self . async_collection :
return await run_in_executor (
None ,
self . asimilarity_search_with_score_id_by_vector ,
embedding ,
k ,
filter ,
)
metadata_parameter = self . _filter_to_metadata ( filter )
#
return [
(
Document (
page_content = hit [ " content " ] ,
metadata = hit [ " metadata " ] ,
) ,
hit [ " $similarity " ] ,
hit [ " _id " ] ,
)
async for hit in self . async_collection . paginated_find (
filter = metadata_parameter ,
sort = { " $vector " : embedding } ,
options = { " limit " : k , " includeSimilarity " : True } ,
projection = {
" _id " : 1 ,
" content " : 1 ,
" metadata " : 1 ,
} ,
)
]
def similarity_search_with_score_id (
def similarity_search_with_score_id (
self ,
self ,
query : str ,
query : str ,
@ -531,6 +812,19 @@ class AstraDB(VectorStore):
filter = filter ,
filter = filter ,
)
)
async def asimilarity_search_with_score_id (
self ,
query : str ,
k : int = 4 ,
filter : Optional [ Dict [ str , Any ] ] = None ,
) - > List [ Tuple [ Document , float , str ] ] :
embedding_vector = await self . embedding . aembed_query ( query )
return await self . asimilarity_search_with_score_id_by_vector (
embedding = embedding_vector ,
k = k ,
filter = filter ,
)
def similarity_search_with_score_by_vector (
def similarity_search_with_score_by_vector (
self ,
self ,
embedding : List [ float ] ,
embedding : List [ float ] ,
@ -554,6 +848,33 @@ class AstraDB(VectorStore):
)
)
]
]
async def asimilarity_search_with_score_by_vector (
self ,
embedding : List [ float ] ,
k : int = 4 ,
filter : Optional [ Dict [ str , Any ] ] = None ,
) - > List [ Tuple [ Document , float ] ] :
""" Return docs most similar to embedding vector.
Args :
embedding ( str ) : Embedding to look up documents similar to .
k ( int ) : Number of Documents to return . Defaults to 4.
Returns :
List of ( Document , score ) , the most similar to the query vector .
"""
return [
( doc , score )
for (
doc ,
score ,
doc_id ,
) in await self . asimilarity_search_with_score_id_by_vector (
embedding = embedding ,
k = k ,
filter = filter ,
)
]
def similarity_search (
def similarity_search (
self ,
self ,
query : str ,
query : str ,
@ -568,6 +889,20 @@ class AstraDB(VectorStore):
filter = filter ,
filter = filter ,
)
)
async def asimilarity_search (
self ,
query : str ,
k : int = 4 ,
filter : Optional [ Dict [ str , Any ] ] = None ,
* * kwargs : Any ,
) - > List [ Document ] :
embedding_vector = await self . embedding . aembed_query ( query )
return await self . asimilarity_search_by_vector (
embedding_vector ,
k ,
filter = filter ,
)
def similarity_search_by_vector (
def similarity_search_by_vector (
self ,
self ,
embedding : List [ float ] ,
embedding : List [ float ] ,
@ -584,6 +919,22 @@ class AstraDB(VectorStore):
)
)
]
]
async def asimilarity_search_by_vector (
self ,
embedding : List [ float ] ,
k : int = 4 ,
filter : Optional [ Dict [ str , Any ] ] = None ,
* * kwargs : Any ,
) - > List [ Document ] :
return [
doc
for doc , _ in await self . asimilarity_search_with_score_by_vector (
embedding ,
k ,
filter = filter ,
)
]
def similarity_search_with_score (
def similarity_search_with_score (
self ,
self ,
query : str ,
query : str ,
@ -597,6 +948,40 @@ class AstraDB(VectorStore):
filter = filter ,
filter = filter ,
)
)
async def asimilarity_search_with_score (
self ,
query : str ,
k : int = 4 ,
filter : Optional [ Dict [ str , Any ] ] = None ,
) - > List [ Tuple [ Document , float ] ] :
embedding_vector = await self . embedding . aembed_query ( query )
return await self . asimilarity_search_with_score_by_vector (
embedding_vector ,
k ,
filter = filter ,
)
@staticmethod
def _get_mmr_hits ( embedding , k , lambda_mult , prefetch_hits ) :
mmr_chosen_indices = maximal_marginal_relevance (
np . array ( embedding , dtype = np . float32 ) ,
[ prefetch_hit [ " $vector " ] for prefetch_hit in prefetch_hits ] ,
k = k ,
lambda_mult = lambda_mult ,
)
mmr_hits = [
prefetch_hit
for prefetch_index , prefetch_hit in enumerate ( prefetch_hits )
if prefetch_index in mmr_chosen_indices
]
return [
Document (
page_content = hit [ " content " ] ,
metadata = hit [ " metadata " ] ,
)
for hit in mmr_hits
]
def max_marginal_relevance_search_by_vector (
def max_marginal_relevance_search_by_vector (
self ,
self ,
embedding : List [ float ] ,
embedding : List [ float ] ,
@ -619,6 +1004,7 @@ class AstraDB(VectorStore):
Returns :
Returns :
List of Documents selected by maximal marginal relevance .
List of Documents selected by maximal marginal relevance .
"""
"""
self . _ensure_astra_db_client ( )
metadata_parameter = self . _filter_to_metadata ( filter )
metadata_parameter = self . _filter_to_metadata ( filter )
prefetch_hits = list (
prefetch_hits = list (
@ -635,25 +1021,61 @@ class AstraDB(VectorStore):
)
)
)
)
mmr_chosen_indices = maximal_marginal_relevance (
return self . _get_mmr_hits ( embedding , k , lambda_mult , prefetch_hits )
np . array ( embedding , dtype = np . float32 ) ,
[ prefetch_hit [ " $vector " ] for prefetch_hit in prefetch_hits ] ,
async def amax_marginal_relevance_search_by_vector (
k = k ,
self ,
lambda_mult = lambda_mult ,
embedding : List [ float ] ,
)
k : int = 4 ,
mmr_hits = [
fetch_k : int = 20 ,
prefetch_hit
lambda_mult : float = 0.5 ,
for prefetch_index , prefetch_hit in enumerate ( prefetch_hits )
filter : Optional [ Dict [ str , Any ] ] = None ,
if prefetch_index in mmr_chosen_indices
* * kwargs : Any ,
]
) - > List [ Document ] :
return [
""" Return docs selected using the maximal marginal relevance.
Document (
Maximal marginal relevance optimizes for similarity to query AND diversity
page_content = hit [ " content " ] ,
among selected documents .
metadata = hit [ " metadata " ] ,
Args :
embedding : Embedding to look up documents similar to .
k : Number of Documents to return .
fetch_k : Number of Documents to fetch to pass to MMR algorithm .
lambda_mult : Number between 0 and 1 that determines the degree
of diversity among the results with 0 corresponding
to maximum diversity and 1 to minimum diversity .
Returns :
List of Documents selected by maximal marginal relevance .
"""
await self . _ensure_db_setup ( )
if not self . async_collection :
return await run_in_executor (
None ,
self . max_marginal_relevance_search_by_vector ,
embedding ,
k ,
fetch_k ,
lambda_mult ,
filter ,
* * kwargs ,
)
metadata_parameter = self . _filter_to_metadata ( filter )
prefetch_hits = [
hit
async for hit in self . async_collection . paginated_find (
filter = metadata_parameter ,
sort = { " $vector " : embedding } ,
options = { " limit " : fetch_k , " includeSimilarity " : True } ,
projection = {
" _id " : 1 ,
" content " : 1 ,
" metadata " : 1 ,
" $vector " : 1 ,
} ,
)
)
for hit in mmr_hits
]
]
return self . _get_mmr_hits ( embedding , k , lambda_mult , prefetch_hits )
def max_marginal_relevance_search (
def max_marginal_relevance_search (
self ,
self ,
query : str ,
query : str ,
@ -686,36 +1108,50 @@ class AstraDB(VectorStore):
filter = filter ,
filter = filter ,
)
)
@classmethod
async def amax_marginal_relevance_search (
def from_texts (
self ,
cls : Type [ ADBVST ] ,
query : str ,
texts: List [ str ] ,
k: int = 4 ,
embedding: Embeddings ,
fetch_k: int = 20 ,
metadatas: Optional [ List [ dict ] ] = None ,
lambda_mult: float = 0.5 ,
ids : Optional [ List [ str ] ] = None ,
filter : Optional [ Dict [ str , Any ] ] = None ,
* * kwargs : Any ,
* * kwargs : Any ,
) - > ADBVST :
) - > List [ Document ] :
""" Create an Astra DB vectorstore from raw texts.
""" Return docs selected using the maximal marginal relevance.
Maximal marginal relevance optimizes for similarity to query AND diversity
among selected documents .
Args :
Args :
texts ( List [ str ] ) : the texts to insert .
query ( str ) : Text to look up documents similar to .
embedding ( Embeddings ) : the embedding function to use in the store .
k ( int = 4 ) : Number of Documents to return .
metadatas ( Optional [ List [ dict ] ] ) : metadata dicts for the texts .
fetch_k ( int = 20 ) : Number of Documents to fetch to pass to MMR algorithm .
ids ( Optional [ List [ str ] ] ) : ids to associate to the texts .
lambda_mult ( float = 0.5 ) : Number between 0 and 1 that determines the degree
* Additional arguments * : you can pass any argument that you would
of diversity among the results with 0 corresponding
to ' add_texts ' and / or to the ' AstraDB ' class constructor
to maximum diversity and 1 to minimum diversity .
( see these methods for details ) . These arguments will be
Optional .
routed to the respective methods as they are .
Returns :
Returns :
an ` AstraDb ` vectorstor e.
List of Documents selected by maximal marginal relevance .
"""
"""
embedding_vector = await self . embedding . aembed_query ( query )
return await self . amax_marginal_relevance_search_by_vector (
embedding_vector ,
k ,
fetch_k ,
lambda_mult = lambda_mult ,
filter = filter ,
)
@classmethod
def _from_kwargs (
cls : Type [ ADBVST ] ,
embedding : Embeddings ,
* * kwargs : Any ,
) - > ADBVST :
known_kwargs = {
known_kwargs = {
" collection_name " ,
" collection_name " ,
" token " ,
" token " ,
" api_endpoint " ,
" api_endpoint " ,
" astra_db_client " ,
" astra_db_client " ,
" async_astra_db_client " ,
" namespace " ,
" namespace " ,
" metric " ,
" metric " ,
" batch_size " ,
" batch_size " ,
@ -738,15 +1174,17 @@ class AstraDB(VectorStore):
token = kwargs . get ( " token " )
token = kwargs . get ( " token " )
api_endpoint = kwargs . get ( " api_endpoint " )
api_endpoint = kwargs . get ( " api_endpoint " )
astra_db_client = kwargs . get ( " astra_db_client " )
astra_db_client = kwargs . get ( " astra_db_client " )
async_astra_db_client = kwargs . get ( " async_astra_db_client " )
namespace = kwargs . get ( " namespace " )
namespace = kwargs . get ( " namespace " )
metric = kwargs . get ( " metric " )
metric = kwargs . get ( " metric " )
astra_db_store = cls (
return cls (
embedding = embedding ,
embedding = embedding ,
collection_name = collection_name ,
collection_name = collection_name ,
token = token ,
token = token ,
api_endpoint = api_endpoint ,
api_endpoint = api_endpoint ,
astra_db_client = astra_db_client ,
astra_db_client = astra_db_client ,
async_astra_db_client = async_astra_db_client ,
namespace = namespace ,
namespace = namespace ,
metric = metric ,
metric = metric ,
batch_size = kwargs . get ( " batch_size " ) ,
batch_size = kwargs . get ( " batch_size " ) ,
@ -756,6 +1194,32 @@ class AstraDB(VectorStore):
) ,
) ,
bulk_delete_concurrency = kwargs . get ( " bulk_delete_concurrency " ) ,
bulk_delete_concurrency = kwargs . get ( " bulk_delete_concurrency " ) ,
)
)
@classmethod
def from_texts (
cls : Type [ ADBVST ] ,
texts : List [ str ] ,
embedding : Embeddings ,
metadatas : Optional [ List [ dict ] ] = None ,
ids : Optional [ List [ str ] ] = None ,
* * kwargs : Any ,
) - > ADBVST :
""" Create an Astra DB vectorstore from raw texts.
Args :
texts ( List [ str ] ) : the texts to insert .
embedding ( Embeddings ) : the embedding function to use in the store .
metadatas ( Optional [ List [ dict ] ] ) : metadata dicts for the texts .
ids ( Optional [ List [ str ] ] ) : ids to associate to the texts .
* Additional arguments * : you can pass any argument that you would
to ' add_texts ' and / or to the ' AstraDB ' class constructor
( see these methods for details ) . These arguments will be
routed to the respective methods as they are .
Returns :
an ` AstraDb ` vectorstore .
"""
astra_db_store = AstraDB . _from_kwargs ( embedding , * * kwargs )
astra_db_store . add_texts (
astra_db_store . add_texts (
texts = texts ,
texts = texts ,
metadatas = metadatas ,
metadatas = metadatas ,
@ -766,6 +1230,41 @@ class AstraDB(VectorStore):
)
)
return astra_db_store
return astra_db_store
@classmethod
async def afrom_texts (
cls : Type [ ADBVST ] ,
texts : List [ str ] ,
embedding : Embeddings ,
metadatas : Optional [ List [ dict ] ] = None ,
ids : Optional [ List [ str ] ] = None ,
* * kwargs : Any ,
) - > ADBVST :
""" Create an Astra DB vectorstore from raw texts.
Args :
texts ( List [ str ] ) : the texts to insert .
embedding ( Embeddings ) : the embedding function to use in the store .
metadatas ( Optional [ List [ dict ] ] ) : metadata dicts for the texts .
ids ( Optional [ List [ str ] ] ) : ids to associate to the texts .
* Additional arguments * : you can pass any argument that you would
to ' add_texts ' and / or to the ' AstraDB ' class constructor
( see these methods for details ) . These arguments will be
routed to the respective methods as they are .
Returns :
an ` AstraDb ` vectorstore .
"""
astra_db_store = AstraDB . _from_kwargs ( embedding , * * kwargs )
await astra_db_store . aadd_texts (
texts = texts ,
metadatas = metadatas ,
ids = ids ,
batch_size = kwargs . get ( " batch_size " ) ,
batch_concurrency = kwargs . get ( " batch_concurrency " ) ,
overwrite_concurrency = kwargs . get ( " overwrite_concurrency " ) ,
)
return astra_db_store
@classmethod
@classmethod
def from_documents (
def from_documents (
cls : Type [ ADBVST ] ,
cls : Type [ ADBVST ] ,