Improve graph repr for runnable passthrough and itemgetter (#15083)

<!-- Thank you for contributing to LangChain!

Please title your PR "<package>: <description>", where <package> is
whichever of langchain, community, core, experimental, etc. is being
modified.

Replace this entire comment with:
  - **Description:** a description of the change, 
  - **Issue:** the issue # it fixes if applicable,
  - **Dependencies:** any dependencies required for this change,
- **Twitter handle:** we announce bigger features on Twitter. If your PR
gets announced, and you'd like a mention, we'll gladly shout you out!

Please make sure your PR is passing linting and testing before
submitting. Run `make format`, `make lint` and `make test` from the root
of the package you've modified to check this locally.

See contribution guidelines for more information on how to write/run
tests, lint, etc: https://python.langchain.com/docs/contributing/

If you're adding a new integration, please include:
1. a test for the integration, preferably unit tests that do not rely on
network access,
2. an example notebook showing its use. It lives in
`docs/docs/integrations` directory.

If no one reviews your PR within a few days, please @-mention one of
@baskaryan, @eyurtsev, @hwchase17.
 -->
This commit is contained in:
Nuno Campos 2023-12-22 16:05:48 -08:00 committed by GitHub
parent 0d0901ea18
commit a2d3042823
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 72 additions and 57 deletions

View File

@ -2007,7 +2007,7 @@ class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]):
):
# This is correct, but pydantic typings/mypy don't think so.
return create_model( # type: ignore[call-overload]
"RunnableParallelInput",
"RunnableMapInput",
**{
k: (v.annotation, v.default)
for step in self.steps.values()
@ -2024,7 +2024,7 @@ class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]):
) -> Type[BaseModel]:
# This is correct, but pydantic typings/mypy don't think so.
return create_model( # type: ignore[call-overload]
"RunnableParallelOutput",
"RunnableMapOutput",
**{k: (v.OutputType, None) for k, v in self.steps.items()},
__config__=_SchemaConfig,
)
@ -2650,7 +2650,9 @@ class RunnableLambda(Runnable[Input, Output]):
def __repr__(self) -> str:
"""A string representation of this runnable."""
if hasattr(self, "func"):
if hasattr(self, "func") and isinstance(self.func, itemgetter):
return f"RunnableLambda({str(self.func)[len('operator.'):]})"
elif hasattr(self, "func"):
return f"RunnableLambda({get_lambda_source(self.func) or '...'})"
elif hasattr(self, "afunc"):
return f"RunnableLambda(afunc={get_lambda_source(self.afunc) or '...'})"

View File

