|
|
|
@ -1,6 +1,7 @@
|
|
|
|
|
"""Wrapper around Elasticsearch vector database."""
|
|
|
|
|
import os
|
|
|
|
|
import uuid
|
|
|
|
|
from typing import Callable, Dict, List
|
|
|
|
|
from typing import Any, Callable, Dict, List
|
|
|
|
|
|
|
|
|
|
from langchain.docstore.document import Document
|
|
|
|
|
from langchain.embeddings.base import Embeddings
|
|
|
|
@ -46,7 +47,7 @@ class ElasticVectorSearch(VectorStore):
|
|
|
|
|
|
|
|
|
|
def __init__(
|
|
|
|
|
self,
|
|
|
|
|
elastic_url: str,
|
|
|
|
|
elasticsearch_url: str,
|
|
|
|
|
index_name: str,
|
|
|
|
|
mapping: Dict,
|
|
|
|
|
embedding_function: Callable,
|
|
|
|
@ -62,7 +63,7 @@ class ElasticVectorSearch(VectorStore):
|
|
|
|
|
self.embedding_function = embedding_function
|
|
|
|
|
self.index_name = index_name
|
|
|
|
|
try:
|
|
|
|
|
es_client = elasticsearch.Elasticsearch(elastic_url) # noqa
|
|
|
|
|
es_client = elasticsearch.Elasticsearch(elasticsearch_url) # noqa
|
|
|
|
|
except ValueError as e:
|
|
|
|
|
raise ValueError(
|
|
|
|
|
"Your elasticsearch client string is misformatted. " f"Got error: {e} "
|
|
|
|
@ -89,7 +90,7 @@ class ElasticVectorSearch(VectorStore):
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
def from_texts(
|
|
|
|
|
cls, elastic_url: str, texts: List[str], embedding: Embeddings
|
|
|
|
|
cls, texts: List[str], embedding: Embeddings, **kwargs: Any
|
|
|
|
|
) -> "ElasticVectorSearch":
|
|
|
|
|
"""Construct ElasticVectorSearch wrapper from raw documents.
|
|
|
|
|
|
|
|
|
@ -107,11 +108,21 @@ class ElasticVectorSearch(VectorStore):
|
|
|
|
|
from langchain.embeddings import OpenAIEmbeddings
|
|
|
|
|
embeddings = OpenAIEmbeddings()
|
|
|
|
|
elastic_vector_search = ElasticVectorSearch.from_texts(
|
|
|
|
|
"http://localhost:9200",
|
|
|
|
|
texts,
|
|
|
|
|
embeddings
|
|
|
|
|
embeddings,
|
|
|
|
|
elasticsearch_url="http://localhost:9200"
|
|
|
|
|
)
|
|
|
|
|
"""
|
|
|
|
|
elasticsearch_url = kwargs.get("elasticsearch_url")
|
|
|
|
|
if not elasticsearch_url:
|
|
|
|
|
elasticsearch_url = os.environ.get("ELASTICSEARCH_URL")
|
|
|
|
|
|
|
|
|
|
if elasticsearch_url is None or elasticsearch_url == "":
|
|
|
|
|
raise ValueError(
|
|
|
|
|
"Did not find Elasticsearch URL, please add an environment variable"
|
|
|
|
|
" `ELASTICSEARCH_URL` which contains it, or pass"
|
|
|
|
|
" `elasticsearch_url` as a named parameter."
|
|
|
|
|
)
|
|
|
|
|
try:
|
|
|
|
|
import elasticsearch
|
|
|
|
|
from elasticsearch.helpers import bulk
|
|
|
|
@ -121,7 +132,7 @@ class ElasticVectorSearch(VectorStore):
|
|
|
|
|
"Please install it with `pip install elasticearch`."
|
|
|
|
|
)
|
|
|
|
|
try:
|
|
|
|
|
client = elasticsearch.Elasticsearch(elastic_url)
|
|
|
|
|
client = elasticsearch.Elasticsearch(elasticsearch_url)
|
|
|
|
|
except ValueError as e:
|
|
|
|
|
raise ValueError(
|
|
|
|
|
"Your elasticsearch client string is misformatted. " f"Got error: {e} "
|
|
|
|
@ -144,4 +155,4 @@ class ElasticVectorSearch(VectorStore):
|
|
|
|
|
requests.append(request)
|
|
|
|
|
bulk(client, requests)
|
|
|
|
|
client.indices.refresh(index=index_name)
|
|
|
|
|
return cls(elastic_url, index_name, mapping, embedding.embed_query)
|
|
|
|
|
return cls(elasticsearch_url, index_name, mapping, embedding.embed_query)
|
|
|
|
|