|
|
|
@ -1,7 +1,10 @@
|
|
|
|
|
from enum import Enum
|
|
|
|
|
from typing import Dict, List, Optional
|
|
|
|
|
|
|
|
|
|
from langchain_core.callbacks import CallbackManagerForRetrieverRun
|
|
|
|
|
from langchain_core.callbacks import (
|
|
|
|
|
AsyncCallbackManagerForRetrieverRun,
|
|
|
|
|
CallbackManagerForRetrieverRun,
|
|
|
|
|
)
|
|
|
|
|
from langchain_core.documents import Document
|
|
|
|
|
from langchain_core.pydantic_v1 import Field, root_validator
|
|
|
|
|
from langchain_core.retrievers import BaseRetriever
|
|
|
|
@ -71,3 +74,30 @@ class MultiVectorRetriever(BaseRetriever):
|
|
|
|
|
ids.append(d.metadata[self.id_key])
|
|
|
|
|
docs = self.docstore.mget(ids)
|
|
|
|
|
return [d for d in docs if d is not None]
|
|
|
|
|
|
|
|
|
|
async def _aget_relevant_documents(
    self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
) -> List[Document]:
    """Asynchronously fetch the stored documents matching a query.

    The vector store is searched first; the ``id_key`` metadata of each
    hit is then used to look up the corresponding entries in the docstore.

    Args:
        query: String to find relevant documents for
        run_manager: The callbacks handler to use (not referenced in
            this method's body)
    Returns:
        List of relevant documents, with ids missing from the docstore
        (``None`` results) dropped
    """
    # Choose the vector store coroutine according to the search type.
    if self.search_type == SearchType.mmr:
        search = self.vectorstore.amax_marginal_relevance_search
    else:
        search = self.vectorstore.asimilarity_search
    matches = await search(query, **self.search_kwargs)

    # A list (rather than a set) keeps ids in the order the hits were
    # returned while still skipping duplicates.
    ordered_ids = []
    for match in matches:
        if self.id_key not in match.metadata:
            continue
        doc_id = match.metadata[self.id_key]
        if doc_id not in ordered_ids:
            ordered_ids.append(doc_id)

    fetched = await self.docstore.amget(ordered_ids)
    return [doc for doc in fetched if doc is not None]
|
|
|
|
|