forked from Archives/langchain
OpenSearch: Add Similarity Search with Score (#4089)
### Description Add `similarity_search_with_score` method for OpenSearch to return scores along with documents in the search results Signed-off-by: Naveen Tatikonda <navtat@amazon.com>
This commit is contained in:
parent
b3ecce0545
commit
782df1db10
@ -2,7 +2,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from typing import Any, Dict, Iterable, List, Optional
|
||||
from typing import Any, Dict, Iterable, List, Optional, Tuple
|
||||
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.embeddings.base import Embeddings
|
||||
@ -410,6 +410,27 @@ class OpenSearchVectorSearch(VectorStore):
|
||||
pre_filter: script_score query to pre-filter documents before identifying
|
||||
nearest neighbors; default: {"match_all": {}}
|
||||
"""
|
||||
docs_with_scores = self.similarity_search_with_score(query, k, **kwargs)
|
||||
return [doc[0] for doc in docs_with_scores]
|
||||
|
||||
def similarity_search_with_score(
|
||||
self, query: str, k: int = 4, **kwargs: Any
|
||||
) -> List[Tuple[Document, float]]:
|
||||
"""Return docs and it's scores most similar to query.
|
||||
|
||||
By default supports Approximate Search.
|
||||
Also supports Script Scoring and Painless Scripting.
|
||||
|
||||
Args:
|
||||
query: Text to look up documents similar to.
|
||||
k: Number of Documents to return. Defaults to 4.
|
||||
|
||||
Returns:
|
||||
List of Documents along with its scores most similar to the query.
|
||||
|
||||
Optional Args:
|
||||
same as `similarity_search`
|
||||
"""
|
||||
embedding = self.embedding_function.embed_query(query)
|
||||
search_type = _get_kwargs_value(kwargs, "search_type", "approximate_search")
|
||||
text_field = _get_kwargs_value(kwargs, "text_field", "text")
|
||||
@ -454,17 +475,20 @@ class OpenSearchVectorSearch(VectorStore):
|
||||
raise ValueError("Invalid `search_type` provided as an argument")
|
||||
|
||||
response = self.client.search(index=self.index_name, body=search_query)
|
||||
hits = [hit["_source"] for hit in response["hits"]["hits"][:k]]
|
||||
documents = [
|
||||
Document(
|
||||
page_content=hit[text_field],
|
||||
metadata=hit
|
||||
if metadata_field == "*" or metadata_field not in hit
|
||||
else hit[metadata_field],
|
||||
hits = [hit for hit in response["hits"]["hits"][:k]]
|
||||
documents_with_scores = [
|
||||
(
|
||||
Document(
|
||||
page_content=hit["_source"][text_field],
|
||||
metadata=hit["_source"]
|
||||
if metadata_field == "*" or metadata_field not in hit["_source"]
|
||||
else hit["_source"][metadata_field],
|
||||
),
|
||||
hit["_score"],
|
||||
)
|
||||
for hit in hits
|
||||
]
|
||||
return documents
|
||||
return documents_with_scores
|
||||
|
||||
@classmethod
|
||||
def from_texts(
|
||||
|
@ -23,6 +23,22 @@ def test_opensearch() -> None:
|
||||
assert output == [Document(page_content="foo")]
|
||||
|
||||
|
||||
def test_similarity_search_with_score() -> None:
|
||||
"""Test similarity search with score using Approximate Search."""
|
||||
metadatas = [{"page": i} for i in range(len(texts))]
|
||||
docsearch = OpenSearchVectorSearch.from_texts(
|
||||
texts,
|
||||
FakeEmbeddings(),
|
||||
metadatas=metadatas,
|
||||
opensearch_url=DEFAULT_OPENSEARCH_URL,
|
||||
)
|
||||
output = docsearch.similarity_search_with_score("foo", k=2)
|
||||
assert output == [
|
||||
(Document(page_content="foo", metadata={"page": 0}), 1.0),
|
||||
(Document(page_content="bar", metadata={"page": 1}), 0.5),
|
||||
]
|
||||
|
||||
|
||||
def test_opensearch_with_custom_field_name() -> None:
|
||||
"""Test indexing and search using custom vector field and text field name."""
|
||||
docsearch = OpenSearchVectorSearch.from_texts(
|
||||
@ -178,7 +194,7 @@ def test_appx_search_with_lucene_filter() -> None:
|
||||
|
||||
def test_opensearch_with_custom_field_name_appx_true() -> None:
|
||||
"""Test Approximate Search with custom field name appx true."""
|
||||
text_input = ["test", "add", "text", "method"]
|
||||
text_input = ["add", "test", "text", "method"]
|
||||
docsearch = OpenSearchVectorSearch.from_texts(
|
||||
text_input,
|
||||
FakeEmbeddings(),
|
||||
@ -191,7 +207,7 @@ def test_opensearch_with_custom_field_name_appx_true() -> None:
|
||||
|
||||
def test_opensearch_with_custom_field_name_appx_false() -> None:
|
||||
"""Test Approximate Search with custom field name appx true."""
|
||||
text_input = ["test", "add", "text", "method"]
|
||||
text_input = ["add", "test", "text", "method"]
|
||||
docsearch = OpenSearchVectorSearch.from_texts(
|
||||
text_input, FakeEmbeddings(), opensearch_url=DEFAULT_OPENSEARCH_URL
|
||||
)
|
||||
|
Loading…
Reference in New Issue
Block a user