Mirror of https://github.com/hwchase17/langchain (synced 2024-11-04 06:00:26 +00:00)
Add JSON representation of runnable graph to serialized representation (#17745)
Sent to LangSmith

Thank you for contributing to LangChain!

Checklist:

- [ ] PR title: Please title your PR "package: description", where "package" is whichever of langchain, community, core, experimental, etc. is being modified. Use "docs: ..." for purely docs changes, "templates: ..." for template changes, "infra: ..." for CI changes.
  - Example: "community: add foobar LLM"
- [ ] PR message: **Delete this entire template message** and replace it with the following bulleted list
  - **Description:** a description of the change
  - **Issue:** the issue # it fixes, if applicable
  - **Dependencies:** any dependencies required for this change
  - **Twitter handle:** if your PR gets announced, and you'd like a mention, we'll gladly shout you out!
- [ ] Pass lint and test: Run `make format`, `make lint` and `make test` from the root of the package(s) you've modified to check that you're passing lint and tests. See the contribution guidelines for more information on how to write/run tests, lint, etc.: https://python.langchain.com/docs/contributing/
- [ ] Add tests and docs: If you're adding a new integration, please include
  1. a test for the integration, preferably unit tests that do not rely on network access,
  2. an example notebook showing its use. It lives in the `docs/docs/integrations` directory.

Additional guidelines:

- Make sure optional dependencies are imported within a function.
- Please do not add dependencies to pyproject.toml files (even optional ones) unless they are required for unit tests.
- Most PRs should not touch more than one package.
- Changes should be backwards compatible.
- If you are adding something to community, do not re-import it in langchain.

If no one reviews your PR within a few days, please @-mention one of baskaryan, efriis, eyurtsev, hwchase17.
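In practice, the payload produced by `dumpd` (which is what tracers attach to LangSmith runs) now carries a JSON description of the runnable's graph under a `graph` key. A rough sketch of the end-to-end effect, assuming `langchain_core` at or after this commit; the chain and the printed values are illustrative only, not part of the PR:

```python
from langchain_core.load.dump import dumpd
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate

# Illustrative chain; any serializable runnable behaves the same way.
chain = PromptTemplate.from_template("Tell me about {topic}") | StrOutputParser()

serialized = dumpd(chain)  # routes through Runnable.to_json() for serializable objects
print(serialized["graph"]["nodes"][0])  # e.g. {"id": 0, "type": "schema", "data": {...}}
print(serialized["graph"]["edges"][0])  # e.g. {"source": 0, "target": 1}
```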
Commit 223e5eff14 (parent 6e854ae371)
@@ -1,5 +1,16 @@
 from abc import ABC
-from typing import Any, Dict, List, Literal, Optional, TypedDict, Union, cast
+from typing import (
+    Any,
+    Dict,
+    List,
+    Literal,
+    Optional,
+    TypedDict,
+    Union,
+    cast,
+)
+
+from typing_extensions import NotRequired

 from langchain_core.pydantic_v1 import BaseModel, PrivateAttr

@@ -9,6 +20,8 @@ class BaseSerialized(TypedDict):

     lc: int
     id: List[str]
+    name: NotRequired[str]
+    graph: NotRequired[Dict[str, Any]]


 class SerializedConstructor(BaseSerialized):
@@ -37,7 +37,11 @@ from typing_extensions import Literal, get_args

 from langchain_core._api import beta_decorator
 from langchain_core.load.dump import dumpd
-from langchain_core.load.serializable import Serializable
+from langchain_core.load.serializable import (
+    Serializable,
+    SerializedConstructor,
+    SerializedNotImplemented,
+)
 from langchain_core.pydantic_v1 import BaseModel, Field
 from langchain_core.runnables.config import (
     RunnableConfig,

@@ -1630,6 +1634,16 @@ class RunnableSerializable(Serializable, Runnable[Input, Output]):
     name: Optional[str] = None
     """The name of the runnable. Used for debugging and tracing."""

+    def to_json(self) -> Union[SerializedConstructor, SerializedNotImplemented]:
+        """Serialize the runnable to JSON."""
+        dumped = super().to_json()
+        try:
+            dumped["name"] = self.get_name()
+            dumped["graph"] = self.get_graph().to_json()
+        except Exception:
+            pass
+        return dumped
+
     def configurable_fields(
         self, **kwargs: AnyConfigurableField
     ) -> RunnableSerializable[Input, Output]:
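The `to_json` override above is additive: the usual serialized constructor fields (`lc`, `id`, `type`, `kwargs`) are untouched, and `name`/`graph` are only attached when `get_name()` and `get_graph().to_json()` succeed, since any exception is deliberately swallowed. A minimal sketch of calling it directly; the chain below is illustrative and the exact `name` value depends on the runnables involved:

```python
from langchain_core.output_parsers import CommaSeparatedListOutputParser
from langchain_core.prompts import PromptTemplate

# Illustrative chain; any RunnableSerializable gains the same two keys.
chain = PromptTemplate.from_template("list some {things}") | CommaSeparatedListOutputParser()

dumped = chain.to_json()
print(dumped["name"])          # e.g. "RunnableSequence"
print(dumped["graph"].keys())  # dict_keys(['nodes', 'edges'])
```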
@@ -1,8 +1,9 @@
 from __future__ import annotations

+import inspect
 from dataclasses import dataclass, field
-from typing import TYPE_CHECKING, Dict, List, NamedTuple, Optional, Type, Union
-from uuid import uuid4
+from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional, Type, Union
+from uuid import UUID, uuid4

 from langchain_core.pydantic_v1 import BaseModel
 from langchain_core.runnables.graph_draw import draw

@@ -11,11 +12,20 @@ if TYPE_CHECKING:
     from langchain_core.runnables.base import Runnable as RunnableType


+def is_uuid(value: str) -> bool:
+    try:
+        UUID(value)
+        return True
+    except ValueError:
+        return False
+
+
 class Edge(NamedTuple):
     """Edge in a graph."""

     source: str
     target: str
+    data: Optional[str] = None


 class Node(NamedTuple):

@@ -25,6 +35,61 @@ class Node(NamedTuple):
     data: Union[Type[BaseModel], RunnableType]


+def node_data_str(node: Node) -> str:
+    from langchain_core.runnables.base import Runnable
+
+    if not is_uuid(node.id):
+        return node.id
+    elif isinstance(node.data, Runnable):
+        try:
+            data = str(node.data)
+            if (
+                data.startswith("<")
+                or data[0] != data[0].upper()
+                or len(data.splitlines()) > 1
+            ):
+                data = node.data.__class__.__name__
+            elif len(data) > 42:
+                data = data[:42] + "..."
+        except Exception:
+            data = node.data.__class__.__name__
+    else:
+        data = node.data.__name__
+    return data if not data.startswith("Runnable") else data[8:]
+
+
+def node_data_json(node: Node) -> Dict[str, Union[str, Dict[str, Any]]]:
+    from langchain_core.load.serializable import to_json_not_implemented
+    from langchain_core.runnables.base import Runnable, RunnableSerializable
+
+    if isinstance(node.data, RunnableSerializable):
+        return {
+            "type": "runnable",
+            "data": {
+                "id": node.data.lc_id(),
+                "name": node.data.get_name(),
+            },
+        }
+    elif isinstance(node.data, Runnable):
+        return {
+            "type": "runnable",
+            "data": {
+                "id": to_json_not_implemented(node.data)["id"],
+                "name": node.data.get_name(),
+            },
+        }
+    elif inspect.isclass(node.data) and issubclass(node.data, BaseModel):
+        return {
+            "type": "schema",
+            "data": node.data.schema(),
+        }
+    else:
+        return {
+            "type": "unknown",
+            "data": node_data_str(node),
+        }
+
+
 @dataclass
 class Graph:
     """Graph of nodes and edges."""

@@ -32,15 +97,46 @@ class Graph:
     nodes: Dict[str, Node] = field(default_factory=dict)
     edges: List[Edge] = field(default_factory=list)

+    def to_json(self) -> Dict[str, List[Dict[str, Any]]]:
+        """Convert the graph to a JSON-serializable format."""
+        stable_node_ids = {
+            node.id: i if is_uuid(node.id) else node.id
+            for i, node in enumerate(self.nodes.values())
+        }
+
+        return {
+            "nodes": [
+                {"id": stable_node_ids[node.id], **node_data_json(node)}
+                for node in self.nodes.values()
+            ],
+            "edges": [
+                {
+                    "source": stable_node_ids[edge.source],
+                    "target": stable_node_ids[edge.target],
+                    "data": edge.data,
+                }
+                if edge.data is not None
+                else {
+                    "source": stable_node_ids[edge.source],
+                    "target": stable_node_ids[edge.target],
+                }
+                for edge in self.edges
+            ],
+        }
+
     def __bool__(self) -> bool:
         return bool(self.nodes)

     def next_id(self) -> str:
         return uuid4().hex

-    def add_node(self, data: Union[Type[BaseModel], RunnableType]) -> Node:
+    def add_node(
+        self, data: Union[Type[BaseModel], RunnableType], id: Optional[str] = None
+    ) -> Node:
         """Add a node to the graph and return it."""
-        node = Node(id=self.next_id(), data=data)
+        if id is not None and id in self.nodes:
+            raise ValueError(f"Node with id {id} already exists")
+        node = Node(id=id or self.next_id(), data=data)
         self.nodes[node.id] = node
         return node

@@ -53,13 +149,13 @@ class Graph:
             if edge.source != node.id and edge.target != node.id
         ]

-    def add_edge(self, source: Node, target: Node) -> Edge:
+    def add_edge(self, source: Node, target: Node, data: Optional[str] = None) -> Edge:
         """Add an edge to the graph and return it."""
         if source.id not in self.nodes:
             raise ValueError(f"Source node {source.id} not in graph")
         if target.id not in self.nodes:
             raise ValueError(f"Target node {target.id} not in graph")
-        edge = Edge(source=source.id, target=target.id)
+        edge = Edge(source=source.id, target=target.id, data=data)
         self.edges.append(edge)
         return edge
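Together, the optional `id` parameter on `add_node` and the optional `data` parameter on `add_edge` allow graphs with stable, human-readable node ids and labelled edges, which `to_json()` preserves, while auto-generated uuid4 ids are replaced by integer indexes. A small sketch built only from the API in this diff; the `Input` model, the `"input"` id, and the `"passthrough"` label are made up for illustration, and the printed output is approximate:

```python
from langchain_core.pydantic_v1 import BaseModel
from langchain_core.runnables.graph import Graph


class Input(BaseModel):  # illustrative schema node
    question: str


graph = Graph()
start = graph.add_node(Input, id="input")  # explicit id is kept as-is in to_json()
end = graph.add_node(Input)                # uuid4 id is replaced by an integer index
graph.add_edge(start, end, data="passthrough")

print(graph.to_json())
# {"nodes": [{"id": "input", "type": "schema", "data": {...}},
#            {"id": 1, "type": "schema", "data": {...}}],
#  "edges": [{"source": "input", "target": 1, "data": "passthrough"}]}
```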
@@ -117,28 +213,8 @@ class Graph:
             self.remove_node(last_node)

     def draw_ascii(self) -> str:
-        from langchain_core.runnables.base import Runnable
-
-        def node_data(node: Node) -> str:
-            if isinstance(node.data, Runnable):
-                try:
-                    data = str(node.data)
-                    if (
-                        data.startswith("<")
-                        or data[0] != data[0].upper()
-                        or len(data.splitlines()) > 1
-                    ):
-                        data = node.data.__class__.__name__
-                    elif len(data) > 42:
-                        data = data[:42] + "..."
-                except Exception:
-                    data = node.data.__class__.__name__
-            else:
-                data = node.data.__name__
-            return data if not data.startswith("Runnable") else data[8:]
-
         return draw(
-            {node.id: node_data(node) for node in self.nodes.values()},
+            {node.id: node_data_str(node) for node in self.nodes.values()},
             [(edge.source, edge.target) for edge in self.edges],
         )
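The nested `node_data` helper that used to live inside `draw_ascii` is gone; the ASCII renderer now reuses the module-level `node_data_str`, so `draw_ascii()` and `to_json()` share one labelling rule (use an explicit node id when present, otherwise fall back to the class name, strip a leading `Runnable` prefix, and truncate representations longer than 42 characters). A quick way to see the labels, using an illustrative two-step chain; the exact ASCII layout depends on the drawing code:

```python
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate

chain = PromptTemplate.from_template("Hi {name}") | StrOutputParser()
# Node labels come from node_data_str, e.g. "PromptTemplate" for the prompt node.
print(chain.get_graph().draw_ascii())
```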
(File diff suppressed because it is too large.)
(File diff suppressed because one or more lines are too long.)
@@ -31,6 +31,63 @@ def test_graph_sequence(snapshot: SnapshotAssertion) -> None:

     sequence = prompt | fake_llm | list_parser
     graph = sequence.get_graph()
+    assert graph.to_json() == {
+        "nodes": [
+            {
+                "id": 0,
+                "type": "schema",
+                "data": {
+                    "title": "PromptInput",
+                    "type": "object",
+                    "properties": {"name": {"title": "Name", "type": "string"}},
+                },
+            },
+            {
+                "id": 1,
+                "type": "runnable",
+                "data": {
+                    "id": ["langchain", "prompts", "prompt", "PromptTemplate"],
+                    "name": "PromptTemplate",
+                },
+            },
+            {
+                "id": 2,
+                "type": "runnable",
+                "data": {
+                    "id": ["tests", "unit_tests", "fake", "llm", "FakeListLLM"],
+                    "name": "FakeListLLM",
+                },
+            },
+            {
+                "id": 3,
+                "type": "runnable",
+                "data": {
+                    "id": [
+                        "langchain",
+                        "output_parsers",
+                        "list",
+                        "CommaSeparatedListOutputParser",
+                    ],
+                    "name": "CommaSeparatedListOutputParser",
+                },
+            },
+            {
+                "id": 4,
+                "type": "schema",
+                "data": {
+                    "title": "CommaSeparatedListOutputParserOutput",
+                    "type": "array",
+                    "items": {"type": "string"},
+                },
+            },
+        ],
+        "edges": [
+            {"source": 0, "target": 1},
+            {"source": 1, "target": 2},
+            {"source": 3, "target": 4},
+            {"source": 2, "target": 3},
+        ],
+    }
     assert graph.draw_ascii() == snapshot

@@ -56,4 +113,343 @@ def test_graph_sequence_map(snapshot: SnapshotAssertion) -> None:
         }
     )
     graph = sequence.get_graph()
+    assert graph.to_json() == {
+        "nodes": [
+            {
+                "id": 0,
+                "type": "schema",
+                "data": {
+                    "title": "PromptInput",
+                    "type": "object",
+                    "properties": {"name": {"title": "Name", "type": "string"}},
+                },
+            },
+            {
+                "id": 1,
+                "type": "runnable",
+                "data": {
+                    "id": ["langchain", "prompts", "prompt", "PromptTemplate"],
+                    "name": "PromptTemplate",
+                },
+            },
+            {
+                "id": 2,
+                "type": "runnable",
+                "data": {
+                    "id": ["tests", "unit_tests", "fake", "llm", "FakeListLLM"],
+                    "name": "FakeListLLM",
+                },
+            },
+            {
+                "id": 3,
+                "type": "schema",
+                "data": {
+                    "title": "RunnableParallel<as_list,as_str>Input",
+                    "anyOf": [
+                        {"type": "string"},
+                        {"$ref": "#/definitions/AIMessage"},
+                        {"$ref": "#/definitions/HumanMessage"},
+                        {"$ref": "#/definitions/ChatMessage"},
+                        {"$ref": "#/definitions/SystemMessage"},
+                        {"$ref": "#/definitions/FunctionMessage"},
+                        {"$ref": "#/definitions/ToolMessage"},
+                    ],
+                    "definitions": {
+                        "AIMessage": {
+                            "title": "AIMessage",
+                            "description": "Message from an AI.",
+                            "type": "object",
+                            "properties": {
+                                "content": {
+                                    "title": "Content",
+                                    "anyOf": [
+                                        {"type": "string"},
+                                        {
+                                            "type": "array",
+                                            "items": {
+                                                "anyOf": [
+                                                    {"type": "string"},
+                                                    {"type": "object"},
+                                                ]
+                                            },
+                                        },
+                                    ],
+                                },
+                                "additional_kwargs": {
+                                    "title": "Additional Kwargs",
+                                    "type": "object",
+                                },
+                                "type": {
+                                    "title": "Type",
+                                    "default": "ai",
+                                    "enum": ["ai"],
+                                    "type": "string",
+                                },
+                                "name": {"title": "Name", "type": "string"},
+                                "example": {
+                                    "title": "Example",
+                                    "default": False,
+                                    "type": "boolean",
+                                },
+                            },
+                            "required": ["content"],
+                        },
+                        "HumanMessage": {
+                            "title": "HumanMessage",
+                            "description": "Message from a human.",
+                            "type": "object",
+                            "properties": {
+                                "content": {
+                                    "title": "Content",
+                                    "anyOf": [
+                                        {"type": "string"},
+                                        {
+                                            "type": "array",
+                                            "items": {
+                                                "anyOf": [
+                                                    {"type": "string"},
+                                                    {"type": "object"},
+                                                ]
+                                            },
+                                        },
+                                    ],
+                                },
+                                "additional_kwargs": {
+                                    "title": "Additional Kwargs",
+                                    "type": "object",
+                                },
+                                "type": {
+                                    "title": "Type",
+                                    "default": "human",
+                                    "enum": ["human"],
+                                    "type": "string",
+                                },
+                                "name": {"title": "Name", "type": "string"},
+                                "example": {
+                                    "title": "Example",
+                                    "default": False,
+                                    "type": "boolean",
+                                },
+                            },
+                            "required": ["content"],
+                        },
+                        "ChatMessage": {
+                            "title": "ChatMessage",
+                            "description": "Message that can be assigned an arbitrary speaker (i.e. role).",  # noqa: E501
+                            "type": "object",
+                            "properties": {
+                                "content": {
+                                    "title": "Content",
+                                    "anyOf": [
+                                        {"type": "string"},
+                                        {
+                                            "type": "array",
+                                            "items": {
+                                                "anyOf": [
+                                                    {"type": "string"},
+                                                    {"type": "object"},
+                                                ]
+                                            },
+                                        },
+                                    ],
+                                },
+                                "additional_kwargs": {
+                                    "title": "Additional Kwargs",
+                                    "type": "object",
+                                },
+                                "type": {
+                                    "title": "Type",
+                                    "default": "chat",
+                                    "enum": ["chat"],
+                                    "type": "string",
+                                },
+                                "name": {"title": "Name", "type": "string"},
+                                "role": {"title": "Role", "type": "string"},
+                            },
+                            "required": ["content", "role"],
+                        },
+                        "SystemMessage": {
+                            "title": "SystemMessage",
+                            "description": "Message for priming AI behavior, usually passed in as the first of a sequence\nof input messages.",  # noqa: E501
+                            "type": "object",
+                            "properties": {
+                                "content": {
+                                    "title": "Content",
+                                    "anyOf": [
+                                        {"type": "string"},
+                                        {
+                                            "type": "array",
+                                            "items": {
+                                                "anyOf": [
+                                                    {"type": "string"},
+                                                    {"type": "object"},
+                                                ]
+                                            },
+                                        },
+                                    ],
+                                },
+                                "additional_kwargs": {
+                                    "title": "Additional Kwargs",
+                                    "type": "object",
+                                },
+                                "type": {
+                                    "title": "Type",
+                                    "default": "system",
+                                    "enum": ["system"],
+                                    "type": "string",
+                                },
+                                "name": {"title": "Name", "type": "string"},
+                            },
+                            "required": ["content"],
+                        },
+                        "FunctionMessage": {
+                            "title": "FunctionMessage",
+                            "description": "Message for passing the result of executing a function back to a model.",  # noqa: E501
+                            "type": "object",
+                            "properties": {
+                                "content": {
+                                    "title": "Content",
+                                    "anyOf": [
+                                        {"type": "string"},
+                                        {
+                                            "type": "array",
+                                            "items": {
+                                                "anyOf": [
+                                                    {"type": "string"},
+                                                    {"type": "object"},
+                                                ]
+                                            },
+                                        },
+                                    ],
+                                },
+                                "additional_kwargs": {
+                                    "title": "Additional Kwargs",
+                                    "type": "object",
+                                },
+                                "type": {
+                                    "title": "Type",
+                                    "default": "function",
+                                    "enum": ["function"],
+                                    "type": "string",
+                                },
+                                "name": {"title": "Name", "type": "string"},
+                            },
+                            "required": ["content", "name"],
+                        },
+                        "ToolMessage": {
+                            "title": "ToolMessage",
+                            "description": "Message for passing the result of executing a tool back to a model.",  # noqa: E501
+                            "type": "object",
+                            "properties": {
+                                "content": {
+                                    "title": "Content",
+                                    "anyOf": [
+                                        {"type": "string"},
+                                        {
+                                            "type": "array",
+                                            "items": {
+                                                "anyOf": [
+                                                    {"type": "string"},
+                                                    {"type": "object"},
+                                                ]
+                                            },
+                                        },
+                                    ],
+                                },
+                                "additional_kwargs": {
+                                    "title": "Additional Kwargs",
+                                    "type": "object",
+                                },
+                                "type": {
+                                    "title": "Type",
+                                    "default": "tool",
+                                    "enum": ["tool"],
+                                    "type": "string",
+                                },
+                                "name": {"title": "Name", "type": "string"},
+                                "tool_call_id": {
+                                    "title": "Tool Call Id",
+                                    "type": "string",
+                                },
+                            },
+                            "required": ["content", "tool_call_id"],
+                        },
+                    },
+                },
+            },
+            {
+                "id": 4,
+                "type": "schema",
+                "data": {
+                    "title": "RunnableParallel<as_list,as_str>Output",
+                    "type": "object",
+                    "properties": {
+                        "as_list": {
+                            "title": "As List",
+                            "type": "array",
+                            "items": {"type": "string"},
+                        },
+                        "as_str": {"title": "As Str"},
+                    },
+                },
+            },
+            {
+                "id": 5,
+                "type": "runnable",
+                "data": {
+                    "id": [
+                        "langchain",
+                        "output_parsers",
+                        "list",
+                        "CommaSeparatedListOutputParser",
+                    ],
+                    "name": "CommaSeparatedListOutputParser",
+                },
+            },
+            {
+                "id": 6,
+                "type": "schema",
+                "data": {"title": "conditional_str_parser_input", "type": "string"},
+            },
+            {
+                "id": 7,
+                "type": "schema",
+                "data": {"title": "conditional_str_parser_output"},
+            },
+            {
+                "id": 8,
+                "type": "runnable",
+                "data": {
+                    "id": ["langchain", "schema", "output_parser", "StrOutputParser"],
+                    "name": "StrOutputParser",
+                },
+            },
+            {
+                "id": 9,
+                "type": "runnable",
+                "data": {
+                    "id": [
+                        "langchain_core",
+                        "output_parsers",
+                        "xml",
+                        "XMLOutputParser",
+                    ],
+                    "name": "XMLOutputParser",
+                },
+            },
+        ],
+        "edges": [
+            {"source": 0, "target": 1},
+            {"source": 1, "target": 2},
+            {"source": 3, "target": 5},
+            {"source": 5, "target": 4},
+            {"source": 6, "target": 8},
+            {"source": 8, "target": 7},
+            {"source": 6, "target": 9},
+            {"source": 9, "target": 7},
+            {"source": 3, "target": 6},
+            {"source": 7, "target": 4},
+            {"source": 2, "target": 3},
+        ],
+    }
     assert graph.draw_ascii() == snapshot
(File diff suppressed because it is too large.)