diff --git a/langchain/vectorstores/elastic_vector_search.py b/langchain/vectorstores/elastic_vector_search.py index 6663e79d..d67974d1 100644 --- a/langchain/vectorstores/elastic_vector_search.py +++ b/langchain/vectorstores/elastic_vector_search.py @@ -177,7 +177,7 @@ class ElasticVectorSearch(VectorStore, ABC): except NotFoundError: # TODO would be nice to create index before embedding, # just to save expensive steps for last - self.client.indices.create(index=self.index_name, mappings=mapping) + self.create_index(self.client, self.index_name, mapping) for i, text in enumerate(texts): metadata = metadatas[i] if metadatas else {} @@ -226,7 +226,9 @@ class ElasticVectorSearch(VectorStore, ABC): """ embedding = self.embedding.embed_query(query) script_query = _default_script_query(embedding, filter) - response = self.client.search(index=self.index_name, query=script_query, size=k) + response = self.client_search( + self.client, self.index_name, script_query, size=k + ) hits = [hit for hit in response["hits"]["hits"]] docs_and_scores = [ ( @@ -281,3 +283,24 @@ class ElasticVectorSearch(VectorStore, ABC): texts, metadatas=metadatas, refresh_indices=refresh_indices ) return vectorsearch + + def create_index(self, client: Any, index_name: str, mapping: Dict) -> None: + version_num = client.info()["version"]["number"][0] + version_num = int(version_num) + if version_num >= 8: + client.indices.create(index=index_name, mappings=mapping) + else: + client.indices.create(index=index_name, body={"mappings": mapping}) + + def client_search( + self, client: Any, index_name: str, script_query: Dict, size: int + ) -> Any: + version_num = client.info()["version"]["number"][0] + version_num = int(version_num) + if version_num >= 8: + response = client.search(index=index_name, query=script_query, size=size) + else: + response = client.search( + index=index_name, body={"query": script_query, "size": size} + ) + return response