langchain[patch]: Add async methods to MultiVectorRetriever (#16878)

Adds async support to multi vector retriever
pull/16433/merge
Christophe Bornet 5 months ago committed by GitHub
parent 7d03d8f586
commit 78a1af4848
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -1,7 +1,10 @@
from enum import Enum
from typing import Dict, List, Optional
from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.callbacks import (
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
)
from langchain_core.documents import Document
from langchain_core.pydantic_v1 import Field, root_validator
from langchain_core.retrievers import BaseRetriever
@ -71,3 +74,30 @@ class MultiVectorRetriever(BaseRetriever):
ids.append(d.metadata[self.id_key])
docs = self.docstore.mget(ids)
return [d for d in docs if d is not None]
async def _aget_relevant_documents(
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
) -> List[Document]:
"""Asynchronously get documents relevant to a query.
Args:
query: String to find relevant documents for
run_manager: The callbacks handler to use
Returns:
List of relevant documents
"""
if self.search_type == SearchType.mmr:
sub_docs = await self.vectorstore.amax_marginal_relevance_search(
query, **self.search_kwargs
)
else:
sub_docs = await self.vectorstore.asimilarity_search(
query, **self.search_kwargs
)
# We do this to maintain the order of the ids that are returned
ids = []
for d in sub_docs:
if self.id_key in d.metadata and d.metadata[self.id_key] not in ids:
ids.append(d.metadata[self.id_key])
docs = await self.docstore.amget(ids)
return [d for d in docs if d is not None]

@ -28,3 +28,16 @@ def test_multi_vector_retriever_initialization() -> None:
results = retriever.invoke("1")
assert len(results) > 0
assert results[0].page_content == "test document"
async def test_multi_vector_retriever_initialization_async() -> None:
vectorstore = InMemoryVectorstoreWithSearch()
retriever = MultiVectorRetriever(
vectorstore=vectorstore, docstore=InMemoryStore(), doc_id="doc_id"
)
documents = [Document(page_content="test document", metadata={"doc_id": "1"})]
await retriever.vectorstore.aadd_documents(documents, ids=["1"])
await retriever.docstore.amset(list(zip(["1"], documents)))
results = await retriever.ainvoke("1")
assert len(results) > 0
assert results[0].page_content == "test document"

Loading…
Cancel
Save