core: fix .bind when used with RunnableLambda async methods (#17739)

**Description:** Here is a minimal example to illustrate behavior:
```python
from langchain_core.runnables import RunnableLambda

def my_function(*args, **kwargs):
    return 3 + kwargs.get("n", 0)

runnable = RunnableLambda(my_function).bind(n=1)


assert 4 == runnable.invoke({})
assert [4] == list(runnable.stream({}))

assert 4 == await runnable.ainvoke({})
assert [4] == [item async for item in runnable.astream({})]
```
Here, `runnable.invoke({})` and `runnable.stream({})` work fine, but
`runnable.ainvoke({})` raises
```
TypeError: RunnableLambda._ainvoke.<locals>.func() got an unexpected keyword argument 'n'
```
and similarly for `runnable.astream({})`:
```
TypeError: RunnableLambda._atransform.<locals>.func() got an unexpected keyword argument 'n'
```
Here we assume that this behavior is undesired and attempt to fix it.

**Issue:** https://github.com/langchain-ai/langchain/issues/17241,
https://github.com/langchain-ai/langchain/discussions/16446
This commit is contained in:
ccurme 2024-02-21 18:31:52 -05:00 committed by GitHub
parent f541545c96
commit 1b0802babe
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 23 additions and 0 deletions

View File

@ -3414,6 +3414,7 @@ class RunnableLambda(Runnable[Input, Output]):
input: Input,
run_manager: AsyncCallbackManagerForChainRun,
config: RunnableConfig,
**kwargs: Any,
) -> Output:
output: Optional[Output] = None
for chunk in call_func_with_variable_args(
@ -3438,6 +3439,7 @@ class RunnableLambda(Runnable[Input, Output]):
input: Input,
run_manager: AsyncCallbackManagerForChainRun,
config: RunnableConfig,
**kwargs: Any,
) -> Output:
return call_func_with_variable_args(
self.func, input, config, run_manager.get_sync(), **kwargs
@ -3643,6 +3645,7 @@ class RunnableLambda(Runnable[Input, Output]):
input: Input,
run_manager: AsyncCallbackManagerForChainRun,
config: RunnableConfig,
**kwargs: Any,
) -> Output:
return call_func_with_variable_args(
self.func, input, config, run_manager.get_sync(), **kwargs

View File

@ -3424,6 +3424,26 @@ def test_bind_bind() -> None:
) == dumpd(llm.bind(stop=["Observation:"], one="two", hello="world"))
def test_bind_with_lambda() -> None:
def my_function(*args: Any, **kwargs: Any) -> int:
return 3 + kwargs.get("n", 0)
runnable = RunnableLambda(my_function).bind(n=1)
assert 4 == runnable.invoke({})
chunks = list(runnable.stream({}))
assert [4] == chunks
async def test_bind_with_lambda_async() -> None:
def my_function(*args: Any, **kwargs: Any) -> int:
return 3 + kwargs.get("n", 0)
runnable = RunnableLambda(my_function).bind(n=1)
assert 4 == await runnable.ainvoke({})
chunks = [item async for item in runnable.astream({})]
assert [4] == chunks
def test_deep_stream() -> None:
prompt = (
SystemMessagePromptTemplate.from_template("You are a nice assistant.")