Nc/runnable lambda recurse (#9390)

<!-- Thank you for contributing to LangChain!

Replace this entire comment with:
  - Description: a description of the change, 
  - Issue: the issue # it fixes (if applicable),
  - Dependencies: any dependencies required for this change,
- Tag maintainer: for a quicker response, tag the relevant maintainer
(see below),
- Twitter handle: we announce bigger features on Twitter. If your PR
gets announced and you'd like a mention, we'll gladly shout you out!

Please make sure your PR is passing linting and testing before
submitting. Run `make format`, `make lint` and `make test` to check this
locally.

See contribution guidelines for more information on how to write/run
tests, lint, etc:

https://github.com/hwchase17/langchain/blob/master/.github/CONTRIBUTING.md

If you're adding a new integration, please include:
1. a test for the integration, preferably unit tests that do not rely on
network access,
2. an example notebook showing its use. These live is docs/extras
directory.

If no one reviews your PR within a few days, please @-mention one of
@baskaryan, @eyurtsev, @hwchase17, @rlancemartin.
 -->
pull/9671/head
Nuno Campos 1 year ago committed by GitHub
commit fa05e18278
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -1,6 +1,7 @@
from __future__ import annotations
import asyncio
import inspect
import threading
from abc import ABC, abstractmethod
from concurrent.futures import FIRST_COMPLETED, wait
@ -1343,9 +1344,18 @@ class RunnableLambda(Runnable[Input, Output]):
A runnable that runs a callable.
"""
def __init__(self, func: Callable[[Input], Output]) -> None:
if callable(func):
self.func = func
def __init__(
self,
func: Union[Callable[[Input], Output], Callable[[Input], Awaitable[Output]]],
afunc: Optional[Callable[[Input], Awaitable[Output]]] = None,
) -> None:
if afunc is not None:
self.afunc = afunc
if inspect.iscoroutinefunction(func):
self.afunc = func
elif callable(func):
self.func = cast(Callable[[Input], Output], func)
else:
raise TypeError(
"Expected a callable type for `func`."
@ -1354,17 +1364,89 @@ class RunnableLambda(Runnable[Input, Output]):
def __eq__(self, other: Any) -> bool:
if isinstance(other, RunnableLambda):
return self.func == other.func
if hasattr(self, "func") and hasattr(other, "func"):
return self.func == other.func
elif hasattr(self, "afunc") and hasattr(other, "afunc"):
return self.afunc == other.afunc
else:
return False
else:
return False
def _invoke(
self,
input: Input,
run_manager: CallbackManagerForChainRun,
config: RunnableConfig,
) -> Output:
output = self.func(input)
# If the output is a runnable, invoke it
if isinstance(output, Runnable):
recursion_limit = config["recursion_limit"]
if recursion_limit <= 0:
raise RecursionError(
f"Recursion limit reached when invoking {self} with input {input}."
)
output = output.invoke(
input,
patch_config(
config,
callbacks=run_manager.get_child(),
recursion_limit=recursion_limit - 1,
),
)
return output
async def _ainvoke(
self,
input: Input,
run_manager: AsyncCallbackManagerForChainRun,
config: RunnableConfig,
) -> Output:
output = await self.afunc(input)
# If the output is a runnable, invoke it
if isinstance(output, Runnable):
recursion_limit = config["recursion_limit"]
if recursion_limit <= 0:
raise RecursionError(
f"Recursion limit reached when invoking {self} with input {input}."
)
output = await output.ainvoke(
input,
patch_config(
config,
callbacks=run_manager.get_child(),
recursion_limit=recursion_limit - 1,
),
)
return output
def invoke(
self,
input: Input,
config: Optional[RunnableConfig] = None,
**kwargs: Optional[Any],
) -> Output:
return self._call_with_config(self.func, input, config)
if hasattr(self, "func"):
return self._call_with_config(self._invoke, input, config)
else:
raise TypeError(
"Cannot invoke a coroutine function synchronously."
"Use `ainvoke` instead."
)
async def ainvoke(
self,
input: Input,
config: Optional[RunnableConfig] = None,
**kwargs: Optional[Any],
) -> Output:
if hasattr(self, "afunc"):
return await self._acall_with_config(self._ainvoke, input, config)
else:
# Delegating to super implementation of ainvoke.
# Uses asyncio executor to run the sync version (invoke)
return await super().ainvoke(input, config)
class RunnableEach(Serializable, Runnable[List[Input], List[Output]]):

@ -47,6 +47,11 @@ class RunnableConfig(TypedDict, total=False):
ThreadPoolExecutor will be created.
"""
recursion_limit: int
"""
Maximum number of times a call can recurse. If not provided, defaults to 10.
"""
def ensure_config(config: Optional[RunnableConfig]) -> RunnableConfig:
empty = RunnableConfig(
@ -54,6 +59,7 @@ def ensure_config(config: Optional[RunnableConfig]) -> RunnableConfig:
metadata={},
callbacks=None,
_locals={},
recursion_limit=10,
)
if config is not None:
empty.update(config)
@ -66,6 +72,7 @@ def patch_config(
deep_copy_locals: bool = False,
callbacks: Optional[BaseCallbackManager] = None,
executor: Optional[Executor] = None,
recursion_limit: Optional[int] = None,
) -> RunnableConfig:
config = ensure_config(config)
if deep_copy_locals:
@ -74,6 +81,8 @@ def patch_config(
config["callbacks"] = callbacks
if executor is not None:
config["executor"] = executor
if recursion_limit is not None:
config["recursion_limit"] = recursion_limit
return config

File diff suppressed because one or more lines are too long

@ -1,4 +1,5 @@
from typing import Any, Dict, List, Optional
from operator import itemgetter
from typing import Any, Dict, List, Optional, Union
from uuid import UUID
import pytest
@ -176,6 +177,7 @@ async def test_default_method_implementations(mocker: MockerFixture) -> None:
tags=[],
callbacks=None,
_locals={},
recursion_limit=10,
),
),
mocker.call(
@ -185,6 +187,7 @@ async def test_default_method_implementations(mocker: MockerFixture) -> None:
tags=[],
callbacks=None,
_locals={},
recursion_limit=10,
),
),
]
@ -438,6 +441,50 @@ async def test_prompt_with_llm(
)
@pytest.mark.asyncio
@freeze_time("2023-01-01")
async def test_prompt_with_llm_and_async_lambda(
mocker: MockerFixture, snapshot: SnapshotAssertion
) -> None:
prompt = (
SystemMessagePromptTemplate.from_template("You are a nice assistant.")
+ "{question}"
)
llm = FakeListLLM(responses=["foo", "bar"])
async def passthrough(input: Any) -> Any:
return input
chain = prompt | llm | passthrough
assert isinstance(chain, RunnableSequence)
assert chain.first == prompt
assert chain.middle == [llm]
assert chain.last == RunnableLambda(func=passthrough)
assert dumps(chain, pretty=True) == snapshot
# Test invoke
prompt_spy = mocker.spy(prompt.__class__, "ainvoke")
llm_spy = mocker.spy(llm.__class__, "ainvoke")
tracer = FakeTracer()
assert (
await chain.ainvoke(
{"question": "What is your name?"}, dict(callbacks=[tracer])
)
== "foo"
)
assert prompt_spy.call_args.args[1] == {"question": "What is your name?"}
assert llm_spy.call_args.args[1] == ChatPromptValue(
messages=[
SystemMessage(content="You are a nice assistant."),
HumanMessage(content="What is your name?"),
]
)
assert tracer.runs == snapshot
mocker.stop(prompt_spy)
mocker.stop(llm_spy)
@freeze_time("2023-01-01")
def test_prompt_with_chat_model_and_parser(
mocker: MockerFixture, snapshot: SnapshotAssertion
@ -722,6 +769,105 @@ async def test_router_runnable(
assert len(router_run.child_runs) == 2
@pytest.mark.asyncio
@freeze_time("2023-01-01")
async def test_higher_order_lambda_runnable(
mocker: MockerFixture, snapshot: SnapshotAssertion
) -> None:
math_chain = ChatPromptTemplate.from_template(
"You are a math genius. Answer the question: {question}"
) | FakeListLLM(responses=["4"])
english_chain = ChatPromptTemplate.from_template(
"You are an english major. Answer the question: {question}"
) | FakeListLLM(responses=["2"])
input_map: Runnable = RunnableMap(
{ # type: ignore[arg-type]
"key": lambda x: x["key"],
"input": {"question": lambda x: x["question"]},
}
)
def router(input: Dict[str, Any]) -> Runnable:
if input["key"] == "math":
return itemgetter("input") | math_chain
elif input["key"] == "english":
return itemgetter("input") | english_chain
else:
raise ValueError(f"Unknown key: {input['key']}")
chain: Runnable = input_map | router
assert dumps(chain, pretty=True) == snapshot
result = chain.invoke({"key": "math", "question": "2 + 2"})
assert result == "4"
result2 = chain.batch(
[{"key": "math", "question": "2 + 2"}, {"key": "english", "question": "2 + 2"}]
)
assert result2 == ["4", "2"]
result = await chain.ainvoke({"key": "math", "question": "2 + 2"})
assert result == "4"
result2 = await chain.abatch(
[{"key": "math", "question": "2 + 2"}, {"key": "english", "question": "2 + 2"}]
)
assert result2 == ["4", "2"]
# Test invoke
math_spy = mocker.spy(math_chain.__class__, "invoke")
tracer = FakeTracer()
assert (
chain.invoke({"key": "math", "question": "2 + 2"}, dict(callbacks=[tracer]))
== "4"
)
assert math_spy.call_args.args[1] == {
"key": "math",
"input": {"question": "2 + 2"},
}
assert len([r for r in tracer.runs if r.parent_run_id is None]) == 1
parent_run = next(r for r in tracer.runs if r.parent_run_id is None)
assert len(parent_run.child_runs) == 2
router_run = parent_run.child_runs[1]
assert router_run.name == "RunnableLambda"
assert len(router_run.child_runs) == 1
math_run = router_run.child_runs[0]
assert math_run.name == "RunnableSequence"
assert len(math_run.child_runs) == 3
# Test ainvoke
async def arouter(input: Dict[str, Any]) -> Runnable:
if input["key"] == "math":
return itemgetter("input") | math_chain
elif input["key"] == "english":
return itemgetter("input") | english_chain
else:
raise ValueError(f"Unknown key: {input['key']}")
achain: Runnable = input_map | arouter
math_spy = mocker.spy(math_chain.__class__, "ainvoke")
tracer = FakeTracer()
assert (
await achain.ainvoke(
{"key": "math", "question": "2 + 2"}, dict(callbacks=[tracer])
)
== "4"
)
assert math_spy.call_args.args[1] == {
"key": "math",
"input": {"question": "2 + 2"},
}
assert len([r for r in tracer.runs if r.parent_run_id is None]) == 1
parent_run = next(r for r in tracer.runs if r.parent_run_id is None)
assert len(parent_run.child_runs) == 2
router_run = parent_run.child_runs[1]
assert router_run.name == "RunnableLambda"
assert len(router_run.child_runs) == 1
math_run = router_run.child_runs[0]
assert math_run.name == "RunnableSequence"
assert len(math_run.child_runs) == 3
@freeze_time("2023-01-01")
def test_seq_prompt_map(mocker: MockerFixture, snapshot: SnapshotAssertion) -> None:
passthrough = mocker.Mock(side_effect=lambda x: x)
@ -1136,3 +1282,17 @@ def test_each(snapshot: SnapshotAssertion) -> None:
"test",
"this",
]
def test_recursive_lambda() -> None:
def _simple_recursion(x: int) -> Union[int, Runnable]:
if x < 10:
return RunnableLambda(lambda *args: _simple_recursion(x + 1))
else:
return x
runnable = RunnableLambda(_simple_recursion)
assert runnable.invoke(5) == 10
with pytest.raises(RecursionError):
runnable.invoke(0, {"recursion_limit": 9})

Loading…
Cancel
Save