Allow config propagation, Add default lambda name, Improve ergonomics of config passed in (#10273)

Makes it easier to do recursion using regular python compositional
patterns

```py
def lambda_decorator(func):
    """Decorate function as a RunnableLambda"""
    return runnable.RunnableLambda(func)

@lambda_decorator
def fibonacci(a, config: runnable.RunnableConfig) -> int:
    if a <= 1:
        return a
    else:
        return fibonacci.invoke(
            a - 1, config
        ) + fibonacci.invoke(a - 2, config)

fibonacci.invoke(10)
```

https://smith.langchain.com/public/cb98edb4-3a09-4798-9c22-a930037faf88/r

Also makes it more natural to do things like error handle and call other
langchain objects in ways we probably don't want to support in
`with_fallbacks()`

```py
@lambda_decorator
def handle_errors(a, config: runnable.RunnableConfig) -> int:
    try:
        return my_chain.invoke(a, config)
    except MyExceptionType as exc:
        return my_other_chain.invoke({"original": a, "error": exc}, config)
```

In this case, the next chain takes in the exception object. Maybe this
could be something we toggle in `with_fallbacks` but I fear we'll get
into uglier APIs + heavier cognitive load if we try to do too much there

---------

Co-authored-by: Nuno Campos <nuno@boringbits.io>
pull/10286/head
William FH 1 year ago committed by GitHub
parent c732d8fffd
commit ffca5e7eea
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -39,6 +39,8 @@ from langchain.load.serializable import Serializable
from langchain.pydantic_v1 import Field
from langchain.schema.runnable.config import (
RunnableConfig,
acall_func_with_variable_args,
call_func_with_variable_args,
ensure_config,
get_async_callback_manager_for_config,
get_callback_manager_for_config,
@ -47,16 +49,15 @@ from langchain.schema.runnable.config import (
patch_config,
)
from langchain.schema.runnable.utils import (
Input,
Output,
accepts_config,
accepts_run_manager,
accepts_run_manager_and_config,
gather_with_concurrency,
)
from langchain.utils.aiter import atee, py_anext
from langchain.utils.iter import safetee
Input = TypeVar("Input")
# Output type should implement __concat__, as eg str, list, dict do
Output = TypeVar("Output")
Other = TypeVar("Other")
@ -311,16 +312,7 @@ class Runnable(Generic[Input, Output], ABC):
name=config.get("run_name"),
)
try:
if accepts_run_manager_and_config(func):
output = func(
input,
run_manager=run_manager,
config=config,
) # type: ignore[call-arg]
elif accepts_run_manager(func):
output = func(input, run_manager=run_manager) # type: ignore[call-arg]
else:
output = func(input) # type: ignore[call-arg]
output = call_func_with_variable_args(func, input, run_manager, config)
except Exception as e:
run_manager.on_chain_error(e)
raise
@ -353,19 +345,9 @@ class Runnable(Generic[Input, Output], ABC):
name=config.get("run_name"),
)
try:
if accepts_run_manager_and_config(func):
output = await func(
input,
run_manager=run_manager,
config=config,
) # type: ignore[call-arg]
elif accepts_run_manager(func):
output = await func(
input,
run_manager=run_manager,
) # type: ignore[call-arg]
else:
output = await func(input) # type: ignore[call-arg]
output = await acall_func_with_variable_args(
func, input, run_manager, config
)
except Exception as e:
await run_manager.on_chain_error(e)
raise
@ -408,16 +390,15 @@ class Runnable(Generic[Input, Output], ABC):
)
]
try:
if accepts_run_manager_and_config(func):
output = func(
input,
run_manager=run_managers,
config=configs,
) # type: ignore[call-arg]
elif accepts_run_manager(func):
output = func(input, run_manager=run_managers) # type: ignore[call-arg]
else:
output = func(input) # type: ignore[call-arg]
kwargs: Dict[str, Any] = {}
if accepts_config(func):
kwargs["config"] = [
patch_config(c, callbacks=rm.get_child())
for c, rm in zip(configs, run_managers)
]
if accepts_run_manager(func):
kwargs["run_manager"] = run_managers
output = func(input, **kwargs) # type: ignore[call-arg]
except Exception as e:
for run_manager in run_managers:
run_manager.on_chain_error(e)
@ -479,16 +460,15 @@ class Runnable(Generic[Input, Output], ABC):
)
)
try:
if accepts_run_manager_and_config(func):
output = await func(
input,
run_manager=run_managers,
config=configs,
) # type: ignore[call-arg]
elif accepts_run_manager(func):
output = await func(input, run_manager=run_managers) # type: ignore
else:
output = await func(input) # type: ignore[call-arg]
kwargs: Dict[str, Any] = {}
if accepts_config(func):
kwargs["config"] = [
patch_config(c, callbacks=rm.get_child())
for c, rm in zip(configs, run_managers)
]
if accepts_run_manager(func):
kwargs["run_manager"] = run_managers
output = await func(input, **kwargs) # type: ignore[call-arg]
except Exception as e:
await asyncio.gather(
*(run_manager.on_chain_error(e) for run_manager in run_managers)
@ -550,19 +530,14 @@ class Runnable(Generic[Input, Output], ABC):
name=config.get("run_name"),
)
try:
if accepts_run_manager_and_config(transformer):
iterator = transformer(
input_for_transform,
run_manager=run_manager,
config=config,
) # type: ignore[call-arg]
elif accepts_run_manager(transformer):
iterator = transformer(
input_for_transform,
run_manager=run_manager,
) # type: ignore[call-arg]
else:
iterator = transformer(input_for_transform) # type: ignore[call-arg]
kwargs: Dict[str, Any] = {}
if accepts_config(transformer):
kwargs["config"] = patch_config(
config, callbacks=run_manager.get_child()
)
if accepts_run_manager(transformer):
kwargs["run_manager"] = run_manager
iterator = transformer(input_for_transform, **kwargs) # type: ignore[call-arg]
for chunk in iterator:
yield chunk
if final_output_supported:
@ -631,21 +606,14 @@ class Runnable(Generic[Input, Output], ABC):
name=config.get("run_name"),
)
try:
# mypy can't quite work out thew type guard here, but this is safe,
# check implementations of the accepts_* functions
if accepts_run_manager_and_config(transformer):
iterator = transformer(
input_for_transform,
run_manager=run_manager,
config=config,
) # type: ignore[call-arg]
elif accepts_run_manager(transformer):
iterator = transformer(
input_for_transform,
run_manager=run_manager,
) # type: ignore[call-arg]
else:
iterator = transformer(input_for_transform) # type: ignore[call-arg]
kwargs: Dict[str, Any] = {}
if accepts_config(transformer):
kwargs["config"] = patch_config(
config, callbacks=run_manager.get_child()
)
if accepts_run_manager(transformer):
kwargs["run_manager"] = run_manager
iterator = transformer(input_for_transform, **kwargs) # type: ignore[call-arg]
async for chunk in iterator:
yield chunk
if final_output_supported:
@ -1756,7 +1724,7 @@ class RunnableLambda(Runnable[Input, Output]):
run_manager: CallbackManagerForChainRun,
config: RunnableConfig,
) -> Output:
output = self.func(input)
output = call_func_with_variable_args(self.func, input, run_manager, config)
# If the output is a runnable, invoke it
if isinstance(output, Runnable):
recursion_limit = config["recursion_limit"]
@ -1780,7 +1748,9 @@ class RunnableLambda(Runnable[Input, Output]):
run_manager: AsyncCallbackManagerForChainRun,
config: RunnableConfig,
) -> Output:
output = await self.afunc(input)
output = await acall_func_with_variable_args(
self.afunc, input, run_manager, config
)
# If the output is a runnable, invoke it
if isinstance(output, Runnable):
recursion_limit = config["recursion_limit"]
@ -1798,6 +1768,21 @@ class RunnableLambda(Runnable[Input, Output]):
)
return output
def _config(
self, config: Optional[RunnableConfig], callable: Callable[..., Any]
) -> RunnableConfig:
config = config or {}
if config.get("run_name") is None:
try:
run_name = callable.__name__
except AttributeError:
run_name = None
if run_name is not None:
return patch_config(config, run_name=run_name)
return config
def invoke(
self,
input: Input,
@ -1805,7 +1790,11 @@ class RunnableLambda(Runnable[Input, Output]):
**kwargs: Optional[Any],
) -> Output:
if hasattr(self, "func"):
return self._call_with_config(self._invoke, input, config)
return self._call_with_config(
self._invoke,
input,
self._config(config, self.func),
)
else:
raise TypeError(
"Cannot invoke a coroutine function synchronously."
@ -1819,7 +1808,11 @@ class RunnableLambda(Runnable[Input, Output]):
**kwargs: Optional[Any],
) -> Output:
if hasattr(self, "afunc"):
return await self._acall_with_config(self._ainvoke, input, config)
return await self._acall_with_config(
self._ainvoke,
input,
self._config(config, self.afunc),
)
else:
# Delegating to super implementation of ainvoke.
# Uses asyncio executor to run the sync version (invoke)

