Add the usage of SSL certificates for Elasticsearch and user password authentication (#5058)

Enhance the code to support SSL authentication for Elasticsearch when
using the VectorStore module, as previous versions did not provide this
capability.
@dev2049

---------

Co-authored-by: caidong <zhucaidong1992@gmail.com>
Co-authored-by: Dev 2049 <dev.dev2049@gmail.com>
This commit is contained in:
Donger 2023-05-23 02:51:32 +08:00 committed by GitHub
parent 44dc959584
commit 039f8f1abb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 42 additions and 51 deletions

View File

@ -7,7 +7,7 @@ from typing import Any, Dict, Iterable, List, Optional, Tuple
from langchain.docstore.document import Document
from langchain.embeddings.base import Embeddings
from langchain.utils import get_from_dict_or_env
from langchain.utils import get_from_env
from langchain.vectorstores.base import VectorStore
@ -114,24 +114,31 @@ class ElasticVectorSearch(VectorStore, ABC):
ValueError: If the elasticsearch python package is not installed.
"""
def __init__(self, elasticsearch_url: str, index_name: str, embedding: Embeddings):
def __init__(
self,
elasticsearch_url: str,
index_name: str,
embedding: Embeddings,
*,
ssl_verify: Optional[Dict[str, Any]] = None,
):
"""Initialize with necessary components."""
try:
import elasticsearch
except ImportError:
raise ValueError(
raise ImportError(
"Could not import elasticsearch python package. "
"Please install it with `pip install elasticsearch`."
)
self.embedding = embedding
self.index_name = index_name
_ssl_verify = ssl_verify or {}
try:
es_client = elasticsearch.Elasticsearch(elasticsearch_url) # noqa
self.client = elasticsearch.Elasticsearch(elasticsearch_url, **_ssl_verify)
except ValueError as e:
raise ValueError(
f"Your elasticsearch client string is misformatted. Got error: {e} "
f"Your elasticsearch client string is mis-formatted. Got error: {e} "
)
self.client = es_client
def add_texts(
self,
@ -154,7 +161,7 @@ class ElasticVectorSearch(VectorStore, ABC):
from elasticsearch.exceptions import NotFoundError
from elasticsearch.helpers import bulk
except ImportError:
raise ValueError(
raise ImportError(
"Could not import elasticsearch python package. "
"Please install it with `pip install elasticsearch`."
)
@ -239,6 +246,9 @@ class ElasticVectorSearch(VectorStore, ABC):
texts: List[str],
embedding: Embeddings,
metadatas: Optional[List[dict]] = None,
elasticsearch_url: Optional[str] = None,
index_name: Optional[str] = None,
refresh_indices: bool = True,
**kwargs: Any,
) -> ElasticVectorSearch:
"""Construct ElasticVectorSearch wrapper from raw documents.
@ -262,48 +272,12 @@ class ElasticVectorSearch(VectorStore, ABC):
elasticsearch_url="http://localhost:9200"
)
"""
elasticsearch_url = get_from_dict_or_env(
kwargs, "elasticsearch_url", "ELASTICSEARCH_URL"
elasticsearch_url = elasticsearch_url or get_from_env(
"elasticsearch_url", "ELASTICSEARCH_URL"
)
try:
import elasticsearch
from elasticsearch.exceptions import NotFoundError
from elasticsearch.helpers import bulk
except ImportError:
raise ValueError(
"Could not import elasticsearch python package. "
"Please install it with `pip install elasticsearch`."
)
try:
client = elasticsearch.Elasticsearch(elasticsearch_url)
except ValueError as e:
raise ValueError(
"Your elasticsearch client string is misformatted. " f"Got error: {e} "
)
index_name = kwargs.get("index_name", uuid.uuid4().hex)
embeddings = embedding.embed_documents(texts)
dim = len(embeddings[0])
mapping = _default_text_mapping(dim)
# check to see if the index already exists
try:
client.indices.get(index=index_name)
except NotFoundError:
# TODO would be nice to create index before embedding,
# just to save expensive steps for last
client.indices.create(index=index_name, mappings=mapping)
requests = []
for i, text in enumerate(texts):
metadata = metadatas[i] if metadatas else {}
request = {
"_op_type": "index",
"_index": index_name,
"vector": embeddings[i],
"text": text,
"metadata": metadata,
}
requests.append(request)
bulk(client, requests)
client.indices.refresh(index=index_name)
return cls(elasticsearch_url, index_name, embedding)
index_name = index_name or uuid.uuid4().hex
vectorsearch = cls(elasticsearch_url, index_name, embedding, **kwargs)
vectorsearch.add_texts(
texts, metadatas=metadatas, refresh_indices=refresh_indices
)
return vectorsearch

View File

@ -48,6 +48,23 @@ class TestElasticsearch:
output = docsearch.similarity_search("foo", k=1)
assert output == [Document(page_content="foo")]
def test_similarity_search_with_ssl_verify(self, elasticsearch_url: str) -> None:
"""Test end to end construction and search with ssl verify."""
ssl_verify = {
"verify_certs": True,
"basic_auth": ("ES_USER", "ES_PASSWORD"),
"ca_certs": "ES_CA_CERTS_PATH",
}
texts = ["foo", "bar", "baz"]
docsearch = ElasticVectorSearch.from_texts(
texts,
FakeEmbeddings(),
elasticsearch_url=elasticsearch_url,
ssl_verify=ssl_verify,
)
output = docsearch.similarity_search("foo", k=1)
assert output == [Document(page_content="foo")]
def test_similarity_search_with_metadata(self, elasticsearch_url: str) -> None:
"""Test end to end construction and search with metadata."""
texts = ["foo", "bar", "baz"]