OpenSearch: Add Similarity Search with Score (#4089)

### Description
Add `similarity_search_with_score` method for OpenSearch to return
scores along with documents in the search results

Signed-off-by: Naveen Tatikonda <navtat@amazon.com>
This commit is contained in:
Naveen Tatikonda 2023-05-08 18:35:21 -05:00 committed by GitHub
parent b3ecce0545
commit 782df1db10
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 51 additions and 11 deletions

View File

@ -2,7 +2,7 @@
from __future__ import annotations
import uuid
from typing import Any, Dict, Iterable, List, Optional
from typing import Any, Dict, Iterable, List, Optional, Tuple
from langchain.docstore.document import Document
from langchain.embeddings.base import Embeddings
@ -410,6 +410,27 @@ class OpenSearchVectorSearch(VectorStore):
pre_filter: script_score query to pre-filter documents before identifying
nearest neighbors; default: {"match_all": {}}
"""
docs_with_scores = self.similarity_search_with_score(query, k, **kwargs)
return [doc[0] for doc in docs_with_scores]
def similarity_search_with_score(
self, query: str, k: int = 4, **kwargs: Any
) -> List[Tuple[Document, float]]:
"""Return docs and it's scores most similar to query.
By default supports Approximate Search.
Also supports Script Scoring and Painless Scripting.
Args:
query: Text to look up documents similar to.
k: Number of Documents to return. Defaults to 4.
Returns:
List of Documents along with its scores most similar to the query.
Optional Args:
same as `similarity_search`
"""
embedding = self.embedding_function.embed_query(query)
search_type = _get_kwargs_value(kwargs, "search_type", "approximate_search")
text_field = _get_kwargs_value(kwargs, "text_field", "text")
@ -454,17 +475,20 @@ class OpenSearchVectorSearch(VectorStore):
raise ValueError("Invalid `search_type` provided as an argument")
response = self.client.search(index=self.index_name, body=search_query)
hits = [hit["_source"] for hit in response["hits"]["hits"][:k]]
documents = [
Document(
page_content=hit[text_field],
metadata=hit
if metadata_field == "*" or metadata_field not in hit
else hit[metadata_field],
hits = [hit for hit in response["hits"]["hits"][:k]]
documents_with_scores = [
(
Document(
page_content=hit["_source"][text_field],
metadata=hit["_source"]
if metadata_field == "*" or metadata_field not in hit["_source"]
else hit["_source"][metadata_field],
),
hit["_score"],
)
for hit in hits
]
return documents
return documents_with_scores
@classmethod
def from_texts(

View File

@ -23,6 +23,22 @@ def test_opensearch() -> None:
assert output == [Document(page_content="foo")]
def test_similarity_search_with_score() -> None:
"""Test similarity search with score using Approximate Search."""
metadatas = [{"page": i} for i in range(len(texts))]
docsearch = OpenSearchVectorSearch.from_texts(
texts,
FakeEmbeddings(),
metadatas=metadatas,
opensearch_url=DEFAULT_OPENSEARCH_URL,
)
output = docsearch.similarity_search_with_score("foo", k=2)
assert output == [
(Document(page_content="foo", metadata={"page": 0}), 1.0),
(Document(page_content="bar", metadata={"page": 1}), 0.5),
]
def test_opensearch_with_custom_field_name() -> None:
"""Test indexing and search using custom vector field and text field name."""
docsearch = OpenSearchVectorSearch.from_texts(
@ -178,7 +194,7 @@ def test_appx_search_with_lucene_filter() -> None:
def test_opensearch_with_custom_field_name_appx_true() -> None:
"""Test Approximate Search with custom field name appx true."""
text_input = ["test", "add", "text", "method"]
text_input = ["add", "test", "text", "method"]
docsearch = OpenSearchVectorSearch.from_texts(
text_input,
FakeEmbeddings(),
@ -191,7 +207,7 @@ def test_opensearch_with_custom_field_name_appx_true() -> None:
def test_opensearch_with_custom_field_name_appx_false() -> None:
"""Test Approximate Search with custom field name appx true."""
text_input = ["test", "add", "text", "method"]
text_input = ["add", "test", "text", "method"]
docsearch = OpenSearchVectorSearch.from_texts(
text_input, FakeEmbeddings(), opensearch_url=DEFAULT_OPENSEARCH_URL
)