From 5fb7d3b4baeddfc20a0abb2b61e8b8b79ae74a56 Mon Sep 17 00:00:00 2001 From: Chester Curme Date: Fri, 26 Apr 2024 16:43:35 -0400 Subject: [PATCH] propagate metadata --- .../vectorstores/test_databricks_vector_search.py | 1 + libs/core/langchain_core/vectorstores.py | 12 ++++++++++-- .../langchain/retrievers/self_query/base.py | 8 ++++++-- 3 files changed, 17 insertions(+), 4 deletions(-) diff --git a/libs/community/tests/unit_tests/vectorstores/test_databricks_vector_search.py b/libs/community/tests/unit_tests/vectorstores/test_databricks_vector_search.py index 19f4a7a6af..4b6fe996b3 100644 --- a/libs/community/tests/unit_tests/vectorstores/test_databricks_vector_search.py +++ b/libs/community/tests/unit_tests/vectorstores/test_databricks_vector_search.py @@ -605,6 +605,7 @@ def test_similarity_score_threshold(index_details: dict, threshold: float) -> No for idx, result in enumerate(result_with_scores): assert result.score >= threshold assert result.page_content == search_result[idx].page_content + assert result.metadata == search_result[idx].metadata @pytest.mark.requires("databricks", "databricks.vector_search") diff --git a/libs/core/langchain_core/vectorstores.py b/libs/core/langchain_core/vectorstores.py index 7a86e2a17c..7a0b9d4e5b 100644 --- a/libs/core/langchain_core/vectorstores.py +++ b/libs/core/langchain_core/vectorstores.py @@ -712,7 +712,11 @@ class VectorStoreRetriever(BaseRetriever): ) if include_score: return [ - DocumentSearchHit(page_content=doc.page_content, score=score) + DocumentSearchHit( + page_content=doc.page_content, + metadata=doc.metadata, + score=score, + ) for doc, score in docs_and_similarities ] docs = [doc for doc, _ in docs_and_similarities] @@ -748,7 +752,11 @@ class VectorStoreRetriever(BaseRetriever): ) if include_score: return [ - DocumentSearchHit(page_content=doc.page_content, score=score) + DocumentSearchHit( + page_content=doc.page_content, + metadata=doc.metadata, + score=score, + ) for doc, score in docs_and_similarities ] docs = [doc for doc, _ in docs_and_similarities] diff --git a/libs/langchain/langchain/retrievers/self_query/base.py b/libs/langchain/langchain/retrievers/self_query/base.py index ba8dd964f1..ce7b685583 100644 --- a/libs/langchain/langchain/retrievers/self_query/base.py +++ b/libs/langchain/langchain/retrievers/self_query/base.py @@ -199,7 +199,9 @@ class SelfQueryRetriever(BaseRetriever): query, **search_kwargs ) return [ - DocumentSearchHit(page_content=doc.page_content, score=score) + DocumentSearchHit( + page_content=doc.page_content, metadata=doc.metadata, score=score + ) for doc, score in docs_and_scores ] else: @@ -214,7 +216,9 @@ class SelfQueryRetriever(BaseRetriever): query, **search_kwargs ) return [ - DocumentSearchHit(page_content=doc.page_content, score=score) + DocumentSearchHit( + page_content=doc.page_content, metadata=doc.metadata, score=score + ) for doc, score in docs_and_scores ] else: