From 782df1db10aa0c30a48729020097a85dbab12575 Mon Sep 17 00:00:00 2001 From: Naveen Tatikonda Date: Mon, 8 May 2023 18:35:21 -0500 Subject: [PATCH] OpenSearch: Add Similarity Search with Score (#4089) ### Description Add `similarity_search_with_score` method for OpenSearch to return scores along with documents in the search results Signed-off-by: Naveen Tatikonda --- .../vectorstores/opensearch_vector_search.py | 42 +++++++++++++++---- .../vectorstores/test_opensearch.py | 20 ++++++++- 2 files changed, 51 insertions(+), 11 deletions(-) diff --git a/langchain/vectorstores/opensearch_vector_search.py b/langchain/vectorstores/opensearch_vector_search.py index 5f432b2d..624d62c5 100644 --- a/langchain/vectorstores/opensearch_vector_search.py +++ b/langchain/vectorstores/opensearch_vector_search.py @@ -2,7 +2,7 @@ from __future__ import annotations import uuid -from typing import Any, Dict, Iterable, List, Optional +from typing import Any, Dict, Iterable, List, Optional, Tuple from langchain.docstore.document import Document from langchain.embeddings.base import Embeddings @@ -410,6 +410,27 @@ class OpenSearchVectorSearch(VectorStore): pre_filter: script_score query to pre-filter documents before identifying nearest neighbors; default: {"match_all": {}} """ + docs_with_scores = self.similarity_search_with_score(query, k, **kwargs) + return [doc[0] for doc in docs_with_scores] + + def similarity_search_with_score( + self, query: str, k: int = 4, **kwargs: Any + ) -> List[Tuple[Document, float]]: + """Return docs and their scores most similar to query. + + By default supports Approximate Search. + Also supports Script Scoring and Painless Scripting. + + Args: + query: Text to look up documents similar to. + k: Number of Documents to return. Defaults to 4. + + Returns: + List of Documents along with their scores most similar to the query. 
+ + Optional Args: + same as `similarity_search` + """ embedding = self.embedding_function.embed_query(query) search_type = _get_kwargs_value(kwargs, "search_type", "approximate_search") text_field = _get_kwargs_value(kwargs, "text_field", "text") @@ -454,17 +475,20 @@ class OpenSearchVectorSearch(VectorStore): raise ValueError("Invalid `search_type` provided as an argument") response = self.client.search(index=self.index_name, body=search_query) - hits = [hit["_source"] for hit in response["hits"]["hits"][:k]] - documents = [ - Document( - page_content=hit[text_field], - metadata=hit - if metadata_field == "*" or metadata_field not in hit - else hit[metadata_field], + hits = [hit for hit in response["hits"]["hits"][:k]] + documents_with_scores = [ + ( + Document( + page_content=hit["_source"][text_field], + metadata=hit["_source"] + if metadata_field == "*" or metadata_field not in hit["_source"] + else hit["_source"][metadata_field], + ), + hit["_score"], ) for hit in hits ] - return documents + return documents_with_scores @classmethod def from_texts( diff --git a/tests/integration_tests/vectorstores/test_opensearch.py b/tests/integration_tests/vectorstores/test_opensearch.py index 88bdec09..8b9e12a8 100644 --- a/tests/integration_tests/vectorstores/test_opensearch.py +++ b/tests/integration_tests/vectorstores/test_opensearch.py @@ -23,6 +23,22 @@ def test_opensearch() -> None: assert output == [Document(page_content="foo")] +def test_similarity_search_with_score() -> None: + """Test similarity search with score using Approximate Search.""" + metadatas = [{"page": i} for i in range(len(texts))] + docsearch = OpenSearchVectorSearch.from_texts( + texts, + FakeEmbeddings(), + metadatas=metadatas, + opensearch_url=DEFAULT_OPENSEARCH_URL, + ) + output = docsearch.similarity_search_with_score("foo", k=2) + assert output == [ + (Document(page_content="foo", metadata={"page": 0}), 1.0), + (Document(page_content="bar", metadata={"page": 1}), 0.5), + ] + + def 
test_opensearch_with_custom_field_name() -> None: """Test indexing and search using custom vector field and text field name.""" docsearch = OpenSearchVectorSearch.from_texts( @@ -178,7 +194,7 @@ def test_appx_search_with_lucene_filter() -> None: def test_opensearch_with_custom_field_name_appx_true() -> None: """Test Approximate Search with custom field name appx true.""" - text_input = ["test", "add", "text", "method"] + text_input = ["add", "test", "text", "method"] docsearch = OpenSearchVectorSearch.from_texts( text_input, FakeEmbeddings(), @@ -191,7 +207,7 @@ def test_opensearch_with_custom_field_name_appx_true() -> None: def test_opensearch_with_custom_field_name_appx_false() -> None: """Test Approximate Search with custom field name appx true.""" - text_input = ["test", "add", "text", "method"] + text_input = ["add", "test", "text", "method"] docsearch = OpenSearchVectorSearch.from_texts( text_input, FakeEmbeddings(), opensearch_url=DEFAULT_OPENSEARCH_URL )