From 00bf472265eb38e686dd6a609d0ab6cdd7cd5850 Mon Sep 17 00:00:00 2001 From: Markus Schiffer <47703802+MarkusSchiffer@users.noreply.github.com> Date: Fri, 11 Aug 2023 13:08:17 -0700 Subject: [PATCH] Fix for SVM retriever discarding document metadata (#9141) As stated in the title the SVM retriever discarded the metadata of passed in docs. This code fixes that. I also added one unit test that should test that. --------- Co-authored-by: Bagatur --- libs/langchain/langchain/retrievers/svm.py | 24 +++++++++++++++---- .../tests/unit_tests/retrievers/test_svm.py | 15 ++++++++++++ 2 files changed, 35 insertions(+), 4 deletions(-) diff --git a/libs/langchain/langchain/retrievers/svm.py b/libs/langchain/langchain/retrievers/svm.py index 3c65e974eb..9594ce0dce 100644 --- a/libs/langchain/langchain/retrievers/svm.py +++ b/libs/langchain/langchain/retrievers/svm.py @@ -38,6 +38,8 @@ class SVMRetriever(BaseRetriever): """Index of embeddings.""" texts: List[str] """List of texts to index.""" + metadatas: Optional[List[dict]] = None + """List of metadatas corresponding with each text.""" k: int = 4 """Number of results to return.""" relevancy_threshold: Optional[float] = None @@ -51,10 +53,20 @@ class SVMRetriever(BaseRetriever): @classmethod def from_texts( - cls, texts: List[str], embeddings: Embeddings, **kwargs: Any + cls, + texts: List[str], + embeddings: Embeddings, + metadatas: Optional[List[dict]] = None, + **kwargs: Any, ) -> SVMRetriever: index = create_index(texts, embeddings) - return cls(embeddings=embeddings, index=index, texts=texts, **kwargs) + return cls( + embeddings=embeddings, + index=index, + texts=texts, + metadatas=metadatas, + **kwargs, + ) @classmethod def from_documents( @@ -64,7 +76,9 @@ class SVMRetriever(BaseRetriever): **kwargs: Any, ) -> SVMRetriever: texts, metadatas = zip(*((d.page_content, d.metadata) for d in documents)) - return cls.from_texts(texts=texts, embeddings=embeddings, **kwargs) + return cls.from_texts( + texts=texts, embeddings=embeddings, metadatas=metadatas, **kwargs + ) def _get_relevant_documents( self, query: str, *, run_manager: CallbackManagerForRetrieverRun @@ -108,5 +122,7 @@ class SVMRetriever(BaseRetriever): self.relevancy_threshold is None or normalized_similarities[row] >= self.relevancy_threshold ): - top_k_results.append(Document(page_content=self.texts[row - 1])) + metadata = self.metadatas[row - 1] if self.metadatas else {} + doc = Document(page_content=self.texts[row - 1], metadata=metadata) + top_k_results.append(doc) return top_k_results diff --git a/libs/langchain/tests/unit_tests/retrievers/test_svm.py b/libs/langchain/tests/unit_tests/retrievers/test_svm.py index 2f52fc38cd..491be75a6c 100644 --- a/libs/langchain/tests/unit_tests/retrievers/test_svm.py +++ b/libs/langchain/tests/unit_tests/retrievers/test_svm.py @@ -25,3 +25,18 @@ class TestSVMRetriever: documents=input_docs, embeddings=FakeEmbeddings(size=100) ) assert len(svm_retriever.texts) == 3 + + @pytest.mark.requires("sklearn") + def test_metadata_persists(self) -> None: + input_docs = [ + Document(page_content="I have a pen.", metadata={"foo": "bar"}), + Document(page_content="How about you?", metadata={"foo": "baz"}), + Document(page_content="I have a bag.", metadata={"foo": "qux"}), + ] + svm_retriever = SVMRetriever.from_documents( + documents=input_docs, embeddings=FakeEmbeddings(size=100) + ) + query = "Have anything?" + output_docs = svm_retriever.get_relevant_documents(query=query) + for doc in output_docs: + assert "foo" in doc.metadata