community: Use _AstraDBCollectionEnvironment in AstraDB VectorStore (community) (#17635)

Another PR will be done for the langchain-astradb package.

Note: for future PRs, devs will be done in the partner package only. This one is just to align with the rest of the components in the community package and it fixes a bunch of issues.
pull/17648/head
Christophe Bornet 5 months ago committed by GitHub
parent 0b33abc8b1
commit 19ebc7418e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -1,13 +1,12 @@
from __future__ import annotations
import asyncio
import uuid
import warnings
from asyncio import Task
from concurrent.futures import ThreadPoolExecutor
from typing import (
TYPE_CHECKING,
Any,
Awaitable,
Callable,
Dict,
Iterable,
@ -17,17 +16,21 @@ from typing import (
Tuple,
Type,
TypeVar,
Union,
)
import numpy as np
from langchain_core._api.deprecation import deprecated
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.runnables import run_in_executor
from langchain_core.runnables.utils import gather_with_concurrency
from langchain_core.utils.iter import batch_iterate
from langchain_core.vectorstores import VectorStore
from langchain_community.utilities.astradb import (
SetupMode,
_AstraDBCollectionEnvironment,
)
from langchain_community.vectorstores.utils import maximal_marginal_relevance
if TYPE_CHECKING:
@ -167,28 +170,12 @@ class AstraDB(VectorStore):
bulk_insert_batch_concurrency: Optional[int] = None,
bulk_insert_overwrite_concurrency: Optional[int] = None,
bulk_delete_concurrency: Optional[int] = None,
setup_mode: SetupMode = SetupMode.SYNC,
pre_delete_collection: bool = False,
) -> None:
"""
Create an AstraDB vector store object. See class docstring for help.
"""
try:
from astrapy.db import AstraDB as LibAstraDB
from astrapy.db import AstraDBCollection
except (ImportError, ModuleNotFoundError):
raise ImportError(
"Could not import a recent astrapy python package. "
"Please install it with `pip install --upgrade astrapy`."
)
# Conflicting-arg checks:
if astra_db_client is not None or async_astra_db_client is not None:
if token is not None or api_endpoint is not None:
raise ValueError(
"You cannot pass 'astra_db_client' or 'async_astra_db_client' to "
"AstraDB if passing 'token' and 'api_endpoint'."
)
self.embedding = embedding
self.collection_name = collection_name
self.token = token
@ -207,105 +194,35 @@ class AstraDB(VectorStore):
bulk_delete_concurrency or DEFAULT_BULK_DELETE_CONCURRENCY
)
# "vector-related" settings
self._embedding_dimension: Optional[int] = None
self.metric = metric
embedding_dimension: Union[int, Awaitable[int], None] = None
if setup_mode == SetupMode.ASYNC:
embedding_dimension = self._aget_embedding_dimension()
elif setup_mode == SetupMode.SYNC:
embedding_dimension = self._get_embedding_dimension()
self.astra_db = astra_db_client
self.async_astra_db = async_astra_db_client
self.collection = None
self.async_collection = None
if token and api_endpoint:
self.astra_db = LibAstraDB(
token=self.token,
api_endpoint=self.api_endpoint,
namespace=self.namespace,
)
try:
from astrapy.db import AsyncAstraDB
self.async_astra_db = AsyncAstraDB(
token=self.token,
api_endpoint=self.api_endpoint,
namespace=self.namespace,
)
except (ImportError, ModuleNotFoundError):
pass
if self.astra_db is not None:
self.collection = AstraDBCollection(
collection_name=self.collection_name,
astra_db=self.astra_db,
)
self.async_setup_db_task: Optional[Task] = None
if self.async_astra_db is not None:
from astrapy.db import AsyncAstraDBCollection
self.async_collection = AsyncAstraDBCollection(
collection_name=self.collection_name,
astra_db=self.async_astra_db,
)
try:
self.async_setup_db_task = asyncio.create_task(
self._setup_db(pre_delete_collection)
)
except RuntimeError:
pass
if self.async_setup_db_task is None:
if not pre_delete_collection:
self._provision_collection()
else:
self.clear()
def _ensure_astra_db_client(self): # type: ignore[no-untyped-def]
if not self.astra_db:
raise ValueError("Missing AstraDB client")
async def _setup_db(self, pre_delete_collection: bool) -> None:
if pre_delete_collection:
await self.async_astra_db.delete_collection( # type: ignore[union-attr]
collection_name=self.collection_name,
)
await self._aprovision_collection()
async def _ensure_db_setup(self) -> None:
if self.async_setup_db_task:
await self.async_setup_db_task
def _get_embedding_dimension(self) -> int:
if self._embedding_dimension is None:
self._embedding_dimension = len(
self.embedding.embed_query("This is a sample sentence.")
)
return self._embedding_dimension
def _provision_collection(self) -> None:
"""
Run the API invocation to create the collection on the backend.
Internal-usage method, no object members are set,
other than working on the underlying actual storage.
"""
self.astra_db.create_collection( # type: ignore[union-attr]
dimension=self._get_embedding_dimension(),
collection_name=self.collection_name,
metric=self.metric,
self.astra_env = _AstraDBCollectionEnvironment(
collection_name=collection_name,
token=token,
api_endpoint=api_endpoint,
astra_db_client=astra_db_client,
async_astra_db_client=async_astra_db_client,
namespace=namespace,
setup_mode=setup_mode,
pre_delete_collection=pre_delete_collection,
embedding_dimension=embedding_dimension,
metric=metric,
)
self.astra_db = self.astra_env.astra_db
self.async_astra_db = self.astra_env.async_astra_db
self.collection = self.astra_env.collection
self.async_collection = self.astra_env.async_collection
async def _aprovision_collection(self) -> None:
"""
Run the API invocation to create the collection on the backend.
def _get_embedding_dimension(self) -> int:
return len(self.embedding.embed_query(text="This is a sample sentence."))
Internal-usage method, no object members are set,
other than working on the underlying actual storage.
"""
await self.async_astra_db.create_collection( # type: ignore[union-attr]
dimension=self._get_embedding_dimension(),
collection_name=self.collection_name,
metric=self.metric,
)
async def _aget_embedding_dimension(self) -> int:
return len(await self.embedding.aembed_query(text="This is a sample sentence."))
@property
def embeddings(self) -> Embeddings:
@ -326,14 +243,12 @@ class AstraDB(VectorStore):
def clear(self) -> None:
"""Empty the collection of all its stored entries."""
self.delete_collection()
self._provision_collection()
self.astra_env.ensure_db_setup()
self.collection.delete_many({})
async def aclear(self) -> None:
"""Empty the collection of all its stored entries."""
await self._ensure_db_setup()
if not self.async_astra_db:
await run_in_executor(None, self.clear)
await self.astra_env.aensure_db_setup()
await self.async_collection.delete_many({}) # type: ignore[union-attr]
def delete_by_document_id(self, document_id: str) -> bool:
@ -341,7 +256,7 @@ class AstraDB(VectorStore):
Remove a single document from the store, given its document_id (str).
Return True if a document has indeed been deleted, False if ID not found.
"""
self._ensure_astra_db_client()
self.astra_env.ensure_db_setup()
deletion_response = self.collection.delete_one(document_id) # type: ignore[union-attr]
return ((deletion_response or {}).get("status") or {}).get(
"deletedCount", 0
@ -352,9 +267,7 @@ class AstraDB(VectorStore):
Remove a single document from the store, given its document_id (str).
Return True if a document has indeed been deleted, False if ID not found.
"""
await self._ensure_db_setup()
if not self.async_collection:
return await run_in_executor(None, self.delete_by_document_id, document_id)
await self.astra_env.aensure_db_setup()
deletion_response = await self.async_collection.delete_one(document_id)
return ((deletion_response or {}).get("status") or {}).get(
"deletedCount", 0
@ -439,8 +352,8 @@ class AstraDB(VectorStore):
Stored data is lost and unrecoverable, resources are freed.
Use with caution.
"""
self._ensure_astra_db_client()
self.astra_db.delete_collection( # type: ignore[union-attr]
self.astra_env.ensure_db_setup()
self.astra_db.delete_collection(
collection_name=self.collection_name,
)
@ -451,10 +364,8 @@ class AstraDB(VectorStore):
Stored data is lost and unrecoverable, resources are freed.
Use with caution.
"""
await self._ensure_db_setup()
if not self.async_astra_db:
await run_in_executor(None, self.delete_collection)
await self.async_astra_db.delete_collection( # type: ignore[union-attr]
await self.astra_env.aensure_db_setup()
await self.async_astra_db.delete_collection(
collection_name=self.collection_name,
)
@ -569,7 +480,7 @@ class AstraDB(VectorStore):
f"unsupported arguments ({', '.join(sorted(kwargs.keys()))}), "
"which will be ignored."
)
self._ensure_astra_db_client()
self.astra_env.ensure_db_setup()
embedding_vectors = self.embedding.embed_documents(list(texts))
documents_to_insert = self._get_documents_to_insert(
@ -655,22 +566,13 @@ class AstraDB(VectorStore):
Returns:
List[str]: List of ids of the added texts.
"""
await self._ensure_db_setup()
if not self.async_collection:
await super().aadd_texts(
texts,
metadatas,
ids=ids,
batch_size=batch_size,
batch_concurrency=batch_concurrency,
overwrite_concurrency=overwrite_concurrency,
)
if kwargs:
warnings.warn(
"Method 'aadd_texts' of AstraDB vector store invoked with "
f"unsupported arguments ({', '.join(sorted(kwargs.keys()))}), "
"which will be ignored."
)
await self.astra_env.aensure_db_setup()
embedding_vectors = await self.embedding.aembed_documents(list(texts))
documents_to_insert = self._get_documents_to_insert(
@ -731,7 +633,7 @@ class AstraDB(VectorStore):
Returns:
List of (Document, score, id), the most similar to the query vector.
"""
self._ensure_astra_db_client()
self.astra_env.ensure_db_setup()
metadata_parameter = self._filter_to_metadata(filter)
#
hits = list(
@ -773,15 +675,7 @@ class AstraDB(VectorStore):
Returns:
List of (Document, score, id), the most similar to the query vector.
"""
await self._ensure_db_setup()
if not self.async_collection:
return await run_in_executor(
None,
self.asimilarity_search_with_score_id_by_vector, # type: ignore[arg-type]
embedding,
k,
filter,
)
await self.astra_env.aensure_db_setup()
metadata_parameter = self._filter_to_metadata(filter)
#
return [
@ -1010,7 +904,7 @@ class AstraDB(VectorStore):
Returns:
List of Documents selected by maximal marginal relevance.
"""
self._ensure_astra_db_client()
self.astra_env.ensure_db_setup()
metadata_parameter = self._filter_to_metadata(filter)
prefetch_hits = list(
@ -1051,18 +945,7 @@ class AstraDB(VectorStore):
Returns:
List of Documents selected by maximal marginal relevance.
"""
await self._ensure_db_setup()
if not self.async_collection:
return await run_in_executor(
None,
self.max_marginal_relevance_search_by_vector,
embedding,
k,
fetch_k,
lambda_mult,
filter,
**kwargs,
)
await self.astra_env.aensure_db_setup()
metadata_parameter = self._filter_to_metadata(filter)
prefetch_hits = [

Loading…
Cancel
Save