mirror of
https://github.com/hwchase17/langchain
synced 2024-11-06 03:20:49 +00:00
Fix for SVM retriever discarding document metadata (#9141)
As stated in the title, the SVM retriever discarded the metadata of passed-in documents. This change fixes that. I also added a unit test that checks the metadata is preserved on retrieved documents. --------- Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
parent
bace17e0aa
commit
00bf472265
@@ -38,6 +38,8 @@ class SVMRetriever(BaseRetriever):
|
|||||||
"""Index of embeddings."""
|
"""Index of embeddings."""
|
||||||
texts: List[str]
|
texts: List[str]
|
||||||
"""List of texts to index."""
|
"""List of texts to index."""
|
||||||
|
metadatas: Optional[List[dict]] = None
|
||||||
|
"""List of metadatas corresponding with each text."""
|
||||||
k: int = 4
|
k: int = 4
|
||||||
"""Number of results to return."""
|
"""Number of results to return."""
|
||||||
relevancy_threshold: Optional[float] = None
|
relevancy_threshold: Optional[float] = None
|
||||||
@@ -51,10 +53,20 @@ class SVMRetriever(BaseRetriever):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_texts(
|
def from_texts(
|
||||||
cls, texts: List[str], embeddings: Embeddings, **kwargs: Any
|
cls,
|
||||||
|
texts: List[str],
|
||||||
|
embeddings: Embeddings,
|
||||||
|
metadatas: Optional[List[dict]] = None,
|
||||||
|
**kwargs: Any,
|
||||||
) -> SVMRetriever:
|
) -> SVMRetriever:
|
||||||
index = create_index(texts, embeddings)
|
index = create_index(texts, embeddings)
|
||||||
return cls(embeddings=embeddings, index=index, texts=texts, **kwargs)
|
return cls(
|
||||||
|
embeddings=embeddings,
|
||||||
|
index=index,
|
||||||
|
texts=texts,
|
||||||
|
metadatas=metadatas,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_documents(
|
def from_documents(
|
||||||
@@ -64,7 +76,9 @@ class SVMRetriever(BaseRetriever):
|
|||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> SVMRetriever:
|
) -> SVMRetriever:
|
||||||
texts, metadatas = zip(*((d.page_content, d.metadata) for d in documents))
|
texts, metadatas = zip(*((d.page_content, d.metadata) for d in documents))
|
||||||
return cls.from_texts(texts=texts, embeddings=embeddings, **kwargs)
|
return cls.from_texts(
|
||||||
|
texts=texts, embeddings=embeddings, metadatas=metadatas, **kwargs
|
||||||
|
)
|
||||||
|
|
||||||
def _get_relevant_documents(
|
def _get_relevant_documents(
|
||||||
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
|
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
|
||||||
@@ -108,5 +122,7 @@ class SVMRetriever(BaseRetriever):
|
|||||||
self.relevancy_threshold is None
|
self.relevancy_threshold is None
|
||||||
or normalized_similarities[row] >= self.relevancy_threshold
|
or normalized_similarities[row] >= self.relevancy_threshold
|
||||||
):
|
):
|
||||||
top_k_results.append(Document(page_content=self.texts[row - 1]))
|
metadata = self.metadatas[row - 1] if self.metadatas else {}
|
||||||
|
doc = Document(page_content=self.texts[row - 1], metadata=metadata)
|
||||||
|
top_k_results.append(doc)
|
||||||
return top_k_results
|
return top_k_results
|
||||||
|
@@ -25,3 +25,18 @@ class TestSVMRetriever:
|
|||||||
documents=input_docs, embeddings=FakeEmbeddings(size=100)
|
documents=input_docs, embeddings=FakeEmbeddings(size=100)
|
||||||
)
|
)
|
||||||
assert len(svm_retriever.texts) == 3
|
assert len(svm_retriever.texts) == 3
|
||||||
|
|
||||||
|
@pytest.mark.requires("sklearn")
|
||||||
|
def test_metadata_persists(self) -> None:
|
||||||
|
input_docs = [
|
||||||
|
Document(page_content="I have a pen.", metadata={"foo": "bar"}),
|
||||||
|
Document(page_content="How about you?", metadata={"foo": "baz"}),
|
||||||
|
Document(page_content="I have a bag.", metadata={"foo": "qux"}),
|
||||||
|
]
|
||||||
|
svm_retriever = SVMRetriever.from_documents(
|
||||||
|
documents=input_docs, embeddings=FakeEmbeddings(size=100)
|
||||||
|
)
|
||||||
|
query = "Have anything?"
|
||||||
|
output_docs = svm_retriever.get_relevant_documents(query=query)
|
||||||
|
for doc in output_docs:
|
||||||
|
assert "foo" in doc.metadata
|
||||||
|
Loading…
Reference in New Issue
Block a user