diff --git a/libs/core/langchain_core/runnables/graph.py b/libs/core/langchain_core/runnables/graph.py index d9d30e3b1e..77137d0a86 100644 --- a/libs/core/langchain_core/runnables/graph.py +++ b/libs/core/langchain_core/runnables/graph.py @@ -2,16 +2,32 @@ from __future__ import annotations import inspect from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional, Type, Union +from typing import ( + TYPE_CHECKING, + Any, + Dict, + List, + NamedTuple, + Optional, + Type, + TypedDict, + Union, + overload, +) from uuid import UUID, uuid4 from langchain_core.pydantic_v1 import BaseModel -from langchain_core.runnables.graph_draw import draw +from langchain_core.runnables.graph_ascii import draw_ascii if TYPE_CHECKING: from langchain_core.runnables.base import Runnable as RunnableType +class LabelsDict(TypedDict): + nodes: dict[str, str] + edges: dict[str, str] + + def is_uuid(value: str) -> bool: try: UUID(value) @@ -213,10 +229,38 @@ class Graph: self.remove_node(last_node) def draw_ascii(self) -> str: - return draw( + return draw_ascii( {node.id: node_data_str(node) for node in self.nodes.values()}, [(edge.source, edge.target) for edge in self.edges], ) def print_ascii(self) -> None: print(self.draw_ascii()) # noqa: T201 + + @overload + def draw_png( + self, + output_file_path: str, + fontname: Optional[str] = None, + labels: Optional[LabelsDict] = None, + ) -> None: + ... + + @overload + def draw_png( + self, + output_file_path: None, + fontname: Optional[str] = None, + labels: Optional[LabelsDict] = None, + ) -> bytes: + ... + + def draw_png( + self, + output_file_path: Optional[str] = None, + fontname: Optional[str] = None, + labels: Optional[LabelsDict] = None, + ) -> Union[bytes, None]: + from langchain_core.runnables.graph_png import PngDrawer + + return PngDrawer(fontname, labels).draw(self, output_file_path) diff --git a/libs/core/langchain_core/runnables/graph_draw.py b/libs/core/langchain_core/runnables/graph_ascii.py similarity index 99% rename from libs/core/langchain_core/runnables/graph_draw.py rename to libs/core/langchain_core/runnables/graph_ascii.py index 87facfb92e..bc809dce28 100644 --- a/libs/core/langchain_core/runnables/graph_draw.py +++ b/libs/core/langchain_core/runnables/graph_ascii.py @@ -209,7 +209,7 @@ def _build_sugiyama_layout( return sug -def draw(vertices: Mapping[str, str], edges: Sequence[Tuple[str, str]]) -> str: +def draw_ascii(vertices: Mapping[str, str], edges: Sequence[Tuple[str, str]]) -> str: """Build a DAG and draw it in ASCII. Args: diff --git a/libs/core/langchain_core/runnables/graph_png.py b/libs/core/langchain_core/runnables/graph_png.py new file mode 100644 index 0000000000..51cfcb9577 --- /dev/null +++ b/libs/core/langchain_core/runnables/graph_png.py @@ -0,0 +1,108 @@ +from typing import Any, Optional + +from langchain_core.runnables.graph import Graph, LabelsDict + + +class PngDrawer: + """ + A helper class to draw a state graph into a PNG file. + Requires graphviz and pygraphviz to be installed. + :param fontname: The font to use for the labels + :param labels: A dictionary of label overrides. The dictionary + should have the following format: + { + "nodes": { + "node1": "CustomLabel1", + "node2": "CustomLabel2", + "__end__": "End Node" + }, + "edges": { + "continue": "ContinueLabel", + "end": "EndLabel" + } + } + The keys are the original labels, and the values are the new labels. + Usage: + drawer = PngDrawer() + drawer.draw(state_graph, 'graph.png') + """ + + def __init__( + self, fontname: Optional[str] = None, labels: Optional[LabelsDict] = None + ) -> None: + self.fontname = fontname or "arial" + self.labels = labels or LabelsDict(nodes={}, edges={}) + + def get_node_label(self, label: str) -> str: + label = self.labels.get("nodes", {}).get(label, label) + return f"<{label}>" + + def get_edge_label(self, label: str) -> str: + label = self.labels.get("edges", {}).get(label, label) + return f"<{label}>" + + def add_node(self, viz: Any, node: str) -> None: + viz.add_node( + node, + label=self.get_node_label(node), + style="filled", + fillcolor="yellow", + fontsize=15, + fontname=self.fontname, + ) + + def add_edge( + self, viz: Any, source: str, target: str, label: Optional[str] = None + ) -> None: + viz.add_edge( + source, + target, + label=self.get_edge_label(label) if label else "", + fontsize=12, + fontname=self.fontname, + ) + + def draw(self, graph: Graph, output_path: Optional[str] = None) -> Optional[bytes]: + """ + Draws the given state graph into a PNG file. + Requires graphviz and pygraphviz to be installed. + :param graph: The graph to draw + :param output_path: The path to save the PNG. If None, PNG bytes are returned. + """ + + try: + import pygraphviz as pgv # type: ignore[import] + except ImportError as exc: + raise ImportError( + "Install pygraphviz to draw graphs: `pip install pygraphviz`." + ) from exc + + # Create a directed graph + viz = pgv.AGraph(directed=True, nodesep=0.9, ranksep=1.0) + + # Add nodes, conditional edges, and edges to the graph + self.add_nodes(viz, graph) + self.add_edges(viz, graph) + + # Update entrypoint and END styles + self.update_styles(viz, graph) + + # Save the graph as PNG + try: + return viz.draw(output_path, format="png", prog="dot") + finally: + viz.close() + + def add_nodes(self, viz: Any, graph: Graph) -> None: + for node in graph.nodes: + self.add_node(viz, node) + + def add_edges(self, viz: Any, graph: Graph) -> None: + for start, end, label in graph.edges: + self.add_edge(viz, start, end, label) + + def update_styles(self, viz: Any, graph: Graph) -> None: + if first := graph.first_node(): + viz.get_node(first.id).attr.update(fillcolor="lightblue") + if last := graph.last_node(): + viz.get_node(last.id).attr.update(fillcolor="orange") diff --git a/libs/core/poetry.lock b/libs/core/poetry.lock index 3078c2ae55..3ecb233f5f 100644 --- a/libs/core/poetry.lock +++ b/libs/core/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.6.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand. [[package]] name = "annotated-types" @@ -1996,6 +1996,7 @@ files = [ {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"}, + {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"}, {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"}, {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"}, {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"},