partners/mongodb: Improved search index commands (#24745)

Hardens index commands with try/except for free clusters and optional
waits for syncing and tests.

[efriis](https://github.com/efriis) These are the upgrades to the search
index commands (CRUD) that I mentioned.

---------

Co-authored-by: Erick Friis <erick@langchain.dev>
This commit is contained in:
Casey Clements 2024-08-01 16:16:32 -04:00 committed by GitHub
parent db42576b09
commit db3ceb4d0a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 260 additions and 23 deletions

View File

@ -1,11 +1,25 @@
import logging
from typing import Any, Dict, List, Optional
from time import monotonic, sleep
from typing import Any, Callable, Dict, List, Optional
from pymongo.collection import Collection
from pymongo.errors import OperationFailure
from pymongo.operations import SearchIndexModel
logger = logging.getLogger(__file__)
_DELAY = 0.5 # Interval between checks for index operations
def _search_index_error_message() -> str:
return (
"Search index operations are not currently available on shared clusters, "
"such as MO. They require dedicated clusters >= M10. "
"You may still perform vector search. "
"You simply must set up indexes manually. Follow the instructions here: "
"https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-type/"
)
def _vector_search_index_definition(
dimensions: int,
@ -32,7 +46,9 @@ def create_vector_search_index(
dimensions: int,
path: str,
similarity: str,
filters: List[Dict[str, str]],
filters: Optional[List[Dict[str, str]]] = None,
*,
wait_until_complete: Optional[float] = None,
) -> None:
"""Experimental Utility function to create a vector search index
@ -43,31 +59,65 @@ def create_vector_search_index(
path (str): field with vector embedding
similarity (str): The similarity score used for the index
filters (List[Dict[str, str]]): additional filters for index definition.
wait_until_complete (Optional[float]): If provided, number of seconds to wait
until search index is ready.
"""
logger.info("Creating Search Index %s on %s", index_name, collection.name)
result = collection.create_search_index(
SearchIndexModel(
definition=_vector_search_index_definition(
dimensions=dimensions, path=path, similarity=similarity, filters=filters
),
name=index_name,
type="vectorSearch",
try:
result = collection.create_search_index(
SearchIndexModel(
definition=_vector_search_index_definition(
dimensions=dimensions,
path=path,
similarity=similarity,
filters=filters,
),
name=index_name,
type="vectorSearch",
)
)
except OperationFailure as e:
raise OperationFailure(_search_index_error_message()) from e
if wait_until_complete:
_wait_for_predicate(
predicate=lambda: _is_index_ready(collection, index_name),
err=f"Index {index_name} creation did not finish in {wait_until_complete}!",
timeout=wait_until_complete,
)
)
logger.info(result)
def drop_vector_search_index(collection: Collection, index_name: str) -> None:
def drop_vector_search_index(
collection: Collection,
index_name: str,
*,
wait_until_complete: Optional[float] = None,
) -> None:
"""Drop a created vector search index
Args:
collection (Collection): MongoDB Collection with index to be dropped
index_name (str): Name of the MongoDB index
wait_until_complete (Optional[float]): If provided, number of seconds to wait
until search index is ready.
"""
logger.info(
"Dropping Search Index %s from Collection: %s", index_name, collection.name
)
collection.drop_search_index(index_name)
try:
collection.drop_search_index(index_name)
except OperationFailure as e:
if "CommandNotSupported" in str(e):
raise OperationFailure(_search_index_error_message()) from e
# else this most likely means an ongoing drop request was made so skip
if wait_until_complete:
_wait_for_predicate(
predicate=lambda: len(list(collection.list_search_indexes())) == 0,
err=f"Index {index_name} did not drop in {wait_until_complete}!",
timeout=wait_until_complete,
)
logger.info("Vector Search index %s.%s dropped", collection.name, index_name)
@ -78,8 +128,12 @@ def update_vector_search_index(
path: str,
similarity: str,
filters: List[Dict[str, str]],
*,
wait_until_complete: Optional[float] = None,
) -> None:
"""Leverages the updateSearchIndex call
"""Update a search index.
Replace the existing index definition with the provided definition.
Args:
collection (Collection): MongoDB Collection
@ -88,18 +142,73 @@ def update_vector_search_index(
path (str): field with vector embedding.
similarity (str): The similarity score used for the index.
filters (List[Dict[str, str]]): additional filters for index definition.
wait_until_complete (Optional[float]): If provided, number of seconds to wait
until search index is ready.
"""
logger.info(
"Updating Search Index %s from Collection: %s", index_name, collection.name
)
collection.update_search_index(
name=index_name,
definition=_vector_search_index_definition(
dimensions=dimensions,
path=path,
similarity=similarity,
filters=filters,
),
)
try:
collection.update_search_index(
name=index_name,
definition=_vector_search_index_definition(
dimensions=dimensions,
path=path,
similarity=similarity,
filters=filters,
),
)
except OperationFailure as e:
raise OperationFailure(_search_index_error_message()) from e
if wait_until_complete:
_wait_for_predicate(
predicate=lambda: _is_index_ready(collection, index_name),
err=f"Index {index_name} update did not complete in {wait_until_complete}!",
timeout=wait_until_complete,
)
logger.info("Update succeeded")
def _is_index_ready(collection: Collection, index_name: str) -> bool:
"""Check for the index name in the list of available search indexes to see if the
specified index is of status READY
Args:
collection (Collection): MongoDB Collection to for the search indexes
index_name (str): Vector Search Index name
Returns:
bool : True if the index is present and READY false otherwise
"""
try:
search_indexes = collection.list_search_indexes(index_name)
except OperationFailure as e:
raise OperationFailure(_search_index_error_message()) from e
for index in search_indexes:
if index["type"] == "vectorSearch" and index["status"] == "READY":
return True
return False
def _wait_for_predicate(
predicate: Callable, err: str, timeout: float = 120, interval: float = 0.5
) -> None:
"""Generic to block until the predicate returns true
Args:
predicate (Callable[, bool]): A function that returns a boolean value
err (str): Error message to raise if nothing occurs
timeout (float, optional): wait time for predicate. Defaults to TIMEOUT.
interval (float, optional): Interval to check predicate. Defaults to DELAY.
Raises:
TimeoutError: _description_
"""
start = monotonic()
while not predicate():
if monotonic() - start > timeout:
raise TimeoutError(err)
sleep(interval)

View File

@ -629,4 +629,4 @@ class MongoDBAtlasVectorSearch(VectorStore):
path=self._embedding_key,
similarity=self._relevance_score_fn,
filters=filters or [],
)
) # type: ignore [operator]

View File

@ -0,0 +1,73 @@
"""Search index commands are only supported on Atlas Clusters >=M10"""
import os
import pytest
from pymongo import MongoClient
from pymongo.collection import Collection
from langchain_mongodb import index
@pytest.fixture
def collection() -> Collection:
"""Depending on uri, this could point to any type of cluster."""
uri = os.environ.get("MONGODB_ATLAS_URI")
client: MongoClient = MongoClient(uri)
clxn = client["db"].create_collection("collection")
return clxn
def test_search_index_commands(collection: Collection) -> None:
index_name = "vector_index"
dimensions = 1536
path = "embedding"
similarity = "cosine"
filters: list = []
wait_until_complete = 120
for index_info in collection.list_search_indexes():
index.drop_vector_search_index(
collection, index_info["name"], wait_until_complete=wait_until_complete
)
assert len(list(collection.list_search_indexes())) == 0
index.create_vector_search_index(
collection=collection,
index_name=index_name,
dimensions=dimensions,
path=path,
similarity=similarity,
filters=filters,
wait_until_complete=wait_until_complete,
)
assert index._is_index_ready(collection, index_name)
indexes = list(collection.list_search_indexes())
assert len(indexes) == 1
assert indexes[0]["name"] == index_name
new_similarity = "euclidean"
index.update_vector_search_index(
collection,
index_name,
1536,
"embedding",
new_similarity,
[],
wait_until_complete=wait_until_complete,
)
assert index._is_index_ready(collection, index_name)
indexes = list(collection.list_search_indexes())
assert len(indexes) == 1
assert indexes[0]["name"] == index_name
assert indexes[0]["latestDefinition"]["fields"][0]["similarity"] == new_similarity
index.drop_vector_search_index(
collection, index_name, wait_until_complete=wait_until_complete
)
indexes = list(collection.list_search_indexes())
assert len(indexes) == 0

View File

@ -0,0 +1,55 @@
"""Search index commands are only supported on Atlas Clusters >=M10"""
import os
from time import sleep
import pytest
from pymongo import MongoClient
from pymongo.collection import Collection
from pymongo.errors import OperationFailure, ServerSelectionTimeoutError
from langchain_mongodb import index
@pytest.fixture
def collection() -> Collection:
"""Depending on uri, this could point to any type of cluster.
For unit tests, MONGODB_URI should be localhost, None, or Atlas cluster <M10.
"""
uri = os.environ.get("MONGODB_URI")
client: MongoClient = MongoClient(uri)
return client["db"]["collection"]
def test_create_vector_search_index(collection: Collection) -> None:
with pytest.raises((OperationFailure, ServerSelectionTimeoutError)):
index.create_vector_search_index(
collection, "index_name", 1536, "embedding", "cosine", []
)
def test_drop_vector_search_index(collection: Collection) -> None:
with pytest.raises((OperationFailure, ServerSelectionTimeoutError)):
index.drop_vector_search_index(collection, "index_name")
def test_update_vector_search_index(collection: Collection) -> None:
with pytest.raises((OperationFailure, ServerSelectionTimeoutError)):
index.update_vector_search_index(
collection, "index_name", 1536, "embedding", "cosine", []
)
def test___is_index_ready(collection: Collection) -> None:
with pytest.raises((OperationFailure, ServerSelectionTimeoutError)):
index._is_index_ready(collection, "index_name")
def test__wait_for_predicate() -> None:
err = "error string"
with pytest.raises(TimeoutError) as e:
index._wait_for_predicate(lambda: sleep(5), err=err, timeout=0.5, interval=0.1)
assert err in str(e)
index._wait_for_predicate(lambda: True, err=err, timeout=1.0, interval=0.5)