Implement better reprs for Runnables

pull/11175/head
Nuno Campos 11 months ago
parent cfa2203c62
commit 5c1f462bb9

@ -68,6 +68,13 @@ class Serializable(BaseModel, ABC):
class Config:
extra = "ignore"
def __repr_args__(self) -> Any:
return [
(k, v)
for k, v in super().__repr_args__()
if (k not in self.__fields__ or self.__fields__[k].get_default() != v)
]
_lc_kwargs = PrivateAttr(default_factory=dict)
def __init__(self, **kwargs: Any) -> None:

@ -59,6 +59,8 @@ from langchain.schema.runnable.utils import (
accepts_run_manager,
gather_with_concurrency,
get_function_first_arg_dict_keys,
get_lambda_source,
indent_lines_after_first,
)
from langchain.utils.aiter import atee, py_anext
from langchain.utils.iter import safetee
@ -1298,6 +1300,12 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
def output_schema(self) -> Type[BaseModel]:
return self.last.output_schema
def __repr__(self) -> str:
return "\n| ".join(
repr(s) if i == 0 else indent_lines_after_first(repr(s), "| ")
for i, s in enumerate(self.steps)
)
def __or__(
self,
other: Union[
@ -1819,6 +1827,13 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
**{k: (v.OutputType, None) for k, v in self.steps.items()},
)
def __repr__(self) -> str:
map_for_repr = ",\n ".join(
f"{k}: {indent_lines_after_first(repr(v), ' ' + k + ': ')}"
for k, v in self.steps.items()
)
return "{\n " + map_for_repr + "\n}"
def invoke(
self, input: Input, config: Optional[RunnableConfig] = None
) -> Dict[str, Any]:
@ -2134,7 +2149,7 @@ class RunnableLambda(Runnable[Input, Output]):
return False
def __repr__(self) -> str:
return "RunnableLambda(...)"
return f"RunnableLambda({get_lambda_source(self.func) or '...'})"
def _invoke(
self,

@ -87,6 +87,14 @@ class IsFunctionArgDict(ast.NodeVisitor):
IsLocalDict(input_arg_name, self.keys).visit(node)
class GetLambdaSource(ast.NodeVisitor):
def __init__(self) -> None:
self.source: Optional[str] = None
def visit_Lambda(self, node: ast.Lambda) -> Any:
self.source = ast.unparse(node)
def get_function_first_arg_dict_keys(func: Callable) -> Optional[List[str]]:
try:
code = inspect.getsource(func)
@ -94,5 +102,23 @@ def get_function_first_arg_dict_keys(func: Callable) -> Optional[List[str]]:
visitor = IsFunctionArgDict()
visitor.visit(tree)
return list(visitor.keys) if visitor.keys else None
except (TypeError, OSError):
except (SyntaxError, TypeError, OSError):
return None
def get_lambda_source(func: Callable) -> Optional[str]:
try:
code = inspect.getsource(func)
tree = ast.parse(textwrap.dedent(code))
visitor = GetLambdaSource()
visitor.visit(tree)
return visitor.source
except (SyntaxError, TypeError, OSError):
return None
def indent_lines_after_first(text: str, prefix: str) -> str:
n_spaces = len(prefix)
spaces = " " * n_spaces
lines = text.splitlines()
return "\n".join([lines[0]] + [spaces + line for line in lines[1:]])

File diff suppressed because one or more lines are too long

@ -867,6 +867,7 @@ async def test_prompt_with_chat_model(
chain = prompt | chat
assert repr(chain) == snapshot
assert isinstance(chain, RunnableSequence)
assert chain.first == prompt
assert chain.middle == []
@ -1350,6 +1351,7 @@ Question:
| parser
)
assert repr(chain) == snapshot
assert isinstance(chain, RunnableSequence)
assert isinstance(chain.first, RunnableMap)
assert chain.middle == [prompt, chat]
@ -1375,7 +1377,7 @@ Question:
SystemMessage(content="You are a nice assistant."),
HumanMessage(
content="""Context:
[Document(page_content='foo', metadata={}), Document(page_content='bar', metadata={})]
[Document(page_content='foo'), Document(page_content='bar')]
Question:
What is your name?"""
@ -1413,6 +1415,7 @@ def test_seq_prompt_dict(mocker: MockerFixture, snapshot: SnapshotAssertion) ->
}
)
assert repr(chain) == snapshot
assert isinstance(chain, RunnableSequence)
assert chain.first == prompt
assert chain.middle == [RunnableLambda(passthrough)]
@ -2196,6 +2199,7 @@ def test_retrying(mocker: MockerFixture) -> None:
with pytest.raises(RuntimeError):
runnable.with_retry(
stop_after_attempt=2,
wait_exponential_jitter=False,
retry_if_exception_type=(ValueError,),
).invoke(2)
@ -2205,6 +2209,7 @@ def test_retrying(mocker: MockerFixture) -> None:
with pytest.raises(ValueError):
runnable.with_retry(
stop_after_attempt=2,
wait_exponential_jitter=False,
retry_if_exception_type=(ValueError,),
).batch([1, 2, 0])
@ -2214,6 +2219,7 @@ def test_retrying(mocker: MockerFixture) -> None:
output = runnable.with_retry(
stop_after_attempt=2,
wait_exponential_jitter=False,
retry_if_exception_type=(ValueError,),
).batch([1, 2, 0], return_exceptions=True)
@ -2248,6 +2254,7 @@ async def test_async_retrying(mocker: MockerFixture) -> None:
with pytest.raises(ValueError):
await runnable.with_retry(
stop_after_attempt=2,
wait_exponential_jitter=False,
retry_if_exception_type=(ValueError, KeyError),
).ainvoke(1)
@ -2257,6 +2264,7 @@ async def test_async_retrying(mocker: MockerFixture) -> None:
with pytest.raises(RuntimeError):
await runnable.with_retry(
stop_after_attempt=2,
wait_exponential_jitter=False,
retry_if_exception_type=(ValueError,),
).ainvoke(2)
@ -2266,6 +2274,7 @@ async def test_async_retrying(mocker: MockerFixture) -> None:
with pytest.raises(ValueError):
await runnable.with_retry(
stop_after_attempt=2,
wait_exponential_jitter=False,
retry_if_exception_type=(ValueError,),
).abatch([1, 2, 0])
@ -2275,6 +2284,7 @@ async def test_async_retrying(mocker: MockerFixture) -> None:
output = await runnable.with_retry(
stop_after_attempt=2,
wait_exponential_jitter=False,
retry_if_exception_type=(ValueError,),
).abatch([1, 2, 0], return_exceptions=True)

Loading…
Cancel
Save