mirror of
https://github.com/hwchase17/langchain
synced 2024-11-10 01:10:59 +00:00
core[patch]: support drawing nested subgraphs in draw_mermaid (#25581)
Previously the code was able to only handle a single level of nesting for subgraphs in mermaid. This change adds support for arbitrary nesting of subgraphs.
This commit is contained in:
parent
1c31234eed
commit
46d344c33d
@ -78,47 +78,78 @@ def draw_mermaid(
|
||||
)
|
||||
mermaid_graph += f"\t{node_label}\n"
|
||||
|
||||
subgraph = ""
|
||||
# Add edges to the graph
|
||||
# Group edges by their common prefixes
|
||||
edge_groups: Dict[str, List[Edge]] = {}
|
||||
for edge in edges:
|
||||
src_prefix = edge.source.split(":")[0] if ":" in edge.source else None
|
||||
tgt_prefix = edge.target.split(":")[0] if ":" in edge.target else None
|
||||
# exit subgraph if source or target is not in the same subgraph
|
||||
if subgraph and (subgraph != src_prefix or subgraph != tgt_prefix):
|
||||
mermaid_graph += "\tend\n"
|
||||
subgraph = ""
|
||||
# enter subgraph if source and target are in the same subgraph
|
||||
if not subgraph and src_prefix and src_prefix == tgt_prefix:
|
||||
mermaid_graph += f"\tsubgraph {src_prefix}\n"
|
||||
subgraph = src_prefix
|
||||
|
||||
source, target = edge.source, edge.target
|
||||
|
||||
# Add BR every wrap_label_n_words words
|
||||
if edge.data is not None:
|
||||
edge_data = edge.data
|
||||
words = str(edge_data).split() # Split the string into words
|
||||
# Group words into chunks of wrap_label_n_words size
|
||||
if len(words) > wrap_label_n_words:
|
||||
edge_data = " <br> ".join(
|
||||
" ".join(words[i : i + wrap_label_n_words])
|
||||
for i in range(0, len(words), wrap_label_n_words)
|
||||
)
|
||||
if edge.conditional:
|
||||
edge_label = f" -.  {edge_data}  .-> "
|
||||
else:
|
||||
edge_label = f" --  {edge_data}  --> "
|
||||
else:
|
||||
if edge.conditional:
|
||||
edge_label = " -.-> "
|
||||
else:
|
||||
edge_label = " --> "
|
||||
mermaid_graph += (
|
||||
f"\t{_escape_node_label(source)}{edge_label}"
|
||||
f"{_escape_node_label(target)};\n"
|
||||
src_parts = edge.source.split(":")
|
||||
tgt_parts = edge.target.split(":")
|
||||
common_prefix = ":".join(
|
||||
src for src, tgt in zip(src_parts, tgt_parts) if src == tgt
|
||||
)
|
||||
if subgraph:
|
||||
mermaid_graph += "end\n"
|
||||
edge_groups.setdefault(common_prefix, []).append(edge)
|
||||
|
||||
seen_subgraphs = set()
|
||||
|
||||
def add_subgraph(edges: List[Edge], prefix: str) -> None:
|
||||
nonlocal mermaid_graph
|
||||
self_loop = len(edges) == 1 and edges[0].source == edges[0].target
|
||||
if prefix and not self_loop:
|
||||
subgraph = prefix.split(":")[-1]
|
||||
if subgraph in seen_subgraphs:
|
||||
raise ValueError(
|
||||
f"Found duplicate subgraph '{subgraph}' -- this likely means that "
|
||||
"you're reusing a subgraph node with the same name. "
|
||||
"Please adjust your graph to have subgraph nodes with unique names."
|
||||
)
|
||||
|
||||
seen_subgraphs.add(subgraph)
|
||||
mermaid_graph += f"\tsubgraph {subgraph}\n"
|
||||
|
||||
for edge in edges:
|
||||
source, target = edge.source, edge.target
|
||||
|
||||
# Add BR every wrap_label_n_words words
|
||||
if edge.data is not None:
|
||||
edge_data = edge.data
|
||||
words = str(edge_data).split() # Split the string into words
|
||||
# Group words into chunks of wrap_label_n_words size
|
||||
if len(words) > wrap_label_n_words:
|
||||
edge_data = " <br> ".join(
|
||||
" ".join(words[i : i + wrap_label_n_words])
|
||||
for i in range(0, len(words), wrap_label_n_words)
|
||||
)
|
||||
if edge.conditional:
|
||||
edge_label = f" -.  {edge_data}  .-> "
|
||||
else:
|
||||
edge_label = f" --  {edge_data}  --> "
|
||||
else:
|
||||
if edge.conditional:
|
||||
edge_label = " -.-> "
|
||||
else:
|
||||
edge_label = " --> "
|
||||
|
||||
mermaid_graph += (
|
||||
f"\t{_escape_node_label(source)}{edge_label}"
|
||||
f"{_escape_node_label(target)};\n"
|
||||
)
|
||||
|
||||
# Recursively add nested subgraphs
|
||||
for nested_prefix in edge_groups.keys():
|
||||
if not nested_prefix.startswith(prefix + ":") or nested_prefix == prefix:
|
||||
continue
|
||||
add_subgraph(edge_groups[nested_prefix], nested_prefix)
|
||||
|
||||
if prefix and not self_loop:
|
||||
mermaid_graph += "\tend\n"
|
||||
|
||||
# Start with the top-level edges (no common prefix)
|
||||
add_subgraph(edge_groups.get("", []), "")
|
||||
|
||||
# Add remaining subgraphs
|
||||
for prefix in edge_groups.keys():
|
||||
if ":" in prefix or prefix == "":
|
||||
continue
|
||||
add_subgraph(edge_groups[prefix], prefix)
|
||||
|
||||
# Add custom styles for nodes
|
||||
if with_styles:
|
||||
|
@ -1063,6 +1063,63 @@
|
||||
|
||||
'''
|
||||
# ---
|
||||
# name: test_parallel_subgraph_mermaid[mermaid]
|
||||
'''
|
||||
%%{init: {'flowchart': {'curve': 'linear'}}}%%
|
||||
graph TD;
|
||||
__start__([__start__]):::first
|
||||
outer_1(outer_1)
|
||||
inner_1_inner_1(inner_1)
|
||||
inner_1_inner_2(inner_2<hr/><small><em>__interrupt = before</em></small>)
|
||||
inner_2_inner_1(inner_1)
|
||||
inner_2_inner_2(inner_2)
|
||||
outer_2(outer_2)
|
||||
__end__([__end__]):::last
|
||||
__start__ --> outer_1;
|
||||
inner_1_inner_2 --> outer_2;
|
||||
inner_2_inner_2 --> outer_2;
|
||||
outer_1 --> inner_1_inner_1;
|
||||
outer_1 --> inner_2_inner_1;
|
||||
outer_2 --> __end__;
|
||||
subgraph inner_1
|
||||
inner_1_inner_1 --> inner_1_inner_2;
|
||||
end
|
||||
subgraph inner_2
|
||||
inner_2_inner_1 --> inner_2_inner_2;
|
||||
end
|
||||
classDef default fill:#f2f0ff,line-height:1.2
|
||||
classDef first fill-opacity:0
|
||||
classDef last fill:#bfb6fc
|
||||
|
||||
'''
|
||||
# ---
|
||||
# name: test_double_nested_subgraph_mermaid[mermaid]
|
||||
'''
|
||||
%%{init: {'flowchart': {'curve': 'linear'}}}%%
|
||||
graph TD;
|
||||
__start__([__start__]):::first
|
||||
parent_1(parent_1)
|
||||
child_child_1_grandchild_1(grandchild_1)
|
||||
child_child_1_grandchild_2(grandchild_2<hr/><small><em>__interrupt = before</em></small>)
|
||||
child_child_2(child_2)
|
||||
parent_2(parent_2)
|
||||
__end__([__end__]):::last
|
||||
__start__ --> parent_1;
|
||||
child_child_2 --> parent_2;
|
||||
parent_1 --> child_child_1_grandchild_1;
|
||||
parent_2 --> __end__;
|
||||
subgraph child
|
||||
child_child_1_grandchild_2 --> child_child_2;
|
||||
subgraph child_1
|
||||
child_child_1_grandchild_1 --> child_child_1_grandchild_2;
|
||||
end
|
||||
end
|
||||
classDef default fill:#f2f0ff,line-height:1.2
|
||||
classDef first fill-opacity:0
|
||||
classDef last fill:#bfb6fc
|
||||
|
||||
'''
|
||||
# ---
|
||||
# name: test_graph_single_runnable[ascii]
|
||||
'''
|
||||
+----------------------+
|
||||
|
@ -9,7 +9,7 @@ from langchain_core.output_parsers.xml import XMLOutputParser
|
||||
from langchain_core.prompts.prompt import PromptTemplate
|
||||
from langchain_core.pydantic_v1 import BaseModel
|
||||
from langchain_core.runnables.base import Runnable, RunnableConfig
|
||||
from langchain_core.runnables.graph import Graph
|
||||
from langchain_core.runnables.graph import Edge, Graph, Node
|
||||
from langchain_core.runnables.graph_mermaid import _escape_node_label
|
||||
from tests.unit_tests.pydantic_utils import _schema
|
||||
|
||||
@ -216,6 +216,136 @@ def test_graph_sequence_map(snapshot: SnapshotAssertion) -> None:
|
||||
assert graph.draw_mermaid(with_styles=False) == snapshot(name="mermaid-simple")
|
||||
|
||||
|
||||
def test_parallel_subgraph_mermaid(snapshot: SnapshotAssertion) -> None:
|
||||
empty_data = BaseModel
|
||||
nodes = {
|
||||
"__start__": Node(
|
||||
id="__start__", name="__start__", data=empty_data, metadata=None
|
||||
),
|
||||
"outer_1": Node(id="outer_1", name="outer_1", data=empty_data, metadata=None),
|
||||
"inner_1:inner_1": Node(
|
||||
id="inner_1:inner_1", name="inner_1", data=empty_data, metadata=None
|
||||
),
|
||||
"inner_1:inner_2": Node(
|
||||
id="inner_1:inner_2",
|
||||
name="inner_2",
|
||||
data=empty_data,
|
||||
metadata={"__interrupt": "before"},
|
||||
),
|
||||
"inner_2:inner_1": Node(
|
||||
id="inner_2:inner_1", name="inner_1", data=empty_data, metadata=None
|
||||
),
|
||||
"inner_2:inner_2": Node(
|
||||
id="inner_2:inner_2", name="inner_2", data=empty_data, metadata=None
|
||||
),
|
||||
"outer_2": Node(id="outer_2", name="outer_2", data=empty_data, metadata=None),
|
||||
"__end__": Node(id="__end__", name="__end__", data=empty_data, metadata=None),
|
||||
}
|
||||
edges = [
|
||||
Edge(
|
||||
source="inner_1:inner_1",
|
||||
target="inner_1:inner_2",
|
||||
data=None,
|
||||
conditional=False,
|
||||
),
|
||||
Edge(
|
||||
source="inner_2:inner_1",
|
||||
target="inner_2:inner_2",
|
||||
data=None,
|
||||
conditional=False,
|
||||
),
|
||||
Edge(source="__start__", target="outer_1", data=None, conditional=False),
|
||||
Edge(
|
||||
source="inner_1:inner_2",
|
||||
target="outer_2",
|
||||
data=None,
|
||||
conditional=False,
|
||||
),
|
||||
Edge(
|
||||
source="inner_2:inner_2",
|
||||
target="outer_2",
|
||||
data=None,
|
||||
conditional=False,
|
||||
),
|
||||
Edge(
|
||||
source="outer_1",
|
||||
target="inner_1:inner_1",
|
||||
data=None,
|
||||
conditional=False,
|
||||
),
|
||||
Edge(
|
||||
source="outer_1",
|
||||
target="inner_2:inner_1",
|
||||
data=None,
|
||||
conditional=False,
|
||||
),
|
||||
Edge(source="outer_2", target="__end__", data=None, conditional=False),
|
||||
]
|
||||
graph = Graph(nodes, edges)
|
||||
assert graph.draw_mermaid() == snapshot(name="mermaid")
|
||||
|
||||
|
||||
def test_double_nested_subgraph_mermaid(snapshot: SnapshotAssertion) -> None:
|
||||
empty_data = BaseModel
|
||||
nodes = {
|
||||
"__start__": Node(
|
||||
id="__start__", name="__start__", data=empty_data, metadata=None
|
||||
),
|
||||
"parent_1": Node(
|
||||
id="parent_1", name="parent_1", data=empty_data, metadata=None
|
||||
),
|
||||
"child:child_1:grandchild_1": Node(
|
||||
id="child:child_1:grandchild_1",
|
||||
name="grandchild_1",
|
||||
data=empty_data,
|
||||
metadata=None,
|
||||
),
|
||||
"child:child_1:grandchild_2": Node(
|
||||
id="child:child_1:grandchild_2",
|
||||
name="grandchild_2",
|
||||
data=empty_data,
|
||||
metadata={"__interrupt": "before"},
|
||||
),
|
||||
"child:child_2": Node(
|
||||
id="child:child_2", name="child_2", data=empty_data, metadata=None
|
||||
),
|
||||
"parent_2": Node(
|
||||
id="parent_2", name="parent_2", data=empty_data, metadata=None
|
||||
),
|
||||
"__end__": Node(id="__end__", name="__end__", data=empty_data, metadata=None),
|
||||
}
|
||||
edges = [
|
||||
Edge(
|
||||
source="child:child_1:grandchild_1",
|
||||
target="child:child_1:grandchild_2",
|
||||
data=None,
|
||||
conditional=False,
|
||||
),
|
||||
Edge(
|
||||
source="child:child_1:grandchild_2",
|
||||
target="child:child_2",
|
||||
data=None,
|
||||
conditional=False,
|
||||
),
|
||||
Edge(source="__start__", target="parent_1", data=None, conditional=False),
|
||||
Edge(
|
||||
source="child:child_2",
|
||||
target="parent_2",
|
||||
data=None,
|
||||
conditional=False,
|
||||
),
|
||||
Edge(
|
||||
source="parent_1",
|
||||
target="child:child_1:grandchild_1",
|
||||
data=None,
|
||||
conditional=False,
|
||||
),
|
||||
Edge(source="parent_2", target="__end__", data=None, conditional=False),
|
||||
]
|
||||
graph = Graph(nodes, edges)
|
||||
assert graph.draw_mermaid() == snapshot(name="mermaid")
|
||||
|
||||
|
||||
def test_runnable_get_graph_with_invalid_input_type() -> None:
|
||||
"""Test that error isn't raised when getting graph with invalid input type."""
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user