@ -1,13 +1,12 @@
from __future__ import annotations
from __future__ import annotations
import asyncio
import uuid
import uuid
import warnings
import warnings
from asyncio import Task
from concurrent . futures import ThreadPoolExecutor
from concurrent . futures import ThreadPoolExecutor
from typing import (
from typing import (
TYPE_CHECKING ,
TYPE_CHECKING ,
Any ,
Any ,
Awaitable ,
Callable ,
Callable ,
Dict ,
Dict ,
Iterable ,
Iterable ,
@ -17,17 +16,21 @@ from typing import (
Tuple ,
Tuple ,
Type ,
Type ,
TypeVar ,
TypeVar ,
Union ,
)
)
import numpy as np
import numpy as np
from langchain_core . _api . deprecation import deprecated
from langchain_core . _api . deprecation import deprecated
from langchain_core . documents import Document
from langchain_core . documents import Document
from langchain_core . embeddings import Embeddings
from langchain_core . embeddings import Embeddings
from langchain_core . runnables import run_in_executor
from langchain_core . runnables . utils import gather_with_concurrency
from langchain_core . runnables . utils import gather_with_concurrency
from langchain_core . utils . iter import batch_iterate
from langchain_core . utils . iter import batch_iterate
from langchain_core . vectorstores import VectorStore
from langchain_core . vectorstores import VectorStore
from langchain_community . utilities . astradb import (
SetupMode ,
_AstraDBCollectionEnvironment ,
)
from langchain_community . vectorstores . utils import maximal_marginal_relevance
from langchain_community . vectorstores . utils import maximal_marginal_relevance
if TYPE_CHECKING :
if TYPE_CHECKING :
@ -167,28 +170,12 @@ class AstraDB(VectorStore):
bulk_insert_batch_concurrency : Optional [ int ] = None ,
bulk_insert_batch_concurrency : Optional [ int ] = None ,
bulk_insert_overwrite_concurrency : Optional [ int ] = None ,
bulk_insert_overwrite_concurrency : Optional [ int ] = None ,
bulk_delete_concurrency : Optional [ int ] = None ,
bulk_delete_concurrency : Optional [ int ] = None ,
setup_mode : SetupMode = SetupMode . SYNC ,
pre_delete_collection : bool = False ,
pre_delete_collection : bool = False ,
) - > None :
) - > None :
"""
"""
Create an AstraDB vector store object . See class docstring for help .
Create an AstraDB vector store object . See class docstring for help .
"""
"""
try :
from astrapy . db import AstraDB as LibAstraDB
from astrapy . db import AstraDBCollection
except ( ImportError , ModuleNotFoundError ) :
raise ImportError (
" Could not import a recent astrapy python package. "
" Please install it with `pip install --upgrade astrapy`. "
)
# Conflicting-arg checks:
if astra_db_client is not None or async_astra_db_client is not None :
if token is not None or api_endpoint is not None :
raise ValueError (
" You cannot pass ' astra_db_client ' or ' async_astra_db_client ' to "
" AstraDB if passing ' token ' and ' api_endpoint ' . "
)
self . embedding = embedding
self . embedding = embedding
self . collection_name = collection_name
self . collection_name = collection_name
self . token = token
self . token = token
@ -207,105 +194,35 @@ class AstraDB(VectorStore):
bulk_delete_concurrency or DEFAULT_BULK_DELETE_CONCURRENCY
bulk_delete_concurrency or DEFAULT_BULK_DELETE_CONCURRENCY
)
)
# "vector-related" settings
# "vector-related" settings
self . _embedding_dimension : Optional [ int ] = None
self . metric = metric
self . metric = metric
embedding_dimension : Union [ int , Awaitable [ int ] , None ] = None
if setup_mode == SetupMode . ASYNC :
embedding_dimension = self . _aget_embedding_dimension ( )
elif setup_mode == SetupMode . SYNC :
embedding_dimension = self . _get_embedding_dimension ( )
self . astra_db = astra_db_client
self . astra_env = _AstraDBCollectionEnvironment (
self . async_astra_db = async_astra_db_client
collection_name = collection_name ,
self . collection = None
token = token ,
self . async_collection = None
api_endpoint = api_endpoint ,
astra_db_client = astra_db_client ,
if token and api_endpoint :
async_astra_db_client = async_astra_db_client ,
self . astra_db = LibAstraDB (
namespace = namespace ,
token = self . token ,
setup_mode = setup_mode ,
api_endpoint = self . api_endpoint ,
pre_delete_collection = pre_delete_collection ,
namespace = self . namespace ,
embedding_dimension = embedding_dimension ,
)
metric = metric ,
try :
from astrapy . db import AsyncAstraDB
self . async_astra_db = AsyncAstraDB (
token = self . token ,
api_endpoint = self . api_endpoint ,
namespace = self . namespace ,
)
except ( ImportError , ModuleNotFoundError ) :
pass
if self . astra_db is not None :
self . collection = AstraDBCollection (
collection_name = self . collection_name ,
astra_db = self . astra_db ,
)
self . async_setup_db_task : Optional [ Task ] = None
if self . async_astra_db is not None :
from astrapy . db import AsyncAstraDBCollection
self . async_collection = AsyncAstraDBCollection (
collection_name = self . collection_name ,
astra_db = self . async_astra_db ,
)
try :
self . async_setup_db_task = asyncio . create_task (
self . _setup_db ( pre_delete_collection )
)
except RuntimeError :
pass
if self . async_setup_db_task is None :
if not pre_delete_collection :
self . _provision_collection ( )
else :
self . clear ( )
def _ensure_astra_db_client(self) -> None:
    """Fail fast when no synchronous AstraDB client is available.

    Raises:
        ValueError: if ``self.astra_db`` is unset (or otherwise falsy).
    """
    # The explicit ``-> None`` annotation makes the former
    # ``type: ignore[no-untyped-def]`` suppression unnecessary.
    if not self.astra_db:
        raise ValueError(" Missing AstraDB client ")
async def _setup_db(self, pre_delete_collection: bool) -> None:
    """Prepare the backing collection through the async client.

    Args:
        pre_delete_collection: when true, the collection named
            ``self.collection_name`` is deleted before it is provisioned.
    """
    if pre_delete_collection:
        async_client = self.async_astra_db
        await async_client.delete_collection(  # type: ignore[union-attr]
            collection_name=self.collection_name,
        )
    # Provisioning runs whether or not a deletion happened first.
    await self._aprovision_collection()
async def _ensure_db_setup(self) -> None:
    """Wait for the pending async setup task, if one was scheduled."""
    pending_setup = self.async_setup_db_task
    if not pending_setup:
        # Nothing was scheduled, so there is nothing to wait for.
        return
    await pending_setup
def _get_embedding_dimension(self) -> int:
    """Return the embedding vector length, probing the model on first use.

    The length is measured by embedding a fixed sample sentence and is then
    stored on ``self._embedding_dimension`` so later calls skip the probe.
    """
    cached = self._embedding_dimension
    if cached is not None:
        return cached
    sample_vector = self.embedding.embed_query(" This is a sample sentence. ")
    self._embedding_dimension = len(sample_vector)
    return self._embedding_dimension
def _provision_collection(self) -> None:
    """
    Create the collection on the backend through the sync client.

    Internal-usage method: it issues a single ``create_collection`` call
    with the embedding dimension, the collection name and the metric.
    """
    # Resolve the client method first so that a missing client fails before
    # the embedding dimension is computed, as in the single-expression form.
    create_collection = self.astra_db.create_collection  # type: ignore[union-attr]
    dimension = self._get_embedding_dimension()
    create_collection(
        dimension=dimension,
        collection_name=self.collection_name,
        metric=self.metric,
    )
)
self . astra_db = self . astra_env . astra_db
self . async_astra_db = self . astra_env . async_astra_db
self . collection = self . astra_env . collection
self . async_collection = self . astra_env . async_collection
async def _aprovision_collection ( self ) - > None :
def _get_embedding_dimension ( self ) - > int :
"""
return len ( self . embedding . embed_query ( text = " This is a sample sentence. " ) )
Run the API invocation to create the collection on the backend .
Internal - usage method , no object members are set ,
async def _aget_embedding_dimension ( self ) - > int :
other than working on the underlying actual storage .
return len ( await self . embedding . aembed_query ( text = " This is a sample sentence. " ) )
"""
await self . async_astra_db . create_collection ( # type: ignore[union-attr]
dimension = self . _get_embedding_dimension ( ) ,
collection_name = self . collection_name ,
metric = self . metric ,
)
@property
@property
def embeddings ( self ) - > Embeddings :
def embeddings ( self ) - > Embeddings :
@ -326,14 +243,12 @@ class AstraDB(VectorStore):
def clear ( self ) - > None :
def clear ( self ) - > None :
""" Empty the collection of all its stored entries. """
""" Empty the collection of all its stored entries. """
self . delete_collection ( )
self . astra_env. ensure_db_setup ( )
self . _provision_ collection( )
self . collection. delete_many ( { } )
async def aclear ( self ) - > None :
async def aclear ( self ) - > None :
""" Empty the collection of all its stored entries. """
""" Empty the collection of all its stored entries. """
await self . _ensure_db_setup ( )
await self . astra_env . aensure_db_setup ( )
if not self . async_astra_db :
await run_in_executor ( None , self . clear )
await self . async_collection . delete_many ( { } ) # type: ignore[union-attr]
await self . async_collection . delete_many ( { } ) # type: ignore[union-attr]
def delete_by_document_id ( self , document_id : str ) - > bool :
def delete_by_document_id ( self , document_id : str ) - > bool :
@ -341,7 +256,7 @@ class AstraDB(VectorStore):
Remove a single document from the store , given its document_id ( str ) .
Remove a single document from the store , given its document_id ( str ) .
Return True if a document has indeed been deleted , False if ID not found .
Return True if a document has indeed been deleted , False if ID not found .
"""
"""
self . _ensure_astra_db_client ( )
self . astra_env. ensure_db_setup ( )
deletion_response = self . collection . delete_one ( document_id ) # type: ignore[union-attr]
deletion_response = self . collection . delete_one ( document_id ) # type: ignore[union-attr]
return ( ( deletion_response or { } ) . get ( " status " ) or { } ) . get (
return ( ( deletion_response or { } ) . get ( " status " ) or { } ) . get (
" deletedCount " , 0
" deletedCount " , 0
@ -352,9 +267,7 @@ class AstraDB(VectorStore):
Remove a single document from the store , given its document_id ( str ) .
Remove a single document from the store , given its document_id ( str ) .
Return True if a document has indeed been deleted , False if ID not found .
Return True if a document has indeed been deleted , False if ID not found .
"""
"""
await self . _ensure_db_setup ( )
await self . astra_env . aensure_db_setup ( )
if not self . async_collection :
return await run_in_executor ( None , self . delete_by_document_id , document_id )
deletion_response = await self . async_collection . delete_one ( document_id )
deletion_response = await self . async_collection . delete_one ( document_id )
return ( ( deletion_response or { } ) . get ( " status " ) or { } ) . get (
return ( ( deletion_response or { } ) . get ( " status " ) or { } ) . get (
" deletedCount " , 0
" deletedCount " , 0
@ -439,8 +352,8 @@ class AstraDB(VectorStore):
Stored data is lost and unrecoverable , resources are freed .
Stored data is lost and unrecoverable , resources are freed .
Use with caution .
Use with caution .
"""
"""
self . _ensure_astra_db_client ( )
self . astra_env. ensure_db_setup ( )
self . astra_db . delete_collection ( # type: ignore[union-attr]
self . astra_db . delete_collection (
collection_name = self . collection_name ,
collection_name = self . collection_name ,
)
)
@ -451,10 +364,8 @@ class AstraDB(VectorStore):
Stored data is lost and unrecoverable , resources are freed .
Stored data is lost and unrecoverable , resources are freed .
Use with caution .
Use with caution .
"""
"""
await self . _ensure_db_setup ( )
await self . astra_env . aensure_db_setup ( )
if not self . async_astra_db :
await self . async_astra_db . delete_collection (
await run_in_executor ( None , self . delete_collection )
await self . async_astra_db . delete_collection ( # type: ignore[union-attr]
collection_name = self . collection_name ,
collection_name = self . collection_name ,
)
)
@ -569,7 +480,7 @@ class AstraDB(VectorStore):
f " unsupported arguments ( { ' , ' . join ( sorted ( kwargs . keys ( ) ) ) } ), "
f " unsupported arguments ( { ' , ' . join ( sorted ( kwargs . keys ( ) ) ) } ), "
" which will be ignored. "
" which will be ignored. "
)
)
self . _ensure_astra_db_client ( )
self . astra_env. ensure_db_setup ( )
embedding_vectors = self . embedding . embed_documents ( list ( texts ) )
embedding_vectors = self . embedding . embed_documents ( list ( texts ) )
documents_to_insert = self . _get_documents_to_insert (
documents_to_insert = self . _get_documents_to_insert (
@ -655,22 +566,13 @@ class AstraDB(VectorStore):
Returns :
Returns :
List [ str ] : List of ids of the added texts .
List [ str ] : List of ids of the added texts .
"""
"""
await self . _ensure_db_setup ( )
if not self . async_collection :
await super ( ) . aadd_texts (
texts ,
metadatas ,
ids = ids ,
batch_size = batch_size ,
batch_concurrency = batch_concurrency ,
overwrite_concurrency = overwrite_concurrency ,
)
if kwargs :
if kwargs :
warnings . warn (
warnings . warn (
" Method ' aadd_texts ' of AstraDB vector store invoked with "
" Method ' aadd_texts ' of AstraDB vector store invoked with "
f " unsupported arguments ( { ' , ' . join ( sorted ( kwargs . keys ( ) ) ) } ), "
f " unsupported arguments ( { ' , ' . join ( sorted ( kwargs . keys ( ) ) ) } ), "
" which will be ignored. "
" which will be ignored. "
)
)
await self . astra_env . aensure_db_setup ( )
embedding_vectors = await self . embedding . aembed_documents ( list ( texts ) )
embedding_vectors = await self . embedding . aembed_documents ( list ( texts ) )
documents_to_insert = self . _get_documents_to_insert (
documents_to_insert = self . _get_documents_to_insert (
@ -731,7 +633,7 @@ class AstraDB(VectorStore):
Returns :
Returns :
List of ( Document , score , id ) , the most similar to the query vector .
List of ( Document , score , id ) , the most similar to the query vector .
"""
"""
self . _ensure_astra_db_client ( )
self . astra_env. ensure_db_setup ( )
metadata_parameter = self . _filter_to_metadata ( filter )
metadata_parameter = self . _filter_to_metadata ( filter )
#
#
hits = list (
hits = list (
@ -773,15 +675,7 @@ class AstraDB(VectorStore):
Returns :
Returns :
List of ( Document , score , id ) , the most similar to the query vector .
List of ( Document , score , id ) , the most similar to the query vector .
"""
"""
await self . _ensure_db_setup ( )
await self . astra_env . aensure_db_setup ( )
if not self . async_collection :
return await run_in_executor (
None ,
self . asimilarity_search_with_score_id_by_vector , # type: ignore[arg-type]
embedding ,
k ,
filter ,
)
metadata_parameter = self . _filter_to_metadata ( filter )
metadata_parameter = self . _filter_to_metadata ( filter )
#
#
return [
return [
@ -1010,7 +904,7 @@ class AstraDB(VectorStore):
Returns :
Returns :
List of Documents selected by maximal marginal relevance .
List of Documents selected by maximal marginal relevance .
"""
"""
self . _ensure_astra_db_client ( )
self . astra_env. ensure_db_setup ( )
metadata_parameter = self . _filter_to_metadata ( filter )
metadata_parameter = self . _filter_to_metadata ( filter )
prefetch_hits = list (
prefetch_hits = list (
@ -1051,18 +945,7 @@ class AstraDB(VectorStore):
Returns :
Returns :
List of Documents selected by maximal marginal relevance .
List of Documents selected by maximal marginal relevance .
"""
"""
await self . _ensure_db_setup ( )
await self . astra_env . aensure_db_setup ( )
if not self . async_collection :
return await run_in_executor (
None ,
self . max_marginal_relevance_search_by_vector ,
embedding ,
k ,
fetch_k ,
lambda_mult ,
filter ,
* * kwargs ,
)
metadata_parameter = self . _filter_to_metadata ( filter )
metadata_parameter = self . _filter_to_metadata ( filter )
prefetch_hits = [
prefetch_hits = [