diff --git a/libs/langchain/langchain/schema/runnable/utils.py b/libs/langchain/langchain/schema/runnable/utils.py index 693749a534..7fb08b433f 100644 --- a/libs/langchain/langchain/schema/runnable/utils.py +++ b/libs/langchain/langchain/schema/runnable/utils.py @@ -90,8 +90,10 @@ class IsFunctionArgDict(ast.NodeVisitor): class GetLambdaSource(ast.NodeVisitor): def __init__(self) -> None: self.source: Optional[str] = None + self.count = 0 def visit_Lambda(self, node: ast.Lambda) -> Any: + self.count += 1 if hasattr(ast, "unparse"): self.source = ast.unparse(node) @@ -118,10 +120,11 @@ def get_lambda_source(func: Callable) -> Optional[str]: """ try: code = inspect.getsource(func) + print(code) tree = ast.parse(textwrap.dedent(code)) visitor = GetLambdaSource() visitor.visit(tree) - return visitor.source + return visitor.source if visitor.count == 1 else None except (SyntaxError, TypeError, OSError): return None diff --git a/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py b/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py index 9abeb1a03d..e276e0a141 100644 --- a/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py +++ b/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py @@ -2741,6 +2741,9 @@ async def test_runnable_branch_abatch() -> None: assert await branch.abatch([1, 10, 0]) == [2, 100, -1] +@pytest.mark.skipif( + sys.version_info < (3, 9), reason="Requires python version >= 3.9 to run." +) def test_reprsentation_of_runnables() -> None: """Test representation of runnables.""" runnable = RunnableLambda(lambda x: x * 2) @@ -2765,9 +2768,9 @@ def test_reprsentation_of_runnables() -> None: "b": RunnableLambda(lambda x: x * 3), } ) == ( - "RunnableLambda(lambda x: x * 3)\n" + "RunnableLambda(...)\n" "| {\n" " a: RunnableLambda(...),\n" " b: RunnableLambda(...)\n" " }" - ) + ), "repr where code string contains multiple lambdas gives up"