core: generate mermaid syntax and render visual graph (#19599)

- **Description:** Add functionality to generate Mermaid syntax and
render flowcharts from graph data. This includes support for custom node
colors and edge curve styles, as well as the ability to export the
generated graphs to PNG images using either the Mermaid.INK API or
Pyppeteer for local rendering.
- **Dependencies:** Optional dependencies are `pyppeteer` if rendering
wants to be done using Pypeteer and Javascript code.

---------

Co-authored-by: Angel Igareta <angel.igareta@klarna.com>
Co-authored-by: Bagatur <22008038+baskaryan@users.noreply.github.com>
pull/19866/head
Ángel Igareta 6 months ago committed by GitHub
parent 8711a05a51
commit c2ccf22dfd
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -2,9 +2,11 @@ from __future__ import annotations
import inspect
from dataclasses import dataclass, field
from enum import Enum
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
List,
NamedTuple,
@ -51,6 +53,46 @@ class Node(NamedTuple):
data: Union[Type[BaseModel], RunnableType]
class Branch(NamedTuple):
"""Branch in a graph."""
condition: Callable[..., str]
ends: Optional[dict[str, str]]
class CurveStyle(Enum):
"""Enum for different curve styles supported by Mermaid"""
BASIS = "basis"
BUMP_X = "bumpX"
BUMP_Y = "bumpY"
CARDINAL = "cardinal"
CATMULL_ROM = "catmullRom"
LINEAR = "linear"
MONOTONE_X = "monotoneX"
MONOTONE_Y = "monotoneY"
NATURAL = "natural"
STEP = "step"
STEP_AFTER = "stepAfter"
STEP_BEFORE = "stepBefore"
@dataclass
class NodeColors:
"""Schema for Hexadecimal color codes for different node types"""
start: str = "#ffdfba"
end: str = "#baffc9"
other: str = "#fad7de"
class MermaidDrawMethod(Enum):
"""Enum for different draw methods supported by Mermaid"""
PYPPETEER = "pyppeteer" # Uses Pyppeteer to render the graph
API = "api" # Uses Mermaid.INK API to render the graph
def node_data_str(node: Node) -> str:
from langchain_core.runnables.base import Runnable
@ -112,6 +154,7 @@ class Graph:
nodes: Dict[str, Node] = field(default_factory=dict)
edges: List[Edge] = field(default_factory=list)
branches: Optional[Dict[str, List[Branch]]] = field(default_factory=dict)
def to_json(self) -> Dict[str, List[Dict[str, Any]]]:
"""Convert the graph to a JSON-serializable format."""
@ -277,3 +320,59 @@ class Graph:
edges=labels["edges"] if labels is not None else {},
),
).draw(self, output_file_path)
def draw_mermaid(
self,
curve_style: CurveStyle = CurveStyle.LINEAR,
node_colors: NodeColors = NodeColors(
start="#ffdfba", end="#baffc9", other="#fad7de"
),
wrap_label_n_words: int = 9,
) -> str:
from langchain_core.runnables.graph_mermaid import draw_mermaid
nodes = {node.id: node_data_str(node) for node in self.nodes.values()}
first_node = self.first_node()
first_label = node_data_str(first_node) if first_node is not None else None
last_node = self.last_node()
last_label = node_data_str(last_node) if last_node is not None else None
return draw_mermaid(
nodes=nodes,
edges=self.edges,
branches=self.branches,
first_node_label=first_label,
last_node_label=last_label,
curve_style=curve_style,
node_colors=node_colors,
wrap_label_n_words=wrap_label_n_words,
)
def draw_mermaid_png(
self,
curve_style: CurveStyle = CurveStyle.LINEAR,
node_colors: NodeColors = NodeColors(
start="#ffdfba", end="#baffc9", other="#fad7de"
),
wrap_label_n_words: int = 9,
output_file_path: str = "graph.png",
draw_method: MermaidDrawMethod = MermaidDrawMethod.API,
background_color: str = "white",
padding: int = 10,
) -> None:
from langchain_core.runnables.graph_mermaid import draw_mermaid_png
mermaid_syntax = self.draw_mermaid(
curve_style=curve_style,
node_colors=node_colors,
wrap_label_n_words=wrap_label_n_words,
)
draw_mermaid_png(
mermaid_syntax=mermaid_syntax,
output_file_path=output_file_path,
draw_method=draw_method,
background_color=background_color,
padding=padding,
)

