From 80b3fdf2f766be786ac7b416805a1205aaffc56b Mon Sep 17 00:00:00 2001 From: Blithe Date: Fri, 2 Jun 2023 01:58:20 +0800 Subject: [PATCH] make the elasticsearch api support version which below 8.x (#5495) the api which create index or search in the elasticsearch below 8.x is different with 8.x. When use the es which below 8.x , it will throw error. I fix the problem Co-authored-by: gaofeng27692 --- .../vectorstores/elastic_vector_search.py | 27 +++++++++++++++++-- 1 file changed, 25 insertions(+), 2 deletions(-) diff --git a/langchain/vectorstores/elastic_vector_search.py b/langchain/vectorstores/elastic_vector_search.py index 6663e79d..d67974d1 100644 --- a/langchain/vectorstores/elastic_vector_search.py +++ b/langchain/vectorstores/elastic_vector_search.py @@ -177,7 +177,7 @@ class ElasticVectorSearch(VectorStore, ABC): except NotFoundError: # TODO would be nice to create index before embedding, # just to save expensive steps for last - self.client.indices.create(index=self.index_name, mappings=mapping) + self.create_index(self.client, self.index_name, mapping) for i, text in enumerate(texts): metadata = metadatas[i] if metadatas else {} @@ -226,7 +226,9 @@ class ElasticVectorSearch(VectorStore, ABC): """ embedding = self.embedding.embed_query(query) script_query = _default_script_query(embedding, filter) - response = self.client.search(index=self.index_name, query=script_query, size=k) + response = self.client_search( + self.client, self.index_name, script_query, size=k + ) hits = [hit for hit in response["hits"]["hits"]] docs_and_scores = [ ( @@ -281,3 +283,24 @@ class ElasticVectorSearch(VectorStore, ABC): texts, metadatas=metadatas, refresh_indices=refresh_indices ) return vectorsearch + + def create_index(self, client: Any, index_name: str, mapping: Dict) -> None: + version_num = client.info()["version"]["number"][0] + version_num = int(version_num) + if version_num >= 8: + client.indices.create(index=index_name, mappings=mapping) + else: + client.indices.create(index=index_name, body={"mappings": mapping}) + + def client_search( + self, client: Any, index_name: str, script_query: Dict, size: int + ) -> Any: + version_num = client.info()["version"]["number"][0] + version_num = int(version_num) + if version_num >= 8: + response = client.search(index=index_name, query=script_query, size=size) + else: + response = client.search( + index=index_name, body={"query": script_query, "size": size} + ) + return response