core[patch]: passthrough BaseRetriever.invoke(**kwargs) (#16551)

Fix for #16547
This commit is contained in:
Bagatur 2024-01-25 08:58:39 -08:00 committed by GitHub
parent 355ef2a4a6
commit e510cfaa23
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -115,7 +115,7 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC):
)
def invoke(
self, input: str, config: Optional[RunnableConfig] = None
self, input: str, config: Optional[RunnableConfig] = None, **kwargs: Any
) -> List[Document]:
config = ensure_config(config)
return self.get_relevant_documents(
@ -124,13 +124,14 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC):
tags=config.get("tags"),
metadata=config.get("metadata"),
run_name=config.get("run_name"),
**kwargs,
)
async def ainvoke(
self,
input: str,
config: Optional[RunnableConfig] = None,
**kwargs: Optional[Any],
**kwargs: Any,
) -> List[Document]:
config = ensure_config(config)
return await self.aget_relevant_documents(
@ -139,6 +140,7 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC):
tags=config.get("tags"),
metadata=config.get("metadata"),
run_name=config.get("run_name"),
**kwargs,
)
@abstractmethod
@ -208,7 +210,6 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC):
dumpd(self),
query,
name=run_name,
**kwargs,
)
try:
_kwargs = kwargs if self._expects_other_args else {}
@ -224,7 +225,6 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC):
else:
run_manager.on_retriever_end(
result,
**kwargs,
)
return result
@ -266,7 +266,6 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC):
dumpd(self),
query,
name=run_name,
**kwargs,
)
try:
_kwargs = kwargs if self._expects_other_args else {}
@ -282,6 +281,5 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC):
else:
await run_manager.on_retriever_end(
result,
**kwargs,
)
return result