From 719da8746e6239b015452c0b90415be171ae5876 Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Wed, 17 Apr 2024 15:38:39 -0700 Subject: [PATCH] core: fix attributeerror in runnablelambda.deps (#20569) - would happen when user's code tries to access attritbute that doesnt exist, we prefer to let this crash in the user's code, rather than here - also catch more cases where a runnable is invoked/streamed inside a lambda. before we weren't seeing these as deps --- libs/core/langchain_core/runnables/base.py | 8 ++++++- libs/core/langchain_core/runnables/utils.py | 5 ++++- .../tests/unit_tests/runnables/test_utils.py | 22 ++++++++++++++++++- 3 files changed, 32 insertions(+), 3 deletions(-) diff --git a/libs/core/langchain_core/runnables/base.py b/libs/core/langchain_core/runnables/base.py index b81a15d5d6..7645bf52bc 100644 --- a/libs/core/langchain_core/runnables/base.py +++ b/libs/core/langchain_core/runnables/base.py @@ -3752,7 +3752,13 @@ class RunnableLambda(Runnable[Input, Output]): else: objects = [] - return [obj for obj in objects if isinstance(obj, Runnable)] + deps: List[Runnable] = [] + for obj in objects: + if isinstance(obj, Runnable): + deps.append(obj) + elif isinstance(getattr(obj, "__self__", None), Runnable): + deps.append(obj.__self__) + return deps @property def config_specs(self) -> List[ConfigurableFieldSpec]: diff --git a/libs/core/langchain_core/runnables/utils.py b/libs/core/langchain_core/runnables/utils.py index 40f52c1816..dff10ad049 100644 --- a/libs/core/langchain_core/runnables/utils.py +++ b/libs/core/langchain_core/runnables/utils.py @@ -263,7 +263,10 @@ def get_function_nonlocals(func: Callable) -> List[Any]: if vv is None: break else: - vv = getattr(vv, part) + try: + vv = getattr(vv, part) + except AttributeError: + break else: values.append(vv) return values diff --git a/libs/core/tests/unit_tests/runnables/test_utils.py b/libs/core/tests/unit_tests/runnables/test_utils.py index 1bbf5a8a91..fa82826857 100644 --- a/libs/core/tests/unit_tests/runnables/test_utils.py +++ b/libs/core/tests/unit_tests/runnables/test_utils.py @@ -1,9 +1,11 @@ import sys -from typing import Callable +from typing import Callable, Dict import pytest +from langchain_core.runnables.base import RunnableLambda from langchain_core.runnables.utils import ( + get_function_nonlocals, get_lambda_source, indent_lines_after_first, ) @@ -37,3 +39,21 @@ def test_indent_lines_after_first(text: str, prefix: str, expected_output: str) """Test indent_lines_after_first function""" indented_text = indent_lines_after_first(text, prefix) assert indented_text == expected_output + + +def test_nonlocals() -> None: + agent = RunnableLambda(lambda x: x * 2) # noqa: F841 + + def my_func(input: str, agent: Dict[str, str]) -> str: + return agent.get("agent_name", input) + + def my_func2(input: str) -> str: + return agent.get("agent_name", input) # type: ignore[attr-defined] + + def my_func3(input: str) -> str: + return agent.invoke(input) + + assert get_function_nonlocals(my_func) == [] + assert get_function_nonlocals(my_func2) == [] + assert get_function_nonlocals(my_func3) == [agent.invoke] + assert RunnableLambda(my_func3).deps == [agent]