From d5eb22887448f9c81c657478fa717670f77bd9fa Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Fri, 18 Aug 2023 15:04:26 +0100 Subject: [PATCH] Add kwargs to all other optional runnable methods (#9439) --- .../langchain/schema/output_parser.py | 10 +- libs/langchain/langchain/schema/retriever.py | 5 +- .../langchain/schema/runnable/base.py | 111 +++++++++++++----- .../langchain/schema/runnable/passthrough.py | 5 +- .../langchain/schema/runnable/router.py | 17 ++- 5 files changed, 114 insertions(+), 34 deletions(-) diff --git a/libs/langchain/langchain/schema/output_parser.py b/libs/langchain/langchain/schema/output_parser.py index 8de216800f..9e7a8083f2 100644 --- a/libs/langchain/langchain/schema/output_parser.py +++ b/libs/langchain/langchain/schema/output_parser.py @@ -79,7 +79,10 @@ class BaseGenerationOutputParser( ) async def ainvoke( - self, input: str | BaseMessage, config: RunnableConfig | None = None + self, + input: str | BaseMessage, + config: RunnableConfig | None = None, + **kwargs: Optional[Any], ) -> T: if isinstance(input, BaseMessage): return await self._acall_with_config( @@ -147,7 +150,10 @@ class BaseOutputParser(BaseLLMOutputParser, Runnable[Union[str, BaseMessage], T] ) async def ainvoke( - self, input: str | BaseMessage, config: RunnableConfig | None = None + self, + input: str | BaseMessage, + config: RunnableConfig | None = None, + **kwargs: Optional[Any], ) -> T: if isinstance(input, BaseMessage): return await self._acall_with_config( diff --git a/libs/langchain/langchain/schema/retriever.py b/libs/langchain/langchain/schema/retriever.py index 72c5cf6366..5da50e1497 100644 --- a/libs/langchain/langchain/schema/retriever.py +++ b/libs/langchain/langchain/schema/retriever.py @@ -116,7 +116,10 @@ class BaseRetriever(Serializable, Runnable[str, List[Document]], ABC): ) async def ainvoke( - self, input: str, config: Optional[RunnableConfig] = None + self, + input: str, + config: Optional[RunnableConfig] = None, + **kwargs: Optional[Any], ) -> List[Document]: if type(self).aget_relevant_documents == BaseRetriever.aget_relevant_documents: # If the retriever doesn't implement async, use default implementation diff --git a/libs/langchain/langchain/schema/runnable/base.py b/libs/langchain/langchain/schema/runnable/base.py index 1119189bc8..d0a7bbac9f 100644 --- a/libs/langchain/langchain/schema/runnable/base.py +++ b/libs/langchain/langchain/schema/runnable/base.py @@ -5,6 +5,7 @@ import copy import threading from abc import ABC, abstractmethod from concurrent.futures import FIRST_COMPLETED, ThreadPoolExecutor, wait +from functools import partial from itertools import tee from typing import ( TYPE_CHECKING, @@ -83,14 +84,14 @@ class Runnable(Generic[Input, Output], ABC): ... async def ainvoke( - self, input: Input, config: Optional[RunnableConfig] = None + self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any ) -> Output: """ Default implementation of ainvoke, which calls invoke in a thread pool. Subclasses should override this method if they can run asynchronously. """ return await asyncio.get_running_loop().run_in_executor( - None, self.invoke, input, config + None, partial(self.invoke, **kwargs), input, config ) def batch( @@ -99,6 +100,7 @@ class Runnable(Generic[Input, Output], ABC): config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, *, max_concurrency: Optional[int] = None, + **kwargs: Optional[Any], ) -> List[Output]: """ Default implementation of batch, which calls invoke N times. @@ -108,10 +110,10 @@ class Runnable(Generic[Input, Output], ABC): # If there's only one input, don't bother with the executor if len(inputs) == 1: - return [self.invoke(inputs[0], configs[0])] + return [self.invoke(inputs[0], configs[0], **kwargs)] with ThreadPoolExecutor(max_workers=max_concurrency) as executor: - return list(executor.map(self.invoke, inputs, configs)) + return list(executor.map(partial(self.invoke, **kwargs), inputs, configs)) async def abatch( self, @@ -119,33 +121,40 @@ class Runnable(Generic[Input, Output], ABC): config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, *, max_concurrency: Optional[int] = None, + **kwargs: Optional[Any], ) -> List[Output]: """ Default implementation of abatch, which calls ainvoke N times. Subclasses should override this method if they can batch more efficiently. """ configs = self._get_config_list(config, len(inputs)) - coros = map(self.ainvoke, inputs, configs) + coros = map(partial(self.ainvoke, **kwargs), inputs, configs) return await gather_with_concurrency(max_concurrency, *coros) def stream( - self, input: Input, config: Optional[RunnableConfig] = None + self, + input: Input, + config: Optional[RunnableConfig] = None, + **kwargs: Optional[Any], ) -> Iterator[Output]: """ Default implementation of stream, which calls invoke. Subclasses should override this method if they support streaming output. """ - yield self.invoke(input, config) + yield self.invoke(input, config, **kwargs) async def astream( - self, input: Input, config: Optional[RunnableConfig] = None + self, + input: Input, + config: Optional[RunnableConfig] = None, + **kwargs: Optional[Any], ) -> AsyncIterator[Output]: """ Default implementation of astream, which calls ainvoke. Subclasses should override this method if they support streaming output. """ - yield await self.ainvoke(input, config) + yield await self.ainvoke(input, config, **kwargs) def transform( self, @@ -601,7 +610,10 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]): raise first_error async def ainvoke( - self, input: Input, config: Optional[RunnableConfig] = None + self, + input: Input, + config: Optional[RunnableConfig] = None, + **kwargs: Optional[Any], ) -> Output: from langchain.callbacks.manager import AsyncCallbackManager @@ -650,6 +662,7 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]): config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, *, max_concurrency: Optional[int] = None, + **kwargs: Optional[Any], ) -> List[Output]: from langchain.callbacks.manager import CallbackManager @@ -712,6 +725,7 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]): config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, *, max_concurrency: Optional[int] = None, + **kwargs: Optional[Any], ) -> List[Output]: from langchain.callbacks.manager import ( AsyncCallbackManager, @@ -879,7 +893,10 @@ class RunnableSequence(Serializable, Runnable[Input, Output]): return cast(Output, input) async def ainvoke( - self, input: Input, config: Optional[RunnableConfig] = None + self, + input: Input, + config: Optional[RunnableConfig] = None, + **kwargs: Optional[Any], ) -> Output: from langchain.callbacks.manager import AsyncCallbackManager @@ -923,6 +940,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]): config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, *, max_concurrency: Optional[int] = None, + **kwargs: Optional[Any], ) -> List[Output]: from langchain.callbacks.manager import CallbackManager @@ -976,6 +994,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]): config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, *, max_concurrency: Optional[int] = None, + **kwargs: Optional[Any], ) -> List[Output]: from langchain.callbacks.manager import ( AsyncCallbackManager, @@ -1034,7 +1053,10 @@ class RunnableSequence(Serializable, Runnable[Input, Output]): return cast(List[Output], inputs) def stream( - self, input: Input, config: Optional[RunnableConfig] = None + self, + input: Input, + config: Optional[RunnableConfig] = None, + **kwargs: Optional[Any], ) -> Iterator[Output]: from langchain.callbacks.manager import CallbackManager @@ -1111,7 +1133,10 @@ class RunnableSequence(Serializable, Runnable[Input, Output]): ) async def astream( - self, input: Input, config: Optional[RunnableConfig] = None + self, + input: Input, + config: Optional[RunnableConfig] = None, + **kwargs: Optional[Any], ) -> AsyncIterator[Output]: from langchain.callbacks.manager import AsyncCallbackManager @@ -1280,7 +1305,10 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]): return output async def ainvoke( - self, input: Input, config: Optional[RunnableConfig] = None + self, + input: Input, + config: Optional[RunnableConfig] = None, + **kwargs: Optional[Any], ) -> Dict[str, Any]: from langchain.callbacks.manager import AsyncCallbackManager @@ -1379,7 +1407,10 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]): ) def stream( - self, input: Input, config: Optional[RunnableConfig] = None + self, + input: Input, + config: Optional[RunnableConfig] = None, + **kwargs: Optional[Any], ) -> Iterator[Dict[str, Any]]: yield from self.transform(iter([input]), config) @@ -1443,7 +1474,10 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]): yield chunk async def astream( - self, input: Input, config: Optional[RunnableConfig] = None + self, + input: Input, + config: Optional[RunnableConfig] = None, + **kwargs: Optional[Any], ) -> AsyncIterator[Dict[str, Any]]: async def input_aiter() -> AsyncIterator[Input]: yield input @@ -1472,7 +1506,12 @@ class RunnableLambda(Runnable[Input, Output]): else: return False - def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Output: + def invoke( + self, + input: Input, + config: Optional[RunnableConfig] = None, + **kwargs: Optional[Any], + ) -> Output: return self._call_with_config(self.func, input, config) @@ -1499,13 +1538,21 @@ class RunnableBinding(Serializable, Runnable[Input, Output]): def bind(self, **kwargs: Any) -> Runnable[Input, Output]: return self.__class__(bound=self.bound, kwargs={**self.kwargs, **kwargs}) - def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Output: - return self.bound.invoke(input, config, **self.kwargs) + def invoke( + self, + input: Input, + config: Optional[RunnableConfig] = None, + **kwargs: Optional[Any], + ) -> Output: + return self.bound.invoke(input, config, **{**self.kwargs, **kwargs}) async def ainvoke( - self, input: Input, config: Optional[RunnableConfig] = None + self, + input: Input, + config: Optional[RunnableConfig] = None, + **kwargs: Optional[Any], ) -> Output: - return await self.bound.ainvoke(input, config, **self.kwargs) + return await self.bound.ainvoke(input, config, **{**self.kwargs, **kwargs}) def batch( self, @@ -1513,9 +1560,10 @@ class RunnableBinding(Serializable, Runnable[Input, Output]): config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, *, max_concurrency: Optional[int] = None, + **kwargs: Optional[Any], ) -> List[Output]: return self.bound.batch( - inputs, config, max_concurrency=max_concurrency, **self.kwargs + inputs, config, max_concurrency=max_concurrency, **{**self.kwargs, **kwargs} ) async def abatch( @@ -1524,20 +1572,29 @@ class RunnableBinding(Serializable, Runnable[Input, Output]): config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, *, max_concurrency: Optional[int] = None, + **kwargs: Optional[Any], ) -> List[Output]: return await self.bound.abatch( - inputs, config, max_concurrency=max_concurrency, **self.kwargs + inputs, config, max_concurrency=max_concurrency, **{**self.kwargs, **kwargs} ) def stream( - self, input: Input, config: Optional[RunnableConfig] = None + self, + input: Input, + config: Optional[RunnableConfig] = None, + **kwargs: Optional[Any], ) -> Iterator[Output]: - yield from self.bound.stream(input, config, **self.kwargs) + yield from self.bound.stream(input, config, **{**self.kwargs, **kwargs}) async def astream( - self, input: Input, config: Optional[RunnableConfig] = None + self, + input: Input, + config: Optional[RunnableConfig] = None, + **kwargs: Optional[Any], ) -> AsyncIterator[Output]: - async for item in self.bound.astream(input, config, **self.kwargs): + async for item in self.bound.astream( + input, config, **{**self.kwargs, **kwargs} + ): yield item def transform( diff --git a/libs/langchain/langchain/schema/runnable/passthrough.py b/libs/langchain/langchain/schema/runnable/passthrough.py index 9ff26589ab..420b13fe80 100644 --- a/libs/langchain/langchain/schema/runnable/passthrough.py +++ b/libs/langchain/langchain/schema/runnable/passthrough.py @@ -32,7 +32,10 @@ class RunnablePassthrough(Serializable, Runnable[Input, Input]): return self._call_with_config(identity, input, config) async def ainvoke( - self, input: Input, config: Optional[RunnableConfig] = None + self, + input: Input, + config: Optional[RunnableConfig] = None, + **kwargs: Optional[Any], ) -> Input: return await self._acall_with_config(aidentity, input, config) diff --git a/libs/langchain/langchain/schema/runnable/router.py b/libs/langchain/langchain/schema/runnable/router.py index a844a01972..68989bfa7d 100644 --- a/libs/langchain/langchain/schema/runnable/router.py +++ b/libs/langchain/langchain/schema/runnable/router.py @@ -104,7 +104,10 @@ class RouterRunnable( return runnable.invoke(actual_input, config) async def ainvoke( - self, input: RouterInput, config: Optional[RunnableConfig] = None + self, + input: RouterInput, + config: Optional[RunnableConfig] = None, + **kwargs: Optional[Any], ) -> Output: key = input["key"] actual_input = input["input"] @@ -120,6 +123,7 @@ class RouterRunnable( config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, *, max_concurrency: Optional[int] = None, + **kwargs: Optional[Any], ) -> List[Output]: keys = [input["key"] for input in inputs] actual_inputs = [input["input"] for input in inputs] @@ -144,6 +148,7 @@ class RouterRunnable( config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, *, max_concurrency: Optional[int] = None, + **kwargs: Optional[Any], ) -> List[Output]: keys = [input["key"] for input in inputs] actual_inputs = [input["input"] for input in inputs] @@ -161,7 +166,10 @@ class RouterRunnable( ) def stream( - self, input: RouterInput, config: Optional[RunnableConfig] = None + self, + input: RouterInput, + config: Optional[RunnableConfig] = None, + **kwargs: Optional[Any], ) -> Iterator[Output]: key = input["key"] actual_input = input["input"] @@ -172,7 +180,10 @@ class RouterRunnable( yield from runnable.stream(actual_input, config) async def astream( - self, input: RouterInput, config: Optional[RunnableConfig] = None + self, + input: RouterInput, + config: Optional[RunnableConfig] = None, + **kwargs: Optional[Any], ) -> AsyncIterator[Output]: key = input["key"] actual_input = input["input"]