diff --git a/langchain/vectorstores/elastic_vector_search.py b/langchain/vectorstores/elastic_vector_search.py index 4f37800e..57a73dcd 100644 --- a/langchain/vectorstores/elastic_vector_search.py +++ b/langchain/vectorstores/elastic_vector_search.py @@ -2,7 +2,7 @@ from __future__ import annotations import uuid -from typing import Any, Callable, Dict, Iterable, List, Optional +from typing import Any, Dict, Iterable, List, Optional from langchain.docstore.document import Document from langchain.embeddings.base import Embeddings @@ -19,7 +19,7 @@ def _default_text_mapping(dim: int) -> Dict: } -def _default_script_query(query_vector: List[int]) -> Dict: +def _default_script_query(query_vector: List[float]) -> Dict: return { "script_score": { "query": {"match_all": {}}, @@ -41,14 +41,12 @@ class ElasticVectorSearch(VectorStore): elastic_vector_search = ElasticVectorSearch( "http://localhost:9200", "embeddings", - embedding_function + embedding ) """ - def __init__( - self, elasticsearch_url: str, index_name: str, embedding_function: Callable - ): + def __init__(self, elasticsearch_url: str, index_name: str, embedding: Embeddings): """Initialize with necessary components.""" try: import elasticsearch @@ -57,7 +55,7 @@ class ElasticVectorSearch(VectorStore): "Could not import elasticsearch python package. " "Please install it with `pip install elasticsearch`." ) - self.embedding_function = embedding_function + self.embedding = embedding self.index_name = index_name try: es_client = elasticsearch.Elasticsearch(elasticsearch_url) # noqa @@ -91,13 +89,14 @@ class ElasticVectorSearch(VectorStore): ) requests = [] ids = [] + embeddings = self.embedding.embed_documents(list(texts)) for i, text in enumerate(texts): metadata = metadatas[i] if metadatas else {} _id = str(uuid.uuid4()) request = { "_op_type": "index", "_index": self.index_name, - "vector": self.embedding_function(text), + "vector": embeddings[i], "text": text, "metadata": metadata, "_id": _id, @@ -121,7 +120,7 @@ class ElasticVectorSearch(VectorStore): Returns: List of Documents most similar to the query. """ - embedding = self.embedding_function(query) + embedding = self.embedding.embed_query(query) script_query = _default_script_query(embedding) response = self.client.search(index=self.index_name, query=script_query) hits = [hit["_source"] for hit in response["hits"]["hits"][:k]] @@ -196,4 +195,4 @@ class ElasticVectorSearch(VectorStore): requests.append(request) bulk(client, requests) client.indices.refresh(index=index_name) - return cls(elasticsearch_url, index_name, embedding.embed_query) + return cls(elasticsearch_url, index_name, embedding)