mirror of https://github.com/hwchase17/langchain
Fix multi vector retriever subclassing (#14350)
Fixes #14342 @eyurtsev @baskaryan --------- Co-authored-by: Erick Friis <erick@langchain.dev>pull/14354/head
parent
7bdfc43766
commit
867ca6d0be
@ -0,0 +1,30 @@
|
|||||||
|
from typing import Any, List
|
||||||
|
|
||||||
|
from langchain_core.documents import Document
|
||||||
|
|
||||||
|
from langchain.retrievers.multi_vector import MultiVectorRetriever
|
||||||
|
from langchain.storage import InMemoryStore
|
||||||
|
from tests.unit_tests.indexes.test_indexing import InMemoryVectorStore
|
||||||
|
|
||||||
|
|
||||||
|
class InMemoryVectorstoreWithSearch(InMemoryVectorStore):
|
||||||
|
def similarity_search(
|
||||||
|
self, query: str, k: int = 4, **kwargs: Any
|
||||||
|
) -> List[Document]:
|
||||||
|
res = self.store.get(query)
|
||||||
|
if res is None:
|
||||||
|
return []
|
||||||
|
return [res]
|
||||||
|
|
||||||
|
|
||||||
|
def test_multi_vector_retriever_initialization() -> None:
|
||||||
|
vectorstore = InMemoryVectorstoreWithSearch()
|
||||||
|
retriever = MultiVectorRetriever(
|
||||||
|
vectorstore=vectorstore, docstore=InMemoryStore(), doc_id="doc_id"
|
||||||
|
)
|
||||||
|
documents = [Document(page_content="test document", metadata={"doc_id": "1"})]
|
||||||
|
retriever.vectorstore.add_documents(documents, ids=["1"])
|
||||||
|
retriever.docstore.mset(list(zip(["1"], documents)))
|
||||||
|
results = retriever.invoke("1")
|
||||||
|
assert len(results) > 0
|
||||||
|
assert results[0].page_content == "test document"
|
@ -0,0 +1,40 @@
|
|||||||
|
from typing import Any, List, Sequence
|
||||||
|
|
||||||
|
from langchain_core.documents import Document
|
||||||
|
|
||||||
|
from langchain.retrievers import ParentDocumentRetriever
|
||||||
|
from langchain.storage import InMemoryStore
|
||||||
|
from langchain.text_splitter import CharacterTextSplitter
|
||||||
|
from tests.unit_tests.indexes.test_indexing import InMemoryVectorStore
|
||||||
|
|
||||||
|
|
||||||
|
class InMemoryVectorstoreWithSearch(InMemoryVectorStore):
|
||||||
|
def similarity_search(
|
||||||
|
self, query: str, k: int = 4, **kwargs: Any
|
||||||
|
) -> List[Document]:
|
||||||
|
res = self.store.get(query)
|
||||||
|
if res is None:
|
||||||
|
return []
|
||||||
|
return [res]
|
||||||
|
|
||||||
|
def add_documents(self, documents: Sequence[Document], **kwargs: Any) -> List[str]:
|
||||||
|
print(documents)
|
||||||
|
return super().add_documents(
|
||||||
|
documents, ids=[f"{i}" for i in range(len(documents))]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_parent_document_retriever_initialization() -> None:
|
||||||
|
vectorstore = InMemoryVectorstoreWithSearch()
|
||||||
|
store = InMemoryStore()
|
||||||
|
child_splitter = CharacterTextSplitter(chunk_size=400)
|
||||||
|
documents = [Document(page_content="test document")]
|
||||||
|
retriever = ParentDocumentRetriever(
|
||||||
|
vectorstore=vectorstore,
|
||||||
|
docstore=store,
|
||||||
|
child_splitter=child_splitter,
|
||||||
|
)
|
||||||
|
retriever.add_documents(documents)
|
||||||
|
results = retriever.invoke("0")
|
||||||
|
assert len(results) > 0
|
||||||
|
assert results[0].page_content == "test document"
|
Loading…
Reference in New Issue