mirror of https://github.com/hwchase17/langchain
Fix multi vector retriever subclassing (#14350)
Fixes #14342 @eyurtsev @baskaryan --------- Co-authored-by: Erick Friis <erick@langchain.dev>pull/14354/head
parent
7bdfc43766
commit
867ca6d0be
@ -0,0 +1,30 @@
|
||||
from typing import Any, List
|
||||
|
||||
from langchain_core.documents import Document
|
||||
|
||||
from langchain.retrievers.multi_vector import MultiVectorRetriever
|
||||
from langchain.storage import InMemoryStore
|
||||
from tests.unit_tests.indexes.test_indexing import InMemoryVectorStore
|
||||
|
||||
|
||||
class InMemoryVectorstoreWithSearch(InMemoryVectorStore):
|
||||
def similarity_search(
|
||||
self, query: str, k: int = 4, **kwargs: Any
|
||||
) -> List[Document]:
|
||||
res = self.store.get(query)
|
||||
if res is None:
|
||||
return []
|
||||
return [res]
|
||||
|
||||
|
||||
def test_multi_vector_retriever_initialization() -> None:
|
||||
vectorstore = InMemoryVectorstoreWithSearch()
|
||||
retriever = MultiVectorRetriever(
|
||||
vectorstore=vectorstore, docstore=InMemoryStore(), doc_id="doc_id"
|
||||
)
|
||||
documents = [Document(page_content="test document", metadata={"doc_id": "1"})]
|
||||
retriever.vectorstore.add_documents(documents, ids=["1"])
|
||||
retriever.docstore.mset(list(zip(["1"], documents)))
|
||||
results = retriever.invoke("1")
|
||||
assert len(results) > 0
|
||||
assert results[0].page_content == "test document"
|
@ -0,0 +1,40 @@
|
||||
from typing import Any, List, Sequence
|
||||
|
||||
from langchain_core.documents import Document
|
||||
|
||||
from langchain.retrievers import ParentDocumentRetriever
|
||||
from langchain.storage import InMemoryStore
|
||||
from langchain.text_splitter import CharacterTextSplitter
|
||||
from tests.unit_tests.indexes.test_indexing import InMemoryVectorStore
|
||||
|
||||
|
||||
class InMemoryVectorstoreWithSearch(InMemoryVectorStore):
|
||||
def similarity_search(
|
||||
self, query: str, k: int = 4, **kwargs: Any
|
||||
) -> List[Document]:
|
||||
res = self.store.get(query)
|
||||
if res is None:
|
||||
return []
|
||||
return [res]
|
||||
|
||||
def add_documents(self, documents: Sequence[Document], **kwargs: Any) -> List[str]:
|
||||
print(documents)
|
||||
return super().add_documents(
|
||||
documents, ids=[f"{i}" for i in range(len(documents))]
|
||||
)
|
||||
|
||||
|
||||
def test_parent_document_retriever_initialization() -> None:
|
||||
vectorstore = InMemoryVectorstoreWithSearch()
|
||||
store = InMemoryStore()
|
||||
child_splitter = CharacterTextSplitter(chunk_size=400)
|
||||
documents = [Document(page_content="test document")]
|
||||
retriever = ParentDocumentRetriever(
|
||||
vectorstore=vectorstore,
|
||||
docstore=store,
|
||||
child_splitter=child_splitter,
|
||||
)
|
||||
retriever.add_documents(documents)
|
||||
results = retriever.invoke("0")
|
||||
assert len(results) > 0
|
||||
assert results[0].page_content == "test document"
|
Loading…
Reference in New Issue