This commit is contained in:
Eugene Yurtsev 2023-09-28 10:51:17 -04:00
parent 5c1f462bb9
commit a5b15e9d0f
3 changed files with 82 additions and 0 deletions

View File

@ -107,6 +107,14 @@ def get_function_first_arg_dict_keys(func: Callable) -> Optional[List[str]]:
def get_lambda_source(func: Callable) -> Optional[str]: def get_lambda_source(func: Callable) -> Optional[str]:
"""Get the source code of a lambda function.
Args:
func: a callable that can be a lambda function
Returns:
str: the source code of the lambda function
"""
try: try:
code = inspect.getsource(func) code = inspect.getsource(func)
tree = ast.parse(textwrap.dedent(code)) tree = ast.parse(textwrap.dedent(code))
@ -118,6 +126,15 @@ def get_lambda_source(func: Callable) -> Optional[str]:
def indent_lines_after_first(text: str, prefix: str) -> str: def indent_lines_after_first(text: str, prefix: str) -> str:
"""Indent all lines of text after the first line.
Args:
text: The text to indent
prefix: Used to determine the number of spaces to indent
Returns:
str: The indented text
"""
n_spaces = len(prefix) n_spaces = len(prefix)
spaces = " " * n_spaces spaces = " " * n_spaces
lines = text.splitlines() lines = text.splitlines()

View File

@ -2739,3 +2739,34 @@ async def test_runnable_branch_abatch() -> None:
) )
assert await branch.abatch([1, 10, 0]) == [2, 100, -1] assert await branch.abatch([1, 10, 0]) == [2, 100, -1]
def test_reprsentation_of_runnables() -> None:
"""Test representation of runnables."""
runnable = RunnableLambda(lambda x: x * 2)
assert repr(runnable) == "RunnableLambda(lambda x: x * 2)"
def f(x: int) -> int:
"""Return 2."""
return 2
assert repr(RunnableLambda(func=f)) == "RunnableLambda(...)"
async def af(x: int) -> int:
"""Return 2."""
return 2
assert repr(RunnableLambda(func=f, afunc=af)) == "RunnableLambda(...)"
sequence = RunnableLambda(lambda x: x + 2) | {
"a": RunnableLambda(lambda x: x * 2),
"b": RunnableLambda(lambda x: x * 3),
}
assert repr(sequence) == (
"RunnableLambda(lambda x: x * 3)\n"
"| {\n"
" a: RunnableLambda(...),\n"
" b: RunnableLambda(...)\n"
" }"
)

View File

@ -0,0 +1,34 @@
import pytest
from langchain.schema.runnable.utils import (
get_lambda_source,
indent_lines_after_first,
)
from langchain.schema.runnable.base import RunnableLambda
# Test get_lambda_source function
@pytest.mark.parametrize(
"func, expected_source",
[
(lambda x: x * 2, "lambda x: x * 2"),
(lambda a, b: a + b, "lambda a, b: a + b"),
(lambda x: x if x > 0 else 0, "lambda x: x if x > 0 else 0"),
],
)
def test_get_lambda_source(func, expected_source):
source = get_lambda_source(func)
assert source == expected_source
@pytest.mark.parametrize(
"text,prefix,expected_output",
[
("line 1\nline 2\nline 3", "1", "line 1\n line 2\n line 3"),
("line 1\nline 2\nline 3", "ax", "line 1\n line 2\n line 3"),
],
)
def test_indent_lines_after_first(text, prefix, expected_output):
indented_text = indent_lines_after_first(text, prefix)
assert indented_text == expected_output