From 039f8f1abb03cc7983282bfa285920d00231e360 Mon Sep 17 00:00:00 2001 From: Donger Date: Tue, 23 May 2023 02:51:32 +0800 Subject: [PATCH] Add the usage of SSL certificates for Elasticsearch and user password authentication (#5058) Enhance the code to support SSL authentication for Elasticsearch when using the VectorStore module, as previous versions did not provide this capability. @dev2049 --------- Co-authored-by: caidong Co-authored-by: Dev 2049 --- .../vectorstores/elastic_vector_search.py | 76 ++++++------------- .../vectorstores/test_elasticsearch.py | 17 +++++ 2 files changed, 42 insertions(+), 51 deletions(-) diff --git a/langchain/vectorstores/elastic_vector_search.py b/langchain/vectorstores/elastic_vector_search.py index dc11a842..6663e79d 100644 --- a/langchain/vectorstores/elastic_vector_search.py +++ b/langchain/vectorstores/elastic_vector_search.py @@ -7,7 +7,7 @@ from typing import Any, Dict, Iterable, List, Optional, Tuple from langchain.docstore.document import Document from langchain.embeddings.base import Embeddings -from langchain.utils import get_from_dict_or_env +from langchain.utils import get_from_env from langchain.vectorstores.base import VectorStore @@ -114,24 +114,31 @@ class ElasticVectorSearch(VectorStore, ABC): ValueError: If the elasticsearch python package is not installed. """ - def __init__(self, elasticsearch_url: str, index_name: str, embedding: Embeddings): + def __init__( + self, + elasticsearch_url: str, + index_name: str, + embedding: Embeddings, + *, + ssl_verify: Optional[Dict[str, Any]] = None, + ): """Initialize with necessary components.""" try: import elasticsearch except ImportError: - raise ValueError( + raise ImportError( "Could not import elasticsearch python package. " "Please install it with `pip install elasticsearch`." ) self.embedding = embedding self.index_name = index_name + _ssl_verify = ssl_verify or {} try: - es_client = elasticsearch.Elasticsearch(elasticsearch_url) # noqa + self.client = elasticsearch.Elasticsearch(elasticsearch_url, **_ssl_verify) except ValueError as e: raise ValueError( - f"Your elasticsearch client string is misformatted. Got error: {e} " + f"Your elasticsearch client string is mis-formatted. Got error: {e} " ) - self.client = es_client def add_texts( self, @@ -154,7 +161,7 @@ class ElasticVectorSearch(VectorStore, ABC): from elasticsearch.exceptions import NotFoundError from elasticsearch.helpers import bulk except ImportError: - raise ValueError( + raise ImportError( "Could not import elasticsearch python package. " "Please install it with `pip install elasticsearch`." ) @@ -239,6 +246,9 @@ class ElasticVectorSearch(VectorStore, ABC): texts: List[str], embedding: Embeddings, metadatas: Optional[List[dict]] = None, + elasticsearch_url: Optional[str] = None, + index_name: Optional[str] = None, + refresh_indices: bool = True, **kwargs: Any, ) -> ElasticVectorSearch: """Construct ElasticVectorSearch wrapper from raw documents. @@ -262,48 +272,12 @@ class ElasticVectorSearch(VectorStore, ABC): elasticsearch_url="http://localhost:9200" ) """ - elasticsearch_url = get_from_dict_or_env( - kwargs, "elasticsearch_url", "ELASTICSEARCH_URL" + elasticsearch_url = elasticsearch_url or get_from_env( + "elasticsearch_url", "ELASTICSEARCH_URL" ) - try: - import elasticsearch - from elasticsearch.exceptions import NotFoundError - from elasticsearch.helpers import bulk - except ImportError: - raise ValueError( - "Could not import elasticsearch python package. " - "Please install it with `pip install elasticsearch`." - ) - try: - client = elasticsearch.Elasticsearch(elasticsearch_url) - except ValueError as e: - raise ValueError( - "Your elasticsearch client string is misformatted. " f"Got error: {e} " - ) - index_name = kwargs.get("index_name", uuid.uuid4().hex) - embeddings = embedding.embed_documents(texts) - dim = len(embeddings[0]) - mapping = _default_text_mapping(dim) - - # check to see if the index already exists - try: - client.indices.get(index=index_name) - except NotFoundError: - # TODO would be nice to create index before embedding, - # just to save expensive steps for last - client.indices.create(index=index_name, mappings=mapping) - - requests = [] - for i, text in enumerate(texts): - metadata = metadatas[i] if metadatas else {} - request = { - "_op_type": "index", - "_index": index_name, - "vector": embeddings[i], - "text": text, - "metadata": metadata, - } - requests.append(request) - bulk(client, requests) - client.indices.refresh(index=index_name) - return cls(elasticsearch_url, index_name, embedding) + index_name = index_name or uuid.uuid4().hex + vectorsearch = cls(elasticsearch_url, index_name, embedding, **kwargs) + vectorsearch.add_texts( + texts, metadatas=metadatas, refresh_indices=refresh_indices + ) + return vectorsearch diff --git a/tests/integration_tests/vectorstores/test_elasticsearch.py b/tests/integration_tests/vectorstores/test_elasticsearch.py index b79d2b6b..cb2d0c0e 100644 --- a/tests/integration_tests/vectorstores/test_elasticsearch.py +++ b/tests/integration_tests/vectorstores/test_elasticsearch.py @@ -48,6 +48,23 @@ class TestElasticsearch: output = docsearch.similarity_search("foo", k=1) assert output == [Document(page_content="foo")] + def test_similarity_search_with_ssl_verify(self, elasticsearch_url: str) -> None: + """Test end to end construction and search with ssl verify.""" + ssl_verify = { + "verify_certs": True, + "basic_auth": ("ES_USER", "ES_PASSWORD"), + "ca_certs": "ES_CA_CERTS_PATH", + } + texts = ["foo", "bar", "baz"] + docsearch = ElasticVectorSearch.from_texts( + texts, + FakeEmbeddings(), + elasticsearch_url=elasticsearch_url, + ssl_verify=ssl_verify, + ) + output = docsearch.similarity_search("foo", k=1) + assert output == [Document(page_content="foo")] + def test_similarity_search_with_metadata(self, elasticsearch_url: str) -> None: """Test end to end construction and search with metadata.""" texts = ["foo", "bar", "baz"]