diff --git a/libs/core/langchain_core/runnables/base.py b/libs/core/langchain_core/runnables/base.py
index 89f2dd6c03..8f2f23d76e 100644
--- a/libs/core/langchain_core/runnables/base.py
+++ b/libs/core/langchain_core/runnables/base.py
@@ -3414,6 +3414,7 @@ class RunnableLambda(Runnable[Input, Output]):
                     input: Input,
                     run_manager: AsyncCallbackManagerForChainRun,
                     config: RunnableConfig,
+                    **kwargs: Any,
                 ) -> Output:
                     output: Optional[Output] = None
                     for chunk in call_func_with_variable_args(
@@ -3438,6 +3439,7 @@ class RunnableLambda(Runnable[Input, Output]):
                     input: Input,
                     run_manager: AsyncCallbackManagerForChainRun,
                     config: RunnableConfig,
+                    **kwargs: Any,
                 ) -> Output:
                     return call_func_with_variable_args(
                         self.func, input, config, run_manager.get_sync(), **kwargs
@@ -3643,6 +3645,7 @@ class RunnableLambda(Runnable[Input, Output]):
                 input: Input,
                 run_manager: AsyncCallbackManagerForChainRun,
                 config: RunnableConfig,
+                **kwargs: Any,
             ) -> Output:
                 return call_func_with_variable_args(
                     self.func, input, config, run_manager.get_sync(), **kwargs
diff --git a/libs/core/tests/unit_tests/runnables/test_runnable.py b/libs/core/tests/unit_tests/runnables/test_runnable.py
index ac7f5a0cce..d621989679 100644
--- a/libs/core/tests/unit_tests/runnables/test_runnable.py
+++ b/libs/core/tests/unit_tests/runnables/test_runnable.py
@@ -3424,6 +3424,26 @@ def test_bind_bind() -> None:
     ) == dumpd(llm.bind(stop=["Observation:"], one="two", hello="world"))
 
 
+def test_bind_with_lambda() -> None:
+    def my_function(*args: Any, **kwargs: Any) -> int:
+        return 3 + kwargs.get("n", 0)
+
+    runnable = RunnableLambda(my_function).bind(n=1)
+    assert 4 == runnable.invoke({})
+    chunks = list(runnable.stream({}))
+    assert [4] == chunks
+
+
+async def test_bind_with_lambda_async() -> None:
+    def my_function(*args: Any, **kwargs: Any) -> int:
+        return 3 + kwargs.get("n", 0)
+
+    runnable = RunnableLambda(my_function).bind(n=1)
+    assert 4 == await runnable.ainvoke({})
+    chunks = [item async for item in runnable.astream({})]
+    assert [4] == chunks
+
+
 def test_deep_stream() -> None:
     prompt = (
         SystemMessagePromptTemplate.from_template("You are a nice assistant.")
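
For reference, a minimal usage sketch of the behavior this change enables, mirroring the new tests (assumes `langchain_core` with this patch installed; the function name and the `n` keyword are illustrative only):

```python
from typing import Any

from langchain_core.runnables import RunnableLambda


def my_function(*args: Any, **kwargs: Any) -> int:
    # Kwargs supplied via .bind() are forwarded to the wrapped function,
    # now in the async code paths as well as the sync ones.
    return 3 + kwargs.get("n", 0)


runnable = RunnableLambda(my_function).bind(n=1)

print(runnable.invoke({}))        # 4
print(list(runnable.stream({})))  # [4]
```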