Implement better reprs for Runnables (#11175)

```
ChatPromptTemplate(messages=[SystemMessagePromptTemplate(prompt=PromptTemplate(input_variables=[], template='You are a nice assistant.')), HumanMessagePromptTemplate(prompt=PromptTemplate(input_variables=['question'], template='{question}'))])
| RunnableLambda(lambda x: x)
| {
    chat: FakeListChatModel(responses=["i'm a chatbot"]),
    llm: FakeListLLM(responses=["i'm a textbot"])
  }
```

<!-- Thank you for contributing to LangChain!

Replace this entire comment with:
  - **Description:** a description of the change, 
  - **Issue:** the issue # it fixes (if applicable),
  - **Dependencies:** any dependencies required for this change,
- **Tag maintainer:** for a quicker response, tag the relevant
maintainer (see below),
- **Twitter handle:** we announce bigger features on Twitter. If your PR
gets announced, and you'd like a mention, we'll gladly shout you out!

Please make sure your PR is passing linting and testing before
submitting. Run `make format`, `make lint` and `make test` to check this
locally.

See contribution guidelines for more information on how to write/run
tests, lint, etc:

https://github.com/hwchase17/langchain/blob/master/.github/CONTRIBUTING.md

If you're adding a new integration, please include:
1. a test for the integration, preferably unit tests that do not rely on
network access,
2. an example notebook showing its use. It lives in `docs/extras`
directory.

If no one reviews your PR within a few days, please @-mention one of
@baskaryan, @eyurtsev, @hwchase17.
 -->
pull/11218/head
Nuno Campos 11 months ago committed by GitHub
commit 61b5942adf
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -77,6 +77,13 @@ class Serializable(BaseModel, ABC):
class Config:
extra = "ignore"
def __repr_args__(self) -> Any:
return [
(k, v)
for k, v in super().__repr_args__()
if (k not in self.__fields__ or self.__fields__[k].get_default() != v)
]
_lc_kwargs = PrivateAttr(default_factory=dict)
def __init__(self, **kwargs: Any) -> None:

@ -59,6 +59,8 @@ from langchain.schema.runnable.utils import (
accepts_run_manager,
gather_with_concurrency,
get_function_first_arg_dict_keys,
get_lambda_source,
indent_lines_after_first,
)
from langchain.utils.aiter import atee, py_anext
from langchain.utils.iter import safetee
@ -1298,6 +1300,12 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
def output_schema(self) -> Type[BaseModel]:
return self.last.output_schema
def __repr__(self) -> str:
return "\n| ".join(
repr(s) if i == 0 else indent_lines_after_first(repr(s), "| ")
for i, s in enumerate(self.steps)
)
def __or__(
self,
other: Union[
@ -1819,6 +1827,13 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
**{k: (v.OutputType, None) for k, v in self.steps.items()},
)
def __repr__(self) -> str:
map_for_repr = ",\n ".join(
f"{k}: {indent_lines_after_first(repr(v), ' ' + k + ': ')}"
for k, v in self.steps.items()
)
return "{\n " + map_for_repr + "\n}"
def invoke(
self, input: Input, config: Optional[RunnableConfig] = None
) -> Dict[str, Any]:
@ -2134,7 +2149,7 @@ class RunnableLambda(Runnable[Input, Output]):
return False
def __repr__(self) -> str:
return "RunnableLambda(...)"
return f"RunnableLambda({get_lambda_source(self.func) or '...'})"
def _invoke(
self,

@ -87,6 +87,17 @@ class IsFunctionArgDict(ast.NodeVisitor):
IsLocalDict(input_arg_name, self.keys).visit(node)
class GetLambdaSource(ast.NodeVisitor):
def __init__(self) -> None:
self.source: Optional[str] = None
self.count = 0
def visit_Lambda(self, node: ast.Lambda) -> Any:
self.count += 1
if hasattr(ast, "unparse"):
self.source = ast.unparse(node)
def get_function_first_arg_dict_keys(func: Callable) -> Optional[List[str]]:
try:
code = inspect.getsource(func)
@ -94,5 +105,40 @@ def get_function_first_arg_dict_keys(func: Callable) -> Optional[List[str]]:
visitor = IsFunctionArgDict()
visitor.visit(tree)
return list(visitor.keys) if visitor.keys else None
except (TypeError, OSError):
except (SyntaxError, TypeError, OSError):
return None
def get_lambda_source(func: Callable) -> Optional[str]:
"""Get the source code of a lambda function.
Args:
func: a callable that can be a lambda function
Returns:
str: the source code of the lambda function
"""
try:
code = inspect.getsource(func)
tree = ast.parse(textwrap.dedent(code))
visitor = GetLambdaSource()
visitor.visit(tree)
return visitor.source if visitor.count == 1 else None
except (SyntaxError, TypeError, OSError):
return None
def indent_lines_after_first(text: str, prefix: str) -> str:
"""Indent all lines of text after the first line.
Args:
text: The text to indent
prefix: Used to determine the number of spaces to indent
Returns:
str: The indented text
"""
n_spaces = len(prefix)
spaces = " " * n_spaces
lines = text.splitlines()
return "\n".join([lines[0]] + [spaces + line for line in lines[1:]])

File diff suppressed because one or more lines are too long

@ -867,6 +867,7 @@ async def test_prompt_with_chat_model(
chain = prompt | chat
assert repr(chain) == snapshot
assert isinstance(chain, RunnableSequence)
assert chain.first == prompt
assert chain.middle == []
@ -1276,7 +1277,8 @@ def test_combining_sequences(
assert chain.first == prompt
assert chain.middle == [chat]
assert chain.last == parser
assert dumps(chain, pretty=True) == snapshot
if sys.version_info >= (3, 9):
assert dumps(chain, pretty=True) == snapshot
prompt2 = (
SystemMessagePromptTemplate.from_template("You are a nicer assistant.")
@ -1294,7 +1296,8 @@ def test_combining_sequences(
assert chain2.first == input_formatter
assert chain2.middle == [prompt2, chat2]
assert chain2.last == parser2
assert dumps(chain2, pretty=True) == snapshot
if sys.version_info >= (3, 9):
assert dumps(chain2, pretty=True) == snapshot
combined_chain = chain | chain2
@ -1307,7 +1310,8 @@ def test_combining_sequences(
chat2,
]
assert combined_chain.last == parser2
assert dumps(combined_chain, pretty=True) == snapshot
if sys.version_info >= (3, 9):
assert dumps(combined_chain, pretty=True) == snapshot
# Test invoke
tracer = FakeTracer()
@ -1315,7 +1319,8 @@ def test_combining_sequences(
{"question": "What is your name?"}, dict(callbacks=[tracer])
) == ["baz", "qux"]
assert tracer.runs == snapshot
if sys.version_info >= (3, 9):
assert tracer.runs == snapshot
@freeze_time("2023-01-01")
@ -1350,6 +1355,7 @@ Question:
| parser
)
assert repr(chain) == snapshot
assert isinstance(chain, RunnableSequence)
assert isinstance(chain.first, RunnableMap)
assert chain.middle == [prompt, chat]
@ -1375,7 +1381,7 @@ Question:
SystemMessage(content="You are a nice assistant."),
HumanMessage(
content="""Context:
[Document(page_content='foo', metadata={}), Document(page_content='bar', metadata={})]
[Document(page_content='foo'), Document(page_content='bar')]
Question:
What is your name?"""
@ -1413,6 +1419,7 @@ def test_seq_prompt_dict(mocker: MockerFixture, snapshot: SnapshotAssertion) ->
}
)
assert repr(chain) == snapshot
assert isinstance(chain, RunnableSequence)
assert chain.first == prompt
assert chain.middle == [RunnableLambda(passthrough)]
@ -2098,7 +2105,8 @@ async def test_llm_with_fallbacks(
assert await runnable.ainvoke("hello") == "bar"
assert await runnable.abatch(["hi", "hey", "bye"]) == ["bar"] * 3
assert list(await runnable.ainvoke("hello")) == list("bar")
assert dumps(runnable, pretty=True) == snapshot
if sys.version_info >= (3, 9):
assert dumps(runnable, pretty=True) == snapshot
class FakeSplitIntoListParser(BaseOutputParser[List[str]]):
@ -2196,6 +2204,7 @@ def test_retrying(mocker: MockerFixture) -> None:
with pytest.raises(RuntimeError):
runnable.with_retry(
stop_after_attempt=2,
wait_exponential_jitter=False,
retry_if_exception_type=(ValueError,),
).invoke(2)
@ -2205,6 +2214,7 @@ def test_retrying(mocker: MockerFixture) -> None:
with pytest.raises(ValueError):
runnable.with_retry(
stop_after_attempt=2,
wait_exponential_jitter=False,
retry_if_exception_type=(ValueError,),
).batch([1, 2, 0])
@ -2214,6 +2224,7 @@ def test_retrying(mocker: MockerFixture) -> None:
output = runnable.with_retry(
stop_after_attempt=2,
wait_exponential_jitter=False,
retry_if_exception_type=(ValueError,),
).batch([1, 2, 0], return_exceptions=True)
@ -2248,6 +2259,7 @@ async def test_async_retrying(mocker: MockerFixture) -> None:
with pytest.raises(ValueError):
await runnable.with_retry(
stop_after_attempt=2,
wait_exponential_jitter=False,
retry_if_exception_type=(ValueError, KeyError),
).ainvoke(1)
@ -2257,6 +2269,7 @@ async def test_async_retrying(mocker: MockerFixture) -> None:
with pytest.raises(RuntimeError):
await runnable.with_retry(
stop_after_attempt=2,
wait_exponential_jitter=False,
retry_if_exception_type=(ValueError,),
).ainvoke(2)
@ -2266,6 +2279,7 @@ async def test_async_retrying(mocker: MockerFixture) -> None:
with pytest.raises(ValueError):
await runnable.with_retry(
stop_after_attempt=2,
wait_exponential_jitter=False,
retry_if_exception_type=(ValueError,),
).abatch([1, 2, 0])
@ -2275,6 +2289,7 @@ async def test_async_retrying(mocker: MockerFixture) -> None:
output = await runnable.with_retry(
stop_after_attempt=2,
wait_exponential_jitter=False,
retry_if_exception_type=(ValueError,),
).abatch([1, 2, 0], return_exceptions=True)
@ -2729,3 +2744,38 @@ async def test_runnable_branch_abatch() -> None:
)
assert await branch.abatch([1, 10, 0]) == [2, 100, -1]
@pytest.mark.skipif(
sys.version_info < (3, 9), reason="Requires python version >= 3.9 to run."
)
def test_representation_of_runnables() -> None:
"""Test representation of runnables."""
runnable = RunnableLambda(lambda x: x * 2)
assert repr(runnable) == "RunnableLambda(lambda x: x * 2)"
def f(x: int) -> int:
"""Return 2."""
return 2
assert repr(RunnableLambda(func=f)) == "RunnableLambda(...)"
async def af(x: int) -> int:
"""Return 2."""
return 2
assert repr(RunnableLambda(func=f, afunc=af)) == "RunnableLambda(...)"
assert repr(
RunnableLambda(lambda x: x + 2)
| {
"a": RunnableLambda(lambda x: x * 2),
"b": RunnableLambda(lambda x: x * 3),
}
) == (
"RunnableLambda(...)\n"
"| {\n"
" a: RunnableLambda(...),\n"
" b: RunnableLambda(...)\n"
" }"
), "repr where code string contains multiple lambdas gives up"

@ -0,0 +1,39 @@
import sys
from typing import Callable
import pytest
from langchain.schema.runnable.utils import (
get_lambda_source,
indent_lines_after_first,
)
@pytest.mark.skipif(
sys.version_info < (3, 9), reason="Requires python version >= 3.9 to run."
)
@pytest.mark.parametrize(
"func, expected_source",
[
(lambda x: x * 2, "lambda x: x * 2"),
(lambda a, b: a + b, "lambda a, b: a + b"),
(lambda x: x if x > 0 else 0, "lambda x: x if x > 0 else 0"),
],
)
def test_get_lambda_source(func: Callable, expected_source: str) -> None:
"""Test get_lambda_source function"""
source = get_lambda_source(func)
assert source == expected_source
@pytest.mark.parametrize(
"text,prefix,expected_output",
[
("line 1\nline 2\nline 3", "1", "line 1\n line 2\n line 3"),
("line 1\nline 2\nline 3", "ax", "line 1\n line 2\n line 3"),
],
)
def test_indent_lines_after_first(text: str, prefix: str, expected_output: str) -> None:
"""Test indent_lines_after_first function"""
indented_text = indent_lines_after_first(text, prefix)
assert indented_text == expected_output
Loading…
Cancel
Save