From 9ac302cb97cbf9392ab8c37309873b069e277176 Mon Sep 17 00:00:00 2001 From: Vadym Barda Date: Fri, 21 Jun 2024 17:35:32 -0400 Subject: [PATCH] core[minor]: update draw_mermaid node label processing (#23285) This fixes processing issue for nodes with numbers in their labels (e.g. `"node_1"`, which would previously be relabeled as `"node__"`, and now are correctly processed as `"node_1"`) --- libs/core/langchain_core/runnables/graph_mermaid.py | 2 +- libs/core/tests/unit_tests/runnables/test_graph.py | 9 +++++++++ 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/libs/core/langchain_core/runnables/graph_mermaid.py b/libs/core/langchain_core/runnables/graph_mermaid.py index 402b0ee4f3..a5cad0a9e5 100644 --- a/libs/core/langchain_core/runnables/graph_mermaid.py +++ b/libs/core/langchain_core/runnables/graph_mermaid.py @@ -114,7 +114,7 @@ def draw_mermaid( def _escape_node_label(node_label: str) -> str: """Escapes the node label for Mermaid syntax.""" - return re.sub(r"[^a-zA-Z-_]", "_", node_label) + return re.sub(r"[^a-zA-Z-_0-9]", "_", node_label) def _adjust_mermaid_edge( diff --git a/libs/core/tests/unit_tests/runnables/test_graph.py b/libs/core/tests/unit_tests/runnables/test_graph.py index dcefdffcfd..bd1d033f7f 100644 --- a/libs/core/tests/unit_tests/runnables/test_graph.py +++ b/libs/core/tests/unit_tests/runnables/test_graph.py @@ -8,6 +8,7 @@ from langchain_core.output_parsers.string import StrOutputParser from langchain_core.output_parsers.xml import XMLOutputParser from langchain_core.prompts.prompt import PromptTemplate from langchain_core.runnables.base import Runnable, RunnableConfig +from langchain_core.runnables.graph_mermaid import _escape_node_label from tests.unit_tests.stubs import AnyStr @@ -734,3 +735,11 @@ def test_runnable_get_graph_with_invalid_output_type() -> None: assert runnable.invoke(1) == 1 # check whether runnable.get_graph works runnable.get_graph() + + +def test_graph_mermaid_escape_node_label() -> None: + """Test that node labels are correctly preprocessed for draw_mermaid""" + assert _escape_node_label("foo") == "foo" + assert _escape_node_label("foo-bar") == "foo-bar" + assert _escape_node_label("foo_1") == "foo_1" + assert _escape_node_label("#foo*&!") == "_foo___"