@ -0,0 +1,292 @@
import base64
import re
from dataclasses import asdict
from typing import Dict, List, Optional, Tuple
from langchain_core.runnables.graph import (
Branch,
CurveStyle,
Edge,
MermaidDrawMethod,
NodeColors,
)
def draw_mermaid(
nodes: Dict[str, str],
edges: List[Edge],
branches: Optional[Dict[str, List[Branch]]] = None,
first_node_label: Optional[str] = None,
last_node_label: Optional[str] = None,
curve_style: CurveStyle = CurveStyle.LINEAR,
node_colors: NodeColors = NodeColors(),
wrap_label_n_words: int = 9,
) -> str:
"""Draws a Mermaid graph using the provided graph data
Args:
nodes (dict[str, str]): List of node ids
edges (List[Edge]): List of edges, object with source,
target and data.
branches (defaultdict[str, list[Branch]]): Branches for the graph (
in case of langgraph) to remove intermediate condition nodes.
curve_style (CurveStyle, optional): Curve style for the edges.
node_colors (NodeColors, optional): Node colors for different types.
wrap_label_n_words (int, optional): Words to wrap the edge labels.
Returns:
str: Mermaid graph syntax
"""
# Initialize Mermaid graph configuration
mermaid_graph = (
f"%%{{init: {{'flowchart': {{'curve': '{curve_style.value}'"
f"}}}}}}%%\ngraph TD;\n"
)
# Node formatting templates
default_class_label = "default"
format_dict = {default_class_label: "{0}([{0}]):::otherclass"}
if first_node_label is not None:
format_dict[first_node_label] = "{0}[{0}]:::startclass"
if last_node_label is not None:
format_dict[last_node_label] = "{0}[{0}]:::endclass"
# Filter out nodes that were created due to conditional edges
# Remove combinations where node name is the same as a branch + condition
mapping_intermediate_node_pure_node = {}
if branches is not None:
for agent, agent_branches in branches.items():
for branch in agent_branches:
condition_name = branch.condition.__name__
intermediate_node_label = f"{agent}_{condition_name}"
if intermediate_node_label in nodes:
mapping_intermediate_node_pure_node[intermediate_node_label] = agent
# Not intermediate nodes
pure_nodes = {
id: value
for id, value in nodes.items()
if value not in mapping_intermediate_node_pure_node.keys()
}
# Add __end__ node if it is in any of the edges.target
if any("__end__" in edge.target for edge in edges):
pure_nodes["__end__"] = "__end__"
# Add nodes to the graph
for node in pure_nodes.values():
node_label = format_dict.get(node, format_dict[default_class_label]).format(
_escape_node_label(node)
)
mermaid_graph += f"\t{node_label};\n"
# Add edges to the graph
for edge in edges:
adjusted_edge = _adjust_mermaid_edge(
edge, nodes, mapping_intermediate_node_pure_node
)
if (
adjusted_edge is None
): # Ignore if it is connection between source and intermediate node
continue
source, target = adjusted_edge
# Add BR every wrap_label_n_words words
if edge.data is not None:
edge_data = edge.data
words = edge_data.split() # Split the string into words
# Group words into chunks of wrap_label_n_words size
if len(words) > wrap_label_n_words:
edge_data = "<br>".join(
[
" ".join(words[i : i + wrap_label_n_words])
for i in range(0, len(words), wrap_label_n_words)
]
)
edge_label = f" -- {edge_data} --> "
else:
edge_label = " --> "
mermaid_graph += (
f"\t{_escape_node_label(source)}{edge_label}"
f"{_escape_node_label(target)};\n"
)
# Add custom styles for nodes
mermaid_graph += _generate_mermaid_graph_styles(node_colors)
return mermaid_graph
def _escape_node_label(node_label: str) -> str:
"""Escapes the node label for Mermaid syntax."""
return re.sub(r"[^a-zA-Z-_]", "_", node_label)
def _adjust_mermaid_edge(
edge: Edge,
nodes: Dict[str, str],
mapping_intermediate_node_pure_node: Dict[str, str],
) -> Optional[Tuple[str, str]]:
"""Adjusts Mermaid edge to map conditional nodes to pure nodes."""
source_node_label = nodes.get(edge.source, edge.source)
target_node_label = nodes.get(edge.target, edge.target)
# Remove nodes between source node to intermediate node
if target_node_label in mapping_intermediate_node_pure_node.keys():
return None
# Replace intermediate nodes by source nodes
if source_node_label in mapping_intermediate_node_pure_node.keys():
source_node_label = mapping_intermediate_node_pure_node[source_node_label]
return source_node_label, target_node_label
def _generate_mermaid_graph_styles(node_colors: NodeColors) -> str:
"""Generates Mermaid graph styles for different node types."""
styles = ""
for class_name, color in asdict(node_colors).items():
styles += f"\tclassDef {class_name}class fill:{color};\n"
return styles
def draw_mermaid_png(
mermaid_syntax: str,
output_file_path: Optional[str] = None,
draw_method: MermaidDrawMethod = MermaidDrawMethod.API,
background_color: Optional[str] = "white",
padding: int = 10,
) -> bytes:
"""Draws a Mermaid graph as PNG using provided syntax."""
if draw_method == MermaidDrawMethod.PYPPETEER:
import asyncio
img_bytes = asyncio.run(
_render_mermaid_using_pyppeteer(
mermaid_syntax, output_file_path, background_color, padding
)
)
elif draw_method == MermaidDrawMethod.API:
img_bytes = _render_mermaid_using_api(
mermaid_syntax, output_file_path, background_color
)
else:
supported_methods = ", ".join([m.value for m in MermaidDrawMethod])
raise ValueError(
f"Invalid draw method: {draw_method}. "
f"Supported draw methods are: {supported_methods}"
)
return img_bytes
async def _render_mermaid_using_pyppeteer(
mermaid_syntax: str,
output_file_path: Optional[str] = None,
background_color: Optional[str] = "white",
padding: int = 10,
) -> bytes:
"""Renders Mermaid graph using Pyppeteer."""
try:
from pyppeteer import launch # type: ignore[import]
except ImportError as e:
raise ImportError(
"Install Pyppeteer to use the Pyppeteer method: `pip install pyppeteer`."
) from e
browser = await launch()
page = await browser.newPage()
# Setup Mermaid JS
await page.goto("about:blank")
await page.addScriptTag({"url": "https://unpkg.com/mermaid/dist/mermaid.min.js"})
await page.evaluate(
"""() => {
mermaid.initialize({startOnLoad:true});
}"""
)
# Render SVG
svg_code = await page.evaluate(
"""(mermaidGraph) => {
return mermaid.mermaidAPI.render('mermaid', mermaidGraph);
}""",
mermaid_syntax,
)
# Set the page background to white
await page.evaluate(
"""(svg, background_color) => {
document.body.innerHTML = svg;
document.body.style.background = background_color;
}""",
svg_code["svg"],
background_color,
)
# Take a screenshot
dimensions = await page.evaluate(
"""() => {
const svgElement = document.querySelector('svg');
const rect = svgElement.getBoundingClientRect();
return { width: rect.width, height: rect.height };
}"""
)
await page.setViewport(
{
"width": int(dimensions["width"] + padding),
"height": int(dimensions["height"] + padding),
}
)
img_bytes = await page.screenshot({"fullPage": False})
await browser.close()
if output_file_path is not None:
with open(output_file_path, "wb") as file:
file.write(img_bytes)
return img_bytes
def _render_mermaid_using_api(
mermaid_syntax: str,
output_file_path: Optional[str] = None,
background_color: Optional[str] = "white",
) -> bytes:
"""Renders Mermaid graph using the Mermaid.INK API."""
try:
import requests # type: ignore[import]
except ImportError as e:
raise ImportError(
"Install the `requests` module to use the Mermaid.INK API: "
"`pip install requests`."
) from e
# Use Mermaid API to render the image
mermaid_syntax_encoded = base64.b64encode(mermaid_syntax.encode("utf8")).decode(
"ascii"
)
# Check if the background color is a hexadecimal color code using regex
if background_color is not None:
hex_color_pattern = re.compile(r"^#(?:[0-9a-fA-F]{3}){1,2}$")
if not hex_color_pattern.match(background_color):
background_color = f"!{background_color}"
image_url = (
f"https://mermaid.ink/img/{mermaid_syntax_encoded}?bgColor={background_color}"
)
response = requests.get(image_url)
if response.status_code == 200:
img_bytes = response.content
if output_file_path is not None:
with open(output_file_path, "wb") as file:
file.write(response.content)
return img_bytes
else:
raise ValueError(
f"Failed to render the graph using the Mermaid.INK API. "
f"Status code: {response.status_code}."
)

