forked from Archives/langchain
You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
154 lines
5.2 KiB
Python
154 lines
5.2 KiB
Python
"""Wrapper around Elasticsearch vector database."""
|
|
import uuid
|
|
from typing import Any, Callable, Dict, List, Optional
|
|
|
|
from langchain.docstore.document import Document
|
|
from langchain.embeddings.base import Embeddings
|
|
from langchain.utils import get_from_dict_or_env
|
|
from langchain.vectorstores.base import VectorStore
|
|
|
|
|
|
def _default_text_mapping(dim: int) -> Dict:
|
|
return {
|
|
"properties": {
|
|
"text": {"type": "text"},
|
|
"vector": {"type": "dense_vector", "dims": dim},
|
|
}
|
|
}
|
|
|
|
|
|
def _default_script_query(query_vector: List[int]) -> Dict:
|
|
return {
|
|
"script_score": {
|
|
"query": {"match_all": {}},
|
|
"script": {
|
|
"source": "cosineSimilarity(params.query_vector, 'vector') + 1.0",
|
|
"params": {"query_vector": query_vector},
|
|
},
|
|
}
|
|
}
|
|
|
|
|
|
class ElasticVectorSearch(VectorStore):
    """Wrapper around Elasticsearch as a vector database.

    Example:
        .. code-block:: python

            from langchain import ElasticVectorSearch
            elastic_vector_search = ElasticVectorSearch(
                "http://localhost:9200",
                "embeddings",
                embedding_function
            )

    """

    def __init__(
        self, elasticsearch_url: str, index_name: str, embedding_function: Callable
    ):
        """Initialize with necessary components.

        Args:
            elasticsearch_url: URL passed to ``elasticsearch.Elasticsearch``.
            index_name: Name of the index that searches are run against.
            embedding_function: Callable applied to the query text in
                ``similarity_search`` to obtain the query vector.

        Raises:
            ValueError: If the ``elasticsearch`` package cannot be imported,
                or if constructing the client raises ``ValueError``.
        """
        try:
            import elasticsearch
        except ImportError as e:
            raise ValueError(
                "Could not import elasticsearch python package. "
                "Please install it with `pip install elasticsearch`."
            ) from e
        self.embedding_function = embedding_function
        self.index_name = index_name
        try:
            es_client = elasticsearch.Elasticsearch(elasticsearch_url)  # noqa
        except ValueError as e:
            raise ValueError(
                f"Your elasticsearch client string is misformatted. Got error: {e} "
            ) from e
        self.client = es_client

    def similarity_search(self, query: str, k: int = 4) -> List[Document]:
        """Return docs most similar to query.

        Args:
            query: Text to look up documents similar to.
            k: Number of Documents to return. Defaults to 4.

        Returns:
            List of Documents most similar to the query.
        """
        embedding = self.embedding_function(query)
        script_query = _default_script_query(embedding)
        # Ask for k hits explicitly: without ``size`` Elasticsearch returns
        # its default page of 10 hits, silently capping any larger k.
        response = self.client.search(
            index=self.index_name, query=script_query, size=k
        )
        hits = [hit["_source"] for hit in response["hits"]["hits"][:k]]
        # Documents that were indexed without a "metadata" field get an
        # empty dict instead of raising KeyError.
        return [
            Document(page_content=hit["text"], metadata=hit.get("metadata", {}))
            for hit in hits
        ]

    @classmethod
    def from_texts(
        cls,
        texts: List[str],
        embedding: Embeddings,
        metadatas: Optional[List[dict]] = None,
        **kwargs: Any,
    ) -> "ElasticVectorSearch":
        """Construct ElasticVectorSearch wrapper from raw documents.

        This is a user-friendly interface that:
            1. Embeds documents.
            2. Creates a new index for the embeddings in the Elasticsearch instance.
            3. Adds the documents to the newly created Elasticsearch index.

        This is intended to be a quick way to get started.

        Args:
            texts: Texts to embed and index.
            embedding: Embeddings object; ``embed_documents`` is used for the
                texts and ``embed_query`` becomes the wrapper's query
                embedding function.
            metadatas: Optional per-text metadata, indexed by position. When
                omitted or empty, every document is stored with ``{}``.
            **kwargs: May carry ``elasticsearch_url``; the URL is resolved
                through ``get_from_dict_or_env`` using the key
                ``elasticsearch_url`` and the name ``ELASTICSEARCH_URL``.

        Returns:
            A wrapper bound to a newly created index whose name is a random
            ``uuid4`` hex string.

        Raises:
            ValueError: If the ``elasticsearch`` package cannot be imported,
                or if constructing the client raises ``ValueError``.

        Example:
            .. code-block:: python

                from langchain import ElasticVectorSearch
                from langchain.embeddings import OpenAIEmbeddings
                embeddings = OpenAIEmbeddings()
                elastic_vector_search = ElasticVectorSearch.from_texts(
                    texts,
                    embeddings,
                    elasticsearch_url="http://localhost:9200"
                )
        """
        elasticsearch_url = get_from_dict_or_env(
            kwargs, "elasticsearch_url", "ELASTICSEARCH_URL"
        )
        try:
            import elasticsearch
            from elasticsearch.helpers import bulk
        except ImportError as e:
            raise ValueError(
                "Could not import elasticsearch python package. "
                "Please install it with `pip install elasticsearch`."
            ) from e
        try:
            client = elasticsearch.Elasticsearch(elasticsearch_url)
        except ValueError as e:
            raise ValueError(
                "Your elasticsearch client string is misformatted. " f"Got error: {e} "
            ) from e
        index_name = uuid.uuid4().hex
        embeddings = embedding.embed_documents(texts)
        # The vector field's dimensionality is taken from the first embedding.
        dim = len(embeddings[0])
        mapping = _default_text_mapping(dim)
        # TODO would be nice to create index before embedding,
        # just to save expensive steps for last
        client.indices.create(index=index_name, mappings=mapping)
        requests = [
            {
                "_op_type": "index",
                "_index": index_name,
                "vector": embeddings[i],
                "text": text,
                "metadata": metadatas[i] if metadatas else {},
            }
            for i, text in enumerate(texts)
        ]
        bulk(client, requests)
        # Refresh so the documents just indexed are visible to searches.
        client.indices.refresh(index=index_name)
        return cls(elasticsearch_url, index_name, embedding.embed_query)
|