diff --git a/libs/core/langchain_core/runnables/graph_mermaid.py b/libs/core/langchain_core/runnables/graph_mermaid.py index 402b0ee4f3..a5cad0a9e5 100644 --- a/libs/core/langchain_core/runnables/graph_mermaid.py +++ b/libs/core/langchain_core/runnables/graph_mermaid.py @@ -114,7 +114,7 @@ def draw_mermaid( def _escape_node_label(node_label: str) -> str: """Escapes the node label for Mermaid syntax.""" - return re.sub(r"[^a-zA-Z-_]", "_", node_label) + return re.sub(r"[^a-zA-Z-_0-9]", "_", node_label) def _adjust_mermaid_edge( diff --git a/libs/core/tests/unit_tests/runnables/test_graph.py b/libs/core/tests/unit_tests/runnables/test_graph.py index dcefdffcfd..bd1d033f7f 100644 --- a/libs/core/tests/unit_tests/runnables/test_graph.py +++ b/libs/core/tests/unit_tests/runnables/test_graph.py @@ -8,6 +8,7 @@ from langchain_core.output_parsers.string import StrOutputParser from langchain_core.output_parsers.xml import XMLOutputParser from langchain_core.prompts.prompt import PromptTemplate from langchain_core.runnables.base import Runnable, RunnableConfig +from langchain_core.runnables.graph_mermaid import _escape_node_label from tests.unit_tests.stubs import AnyStr @@ -734,3 +735,11 @@ def test_runnable_get_graph_with_invalid_output_type() -> None: assert runnable.invoke(1) == 1 # check whether runnable.get_graph works runnable.get_graph() + + +def test_graph_mermaid_escape_node_label() -> None: + """Test that node labels are correctly preprocessed for draw_mermaid""" + assert _escape_node_label("foo") == "foo" + assert _escape_node_label("foo-bar") == "foo-bar" + assert _escape_node_label("foo_1") == "foo_1" + assert _escape_node_label("#foo*&!") == "_foo___"