diff --git a/libs/core/langchain_core/runnables/graph.py b/libs/core/langchain_core/runnables/graph.py index 969780446e..b9690b9edb 100644 --- a/libs/core/langchain_core/runnables/graph.py +++ b/libs/core/langchain_core/runnables/graph.py @@ -19,7 +19,6 @@ from typing import ( from uuid import UUID, uuid4 from langchain_core.pydantic_v1 import BaseModel -from langchain_core.runnables.graph_ascii import draw_ascii if TYPE_CHECKING: from langchain_core.runnables.base import Runnable as RunnableType @@ -44,6 +43,7 @@ class Edge(NamedTuple): source: str target: str data: Optional[str] = None + conditional: bool = False class Node(NamedTuple): @@ -219,13 +219,21 @@ class Graph: if edge.source != node.id and edge.target != node.id ] - def add_edge(self, source: Node, target: Node, data: Optional[str] = None) -> Edge: + def add_edge( + self, + source: Node, + target: Node, + data: Optional[str] = None, + conditional: bool = False, + ) -> Edge: """Add an edge to the graph and return it.""" if source.id not in self.nodes: raise ValueError(f"Source node {source.id} not in graph") if target.id not in self.nodes: raise ValueError(f"Target node {target.id} not in graph") - edge = Edge(source=source.id, target=target.id, data=data) + edge = Edge( + source=source.id, target=target.id, data=data, conditional=conditional + ) self.edges.append(edge) return edge @@ -283,9 +291,11 @@ class Graph: self.remove_node(last_node) def draw_ascii(self) -> str: + from langchain_core.runnables.graph_ascii import draw_ascii + return draw_ascii( {node.id: node_data_str(node) for node in self.nodes.values()}, - [(edge.source, edge.target) for edge in self.edges], + self.edges, ) def print_ascii(self) -> None: diff --git a/libs/core/langchain_core/runnables/graph_ascii.py b/libs/core/langchain_core/runnables/graph_ascii.py index bc809dce28..089cdb9923 100644 --- a/libs/core/langchain_core/runnables/graph_ascii.py +++ b/libs/core/langchain_core/runnables/graph_ascii.py @@ -3,7 +3,9 @@ Adapted from https://github.com/iterative/dvc/blob/main/dvc/dagascii.py""" import math import os -from typing import Any, Mapping, Sequence, Tuple +from typing import Any, Mapping, Sequence + +from langchain_core.runnables.graph import Edge as LangEdge class VertexViewer: @@ -156,7 +158,7 @@ class AsciiCanvas: def _build_sugiyama_layout( - vertices: Mapping[str, str], edges: Sequence[Tuple[str, str]] + vertices: Mapping[str, str], edges: Sequence[LangEdge] ) -> Any: try: from grandalf.graphs import Edge, Graph, Vertex # type: ignore[import] @@ -181,7 +183,7 @@ def _build_sugiyama_layout( # vertices_ = {id: Vertex(f" {data} ") for id, data in vertices.items()} - edges_ = [Edge(vertices_[s], vertices_[e]) for s, e in edges] + edges_ = [Edge(vertices_[s], vertices_[e], data=cond) for s, e, _, cond in edges] vertices_list = vertices_.values() graph = Graph(vertices_list, edges_) @@ -209,7 +211,7 @@ def _build_sugiyama_layout( return sug -def draw_ascii(vertices: Mapping[str, str], edges: Sequence[Tuple[str, str]]) -> str: +def draw_ascii(vertices: Mapping[str, str], edges: Sequence[LangEdge]) -> str: """Build a DAG and draw it in ASCII. Args: @@ -220,7 +222,6 @@ def draw_ascii(vertices: Mapping[str, str], edges: Sequence[Tuple[str, str]]) -> str: ASCII representation Example: - >>> from dvc.dagascii import draw >>> vertices = [1, 2, 3, 4] >>> edges = [(1, 2), (2, 3), (2, 4), (1, 4)] >>> print(draw(vertices, edges)) @@ -287,7 +288,7 @@ def draw_ascii(vertices: Mapping[str, str], edges: Sequence[Tuple[str, str]]) -> assert end_x >= 0 assert end_y >= 0 - canvas.line(start_x, start_y, end_x, end_y, "*") + canvas.line(start_x, start_y, end_x, end_y, "." if edge.data else "*") for vertex in sug.g.sV: # NOTE: moving boxes w/2 to the left diff --git a/libs/core/langchain_core/runnables/graph_mermaid.py b/libs/core/langchain_core/runnables/graph_mermaid.py index 6912f2bcc2..210a3b5be5 100644 --- a/libs/core/langchain_core/runnables/graph_mermaid.py +++ b/libs/core/langchain_core/runnables/graph_mermaid.py @@ -76,9 +76,15 @@ def draw_mermaid( for i in range(0, len(words), wrap_label_n_words) ] ) - edge_label = f" -- {edge_data} --> " + if edge.conditional: + edge_label = f" -. {edge_data} .-> " + else: + edge_label = f" -- {edge_data} --> " else: - edge_label = " --> " + if edge.conditional: + edge_label = " -.-> " + else: + edge_label = " --> " mermaid_graph += ( f"\t{_escape_node_label(source)}{edge_label}" f"{_escape_node_label(target)};\n" diff --git a/libs/core/langchain_core/runnables/graph_png.py b/libs/core/langchain_core/runnables/graph_png.py index 51cfcb9577..9116fc81f9 100644 --- a/libs/core/langchain_core/runnables/graph_png.py +++ b/libs/core/langchain_core/runnables/graph_png.py @@ -52,7 +52,12 @@ class PngDrawer: ) def add_edge( - self, viz: Any, source: str, target: str, label: Optional[str] = None + self, + viz: Any, + source: str, + target: str, + label: Optional[str] = None, + conditional: bool = False, ) -> None: viz.add_edge( source, @@ -60,6 +65,7 @@ class PngDrawer: label=self.get_edge_label(label) if label else "", fontsize=12, fontname=self.fontname, + style="dotted" if conditional else "solid", ) def draw(self, graph: Graph, output_path: Optional[str] = None) -> Optional[bytes]: @@ -98,8 +104,8 @@ class PngDrawer: self.add_node(viz, node) def add_edges(self, viz: Any, graph: Graph) -> None: - for start, end, label in graph.edges: - self.add_edge(viz, start, end, label) + for start, end, label, cond in graph.edges: + self.add_edge(viz, start, end, label, cond) def update_styles(self, viz: Any, graph: Graph) -> None: if first := graph.first_node():