@ -123,13 +123,13 @@ class Graph:
or len(data.splitlines()) > 1
):
data = node.data.__class__.__name__
elif len(data) > 36:
data = data[:36] + "..."
elif len(data) > 42:
data = data[:42] + "..."
except Exception:
data = node.data.__class__.__name__
else:
data = node.data.__name__
return data
return data if not data.startswith("Runnable") else data[8:]
return draw(
{node.id: node_data(node) for node in self.nodes.values()},

View File

@ -34,6 +34,7 @@ from langchain_core.runnables.config import (
get_executor_for_config,
patch_config,
)
from langchain_core.runnables.graph import Graph
from langchain_core.runnables.utils import AddableDict, ConfigurableFieldSpec
from langchain_core.utils.aiter import atee, py_anext
from langchain_core.utils.iter import safetee
@ -297,6 +298,9 @@ class RunnablePassthrough(RunnableSerializable[Other, Other]):
yield chunk
_graph_passthrough: RunnablePassthrough = RunnablePassthrough()
class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
"""
A runnable that assigns key-value pairs to Dict[str, Any] inputs.
@ -355,6 +359,18 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
def config_specs(self) -> List[ConfigurableFieldSpec]:
return self.mapper.config_specs
def get_graph(self, config: RunnableConfig | None = None) -> Graph:
# get graph from mapper
graph = self.mapper.get_graph(config)
# add passthrough node and edges
input_node = graph.first_node()
output_node = graph.last_node()
if input_node is not None and output_node is not None:
passthrough_node = graph.add_node(_graph_passthrough)
graph.add_edge(input_node, passthrough_node)
graph.add_edge(passthrough_node, output_node)
return graph
def _invoke(
self,
input: Dict[str, Any],

View File

@ -32,51 +32,51 @@
# ---
# name: test_graph_sequence_map
'''
+-------------+
| PromptInput |
+-------------+
*
*
*
+----------------+
| PromptTemplate |
+----------------+
*
*
*
+-------------+
| FakeListLLM |
+-------------+
*
*
*
+-----------------------+
| RunnableParallelInput |
+-----------------------+**
**** *******
**** *****
** *******
+---------------------+ ***
| RunnableLambdaInput | *
+---------------------+ *
*** *** *
*** *** *
** ** *
+-----------------+ +-----------------+ *
| StrOutputParser | | XMLOutputParser | *
+-----------------+ +-----------------+ *
*** *** *
*** *** *
** ** *
+----------------------+ +--------------------------------+
| RunnableLambdaOutput | | CommaSeparatedListOutputParser |
+----------------------+ +--------------------------------+
**** *******
**** *****
** ****
+------------------------+
| RunnableParallelOutput |
+------------------------+
+-------------+
| PromptInput |
+-------------+
*
*
*
+----------------+
| PromptTemplate |
+----------------+
*
*
*
+-------------+
| FakeListLLM |
+-------------+
*
*
*
+---------------+
| ParallelInput |
+---------------+*****
*** ******
*** *****
** *****
+-------------+ ***
| LambdaInput | *
+-------------+ *
** ** *
*** *** *
** ** *
+-----------------+ +-----------------+ *
| StrOutputParser | | XMLOutputParser | *
+-----------------+ +-----------------+ *
** ** *
*** *** *
** ** *
+--------------+ +--------------------------------+
| LambdaOutput | | CommaSeparatedListOutputParser |
+--------------+ +--------------------------------+
*** ******
*** *****
** ***
+-----------+
| MapOutput |
+-----------+
'''
# ---
# name: test_graph_single_runnable

View File

@ -569,7 +569,7 @@ def test_schemas(snapshot: SnapshotAssertion) -> None:
"properties": {"name": {"title": "Name", "type": "string"}},
}
assert seq_w_map.output_schema.schema() == {
"title": "RunnableParallelOutput",
"title": "RunnableMapOutput",
"type": "object",
"properties": {
"original": {"title": "Original", "type": "string"},
@ -613,7 +613,7 @@ def test_passthrough_assign_schema() -> None:
# expected dict input_schema
assert invalid_seq_w_assign.input_schema.schema() == {
"properties": {"question": {"title": "Question"}},
"title": "RunnableParallelInput",
"title": "RunnableMapInput",
"type": "object",
}
@ -768,7 +768,7 @@ def test_schema_complex_seq() -> None:
)
assert chain2.input_schema.schema() == {
"title": "RunnableParallelInput",
"title": "RunnableMapInput",
"type": "object",
"properties": {
"person": {"title": "Person", "type": "string"},
@ -2221,7 +2221,6 @@ async def test_stream_log_lists() -> None:
}
@pytest.mark.asyncio
@freeze_time("2023-01-01")
async def test_prompt_with_llm_and_async_lambda(
mocker: MockerFixture, snapshot: SnapshotAssertion
@ -4262,7 +4261,6 @@ def test_with_config_callbacks() -> None:
assert isinstance(result, RunnableBinding)
@pytest.mark.asyncio
async def test_ainvoke_on_returned_runnable() -> None:
"""Verify that a runnable returned by a sync runnable in the async path will
be runthroughaasync path (issue #13407)"""
@ -4301,7 +4299,6 @@ def test_invoke_stream_passthrough_assign_trace() -> None:
assert tracer.runs[0].child_runs[0].name == "RunnableParallel"
@pytest.mark.asyncio
async def test_ainvoke_astream_passthrough_assign_trace() -> None:
def idchain_sync(__input: dict) -> bool:
return False