From 1b0802babeeee65afe283b39ca3f0e9fd508c103 Mon Sep 17 00:00:00 2001 From: ccurme Date: Wed, 21 Feb 2024 18:31:52 -0500 Subject: [PATCH] core: fix .bind when used with RunnableLambda async methods (#17739) **Description:** Here is a minimal example to illustrate behavior: ```python from langchain_core.runnables import RunnableLambda def my_function(*args, **kwargs): return 3 + kwargs.get("n", 0) runnable = RunnableLambda(my_function).bind(n=1) assert 4 == runnable.invoke({}) assert [4] == list(runnable.stream({})) assert 4 == await runnable.ainvoke({}) assert [4] == [item async for item in runnable.astream({})] ``` Here, `runnable.invoke({})` and `runnable.stream({})` work fine, but `runnable.ainvoke({})` raises ``` TypeError: RunnableLambda._ainvoke..func() got an unexpected keyword argument 'n' ``` and similarly for `runnable.astream({})`: ``` TypeError: RunnableLambda._atransform..func() got an unexpected keyword argument 'n' ``` Here we assume that this behavior is undesired and attempt to fix it. **Issue:** https://github.com/langchain-ai/langchain/issues/17241, https://github.com/langchain-ai/langchain/discussions/16446 --- libs/core/langchain_core/runnables/base.py | 3 +++ .../unit_tests/runnables/test_runnable.py | 20 +++++++++++++++++++ 2 files changed, 23 insertions(+) diff --git a/libs/core/langchain_core/runnables/base.py b/libs/core/langchain_core/runnables/base.py index 89f2dd6c03..8f2f23d76e 100644 --- a/libs/core/langchain_core/runnables/base.py +++ b/libs/core/langchain_core/runnables/base.py @@ -3414,6 +3414,7 @@ class RunnableLambda(Runnable[Input, Output]): input: Input, run_manager: AsyncCallbackManagerForChainRun, config: RunnableConfig, + **kwargs: Any, ) -> Output: output: Optional[Output] = None for chunk in call_func_with_variable_args( @@ -3438,6 +3439,7 @@ class RunnableLambda(Runnable[Input, Output]): input: Input, run_manager: AsyncCallbackManagerForChainRun, config: RunnableConfig, + **kwargs: Any, ) -> Output: return call_func_with_variable_args( self.func, input, config, run_manager.get_sync(), **kwargs @@ -3643,6 +3645,7 @@ class RunnableLambda(Runnable[Input, Output]): input: Input, run_manager: AsyncCallbackManagerForChainRun, config: RunnableConfig, + **kwargs: Any, ) -> Output: return call_func_with_variable_args( self.func, input, config, run_manager.get_sync(), **kwargs diff --git a/libs/core/tests/unit_tests/runnables/test_runnable.py b/libs/core/tests/unit_tests/runnables/test_runnable.py index ac7f5a0cce..d621989679 100644 --- a/libs/core/tests/unit_tests/runnables/test_runnable.py +++ b/libs/core/tests/unit_tests/runnables/test_runnable.py @@ -3424,6 +3424,26 @@ def test_bind_bind() -> None: ) == dumpd(llm.bind(stop=["Observation:"], one="two", hello="world")) +def test_bind_with_lambda() -> None: + def my_function(*args: Any, **kwargs: Any) -> int: + return 3 + kwargs.get("n", 0) + + runnable = RunnableLambda(my_function).bind(n=1) + assert 4 == runnable.invoke({}) + chunks = list(runnable.stream({})) + assert [4] == chunks + + +async def test_bind_with_lambda_async() -> None: + def my_function(*args: Any, **kwargs: Any) -> int: + return 3 + kwargs.get("n", 0) + + runnable = RunnableLambda(my_function).bind(n=1) + assert 4 == await runnable.ainvoke({}) + chunks = [item async for item in runnable.astream({})] + assert [4] == chunks + + def test_deep_stream() -> None: prompt = ( SystemMessagePromptTemplate.from_template("You are a nice assistant.")