@ -3,13 +3,35 @@ from __future__ import annotations
from concurrent.futures import Executor, ThreadPoolExecutor
from contextlib import contextmanager
from copy import deepcopy
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Union
from typing import (
TYPE_CHECKING,
Any,
Awaitable,
Callable,
Dict,
Generator,
List,
Optional,
Union,
)
from typing_extensions import TypedDict
from langchain.schema.runnable.utils import (
Input,
Output,
accepts_config,
accepts_run_manager,
)
if TYPE_CHECKING:
from langchain.callbacks.base import BaseCallbackManager, Callbacks
from langchain.callbacks.manager import AsyncCallbackManager, CallbackManager
from langchain.callbacks.manager import (
AsyncCallbackManager,
AsyncCallbackManagerForChainRun,
CallbackManager,
CallbackManagerForChainRun,
)
class RunnableConfig(TypedDict, total=False):
@ -117,6 +139,47 @@ def patch_config(
return config
def call_func_with_variable_args(
func: Union[
Callable[[Input], Output],
Callable[[Input, CallbackManagerForChainRun], Output],
Callable[[Input, CallbackManagerForChainRun, RunnableConfig], Output],
],
input: Input,
run_manager: CallbackManagerForChainRun,
config: RunnableConfig,
) -> Output:
"""Call function that may optionally accept a run_manager and/or config."""
kwargs: Dict[str, Any] = {}
if accepts_config(func):
kwargs["config"] = patch_config(config, callbacks=run_manager.get_child())
if accepts_run_manager(func):
kwargs["run_manager"] = run_manager
return func(input, **kwargs) # type: ignore[call-arg]
async def acall_func_with_variable_args(
func: Union[
Callable[[Input], Awaitable[Output]],
Callable[[Input, AsyncCallbackManagerForChainRun], Awaitable[Output]],
Callable[
[Input, AsyncCallbackManagerForChainRun, RunnableConfig],
Awaitable[Output],
],
],
input: Input,
run_manager: AsyncCallbackManagerForChainRun,
config: RunnableConfig,
) -> Output:
"""Call function that may optionally accept a run_manager and/or config."""
kwargs: Dict[str, Any] = {}
if accepts_config(func):
kwargs["config"] = patch_config(config, callbacks=run_manager.get_child())
if accepts_run_manager(func):
kwargs["run_manager"] = run_manager
return await func(input, **kwargs) # type: ignore[call-arg]
def get_callback_manager_for_config(config: RunnableConfig) -> CallbackManager:
from langchain.callbacks.manager import CallbackManager

@ -2,7 +2,11 @@ from __future__ import annotations
import asyncio
from inspect import signature
from typing import Any, Callable, Coroutine, Union
from typing import Any, Callable, Coroutine, TypeVar, Union
Input = TypeVar("Input")
# Output type should implement __concat__, as eg str, list, dict do
Output = TypeVar("Output")
async def gated_coro(semaphore: asyncio.Semaphore, coro: Coroutine) -> Any:
@ -26,8 +30,8 @@ def accepts_run_manager(callable: Callable[..., Any]) -> bool:
return False
def accepts_run_manager_and_config(callable: Callable[..., Any]) -> bool:
return (
accepts_run_manager(callable)
and signature(callable).parameters.get("config") is not None
)
def accepts_config(callable: Callable[..., Any]) -> bool:
try:
return signature(callable).parameters.get("config") is not None
except ValueError:
return False

File diff suppressed because one or more lines are too long

@ -948,7 +948,7 @@ async def test_higher_order_lambda_runnable(
parent_run = next(r for r in tracer.runs if r.parent_run_id is None)
assert len(parent_run.child_runs) == 2
router_run = parent_run.child_runs[1]
assert router_run.name == "RunnableLambda"
assert router_run.name == "router"
assert len(router_run.child_runs) == 1
math_run = router_run.child_runs[0]
assert math_run.name == "RunnableSequence"
@ -980,7 +980,7 @@ async def test_higher_order_lambda_runnable(
parent_run = next(r for r in tracer.runs if r.parent_run_id is None)
assert len(parent_run.child_runs) == 2
router_run = parent_run.child_runs[1]
assert router_run.name == "RunnableLambda"
assert router_run.name == "arouter"
assert len(router_run.child_runs) == 1
math_run = router_run.child_runs[0]
assert math_run.name == "RunnableSequence"

Loading…
Cancel
Save