forked from Archives/langchain
Add the usage of SSL certificates for Elasticsearch and user password authentication (#5058)
Enhance the code to support SSL authentication for Elasticsearch when using the VectorStore module, as previous versions did not provide this capability. @dev2049 --------- Co-authored-by: caidong <zhucaidong1992@gmail.com> Co-authored-by: Dev 2049 <dev.dev2049@gmail.com>
This commit is contained in:
parent
44dc959584
commit
039f8f1abb
@ -7,7 +7,7 @@ from typing import Any, Dict, Iterable, List, Optional, Tuple
|
||||
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
from langchain.utils import get_from_env
|
||||
from langchain.vectorstores.base import VectorStore
|
||||
|
||||
|
||||
@ -114,24 +114,31 @@ class ElasticVectorSearch(VectorStore, ABC):
|
||||
ValueError: If the elasticsearch python package is not installed.
|
||||
"""
|
||||
|
||||
def __init__(self, elasticsearch_url: str, index_name: str, embedding: Embeddings):
|
||||
def __init__(
|
||||
self,
|
||||
elasticsearch_url: str,
|
||||
index_name: str,
|
||||
embedding: Embeddings,
|
||||
*,
|
||||
ssl_verify: Optional[Dict[str, Any]] = None,
|
||||
):
|
||||
"""Initialize with necessary components."""
|
||||
try:
|
||||
import elasticsearch
|
||||
except ImportError:
|
||||
raise ValueError(
|
||||
raise ImportError(
|
||||
"Could not import elasticsearch python package. "
|
||||
"Please install it with `pip install elasticsearch`."
|
||||
)
|
||||
self.embedding = embedding
|
||||
self.index_name = index_name
|
||||
_ssl_verify = ssl_verify or {}
|
||||
try:
|
||||
es_client = elasticsearch.Elasticsearch(elasticsearch_url) # noqa
|
||||
self.client = elasticsearch.Elasticsearch(elasticsearch_url, **_ssl_verify)
|
||||
except ValueError as e:
|
||||
raise ValueError(
|
||||
f"Your elasticsearch client string is misformatted. Got error: {e} "
|
||||
f"Your elasticsearch client string is mis-formatted. Got error: {e} "
|
||||
)
|
||||
self.client = es_client
|
||||
|
||||
def add_texts(
|
||||
self,
|
||||
@ -154,7 +161,7 @@ class ElasticVectorSearch(VectorStore, ABC):
|
||||
from elasticsearch.exceptions import NotFoundError
|
||||
from elasticsearch.helpers import bulk
|
||||
except ImportError:
|
||||
raise ValueError(
|
||||
raise ImportError(
|
||||
"Could not import elasticsearch python package. "
|
||||
"Please install it with `pip install elasticsearch`."
|
||||
)
|
||||
@ -239,6 +246,9 @@ class ElasticVectorSearch(VectorStore, ABC):
|
||||
texts: List[str],
|
||||
embedding: Embeddings,
|
||||
metadatas: Optional[List[dict]] = None,
|
||||
elasticsearch_url: Optional[str] = None,
|
||||
index_name: Optional[str] = None,
|
||||
refresh_indices: bool = True,
|
||||
**kwargs: Any,
|
||||
) -> ElasticVectorSearch:
|
||||
"""Construct ElasticVectorSearch wrapper from raw documents.
|
||||
@ -262,48 +272,12 @@ class ElasticVectorSearch(VectorStore, ABC):
|
||||
elasticsearch_url="http://localhost:9200"
|
||||
)
|
||||
"""
|
||||
elasticsearch_url = get_from_dict_or_env(
|
||||
kwargs, "elasticsearch_url", "ELASTICSEARCH_URL"
|
||||
elasticsearch_url = elasticsearch_url or get_from_env(
|
||||
"elasticsearch_url", "ELASTICSEARCH_URL"
|
||||
)
|
||||
try:
|
||||
import elasticsearch
|
||||
from elasticsearch.exceptions import NotFoundError
|
||||
from elasticsearch.helpers import bulk
|
||||
except ImportError:
|
||||
raise ValueError(
|
||||
"Could not import elasticsearch python package. "
|
||||
"Please install it with `pip install elasticsearch`."
|
||||
index_name = index_name or uuid.uuid4().hex
|
||||
vectorsearch = cls(elasticsearch_url, index_name, embedding, **kwargs)
|
||||
vectorsearch.add_texts(
|
||||
texts, metadatas=metadatas, refresh_indices=refresh_indices
|
||||
)
|
||||
try:
|
||||
client = elasticsearch.Elasticsearch(elasticsearch_url)
|
||||
except ValueError as e:
|
||||
raise ValueError(
|
||||
"Your elasticsearch client string is misformatted. " f"Got error: {e} "
|
||||
)
|
||||
index_name = kwargs.get("index_name", uuid.uuid4().hex)
|
||||
embeddings = embedding.embed_documents(texts)
|
||||
dim = len(embeddings[0])
|
||||
mapping = _default_text_mapping(dim)
|
||||
|
||||
# check to see if the index already exists
|
||||
try:
|
||||
client.indices.get(index=index_name)
|
||||
except NotFoundError:
|
||||
# TODO would be nice to create index before embedding,
|
||||
# just to save expensive steps for last
|
||||
client.indices.create(index=index_name, mappings=mapping)
|
||||
|
||||
requests = []
|
||||
for i, text in enumerate(texts):
|
||||
metadata = metadatas[i] if metadatas else {}
|
||||
request = {
|
||||
"_op_type": "index",
|
||||
"_index": index_name,
|
||||
"vector": embeddings[i],
|
||||
"text": text,
|
||||
"metadata": metadata,
|
||||
}
|
||||
requests.append(request)
|
||||
bulk(client, requests)
|
||||
client.indices.refresh(index=index_name)
|
||||
return cls(elasticsearch_url, index_name, embedding)
|
||||
return vectorsearch
|
||||
|
@ -48,6 +48,23 @@ class TestElasticsearch:
|
||||
output = docsearch.similarity_search("foo", k=1)
|
||||
assert output == [Document(page_content="foo")]
|
||||
|
||||
def test_similarity_search_with_ssl_verify(self, elasticsearch_url: str) -> None:
|
||||
"""Test end to end construction and search with ssl verify."""
|
||||
ssl_verify = {
|
||||
"verify_certs": True,
|
||||
"basic_auth": ("ES_USER", "ES_PASSWORD"),
|
||||
"ca_certs": "ES_CA_CERTS_PATH",
|
||||
}
|
||||
texts = ["foo", "bar", "baz"]
|
||||
docsearch = ElasticVectorSearch.from_texts(
|
||||
texts,
|
||||
FakeEmbeddings(),
|
||||
elasticsearch_url=elasticsearch_url,
|
||||
ssl_verify=ssl_verify,
|
||||
)
|
||||
output = docsearch.similarity_search("foo", k=1)
|
||||
assert output == [Document(page_content="foo")]
|
||||
|
||||
def test_similarity_search_with_metadata(self, elasticsearch_url: str) -> None:
|
||||
"""Test end to end construction and search with metadata."""
|
||||
texts = ["foo", "bar", "baz"]
|
||||
|
Loading…
Reference in New Issue
Block a user