mirror of
https://github.com/hwchase17/langchain
synced 2024-11-10 01:10:59 +00:00
partners/mongodb: Improved search index commands (#24745)
Hardens index commands with try/except for free clusters and optional waits for syncing and tests. [efriis](https://github.com/efriis) These are the upgrades to the search index commands (CRUD) that I mentioned. --------- Co-authored-by: Erick Friis <erick@langchain.dev>
This commit is contained in:
parent
db42576b09
commit
db3ceb4d0a
@ -1,11 +1,25 @@
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional
|
||||
from time import monotonic, sleep
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
|
||||
from pymongo.collection import Collection
|
||||
from pymongo.errors import OperationFailure
|
||||
from pymongo.operations import SearchIndexModel
|
||||
|
||||
# Name the logger after the module's import name rather than its filesystem
# path, so it fits into the dotted logger hierarchy used for configuration.
logger = logging.getLogger(__name__)

_DELAY = 0.5  # Interval between checks for index operations
|
||||
|
||||
|
||||
def _search_index_error_message() -> str:
    """Return the message used when search index commands are unavailable.

    The text explains that search index operations need a dedicated cluster
    and points the user at the documentation for creating indexes manually.

    Returns:
        str: The user-facing error message.
    """
    return (
        "Search index operations are not currently available on shared clusters, "
        # "M0" (digit zero) is the shared tier name; the text previously read "MO".
        "such as M0. They require dedicated clusters >= M10. "
        "You may still perform vector search. "
        "You simply must set up indexes manually. Follow the instructions here: "
        "https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-type/"
    )
|
||||
|
||||
|
||||
def _vector_search_index_definition(
|
||||
dimensions: int,
|
||||
@ -32,7 +46,9 @@ def create_vector_search_index(
|
||||
dimensions: int,
|
||||
path: str,
|
||||
similarity: str,
|
||||
filters: List[Dict[str, str]],
|
||||
filters: Optional[List[Dict[str, str]]] = None,
|
||||
*,
|
||||
wait_until_complete: Optional[float] = None,
|
||||
) -> None:
|
||||
"""Experimental Utility function to create a vector search index
|
||||
|
||||
@ -43,31 +59,65 @@ def create_vector_search_index(
|
||||
path (str): field with vector embedding
|
||||
similarity (str): The similarity score used for the index
|
||||
filters (List[Dict[str, str]]): additional filters for index definition.
|
||||
wait_until_complete (Optional[float]): If provided, number of seconds to wait
|
||||
until search index is ready.
|
||||
"""
|
||||
logger.info("Creating Search Index %s on %s", index_name, collection.name)
|
||||
result = collection.create_search_index(
|
||||
SearchIndexModel(
|
||||
definition=_vector_search_index_definition(
|
||||
dimensions=dimensions, path=path, similarity=similarity, filters=filters
|
||||
),
|
||||
name=index_name,
|
||||
type="vectorSearch",
|
||||
|
||||
try:
|
||||
result = collection.create_search_index(
|
||||
SearchIndexModel(
|
||||
definition=_vector_search_index_definition(
|
||||
dimensions=dimensions,
|
||||
path=path,
|
||||
similarity=similarity,
|
||||
filters=filters,
|
||||
),
|
||||
name=index_name,
|
||||
type="vectorSearch",
|
||||
)
|
||||
)
|
||||
except OperationFailure as e:
|
||||
raise OperationFailure(_search_index_error_message()) from e
|
||||
|
||||
if wait_until_complete:
|
||||
_wait_for_predicate(
|
||||
predicate=lambda: _is_index_ready(collection, index_name),
|
||||
err=f"Index {index_name} creation did not finish in {wait_until_complete}!",
|
||||
timeout=wait_until_complete,
|
||||
)
|
||||
)
|
||||
logger.info(result)
|
||||
|
||||
|
||||
def drop_vector_search_index(
    collection: Collection,
    index_name: str,
    *,
    wait_until_complete: Optional[float] = None,
) -> None:
    """Drop a created vector search index

    Args:
        collection (Collection): MongoDB Collection with index to be dropped
        index_name (str): Name of the MongoDB index
        wait_until_complete (Optional[float]): If provided, number of seconds to wait
            until the index is no longer listed on the collection.

    Raises:
        OperationFailure: If the server reports ``CommandNotSupported``.
        TimeoutError: If ``wait_until_complete`` is set and the index is still
            listed once that many seconds have elapsed.
    """
    logger.info(
        "Dropping Search Index %s from Collection: %s", index_name, collection.name
    )
    try:
        collection.drop_search_index(index_name)
    except OperationFailure as e:
        if "CommandNotSupported" in str(e):
            raise OperationFailure(_search_index_error_message()) from e
        # else this most likely means an ongoing drop request was made so skip
        logger.debug("Ignoring failure while dropping %s: %s", index_name, e)
    if wait_until_complete:
        _wait_for_predicate(
            # Only wait for *this* index to disappear; other search indexes on
            # the collection must not keep the wait from completing.
            predicate=lambda: not list(collection.list_search_indexes(index_name)),
            err=f"Index {index_name} did not drop in {wait_until_complete}!",
            timeout=wait_until_complete,
        )
    logger.info("Vector Search index %s.%s dropped", collection.name, index_name)
|
||||
|
||||
|
||||
@ -78,8 +128,12 @@ def update_vector_search_index(
|
||||
path: str,
|
||||
similarity: str,
|
||||
filters: List[Dict[str, str]],
|
||||
*,
|
||||
wait_until_complete: Optional[float] = None,
|
||||
) -> None:
|
||||
"""Leverages the updateSearchIndex call
|
||||
"""Update a search index.
|
||||
|
||||
Replace the existing index definition with the provided definition.
|
||||
|
||||
Args:
|
||||
collection (Collection): MongoDB Collection
|
||||
@ -88,18 +142,73 @@ def update_vector_search_index(
|
||||
path (str): field with vector embedding.
|
||||
similarity (str): The similarity score used for the index.
|
||||
filters (List[Dict[str, str]]): additional filters for index definition.
|
||||
wait_until_complete (Optional[float]): If provided, number of seconds to wait
|
||||
until search index is ready.
|
||||
"""
|
||||
|
||||
logger.info(
|
||||
"Updating Search Index %s from Collection: %s", index_name, collection.name
|
||||
)
|
||||
collection.update_search_index(
|
||||
name=index_name,
|
||||
definition=_vector_search_index_definition(
|
||||
dimensions=dimensions,
|
||||
path=path,
|
||||
similarity=similarity,
|
||||
filters=filters,
|
||||
),
|
||||
)
|
||||
try:
|
||||
collection.update_search_index(
|
||||
name=index_name,
|
||||
definition=_vector_search_index_definition(
|
||||
dimensions=dimensions,
|
||||
path=path,
|
||||
similarity=similarity,
|
||||
filters=filters,
|
||||
),
|
||||
)
|
||||
except OperationFailure as e:
|
||||
raise OperationFailure(_search_index_error_message()) from e
|
||||
|
||||
if wait_until_complete:
|
||||
_wait_for_predicate(
|
||||
predicate=lambda: _is_index_ready(collection, index_name),
|
||||
err=f"Index {index_name} update did not complete in {wait_until_complete}!",
|
||||
timeout=wait_until_complete,
|
||||
)
|
||||
logger.info("Update succeeded")
|
||||
|
||||
|
||||
def _is_index_ready(collection: Collection, index_name: str) -> bool:
    """Check whether the named vector search index is listed with status READY.

    Args:
        collection (Collection): MongoDB Collection whose search indexes are listed
        index_name (str): Vector Search Index name

    Returns:
        bool : True if the index is present and READY, False otherwise

    Raises:
        OperationFailure: If listing the search indexes fails; the original
            error is chained as the cause.
    """
    try:
        search_indexes = collection.list_search_indexes(index_name)
    except OperationFailure as e:
        raise OperationFailure(_search_index_error_message()) from e

    # Use .get() so a listing document without "type"/"status" counts as
    # "not ready" instead of raising KeyError in the middle of a wait loop.
    return any(
        index.get("type") == "vectorSearch" and index.get("status") == "READY"
        for index in search_indexes
    )
|
||||
|
||||
|
||||
def _wait_for_predicate(
    predicate: Callable[[], bool], err: str, timeout: float = 120, interval: float = 0.5
) -> None:
    """Block until the predicate returns a truthy value or the timeout elapses.

    The predicate is always evaluated at least once, and once more at the
    deadline before giving up.

    Args:
        predicate (Callable[[], bool]): Zero-argument function polled for a
            truthy result
        err (str): Message of the TimeoutError raised if the predicate never
            becomes truthy in time
        timeout (float, optional): Maximum number of seconds to wait.
            Defaults to 120.
        interval (float, optional): Seconds to sleep between checks.
            Defaults to 0.5.

    Raises:
        TimeoutError: If the predicate is still falsy after ``timeout`` seconds.
    """
    deadline = monotonic() + timeout
    while not predicate():
        remaining = deadline - monotonic()
        if remaining < 0:
            raise TimeoutError(err)
        # Never sleep past the deadline: a large interval must not stretch the
        # wait well beyond the requested timeout.
        sleep(min(interval, remaining))
|
||||
|
@ -629,4 +629,4 @@ class MongoDBAtlasVectorSearch(VectorStore):
|
||||
path=self._embedding_key,
|
||||
similarity=self._relevance_score_fn,
|
||||
filters=filters or [],
|
||||
)
|
||||
) # type: ignore [operator]
|
||||
|
73
libs/partners/mongodb/tests/integration_tests/test_index.py
Normal file
73
libs/partners/mongodb/tests/integration_tests/test_index.py
Normal file
@ -0,0 +1,73 @@
|
||||
"""Search index commands are only supported on Atlas Clusters >=M10"""
|
||||
|
||||
import os
|
||||
|
||||
import pytest
|
||||
from pymongo import MongoClient
|
||||
from pymongo.collection import Collection
|
||||
|
||||
from langchain_mongodb import index
|
||||
|
||||
|
||||
@pytest.fixture
def collection() -> Collection:
    """Depending on uri, this could point to any type of cluster.

    The collection is created only when it does not exist yet, so the fixture
    can be used on repeated runs against the same database.
    """
    uri = os.environ.get("MONGODB_ATLAS_URI")
    client: MongoClient = MongoClient(uri)
    db = client["db"]
    # create_collection raises if the collection already exists (e.g. left over
    # from a previous run), so only create it when missing.
    if "collection" not in db.list_collection_names():
        return db.create_collection("collection")
    return db["collection"]
|
||||
|
||||
|
||||
def test_search_index_commands(collection: Collection) -> None:
    """Run drop, create, update and drop of a vector search index in sequence."""
    index_name = "vector_index"
    dimensions = 1536
    path = "embedding"
    similarity = "cosine"
    filters: list = []
    wait_until_complete = 120

    # Start from a clean slate: remove every search index already present.
    for existing in collection.list_search_indexes():
        index.drop_vector_search_index(
            collection, existing["name"], wait_until_complete=wait_until_complete
        )

    assert not list(collection.list_search_indexes())

    # Create the index and wait for it to become ready.
    index.create_vector_search_index(
        collection=collection,
        index_name=index_name,
        dimensions=dimensions,
        path=path,
        similarity=similarity,
        filters=filters,
        wait_until_complete=wait_until_complete,
    )

    assert index._is_index_ready(collection, index_name)
    listed = list(collection.list_search_indexes())
    assert len(listed) == 1
    assert listed[0]["name"] == index_name

    # Replace the definition with a different similarity function.
    updated_similarity = "euclidean"
    index.update_vector_search_index(
        collection,
        index_name,
        dimensions,
        path,
        updated_similarity,
        [],
        wait_until_complete=wait_until_complete,
    )

    assert index._is_index_ready(collection, index_name)
    listed = list(collection.list_search_indexes())
    assert len(listed) == 1
    only_index = listed[0]
    assert only_index["name"] == index_name
    first_field = only_index["latestDefinition"]["fields"][0]
    assert first_field["similarity"] == updated_similarity

    # Drop the index again and confirm nothing is left.
    index.drop_vector_search_index(
        collection, index_name, wait_until_complete=wait_until_complete
    )

    assert not list(collection.list_search_indexes())
|
55
libs/partners/mongodb/tests/unit_tests/test_index.py
Normal file
55
libs/partners/mongodb/tests/unit_tests/test_index.py
Normal file
@ -0,0 +1,55 @@
|
||||
"""Search index commands are only supported on Atlas Clusters >=M10"""
|
||||
|
||||
import os
|
||||
from time import sleep
|
||||
|
||||
import pytest
|
||||
from pymongo import MongoClient
|
||||
from pymongo.collection import Collection
|
||||
from pymongo.errors import OperationFailure, ServerSelectionTimeoutError
|
||||
|
||||
from langchain_mongodb import index
|
||||
|
||||
|
||||
@pytest.fixture
def collection() -> Collection:
    """Return the ``db.collection`` handle for the configured cluster.

    Depending on uri, this could point to any type of cluster.
    For unit tests, MONGODB_URI should be localhost, None, or Atlas cluster <M10.
    """
    mongo_uri = os.environ.get("MONGODB_URI")
    client: MongoClient = MongoClient(mongo_uri)
    database = client["db"]
    return database["collection"]
|
||||
|
||||
|
||||
def test_create_vector_search_index(collection: Collection) -> None:
    """Creating an index must raise one of the two accepted pymongo errors."""
    expected_errors = (OperationFailure, ServerSelectionTimeoutError)
    with pytest.raises(expected_errors):
        index.create_vector_search_index(
            collection, "index_name", 1536, "embedding", "cosine", []
        )
|
||||
|
||||
|
||||
def test_drop_vector_search_index(collection: Collection) -> None:
    """Dropping an index must raise one of the two accepted pymongo errors."""
    expected_errors = (OperationFailure, ServerSelectionTimeoutError)
    with pytest.raises(expected_errors):
        index.drop_vector_search_index(collection, "index_name")
|
||||
|
||||
|
||||
def test_update_vector_search_index(collection: Collection) -> None:
    """Updating an index must raise one of the two accepted pymongo errors."""
    expected_errors = (OperationFailure, ServerSelectionTimeoutError)
    with pytest.raises(expected_errors):
        index.update_vector_search_index(
            collection, "index_name", 1536, "embedding", "cosine", []
        )
|
||||
|
||||
|
||||
def test___is_index_ready(collection: Collection) -> None:
    """The readiness check must raise one of the two accepted pymongo errors."""
    expected_errors = (OperationFailure, ServerSelectionTimeoutError)
    with pytest.raises(expected_errors):
        index._is_index_ready(collection, "index_name")
|
||||
|
||||
|
||||
def test__wait_for_predicate() -> None:
    """Check both the timeout path and the immediate-success path."""
    err = "error string"
    # The predicate outlasts the timeout and returns None (falsy), so the very
    # first elapsed-time check fails. Short durations keep this unit test fast.
    with pytest.raises(TimeoutError) as e:
        index._wait_for_predicate(
            lambda: sleep(0.1), err=err, timeout=0.05, interval=0.01
        )
    # Inspect the raised exception itself, outside the ``with`` block so the
    # assertion actually runs after the call has raised.
    assert err in str(e.value)

    # A predicate that is already true returns without raising.
    index._wait_for_predicate(lambda: True, err=err, timeout=1.0, interval=0.5)
|
Loading…
Reference in New Issue
Block a user