mirror of
https://github.com/hwchase17/langchain
synced 2024-11-06 03:20:49 +00:00
ES similarity_search_with_score() and metadata filter (#3046)
Add similarity_search_with_score() to ElasticVectorSearch, add metadata filter to both similarity_search() and similarity_search_with_score()
This commit is contained in:
parent
416f3bdf11
commit
239dc10852
@ -3,7 +3,7 @@ from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from abc import ABC
|
||||
from typing import Any, Dict, Iterable, List, Optional
|
||||
from typing import Any, Dict, Iterable, List, Optional, Tuple
|
||||
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.embeddings.base import Embeddings
|
||||
@ -20,10 +20,15 @@ def _default_text_mapping(dim: int) -> Dict:
|
||||
}
|
||||
|
||||
|
||||
def _default_script_query(query_vector: List[float]) -> Dict:
|
||||
def _default_script_query(query_vector: List[float], filter: Optional[dict]) -> Dict:
|
||||
if filter:
|
||||
((key, value),) = filter.items()
|
||||
filter = {"match": {f"metadata.{key}.keyword": f"{value}"}}
|
||||
else:
|
||||
filter = {"match_all": {}}
|
||||
return {
|
||||
"script_score": {
|
||||
"query": {"match_all": {}},
|
||||
"query": filter,
|
||||
"script": {
|
||||
"source": "cosineSimilarity(params.query_vector, 'vector') + 1.0",
|
||||
"params": {"query_vector": query_vector},
|
||||
@ -187,7 +192,7 @@ class ElasticVectorSearch(VectorStore, ABC):
|
||||
return ids
|
||||
|
||||
def similarity_search(
|
||||
self, query: str, k: int = 4, **kwargs: Any
|
||||
self, query: str, k: int = 4, filter: Optional[dict] = None, **kwargs: Any
|
||||
) -> List[Document]:
|
||||
"""Return docs most similar to query.
|
||||
|
||||
@ -198,15 +203,36 @@ class ElasticVectorSearch(VectorStore, ABC):
|
||||
Returns:
|
||||
List of Documents most similar to the query.
|
||||
"""
|
||||
embedding = self.embedding.embed_query(query)
|
||||
script_query = _default_script_query(embedding)
|
||||
response = self.client.search(index=self.index_name, query=script_query, size=k)
|
||||
hits = [hit["_source"] for hit in response["hits"]["hits"]]
|
||||
documents = [
|
||||
Document(page_content=hit["text"], metadata=hit["metadata"]) for hit in hits
|
||||
]
|
||||
docs_and_scores = self.similarity_search_with_score(query, k, filter=filter)
|
||||
documents = [d[0] for d in docs_and_scores]
|
||||
return documents
|
||||
|
||||
def similarity_search_with_score(
|
||||
self, query: str, k: int = 4, filter: Optional[dict] = None, **kwargs: Any
|
||||
) -> List[Tuple[Document, float]]:
|
||||
"""Return docs most similar to query.
|
||||
Args:
|
||||
query: Text to look up documents similar to.
|
||||
k: Number of Documents to return. Defaults to 4.
|
||||
Returns:
|
||||
List of Documents most similar to the query.
|
||||
"""
|
||||
embedding = self.embedding.embed_query(query)
|
||||
script_query = _default_script_query(embedding, filter)
|
||||
response = self.client.search(index=self.index_name, query=script_query, size=k)
|
||||
hits = [hit for hit in response["hits"]["hits"]]
|
||||
docs_and_scores = [
|
||||
(
|
||||
Document(
|
||||
page_content=hit["_source"]["text"],
|
||||
metadata=hit["_source"]["metadata"],
|
||||
),
|
||||
hit["_score"],
|
||||
)
|
||||
for hit in hits
|
||||
]
|
||||
return docs_and_scores
|
||||
|
||||
@classmethod
|
||||
def from_texts(
|
||||
cls,
|
||||
|
Loading…
Reference in New Issue
Block a user