diff --git a/libs/core/langchain_core/runnables/graph.py b/libs/core/langchain_core/runnables/graph.py
index d9d30e3b1e..77137d0a86 100644
--- a/libs/core/langchain_core/runnables/graph.py
+++ b/libs/core/langchain_core/runnables/graph.py
@@ -2,16 +2,32 @@ from __future__ import annotations
import inspect
from dataclasses import dataclass, field
-from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional, Type, Union
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Dict,
+ List,
+ NamedTuple,
+ Optional,
+ Type,
+ TypedDict,
+ Union,
+ overload,
+)
from uuid import UUID, uuid4
from langchain_core.pydantic_v1 import BaseModel
-from langchain_core.runnables.graph_draw import draw
+from langchain_core.runnables.graph_ascii import draw_ascii
if TYPE_CHECKING:
from langchain_core.runnables.base import Runnable as RunnableType
+class LabelsDict(TypedDict):
+ nodes: dict[str, str]
+ edges: dict[str, str]
+
+
def is_uuid(value: str) -> bool:
try:
UUID(value)
@@ -213,10 +229,38 @@ class Graph:
self.remove_node(last_node)
def draw_ascii(self) -> str:
- return draw(
+ return draw_ascii(
{node.id: node_data_str(node) for node in self.nodes.values()},
[(edge.source, edge.target) for edge in self.edges],
)
def print_ascii(self) -> None:
print(self.draw_ascii()) # noqa: T201
+
+ @overload
+ def draw_png(
+ self,
+ output_file_path: str,
+ fontname: Optional[str] = None,
+ labels: Optional[LabelsDict] = None,
+ ) -> None:
+ ...
+
+ @overload
+ def draw_png(
+ self,
+ output_file_path: None,
+ fontname: Optional[str] = None,
+ labels: Optional[LabelsDict] = None,
+ ) -> bytes:
+ ...
+
+ def draw_png(
+ self,
+ output_file_path: Optional[str] = None,
+ fontname: Optional[str] = None,
+ labels: Optional[LabelsDict] = None,
+ ) -> Union[bytes, None]:
+ from langchain_core.runnables.graph_png import PngDrawer
+
+ return PngDrawer(fontname, labels).draw(self, output_file_path)
diff --git a/libs/core/langchain_core/runnables/graph_draw.py b/libs/core/langchain_core/runnables/graph_ascii.py
similarity index 99%
rename from libs/core/langchain_core/runnables/graph_draw.py
rename to libs/core/langchain_core/runnables/graph_ascii.py
index 87facfb92e..bc809dce28 100644
--- a/libs/core/langchain_core/runnables/graph_draw.py
+++ b/libs/core/langchain_core/runnables/graph_ascii.py
@@ -209,7 +209,7 @@ def _build_sugiyama_layout(
return sug
-def draw(vertices: Mapping[str, str], edges: Sequence[Tuple[str, str]]) -> str:
+def draw_ascii(vertices: Mapping[str, str], edges: Sequence[Tuple[str, str]]) -> str:
"""Build a DAG and draw it in ASCII.
Args:
diff --git a/libs/core/langchain_core/runnables/graph_png.py b/libs/core/langchain_core/runnables/graph_png.py
new file mode 100644
index 0000000000..51cfcb9577
--- /dev/null
+++ b/libs/core/langchain_core/runnables/graph_png.py
@@ -0,0 +1,108 @@
+from typing import Any, Optional
+
+from langchain_core.runnables.graph import Graph, LabelsDict
+
+
+class PngDrawer:
+ """
+ A helper class to draw a state graph into a PNG file.
+ Requires graphviz and pygraphviz to be installed.
+ :param fontname: The font to use for the labels
+ :param labels: A dictionary of label overrides. The dictionary
+ should have the following format:
+ {
+ "nodes": {
+ "node1": "CustomLabel1",
+ "node2": "CustomLabel2",
+ "__end__": "End Node"
+ },
+ "edges": {
+ "continue": "ContinueLabel",
+ "end": "EndLabel"
+ }
+ }
+ The keys are the original labels, and the values are the new labels.
+ Usage:
+ drawer = PngDrawer()
+ drawer.draw(state_graph, 'graph.png')
+ """
+
+ def __init__(
+ self, fontname: Optional[str] = None, labels: Optional[LabelsDict] = None
+ ) -> None:
+ self.fontname = fontname or "arial"
+ self.labels = labels or LabelsDict(nodes={}, edges={})
+
+ def get_node_label(self, label: str) -> str:
+ label = self.labels.get("nodes", {}).get(label, label)
+ return f"<{label}>"
+
+ def get_edge_label(self, label: str) -> str:
+ label = self.labels.get("edges", {}).get(label, label)
+ return f"<{label}>"
+
+ def add_node(self, viz: Any, node: str) -> None:
+ viz.add_node(
+ node,
+ label=self.get_node_label(node),
+ style="filled",
+ fillcolor="yellow",
+ fontsize=15,
+ fontname=self.fontname,
+ )
+
+ def add_edge(
+ self, viz: Any, source: str, target: str, label: Optional[str] = None
+ ) -> None:
+ viz.add_edge(
+ source,
+ target,
+ label=self.get_edge_label(label) if label else "",
+ fontsize=12,
+ fontname=self.fontname,
+ )
+
+ def draw(self, graph: Graph, output_path: Optional[str] = None) -> Optional[bytes]:
+ """
+ Draws the given state graph into a PNG file.
+ Requires graphviz and pygraphviz to be installed.
+ :param graph: The graph to draw
+ :param output_path: The path to save the PNG. If None, PNG bytes are returned.
+ """
+
+ try:
+ import pygraphviz as pgv # type: ignore[import]
+ except ImportError as exc:
+ raise ImportError(
+ "Install pygraphviz to draw graphs: `pip install pygraphviz`."
+ ) from exc
+
+ # Create a directed graph
+ viz = pgv.AGraph(directed=True, nodesep=0.9, ranksep=1.0)
+
+ # Add nodes, conditional edges, and edges to the graph
+ self.add_nodes(viz, graph)
+ self.add_edges(viz, graph)
+
+ # Update entrypoint and END styles
+ self.update_styles(viz, graph)
+
+ # Save the graph as PNG
+ try:
+ return viz.draw(output_path, format="png", prog="dot")
+ finally:
+ viz.close()
+
+ def add_nodes(self, viz: Any, graph: Graph) -> None:
+ for node in graph.nodes:
+ self.add_node(viz, node)
+
+ def add_edges(self, viz: Any, graph: Graph) -> None:
+ for start, end, label in graph.edges:
+ self.add_edge(viz, start, end, label)
+
+ def update_styles(self, viz: Any, graph: Graph) -> None:
+ if first := graph.first_node():
+ viz.get_node(first.id).attr.update(fillcolor="lightblue")
+ if last := graph.last_node():
+ viz.get_node(last.id).attr.update(fillcolor="orange")
diff --git a/libs/core/poetry.lock b/libs/core/poetry.lock
index 3078c2ae55..3ecb233f5f 100644
--- a/libs/core/poetry.lock
+++ b/libs/core/poetry.lock
@@ -1,4 +1,4 @@
-# This file is automatically @generated by Poetry 1.6.1 and should not be changed by hand.
+# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand.
[[package]]
name = "annotated-types"
@@ -1996,6 +1996,7 @@ files = [
{file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"},
{file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"},
{file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"},
+ {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"},
{file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"},
{file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"},
{file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"},