diff --git a/langchain/vectorstores/elastic_vector_search.py b/langchain/vectorstores/elastic_vector_search.py index 17af42c6..dc11a842 100644 --- a/langchain/vectorstores/elastic_vector_search.py +++ b/langchain/vectorstores/elastic_vector_search.py @@ -3,7 +3,7 @@ from __future__ import annotations import uuid from abc import ABC -from typing import Any, Dict, Iterable, List, Optional +from typing import Any, Dict, Iterable, List, Optional, Tuple from langchain.docstore.document import Document from langchain.embeddings.base import Embeddings @@ -20,10 +20,15 @@ def _default_text_mapping(dim: int) -> Dict: } -def _default_script_query(query_vector: List[float]) -> Dict: +def _default_script_query(query_vector: List[float], filter: Optional[dict]) -> Dict: + if filter: + ((key, value),) = filter.items() + filter = {"match": {f"metadata.{key}.keyword": f"{value}"}} + else: + filter = {"match_all": {}} return { "script_score": { - "query": {"match_all": {}}, + "query": filter, "script": { "source": "cosineSimilarity(params.query_vector, 'vector') + 1.0", "params": {"query_vector": query_vector}, @@ -187,7 +192,7 @@ class ElasticVectorSearch(VectorStore, ABC): return ids def similarity_search( - self, query: str, k: int = 4, **kwargs: Any + self, query: str, k: int = 4, filter: Optional[dict] = None, **kwargs: Any ) -> List[Document]: """Return docs most similar to query. @@ -198,14 +203,35 @@ class ElasticVectorSearch(VectorStore, ABC): Returns: List of Documents most similar to the query. """ + docs_and_scores = self.similarity_search_with_score(query, k, filter=filter) + documents = [d[0] for d in docs_and_scores] + return documents + + def similarity_search_with_score( + self, query: str, k: int = 4, filter: Optional[dict] = None, **kwargs: Any + ) -> List[Tuple[Document, float]]: + """Return docs most similar to query. + Args: + query: Text to look up documents similar to. + k: Number of Documents to return. Defaults to 4. + Returns: + List of Documents most similar to the query. + """ embedding = self.embedding.embed_query(query) - script_query = _default_script_query(embedding) + script_query = _default_script_query(embedding, filter) response = self.client.search(index=self.index_name, query=script_query, size=k) - hits = [hit["_source"] for hit in response["hits"]["hits"]] - documents = [ - Document(page_content=hit["text"], metadata=hit["metadata"]) for hit in hits + hits = [hit for hit in response["hits"]["hits"]] + docs_and_scores = [ + ( + Document( + page_content=hit["_source"]["text"], + metadata=hit["_source"]["metadata"], + ), + hit["_score"], + ) + for hit in hits ] - return documents + return docs_and_scores @classmethod def from_texts(