OpenSearch: Add Similarity Search with Score (#4089)

### Description
Add `similarity_search_with_score` method for OpenSearch to return
scores along with documents in the search results

Signed-off-by: Naveen Tatikonda <navtat@amazon.com>
This commit is contained in:
Naveen Tatikonda 2023-05-08 18:35:21 -05:00 committed by GitHub
parent b3ecce0545
commit 782df1db10
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 51 additions and 11 deletions

View File

@ -2,7 +2,7 @@
from __future__ import annotations from __future__ import annotations
import uuid import uuid
from typing import Any, Dict, Iterable, List, Optional from typing import Any, Dict, Iterable, List, Optional, Tuple
from langchain.docstore.document import Document from langchain.docstore.document import Document
from langchain.embeddings.base import Embeddings from langchain.embeddings.base import Embeddings
@ -410,6 +410,27 @@ class OpenSearchVectorSearch(VectorStore):
pre_filter: script_score query to pre-filter documents before identifying pre_filter: script_score query to pre-filter documents before identifying
nearest neighbors; default: {"match_all": {}} nearest neighbors; default: {"match_all": {}}
""" """
docs_with_scores = self.similarity_search_with_score(query, k, **kwargs)
return [doc[0] for doc in docs_with_scores]
def similarity_search_with_score(
self, query: str, k: int = 4, **kwargs: Any
) -> List[Tuple[Document, float]]:
"""Return docs and it's scores most similar to query.
By default supports Approximate Search.
Also supports Script Scoring and Painless Scripting.
Args:
query: Text to look up documents similar to.
k: Number of Documents to return. Defaults to 4.
Returns:
List of Documents along with its scores most similar to the query.
Optional Args:
same as `similarity_search`
"""
embedding = self.embedding_function.embed_query(query) embedding = self.embedding_function.embed_query(query)
search_type = _get_kwargs_value(kwargs, "search_type", "approximate_search") search_type = _get_kwargs_value(kwargs, "search_type", "approximate_search")
text_field = _get_kwargs_value(kwargs, "text_field", "text") text_field = _get_kwargs_value(kwargs, "text_field", "text")
@ -454,17 +475,20 @@ class OpenSearchVectorSearch(VectorStore):
raise ValueError("Invalid `search_type` provided as an argument") raise ValueError("Invalid `search_type` provided as an argument")
response = self.client.search(index=self.index_name, body=search_query) response = self.client.search(index=self.index_name, body=search_query)
hits = [hit["_source"] for hit in response["hits"]["hits"][:k]] hits = [hit for hit in response["hits"]["hits"][:k]]
documents = [ documents_with_scores = [
Document( (
page_content=hit[text_field], Document(
metadata=hit page_content=hit["_source"][text_field],
if metadata_field == "*" or metadata_field not in hit metadata=hit["_source"]
else hit[metadata_field], if metadata_field == "*" or metadata_field not in hit["_source"]
else hit["_source"][metadata_field],
),
hit["_score"],
) )
for hit in hits for hit in hits
] ]
return documents return documents_with_scores
@classmethod @classmethod
def from_texts( def from_texts(

View File

@ -23,6 +23,22 @@ def test_opensearch() -> None:
assert output == [Document(page_content="foo")] assert output == [Document(page_content="foo")]
def test_similarity_search_with_score() -> None:
"""Test similarity search with score using Approximate Search."""
metadatas = [{"page": i} for i in range(len(texts))]
docsearch = OpenSearchVectorSearch.from_texts(
texts,
FakeEmbeddings(),
metadatas=metadatas,
opensearch_url=DEFAULT_OPENSEARCH_URL,
)
output = docsearch.similarity_search_with_score("foo", k=2)
assert output == [
(Document(page_content="foo", metadata={"page": 0}), 1.0),
(Document(page_content="bar", metadata={"page": 1}), 0.5),
]
def test_opensearch_with_custom_field_name() -> None: def test_opensearch_with_custom_field_name() -> None:
"""Test indexing and search using custom vector field and text field name.""" """Test indexing and search using custom vector field and text field name."""
docsearch = OpenSearchVectorSearch.from_texts( docsearch = OpenSearchVectorSearch.from_texts(
@ -178,7 +194,7 @@ def test_appx_search_with_lucene_filter() -> None:
def test_opensearch_with_custom_field_name_appx_true() -> None: def test_opensearch_with_custom_field_name_appx_true() -> None:
"""Test Approximate Search with custom field name appx true.""" """Test Approximate Search with custom field name appx true."""
text_input = ["test", "add", "text", "method"] text_input = ["add", "test", "text", "method"]
docsearch = OpenSearchVectorSearch.from_texts( docsearch = OpenSearchVectorSearch.from_texts(
text_input, text_input,
FakeEmbeddings(), FakeEmbeddings(),
@ -191,7 +207,7 @@ def test_opensearch_with_custom_field_name_appx_true() -> None:
def test_opensearch_with_custom_field_name_appx_false() -> None: def test_opensearch_with_custom_field_name_appx_false() -> None:
"""Test Approximate Search with custom field name appx true.""" """Test Approximate Search with custom field name appx true."""
text_input = ["test", "add", "text", "method"] text_input = ["add", "test", "text", "method"]
docsearch = OpenSearchVectorSearch.from_texts( docsearch = OpenSearchVectorSearch.from_texts(
text_input, FakeEmbeddings(), opensearch_url=DEFAULT_OPENSEARCH_URL text_input, FakeEmbeddings(), opensearch_url=DEFAULT_OPENSEARCH_URL
) )