forked from Archives/langchain
OpenSearch: Add Similarity Search with Score (#4089)
### Description: Add a `similarity_search_with_score` method for OpenSearch that returns scores along with documents in the search results. Signed-off-by: Naveen Tatikonda <navtat@amazon.com>
This commit is contained in:
parent
b3ecce0545
commit
782df1db10
@ -2,7 +2,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import uuid
|
import uuid
|
||||||
from typing import Any, Dict, Iterable, List, Optional
|
from typing import Any, Dict, Iterable, List, Optional, Tuple
|
||||||
|
|
||||||
from langchain.docstore.document import Document
|
from langchain.docstore.document import Document
|
||||||
from langchain.embeddings.base import Embeddings
|
from langchain.embeddings.base import Embeddings
|
||||||
@ -410,6 +410,27 @@ class OpenSearchVectorSearch(VectorStore):
|
|||||||
pre_filter: script_score query to pre-filter documents before identifying
|
pre_filter: script_score query to pre-filter documents before identifying
|
||||||
nearest neighbors; default: {"match_all": {}}
|
nearest neighbors; default: {"match_all": {}}
|
||||||
"""
|
"""
|
||||||
|
docs_with_scores = self.similarity_search_with_score(query, k, **kwargs)
|
||||||
|
return [doc[0] for doc in docs_with_scores]
|
||||||
|
|
||||||
|
def similarity_search_with_score(
|
||||||
|
self, query: str, k: int = 4, **kwargs: Any
|
||||||
|
) -> List[Tuple[Document, float]]:
|
||||||
|
"""Return the docs most similar to query, along with their scores.
|
||||||
|
|
||||||
|
By default supports Approximate Search.
|
||||||
|
Also supports Script Scoring and Painless Scripting.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: Text to look up documents similar to.
|
||||||
|
k: Number of Documents to return. Defaults to 4.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of Documents most similar to the query, along with their scores.
|
||||||
|
|
||||||
|
Optional Args:
|
||||||
|
same as `similarity_search`
|
||||||
|
"""
|
||||||
embedding = self.embedding_function.embed_query(query)
|
embedding = self.embedding_function.embed_query(query)
|
||||||
search_type = _get_kwargs_value(kwargs, "search_type", "approximate_search")
|
search_type = _get_kwargs_value(kwargs, "search_type", "approximate_search")
|
||||||
text_field = _get_kwargs_value(kwargs, "text_field", "text")
|
text_field = _get_kwargs_value(kwargs, "text_field", "text")
|
||||||
@ -454,17 +475,20 @@ class OpenSearchVectorSearch(VectorStore):
|
|||||||
raise ValueError("Invalid `search_type` provided as an argument")
|
raise ValueError("Invalid `search_type` provided as an argument")
|
||||||
|
|
||||||
response = self.client.search(index=self.index_name, body=search_query)
|
response = self.client.search(index=self.index_name, body=search_query)
|
||||||
hits = [hit["_source"] for hit in response["hits"]["hits"][:k]]
|
hits = [hit for hit in response["hits"]["hits"][:k]]
|
||||||
documents = [
|
documents_with_scores = [
|
||||||
|
(
|
||||||
Document(
|
Document(
|
||||||
page_content=hit[text_field],
|
page_content=hit["_source"][text_field],
|
||||||
metadata=hit
|
metadata=hit["_source"]
|
||||||
if metadata_field == "*" or metadata_field not in hit
|
if metadata_field == "*" or metadata_field not in hit["_source"]
|
||||||
else hit[metadata_field],
|
else hit["_source"][metadata_field],
|
||||||
|
),
|
||||||
|
hit["_score"],
|
||||||
)
|
)
|
||||||
for hit in hits
|
for hit in hits
|
||||||
]
|
]
|
||||||
return documents
|
return documents_with_scores
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_texts(
|
def from_texts(
|
||||||
|
@ -23,6 +23,22 @@ def test_opensearch() -> None:
|
|||||||
assert output == [Document(page_content="foo")]
|
assert output == [Document(page_content="foo")]
|
||||||
|
|
||||||
|
|
||||||
|
def test_similarity_search_with_score() -> None:
|
||||||
|
"""Test similarity search with score using Approximate Search."""
|
||||||
|
metadatas = [{"page": i} for i in range(len(texts))]
|
||||||
|
docsearch = OpenSearchVectorSearch.from_texts(
|
||||||
|
texts,
|
||||||
|
FakeEmbeddings(),
|
||||||
|
metadatas=metadatas,
|
||||||
|
opensearch_url=DEFAULT_OPENSEARCH_URL,
|
||||||
|
)
|
||||||
|
output = docsearch.similarity_search_with_score("foo", k=2)
|
||||||
|
assert output == [
|
||||||
|
(Document(page_content="foo", metadata={"page": 0}), 1.0),
|
||||||
|
(Document(page_content="bar", metadata={"page": 1}), 0.5),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
def test_opensearch_with_custom_field_name() -> None:
|
def test_opensearch_with_custom_field_name() -> None:
|
||||||
"""Test indexing and search using custom vector field and text field name."""
|
"""Test indexing and search using custom vector field and text field name."""
|
||||||
docsearch = OpenSearchVectorSearch.from_texts(
|
docsearch = OpenSearchVectorSearch.from_texts(
|
||||||
@ -178,7 +194,7 @@ def test_appx_search_with_lucene_filter() -> None:
|
|||||||
|
|
||||||
def test_opensearch_with_custom_field_name_appx_true() -> None:
|
def test_opensearch_with_custom_field_name_appx_true() -> None:
|
||||||
"""Test Approximate Search with custom field name appx true."""
|
"""Test Approximate Search with custom field name appx true."""
|
||||||
text_input = ["test", "add", "text", "method"]
|
text_input = ["add", "test", "text", "method"]
|
||||||
docsearch = OpenSearchVectorSearch.from_texts(
|
docsearch = OpenSearchVectorSearch.from_texts(
|
||||||
text_input,
|
text_input,
|
||||||
FakeEmbeddings(),
|
FakeEmbeddings(),
|
||||||
@ -191,7 +207,7 @@ def test_opensearch_with_custom_field_name_appx_true() -> None:
|
|||||||
|
|
||||||
def test_opensearch_with_custom_field_name_appx_false() -> None:
|
def test_opensearch_with_custom_field_name_appx_false() -> None:
|
||||||
"""Test Approximate Search with custom field name appx false."""
|
"""Test Approximate Search with custom field name appx false."""
|
||||||
text_input = ["test", "add", "text", "method"]
|
text_input = ["add", "test", "text", "method"]
|
||||||
docsearch = OpenSearchVectorSearch.from_texts(
|
docsearch = OpenSearchVectorSearch.from_texts(
|
||||||
text_input, FakeEmbeddings(), opensearch_url=DEFAULT_OPENSEARCH_URL
|
text_input, FakeEmbeddings(), opensearch_url=DEFAULT_OPENSEARCH_URL
|
||||||
)
|
)
|
||||||
|
Loading…
Reference in New Issue
Block a user