Add JSON representation of runnable graph to serialized representation (#17745)

Sent to LangSmith

Thank you for contributing to LangChain!

Checklist:

- [ ] PR title: Please title your PR "package: description", where
"package" is whichever of langchain, community, core, experimental, etc.
is being modified. Use "docs: ..." for purely docs changes, "templates:
..." for template changes, "infra: ..." for CI changes.
  - Example: "community: add foobar LLM"
- [ ] PR message: **Delete this entire template message** and replace it
with the following bulleted list
    - **Description:** a description of the change
    - **Issue:** the issue # it fixes, if applicable
    - **Dependencies:** any dependencies required for this change
- **Twitter handle:** if your PR gets announced, and you'd like a
mention, we'll gladly shout you out!
- [ ] Pass lint and test: Run `make format`, `make lint` and `make test`
from the root of the package(s) you've modified to check that you're
passing lint and tests. See contribution guidelines for more
information on how to write/run tests, lint, etc:
https://python.langchain.com/docs/contributing/
- [ ] Add tests and docs: If you're adding a new integration, please
include
1. a test for the integration, preferably unit tests that do not rely on
network access,
2. an example notebook showing its use. It lives in
`docs/docs/integrations` directory.

Additional guidelines:
- Make sure optional dependencies are imported within a function.
- Please do not add dependencies to pyproject.toml files (even optional
ones) unless they are required for unit tests.
- Most PRs should not touch more than one package.
- Changes should be backwards compatible.
- If you are adding something to community, do not re-import it in
langchain.

If no one reviews your PR within a few days, please @-mention one of
baskaryan, efriis, eyurtsev, hwchase17.
This commit is contained in:
Nuno Campos 2024-02-20 14:51:09 -08:00 committed by GitHub
parent 6e854ae371
commit 223e5eff14
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 50361 additions and 75 deletions

View File

@ -1,5 +1,16 @@
from abc import ABC
from typing import Any, Dict, List, Literal, Optional, TypedDict, Union, cast
from typing import (
Any,
Dict,
List,
Literal,
Optional,
TypedDict,
Union,
cast,
)
from typing_extensions import NotRequired
from langchain_core.pydantic_v1 import BaseModel, PrivateAttr
@ -9,6 +20,8 @@ class BaseSerialized(TypedDict):
lc: int
id: List[str]
name: NotRequired[str]
graph: NotRequired[Dict[str, Any]]
class SerializedConstructor(BaseSerialized):

View File

@ -37,7 +37,11 @@ from typing_extensions import Literal, get_args
from langchain_core._api import beta_decorator
from langchain_core.load.dump import dumpd
from langchain_core.load.serializable import Serializable
from langchain_core.load.serializable import (
Serializable,
SerializedConstructor,
SerializedNotImplemented,
)
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.runnables.config import (
RunnableConfig,
@ -1630,6 +1634,16 @@ class RunnableSerializable(Serializable, Runnable[Input, Output]):
name: Optional[str] = None
"""The name of the runnable. Used for debugging and tracing."""
def to_json(self) -> Union[SerializedConstructor, SerializedNotImplemented]:
    """Serialize the runnable to JSON."""
    dumped = super().to_json()
    try:
        # Best-effort enrichment: attach the runnable's name and a JSON
        # rendering of its graph. Introspection can fail for some runnables,
        # so errors are swallowed and a partial result (e.g. name set but no
        # graph) is acceptable rather than breaking serialization.
        dumped["name"] = self.get_name()
        dumped["graph"] = self.get_graph().to_json()
    except Exception:
        pass
    return dumped
def configurable_fields(
self, **kwargs: AnyConfigurableField
) -> RunnableSerializable[Input, Output]:

View File

@ -1,8 +1,9 @@
from __future__ import annotations
import inspect
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Dict, List, NamedTuple, Optional, Type, Union
from uuid import uuid4
from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional, Type, Union
from uuid import UUID, uuid4
from langchain_core.pydantic_v1 import BaseModel
from langchain_core.runnables.graph_draw import draw
@ -11,11 +12,20 @@ if TYPE_CHECKING:
from langchain_core.runnables.base import Runnable as RunnableType
def is_uuid(value: str) -> bool:
    """Return True if ``value`` parses as a UUID string, else False."""
    try:
        UUID(value)
    except ValueError:
        return False
    return True
class Edge(NamedTuple):
    """Edge in a graph."""

    # Id of the node the edge leaves from.
    source: str
    # Id of the node the edge points to.
    target: str
    # Optional label carried through to the JSON rendering ("data" key).
    data: Optional[str] = None
class Node(NamedTuple):
@ -25,6 +35,61 @@ class Node(NamedTuple):
data: Union[Type[BaseModel], RunnableType]
def node_data_str(node: Node) -> str:
    """Return a short human-readable label for ``node``.

    Non-UUID ids are assumed to be human-assigned and are used verbatim;
    otherwise a label is derived from the node's payload.
    """
    from langchain_core.runnables.base import Runnable

    if not is_uuid(node.id):
        return node.id
    if isinstance(node.data, Runnable):
        try:
            label = str(node.data)
            # Fall back to the class name for repr-like, lowercase-leading,
            # or multi-line strings; truncate anything longer than 42 chars.
            if (
                label.startswith("<")
                or label[0] != label[0].upper()
                or len(label.splitlines()) > 1
            ):
                label = node.data.__class__.__name__
            elif len(label) > 42:
                label = label[:42] + "..."
        except Exception:
            label = node.data.__class__.__name__
    else:
        label = node.data.__name__
    # Strip the common "Runnable" prefix for brevity.
    return label[8:] if label.startswith("Runnable") else label
def node_data_json(node: Node) -> Dict[str, Union[str, Dict[str, Any]]]:
    """Represent a node's payload as a JSON-serializable tagged dict.

    The result has a ``type`` discriminator ("runnable", "schema", or
    "unknown") and a matching ``data`` payload.
    """
    from langchain_core.load.serializable import to_json_not_implemented
    from langchain_core.runnables.base import Runnable, RunnableSerializable

    payload = node.data
    if isinstance(payload, RunnableSerializable):
        # Fully serializable runnables expose a stable lc_id.
        return {
            "type": "runnable",
            "data": {"id": payload.lc_id(), "name": payload.get_name()},
        }
    if isinstance(payload, Runnable):
        # Other runnables: fall back to the not-implemented serialization id.
        return {
            "type": "runnable",
            "data": {
                "id": to_json_not_implemented(payload)["id"],
                "name": payload.get_name(),
            },
        }
    if inspect.isclass(payload) and issubclass(payload, BaseModel):
        return {"type": "schema", "data": payload.schema()}
    return {"type": "unknown", "data": node_data_str(node)}
@dataclass
class Graph:
"""Graph of nodes and edges."""
@ -32,15 +97,46 @@ class Graph:
nodes: Dict[str, Node] = field(default_factory=dict)
edges: List[Edge] = field(default_factory=list)
def to_json(self) -> Dict[str, List[Dict[str, Any]]]:
    """Convert the graph to a JSON-serializable format.

    UUID node ids are remapped to small integers (their insertion-order
    position) so the output is stable; explicitly-assigned (non-UUID)
    ids are kept as-is.
    """
    ids: Dict[str, Any] = {}
    for index, node in enumerate(self.nodes.values()):
        ids[node.id] = index if is_uuid(node.id) else node.id

    json_nodes = [
        {"id": ids[node.id], **node_data_json(node)}
        for node in self.nodes.values()
    ]
    json_edges = []
    for edge in self.edges:
        record: Dict[str, Any] = {
            "source": ids[edge.source],
            "target": ids[edge.target],
        }
        # Only emit "data" when the edge carries a label.
        if edge.data is not None:
            record["data"] = edge.data
        json_edges.append(record)
    return {"nodes": json_nodes, "edges": json_edges}
def __bool__(self) -> bool:
    """A graph is truthy iff it contains at least one node."""
    return len(self.nodes) > 0
def next_id(self) -> str:
    """Return a fresh opaque node id (hex digest of a random UUID4)."""
    return uuid4().hex
def add_node(
    self, data: Union[Type[BaseModel], RunnableType], id: Optional[str] = None
) -> Node:
    """Add a node to the graph and return it.

    Args:
        data: the node's payload — a pydantic model class or a runnable.
        id: optional explicit id; a fresh UUID hex is generated when omitted.

    Returns:
        The newly created ``Node``.

    Raises:
        ValueError: if an explicit ``id`` collides with an existing node.
    """
    # NOTE(review): the scraped diff interleaved the pre-image signature and
    # a dead pre-image assignment here; this is the reconstructed post-image.
    if id is not None and id in self.nodes:
        raise ValueError(f"Node with id {id} already exists")
    node = Node(id=id or self.next_id(), data=data)
    self.nodes[node.id] = node
    return node
@ -53,13 +149,13 @@ class Graph:
if edge.source != node.id and edge.target != node.id
]
def add_edge(self, source: Node, target: Node, data: Optional[str] = None) -> Edge:
    """Add an edge between two existing nodes and return it.

    Args:
        source: node the edge leaves from; must already be in the graph.
        target: node the edge points to; must already be in the graph.
        data: optional label attached to the edge.

    Raises:
        ValueError: if either endpoint is not already in the graph.
    """
    # NOTE(review): the scraped diff interleaved the pre-image signature and
    # a dead pre-image Edge construction; this is the reconstructed post-image.
    if source.id not in self.nodes:
        raise ValueError(f"Source node {source.id} not in graph")
    if target.id not in self.nodes:
        raise ValueError(f"Target node {target.id} not in graph")
    edge = Edge(source=source.id, target=target.id, data=data)
    self.edges.append(edge)
    return edge
@ -117,28 +213,8 @@ class Graph:
self.remove_node(last_node)
def draw_ascii(self) -> str:
    """Render the graph as an ASCII diagram.

    Node labels come from the module-level ``node_data_str`` helper so the
    ASCII and JSON renderings stay consistent.
    """
    # NOTE(review): the scraped diff retained the deleted nested node_data
    # helper and both the old and new first argument to draw(); this is the
    # reconstructed post-image, which delegates labeling to node_data_str.
    return draw(
        {node.id: node_data_str(node) for node in self.nodes.values()},
        [(edge.source, edge.target) for edge in self.edges],
    )

File diff suppressed because one or more lines are too long

View File

@ -31,6 +31,63 @@ def test_graph_sequence(snapshot: SnapshotAssertion) -> None:
sequence = prompt | fake_llm | list_parser
graph = sequence.get_graph()
assert graph.to_json() == {
"nodes": [
{
"id": 0,
"type": "schema",
"data": {
"title": "PromptInput",
"type": "object",
"properties": {"name": {"title": "Name", "type": "string"}},
},
},
{
"id": 1,
"type": "runnable",
"data": {
"id": ["langchain", "prompts", "prompt", "PromptTemplate"],
"name": "PromptTemplate",
},
},
{
"id": 2,
"type": "runnable",
"data": {
"id": ["tests", "unit_tests", "fake", "llm", "FakeListLLM"],
"name": "FakeListLLM",
},
},
{
"id": 3,
"type": "runnable",
"data": {
"id": [
"langchain",
"output_parsers",
"list",
"CommaSeparatedListOutputParser",
],
"name": "CommaSeparatedListOutputParser",
},
},
{
"id": 4,
"type": "schema",
"data": {
"title": "CommaSeparatedListOutputParserOutput",
"type": "array",
"items": {"type": "string"},
},
},
],
"edges": [
{"source": 0, "target": 1},
{"source": 1, "target": 2},
{"source": 3, "target": 4},
{"source": 2, "target": 3},
],
}
assert graph.draw_ascii() == snapshot
@ -56,4 +113,343 @@ def test_graph_sequence_map(snapshot: SnapshotAssertion) -> None:
}
)
graph = sequence.get_graph()
assert graph.to_json() == {
"nodes": [
{
"id": 0,
"type": "schema",
"data": {
"title": "PromptInput",
"type": "object",
"properties": {"name": {"title": "Name", "type": "string"}},
},
},
{
"id": 1,
"type": "runnable",
"data": {
"id": ["langchain", "prompts", "prompt", "PromptTemplate"],
"name": "PromptTemplate",
},
},
{
"id": 2,
"type": "runnable",
"data": {
"id": ["tests", "unit_tests", "fake", "llm", "FakeListLLM"],
"name": "FakeListLLM",
},
},
{
"id": 3,
"type": "schema",
"data": {
"title": "RunnableParallel<as_list,as_str>Input",
"anyOf": [
{"type": "string"},
{"$ref": "#/definitions/AIMessage"},
{"$ref": "#/definitions/HumanMessage"},
{"$ref": "#/definitions/ChatMessage"},
{"$ref": "#/definitions/SystemMessage"},
{"$ref": "#/definitions/FunctionMessage"},
{"$ref": "#/definitions/ToolMessage"},
],
"definitions": {
"AIMessage": {
"title": "AIMessage",
"description": "Message from an AI.",
"type": "object",
"properties": {
"content": {
"title": "Content",
"anyOf": [
{"type": "string"},
{
"type": "array",
"items": {
"anyOf": [
{"type": "string"},
{"type": "object"},
]
},
},
],
},
"additional_kwargs": {
"title": "Additional Kwargs",
"type": "object",
},
"type": {
"title": "Type",
"default": "ai",
"enum": ["ai"],
"type": "string",
},
"name": {"title": "Name", "type": "string"},
"example": {
"title": "Example",
"default": False,
"type": "boolean",
},
},
"required": ["content"],
},
"HumanMessage": {
"title": "HumanMessage",
"description": "Message from a human.",
"type": "object",
"properties": {
"content": {
"title": "Content",
"anyOf": [
{"type": "string"},
{
"type": "array",
"items": {
"anyOf": [
{"type": "string"},
{"type": "object"},
]
},
},
],
},
"additional_kwargs": {
"title": "Additional Kwargs",
"type": "object",
},
"type": {
"title": "Type",
"default": "human",
"enum": ["human"],
"type": "string",
},
"name": {"title": "Name", "type": "string"},
"example": {
"title": "Example",
"default": False,
"type": "boolean",
},
},
"required": ["content"],
},
"ChatMessage": {
"title": "ChatMessage",
"description": "Message that can be assigned an arbitrary speaker (i.e. role).", # noqa: E501
"type": "object",
"properties": {
"content": {
"title": "Content",
"anyOf": [
{"type": "string"},
{
"type": "array",
"items": {
"anyOf": [
{"type": "string"},
{"type": "object"},
]
},
},
],
},
"additional_kwargs": {
"title": "Additional Kwargs",
"type": "object",
},
"type": {
"title": "Type",
"default": "chat",
"enum": ["chat"],
"type": "string",
},
"name": {"title": "Name", "type": "string"},
"role": {"title": "Role", "type": "string"},
},
"required": ["content", "role"],
},
"SystemMessage": {
"title": "SystemMessage",
"description": "Message for priming AI behavior, usually passed in as the first of a sequence\nof input messages.", # noqa: E501
"type": "object",
"properties": {
"content": {
"title": "Content",
"anyOf": [
{"type": "string"},
{
"type": "array",
"items": {
"anyOf": [
{"type": "string"},
{"type": "object"},
]
},
},
],
},
"additional_kwargs": {
"title": "Additional Kwargs",
"type": "object",
},
"type": {
"title": "Type",
"default": "system",
"enum": ["system"],
"type": "string",
},
"name": {"title": "Name", "type": "string"},
},
"required": ["content"],
},
"FunctionMessage": {
"title": "FunctionMessage",
"description": "Message for passing the result of executing a function back to a model.", # noqa: E501
"type": "object",
"properties": {
"content": {
"title": "Content",
"anyOf": [
{"type": "string"},
{
"type": "array",
"items": {
"anyOf": [
{"type": "string"},
{"type": "object"},
]
},
},
],
},
"additional_kwargs": {
"title": "Additional Kwargs",
"type": "object",
},
"type": {
"title": "Type",
"default": "function",
"enum": ["function"],
"type": "string",
},
"name": {"title": "Name", "type": "string"},
},
"required": ["content", "name"],
},
"ToolMessage": {
"title": "ToolMessage",
"description": "Message for passing the result of executing a tool back to a model.", # noqa: E501
"type": "object",
"properties": {
"content": {
"title": "Content",
"anyOf": [
{"type": "string"},
{
"type": "array",
"items": {
"anyOf": [
{"type": "string"},
{"type": "object"},
]
},
},
],
},
"additional_kwargs": {
"title": "Additional Kwargs",
"type": "object",
},
"type": {
"title": "Type",
"default": "tool",
"enum": ["tool"],
"type": "string",
},
"name": {"title": "Name", "type": "string"},
"tool_call_id": {
"title": "Tool Call Id",
"type": "string",
},
},
"required": ["content", "tool_call_id"],
},
},
},
},
{
"id": 4,
"type": "schema",
"data": {
"title": "RunnableParallel<as_list,as_str>Output",
"type": "object",
"properties": {
"as_list": {
"title": "As List",
"type": "array",
"items": {"type": "string"},
},
"as_str": {"title": "As Str"},
},
},
},
{
"id": 5,
"type": "runnable",
"data": {
"id": [
"langchain",
"output_parsers",
"list",
"CommaSeparatedListOutputParser",
],
"name": "CommaSeparatedListOutputParser",
},
},
{
"id": 6,
"type": "schema",
"data": {"title": "conditional_str_parser_input", "type": "string"},
},
{
"id": 7,
"type": "schema",
"data": {"title": "conditional_str_parser_output"},
},
{
"id": 8,
"type": "runnable",
"data": {
"id": ["langchain", "schema", "output_parser", "StrOutputParser"],
"name": "StrOutputParser",
},
},
{
"id": 9,
"type": "runnable",
"data": {
"id": [
"langchain_core",
"output_parsers",
"xml",
"XMLOutputParser",
],
"name": "XMLOutputParser",
},
},
],
"edges": [
{"source": 0, "target": 1},
{"source": 1, "target": 2},
{"source": 3, "target": 5},
{"source": 5, "target": 4},
{"source": 6, "target": 8},
{"source": 8, "target": 7},
{"source": 6, "target": 9},
{"source": 9, "target": 7},
{"source": 3, "target": 6},
{"source": 7, "target": 4},
{"source": 2, "target": 3},
],
}
assert graph.draw_ascii() == snapshot