From e510cfaa23b90662ed88d54115f0f8c0726213b8 Mon Sep 17 00:00:00 2001 From: Bagatur <22008038+baskaryan@users.noreply.github.com> Date: Thu, 25 Jan 2024 08:58:39 -0800 Subject: [PATCH] core[patch]: passthrough BaseRetriever.invoke(**kwargs) (#16551) Fix for #16547 --- libs/core/langchain_core/retrievers.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/libs/core/langchain_core/retrievers.py b/libs/core/langchain_core/retrievers.py index f215636695..b01b29bb43 100644 --- a/libs/core/langchain_core/retrievers.py +++ b/libs/core/langchain_core/retrievers.py @@ -115,7 +115,7 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC): ) def invoke( - self, input: str, config: Optional[RunnableConfig] = None + self, input: str, config: Optional[RunnableConfig] = None, **kwargs: Any ) -> List[Document]: config = ensure_config(config) return self.get_relevant_documents( @@ -124,13 +124,14 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC): tags=config.get("tags"), metadata=config.get("metadata"), run_name=config.get("run_name"), + **kwargs, ) async def ainvoke( self, input: str, config: Optional[RunnableConfig] = None, - **kwargs: Optional[Any], + **kwargs: Any, ) -> List[Document]: config = ensure_config(config) return await self.aget_relevant_documents( @@ -139,6 +140,7 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC): tags=config.get("tags"), metadata=config.get("metadata"), run_name=config.get("run_name"), + **kwargs, ) @abstractmethod @@ -208,7 +210,6 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC): dumpd(self), query, name=run_name, - **kwargs, ) try: _kwargs = kwargs if self._expects_other_args else {} @@ -224,7 +225,6 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC): else: run_manager.on_retriever_end( result, - **kwargs, ) return result @@ -266,7 +266,6 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC): dumpd(self), query, name=run_name, - **kwargs, ) try: _kwargs = kwargs if self._expects_other_args else {} @@ -282,6 +281,5 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC): else: await run_manager.on_retriever_end( result, - **kwargs, ) return result