Fix for SVM retriever discarding document metadata (#9141)

As stated in the title the SVM retriever discarded the metadata of
passed in docs. This code fixes that. I also added one unit test that
should test that.
---------

Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
Markus Schiffer 2023-08-11 13:08:17 -07:00 committed by GitHub
parent bace17e0aa
commit 00bf472265
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 35 additions and 4 deletions

View File

@ -38,6 +38,8 @@ class SVMRetriever(BaseRetriever):
"""Index of embeddings.""" """Index of embeddings."""
texts: List[str] texts: List[str]
"""List of texts to index.""" """List of texts to index."""
metadatas: Optional[List[dict]] = None
"""List of metadatas corresponding with each text."""
k: int = 4 k: int = 4
"""Number of results to return.""" """Number of results to return."""
relevancy_threshold: Optional[float] = None relevancy_threshold: Optional[float] = None
@ -51,10 +53,20 @@ class SVMRetriever(BaseRetriever):
@classmethod @classmethod
def from_texts( def from_texts(
cls, texts: List[str], embeddings: Embeddings, **kwargs: Any cls,
texts: List[str],
embeddings: Embeddings,
metadatas: Optional[List[dict]] = None,
**kwargs: Any,
) -> SVMRetriever: ) -> SVMRetriever:
index = create_index(texts, embeddings) index = create_index(texts, embeddings)
return cls(embeddings=embeddings, index=index, texts=texts, **kwargs) return cls(
embeddings=embeddings,
index=index,
texts=texts,
metadatas=metadatas,
**kwargs,
)
@classmethod @classmethod
def from_documents( def from_documents(
@ -64,7 +76,9 @@ class SVMRetriever(BaseRetriever):
**kwargs: Any, **kwargs: Any,
) -> SVMRetriever: ) -> SVMRetriever:
texts, metadatas = zip(*((d.page_content, d.metadata) for d in documents)) texts, metadatas = zip(*((d.page_content, d.metadata) for d in documents))
return cls.from_texts(texts=texts, embeddings=embeddings, **kwargs) return cls.from_texts(
texts=texts, embeddings=embeddings, metadatas=metadatas, **kwargs
)
def _get_relevant_documents( def _get_relevant_documents(
self, query: str, *, run_manager: CallbackManagerForRetrieverRun self, query: str, *, run_manager: CallbackManagerForRetrieverRun
@ -108,5 +122,7 @@ class SVMRetriever(BaseRetriever):
self.relevancy_threshold is None self.relevancy_threshold is None
or normalized_similarities[row] >= self.relevancy_threshold or normalized_similarities[row] >= self.relevancy_threshold
): ):
top_k_results.append(Document(page_content=self.texts[row - 1])) metadata = self.metadatas[row - 1] if self.metadatas else {}
doc = Document(page_content=self.texts[row - 1], metadata=metadata)
top_k_results.append(doc)
return top_k_results return top_k_results

View File

@ -25,3 +25,18 @@ class TestSVMRetriever:
documents=input_docs, embeddings=FakeEmbeddings(size=100) documents=input_docs, embeddings=FakeEmbeddings(size=100)
) )
assert len(svm_retriever.texts) == 3 assert len(svm_retriever.texts) == 3
@pytest.mark.requires("sklearn")
def test_metadata_persists(self) -> None:
input_docs = [
Document(page_content="I have a pen.", metadata={"foo": "bar"}),
Document(page_content="How about you?", metadata={"foo": "baz"}),
Document(page_content="I have a bag.", metadata={"foo": "qux"}),
]
svm_retriever = SVMRetriever.from_documents(
documents=input_docs, embeddings=FakeEmbeddings(size=100)
)
query = "Have anything?"
output_docs = svm_retriever.get_relevant_documents(query=query)
for doc in output_docs:
assert "foo" in doc.metadata