mirror of
https://github.com/hwchase17/langchain
synced 2024-11-04 06:00:26 +00:00
core: fix .bind when used with RunnableLambda async methods (#17739)
**Description:** Here is a minimal example to illustrate behavior: ```python from langchain_core.runnables import RunnableLambda def my_function(*args, **kwargs): return 3 + kwargs.get("n", 0) runnable = RunnableLambda(my_function).bind(n=1) assert 4 == runnable.invoke({}) assert [4] == list(runnable.stream({})) assert 4 == await runnable.ainvoke({}) assert [4] == [item async for item in runnable.astream({})] ``` Here, `runnable.invoke({})` and `runnable.stream({})` work fine, but `runnable.ainvoke({})` raises ``` TypeError: RunnableLambda._ainvoke.<locals>.func() got an unexpected keyword argument 'n' ``` and similarly for `runnable.astream({})`: ``` TypeError: RunnableLambda._atransform.<locals>.func() got an unexpected keyword argument 'n' ``` Here we assume that this behavior is undesired and attempt to fix it. **Issue:** https://github.com/langchain-ai/langchain/issues/17241, https://github.com/langchain-ai/langchain/discussions/16446
This commit is contained in:
parent
f541545c96
commit
1b0802babe
@ -3414,6 +3414,7 @@ class RunnableLambda(Runnable[Input, Output]):
|
||||
input: Input,
|
||||
run_manager: AsyncCallbackManagerForChainRun,
|
||||
config: RunnableConfig,
|
||||
**kwargs: Any,
|
||||
) -> Output:
|
||||
output: Optional[Output] = None
|
||||
for chunk in call_func_with_variable_args(
|
||||
@ -3438,6 +3439,7 @@ class RunnableLambda(Runnable[Input, Output]):
|
||||
input: Input,
|
||||
run_manager: AsyncCallbackManagerForChainRun,
|
||||
config: RunnableConfig,
|
||||
**kwargs: Any,
|
||||
) -> Output:
|
||||
return call_func_with_variable_args(
|
||||
self.func, input, config, run_manager.get_sync(), **kwargs
|
||||
@ -3643,6 +3645,7 @@ class RunnableLambda(Runnable[Input, Output]):
|
||||
input: Input,
|
||||
run_manager: AsyncCallbackManagerForChainRun,
|
||||
config: RunnableConfig,
|
||||
**kwargs: Any,
|
||||
) -> Output:
|
||||
return call_func_with_variable_args(
|
||||
self.func, input, config, run_manager.get_sync(), **kwargs
|
||||
|
@ -3424,6 +3424,26 @@ def test_bind_bind() -> None:
|
||||
) == dumpd(llm.bind(stop=["Observation:"], one="two", hello="world"))
|
||||
|
||||
|
||||
def test_bind_with_lambda() -> None:
|
||||
def my_function(*args: Any, **kwargs: Any) -> int:
|
||||
return 3 + kwargs.get("n", 0)
|
||||
|
||||
runnable = RunnableLambda(my_function).bind(n=1)
|
||||
assert 4 == runnable.invoke({})
|
||||
chunks = list(runnable.stream({}))
|
||||
assert [4] == chunks
|
||||
|
||||
|
||||
async def test_bind_with_lambda_async() -> None:
|
||||
def my_function(*args: Any, **kwargs: Any) -> int:
|
||||
return 3 + kwargs.get("n", 0)
|
||||
|
||||
runnable = RunnableLambda(my_function).bind(n=1)
|
||||
assert 4 == await runnable.ainvoke({})
|
||||
chunks = [item async for item in runnable.astream({})]
|
||||
assert [4] == chunks
|
||||
|
||||
|
||||
def test_deep_stream() -> None:
|
||||
prompt = (
|
||||
SystemMessagePromptTemplate.from_template("You are a nice assistant.")
|
||||
|
Loading…
Reference in New Issue
Block a user