From 76283e962518ab1832b1701eabde0aeffac0823b Mon Sep 17 00:00:00 2001 From: Jacob Lee Date: Wed, 8 Nov 2023 17:50:06 -0800 Subject: [PATCH] Adds embeddings filter option to return scores in state (#12489) CC @baskaryan @assafelovic --- .../retrievers/document_compressors/embeddings_filter.py | 2 ++ .../retrievers/document_compressors/test_embeddings_filter.py | 4 +++- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/libs/langchain/langchain/retrievers/document_compressors/embeddings_filter.py b/libs/langchain/langchain/retrievers/document_compressors/embeddings_filter.py index 001b9494a0..9241c3bc59 100644 --- a/libs/langchain/langchain/retrievers/document_compressors/embeddings_filter.py +++ b/libs/langchain/langchain/retrievers/document_compressors/embeddings_filter.py @@ -67,4 +67,6 @@ class EmbeddingsFilter(BaseDocumentCompressor): similarity[included_idxs] > self.similarity_threshold ) included_idxs = included_idxs[similar_enough] + for i in included_idxs: + stateful_documents[i].state["query_similarity_score"] = similarity[i] return [stateful_documents[i] for i in included_idxs] diff --git a/libs/langchain/tests/integration_tests/retrievers/document_compressors/test_embeddings_filter.py b/libs/langchain/tests/integration_tests/retrievers/document_compressors/test_embeddings_filter.py index 7faa97b688..ad2f71d7bf 100644 --- a/libs/langchain/tests/integration_tests/retrievers/document_compressors/test_embeddings_filter.py +++ b/libs/langchain/tests/integration_tests/retrievers/document_compressors/test_embeddings_filter.py @@ -35,7 +35,9 @@ def test_embeddings_filter_with_state() -> None: state = {"embedded_doc": np.zeros(len(embedded_query))} docs = [_DocumentWithState(page_content=t, state=state) for t in texts] docs[-1].state = {"embedded_doc": embedded_query} - relevant_filter = EmbeddingsFilter(embeddings=embeddings, similarity_threshold=0.75) + relevant_filter = EmbeddingsFilter( + embeddings=embeddings, similarity_threshold=0.75, return_similarity_scores=True + ) actual = relevant_filter.compress_documents(docs, query) assert len(actual) == 1 assert texts[-1] == actual[0].page_content