elasticsearch: check for deployed models (#18973)

When creating a new index, if we use a retrieval strategy that expects a
model to be deployed in Elasticsearch, check if a model with this name
is indeed deployed before creating an index. This lowers the probability
to get into a state in which an index was created with a faulty model
ID, which cannot be overwritten any more (the index has to manually be
deleted).
pull/18645/head
Max Jakob 4 months ago committed by GitHub
parent b82644078e
commit 6f544a6a25
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -2,7 +2,7 @@ from enum import Enum
from typing import List, Union
import numpy as np
from elasticsearch import Elasticsearch
from elasticsearch import BadRequestError, ConflictError, Elasticsearch, NotFoundError
from langchain_core import __version__ as langchain_version
Matrix = Union[List[List[float]], List[np.ndarray], np.ndarray]
@ -88,3 +88,21 @@ def cosine_similarity(X: Matrix, Y: Matrix) -> np.ndarray:
similarity = np.dot(X, Y.T) / np.outer(X_norm, Y_norm)
similarity[np.isnan(similarity) | np.isinf(similarity)] = 0.0
return similarity
def check_if_model_deployed(client: Elasticsearch, model_id: str) -> None:
try:
dummy = {"x": "y"}
client.ml.infer_trained_model(model_id=model_id, docs=[dummy])
except NotFoundError as err:
raise err
except ConflictError as err:
raise NotFoundError(
f"model '{model_id}' not found, please deploy it first",
meta=err.meta,
body=err.body,
) from err
except BadRequestError:
# This error is expected because we do not know the expected document
# shape and just use a dummy doc above.
pass

@ -22,6 +22,7 @@ from langchain_core.vectorstores import VectorStore
from langchain_elasticsearch._utilities import (
DistanceStrategy,
check_if_model_deployed,
maximal_marginal_relevance,
with_user_agent_header,
)
@ -199,6 +200,12 @@ class ApproxRetrievalStrategy(BaseRetrievalStrategy):
else:
return {"knn": knn}
def before_index_setup(
self, client: "Elasticsearch", text_field: str, vector_query_field: str
) -> None:
if self.query_model_id:
check_if_model_deployed(client, self.query_model_id)
def index(
self,
dims_length: Union[int, None],
@ -340,8 +347,10 @@ class SparseRetrievalStrategy(BaseRetrievalStrategy):
def before_index_setup(
self, client: "Elasticsearch", text_field: str, vector_query_field: str
) -> None:
# If model_id is provided, create a pipeline for the model
if self.model_id:
check_if_model_deployed(client, self.model_id)
# Create a pipeline for the model
client.ingest.put_pipeline(
id=self._get_pipeline_name(),
description="Embedding pipeline for langchain vectorstore",

@ -7,7 +7,7 @@ import uuid
from typing import Any, Dict, Generator, List, Union
import pytest
from elasticsearch import Elasticsearch
from elasticsearch import Elasticsearch, NotFoundError
from elasticsearch.helpers import BulkIndexError
from langchain_core.documents import Document
@ -40,7 +40,7 @@ Enable them by adding the model name to the modelsDeployed list below.
"""
modelsDeployed: List[str] = [
# "elser",
# ".elser_model_1",
# "sentence-transformers__all-minilm-l6-v2",
]
@ -709,7 +709,7 @@ class TestElasticsearch:
assert output == [Document(page_content="bar")]
@pytest.mark.skipif(
"elser" not in modelsDeployed,
".elser_model_1" not in modelsDeployed,
reason="ELSER not deployed in ML Node, skipping test",
)
def test_similarity_search_with_sparse_infer_instack(
@ -726,6 +726,35 @@ class TestElasticsearch:
output = docsearch.similarity_search("foo", k=1)
assert output == [Document(page_content="foo")]
def test_deployed_model_check_fails_approx(
self, elasticsearch_connection: dict, index_name: str
) -> None:
"""test that exceptions are raised if a specified model is not deployed"""
with pytest.raises(NotFoundError):
ElasticsearchStore.from_texts(
texts=["foo", "bar", "baz"],
embedding=ConsistentFakeEmbeddings(10),
**elasticsearch_connection,
index_name=index_name,
strategy=ElasticsearchStore.ApproxRetrievalStrategy(
query_model_id="non-existing model ID",
),
)
def test_deployed_model_check_fails_sparse(
self, elasticsearch_connection: dict, index_name: str
) -> None:
"""test that exceptions are raised if a specified model is not deployed"""
with pytest.raises(NotFoundError):
ElasticsearchStore.from_texts(
texts=["foo", "bar", "baz"],
**elasticsearch_connection,
index_name=index_name,
strategy=ElasticsearchStore.SparseVectorRetrievalStrategy(
model_id="non-existing model ID"
),
)
def test_elasticsearch_with_relevance_score(
self, elasticsearch_connection: dict, index_name: str
) -> None:

Loading…
Cancel
Save