Runnable graph viz improvements (#20529)

- Add conditional: bool property to json representation of the graphs
- Add option to generate mermaid graph stripped of styles (useful as a
text representation of graph)
pull/20534/head
Nuno Campos 2 months ago committed by GitHub
parent f3aa26d6bf
commit 806a54908c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -170,6 +170,17 @@ class Graph:
node.id: i if is_uuid(node.id) else node.id
for i, node in enumerate(self.nodes.values())
}
edges: List[Dict[str, Any]] = []
for edge in self.edges:
edge_dict = {
"source": stable_node_ids[edge.source],
"target": stable_node_ids[edge.target],
}
if edge.data is not None:
edge_dict["data"] = edge.data
if edge.conditional:
edge_dict["conditional"] = True
edges.append(edge_dict)
return {
"nodes": [
@ -179,19 +190,7 @@ class Graph:
}
for node in self.nodes.values()
],
"edges": [
{
"source": stable_node_ids[edge.source],
"target": stable_node_ids[edge.target],
"data": edge.data,
}
if edge.data is not None
else {
"source": stable_node_ids[edge.source],
"target": stable_node_ids[edge.target],
}
for edge in self.edges
],
"edges": edges,
}
def __bool__(self) -> bool:
@ -345,6 +344,7 @@ class Graph:
def draw_mermaid(
self,
*,
with_styles: bool = True,
curve_style: CurveStyle = CurveStyle.LINEAR,
node_colors: NodeColors = NodeColors(
start="#ffdfba", end="#baffc9", other="#fad7de"
@ -366,6 +366,7 @@ class Graph:
edges=self.edges,
first_node_label=first_label,
last_node_label=last_label,
with_styles=with_styles,
curve_style=curve_style,
node_colors=node_colors,
wrap_label_n_words=wrap_label_n_words,

@ -17,6 +17,7 @@ def draw_mermaid(
*,
first_node_label: Optional[str] = None,
last_node_label: Optional[str] = None,
with_styles: bool = True,
curve_style: CurveStyle = CurveStyle.LINEAR,
node_colors: NodeColors = NodeColors(),
wrap_label_n_words: int = 9,
@ -36,24 +37,29 @@ def draw_mermaid(
"""
# Initialize Mermaid graph configuration
mermaid_graph = (
f"%%{{init: {{'flowchart': {{'curve': '{curve_style.value}'"
f"}}}}}}%%\ngraph TD;\n"
(
f"%%{{init: {{'flowchart': {{'curve': '{curve_style.value}'"
f"}}}}}}%%\ngraph TD;\n"
)
if with_styles
else "graph TD;\n"
)
# Node formatting templates
default_class_label = "default"
format_dict = {default_class_label: "{0}([{0}]):::otherclass"}
if first_node_label is not None:
format_dict[first_node_label] = "{0}[{0}]:::startclass"
if last_node_label is not None:
format_dict[last_node_label] = "{0}[{0}]:::endclass"
# Add nodes to the graph
for node in nodes.values():
node_label = format_dict.get(node, format_dict[default_class_label]).format(
_escape_node_label(node)
)
mermaid_graph += f"\t{node_label};\n"
if with_styles:
# Node formatting templates
default_class_label = "default"
format_dict = {default_class_label: "{0}([{0}]):::otherclass"}
if first_node_label is not None:
format_dict[first_node_label] = "{0}[{0}]:::startclass"
if last_node_label is not None:
format_dict[last_node_label] = "{0}[{0}]:::endclass"
# Add nodes to the graph
for node in nodes.values():
node_label = format_dict.get(node, format_dict[default_class_label]).format(
_escape_node_label(node)
)
mermaid_graph += f"\t{node_label};\n"
# Add edges to the graph
for edge in edges:
@ -92,7 +98,8 @@ def draw_mermaid(
)
# Add custom styles for nodes
mermaid_graph += _generate_mermaid_graph_styles(node_colors)
if with_styles:
mermaid_graph += _generate_mermaid_graph_styles(node_colors)
return mermaid_graph

@ -98,6 +98,23 @@
+--------------------------------+
'''
# ---
# name: test_graph_sequence_map[mermaid-simple]
'''
graph TD;
PromptInput --> PromptTemplate;
PromptTemplate --> FakeListLLM;
Parallel_as_list_as_str_Input --> CommaSeparatedListOutputParser;
CommaSeparatedListOutputParser --> Parallel_as_list_as_str_Output;
conditional_str_parser_input --> StrOutputParser;
StrOutputParser --> conditional_str_parser_output;
conditional_str_parser_input --> XMLOutputParser;
XMLOutputParser --> conditional_str_parser_output;
Parallel_as_list_as_str_Input --> conditional_str_parser_input;
conditional_str_parser_output --> Parallel_as_list_as_str_Output;
FakeListLLM --> Parallel_as_list_as_str_Input;
'''
# ---
# name: test_graph_sequence_map[mermaid]
'''
%%{init: {'flowchart': {'curve': 'linear'}}}%%

@ -660,3 +660,4 @@ def test_graph_sequence_map(snapshot: SnapshotAssertion) -> None:
}
assert graph.draw_ascii() == snapshot(name="ascii")
assert graph.draw_mermaid() == snapshot(name="mermaid")
assert graph.draw_mermaid(with_styles=False) == snapshot(name="mermaid-simple")

Loading…
Cancel
Save