@ -1,5 +1,5 @@
# serializer version: 1
# name: test_graph_sequence
# name: test_graph_sequence[ascii]
'''
+-------------+
| PromptInput |
@ -30,7 +30,26 @@
+--------------------------------------+
'''
# ---
# name: test_graph_sequence_map
# name: test_graph_sequence[mermaid]
'''
%%{init: {'flowchart': {'curve': 'linear'}}}%%
graph TD;
PromptInput[PromptInput]:::startclass;
PromptTemplate([PromptTemplate]):::otherclass;
FakeListLLM([FakeListLLM]):::otherclass;
CommaSeparatedListOutputParser([CommaSeparatedListOutputParser]):::otherclass;
CommaSeparatedListOutputParserOutput[CommaSeparatedListOutputParserOutput]:::endclass;
PromptInput --> PromptTemplate;
PromptTemplate --> FakeListLLM;
CommaSeparatedListOutputParser --> CommaSeparatedListOutputParserOutput;
FakeListLLM --> CommaSeparatedListOutputParser;
classDef startclass fill:#ffdfba;
classDef endclass fill:#baffc9;
classDef otherclass fill:#fad7de;
'''
# ---
# name: test_graph_sequence_map[ascii]
'''
+-------------+
| PromptInput |
@ -79,7 +98,38 @@
+--------------------------------+
'''
# ---
# name: test_graph_single_runnable
# name: test_graph_sequence_map[mermaid]
'''
%%{init: {'flowchart': {'curve': 'linear'}}}%%
graph TD;
PromptInput[PromptInput]:::startclass;
PromptTemplate([PromptTemplate]):::otherclass;
FakeListLLM([FakeListLLM]):::otherclass;
Parallel_as_list_as_str_Input([Parallel_as_list_as_str_Input]):::otherclass;
Parallel_as_list_as_str_Output[Parallel_as_list_as_str_Output]:::endclass;
CommaSeparatedListOutputParser([CommaSeparatedListOutputParser]):::otherclass;
conditional_str_parser_input([conditional_str_parser_input]):::otherclass;
conditional_str_parser_output([conditional_str_parser_output]):::otherclass;
StrOutputParser([StrOutputParser]):::otherclass;
XMLOutputParser([XMLOutputParser]):::otherclass;
PromptInput --> PromptTemplate;
PromptTemplate --> FakeListLLM;
Parallel_as_list_as_str_Input --> CommaSeparatedListOutputParser;
CommaSeparatedListOutputParser --> Parallel_as_list_as_str_Output;
conditional_str_parser_input --> StrOutputParser;
StrOutputParser --> conditional_str_parser_output;
conditional_str_parser_input --> XMLOutputParser;
XMLOutputParser --> conditional_str_parser_output;
Parallel_as_list_as_str_Input --> conditional_str_parser_input;
conditional_str_parser_output --> Parallel_as_list_as_str_Output;
FakeListLLM --> Parallel_as_list_as_str_Input;
classDef startclass fill:#ffdfba;
classDef endclass fill:#baffc9;
classDef otherclass fill:#fad7de;
'''
# ---
# name: test_graph_single_runnable[ascii]
'''
+----------------------+
| StrOutputParserInput |
@ -98,3 +148,18 @@
+-----------------------+
'''
# ---
# name: test_graph_single_runnable[mermaid]
'''
%%{init: {'flowchart': {'curve': 'linear'}}}%%
graph TD;
StrOutputParserInput[StrOutputParserInput]:::startclass;
StrOutputParser([StrOutputParser]):::otherclass;
StrOutputParserOutput[StrOutputParserOutput]:::endclass;
StrOutputParserInput --> StrOutputParser;
StrOutputParser --> StrOutputParserOutput;
classDef startclass fill:#ffdfba;
classDef endclass fill:#baffc9;
classDef otherclass fill:#fad7de;
'''
# ---

@ -21,7 +21,8 @@ def test_graph_single_runnable(snapshot: SnapshotAssertion) -> None:
assert len(graph.edges) == 2
assert graph.edges[0].source == first_node.id
assert graph.edges[1].target == last_node.id
assert graph.draw_ascii() == snapshot
assert graph.draw_ascii() == snapshot(name="ascii")
assert graph.draw_mermaid() == snapshot(name="mermaid")
def test_graph_sequence(snapshot: SnapshotAssertion) -> None:
@ -88,7 +89,8 @@ def test_graph_sequence(snapshot: SnapshotAssertion) -> None:
{"source": 2, "target": 3},
],
}
assert graph.draw_ascii() == snapshot
assert graph.draw_ascii() == snapshot(name="ascii")
assert graph.draw_mermaid() == snapshot(name="mermaid")
def test_graph_sequence_map(snapshot: SnapshotAssertion) -> None:
@ -482,4 +484,5 @@ def test_graph_sequence_map(snapshot: SnapshotAssertion) -> None:
{"source": 2, "target": 3},
],
}
assert graph.draw_ascii() == snapshot
assert graph.draw_ascii() == snapshot(name="ascii")
assert graph.draw_mermaid() == snapshot(name="mermaid")

Loading…
Cancel
Save