You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
langchain/libs/core/langchain_core/runnables/graph.py

141 lines
4.9 KiB
Python

from __future__ import annotations
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Dict, List, NamedTuple, Optional, Type, Union
from uuid import uuid4
from langchain_core.pydantic_v1 import BaseModel
from langchain_core.runnables.graph_draw import draw
if TYPE_CHECKING:
from langchain_core.runnables.base import Runnable as RunnableType
class Edge(NamedTuple):
source: str
target: str
class Node(NamedTuple):
id: str
data: Union[Type[BaseModel], RunnableType]
@dataclass
class Graph:
nodes: Dict[str, Node] = field(default_factory=dict)
edges: List[Edge] = field(default_factory=list)
def __bool__(self) -> bool:
return bool(self.nodes)
def next_id(self) -> str:
return uuid4().hex
def add_node(self, data: Union[Type[BaseModel], RunnableType]) -> Node:
"""Add a node to the graph and return it."""
node = Node(id=self.next_id(), data=data)
self.nodes[node.id] = node
return node
def remove_node(self, node: Node) -> None:
"""Remove a node from the graphm and all edges connected to it."""
self.nodes.pop(node.id)
self.edges = [
edge
for edge in self.edges
if edge.source != node.id and edge.target != node.id
]
def add_edge(self, source: Node, target: Node) -> Edge:
"""Add an edge to the graph and return it."""
if source.id not in self.nodes:
raise ValueError(f"Source node {source.id} not in graph")
if target.id not in self.nodes:
raise ValueError(f"Target node {target.id} not in graph")
edge = Edge(source=source.id, target=target.id)
self.edges.append(edge)
return edge
def extend(self, graph: Graph) -> None:
"""Add all nodes and edges from another graph.
Note this doesn't check for duplicates, nor does it connect the graphs."""
self.nodes.update(graph.nodes)
self.edges.extend(graph.edges)
def first_node(self) -> Optional[Node]:
"""Find the single node that is not a target of any edge.
If there is no such node, or there are multiple, return None.
When drawing the graph this node would be the origin."""
targets = {edge.target for edge in self.edges}
found: List[Node] = []
for node in self.nodes.values():
if node.id not in targets:
found.append(node)
return found[0] if len(found) == 1 else None
def last_node(self) -> Optional[Node]:
"""Find the single node that is not a source of any edge.
If there is no such node, or there are multiple, return None.
When drawing the graph this node would be the destination.
"""
sources = {edge.source for edge in self.edges}
found: List[Node] = []
for node in self.nodes.values():
if node.id not in sources:
found.append(node)
return found[0] if len(found) == 1 else None
def trim_first_node(self) -> None:
"""Remove the first node if it exists and has a single outgoing edge,
ie. if removing it would not leave the graph without a "first" node."""
first_node = self.first_node()
if first_node:
if (
len(self.nodes) == 1
or len([edge for edge in self.edges if edge.source == first_node.id])
== 1
):
self.remove_node(first_node)
def trim_last_node(self) -> None:
"""Remove the last node if it exists and has a single incoming edge,
ie. if removing it would not leave the graph without a "last" node."""
last_node = self.last_node()
if last_node:
if (
len(self.nodes) == 1
or len([edge for edge in self.edges if edge.target == last_node.id])
== 1
):
self.remove_node(last_node)
def draw_ascii(self) -> str:
from langchain_core.runnables.base import Runnable
def node_data(node: Node) -> str:
if isinstance(node.data, Runnable):
try:
data = str(node.data)
if (
data.startswith("<")
or data[0] != data[0].upper()
or len(data.splitlines()) > 1
):
data = node.data.__class__.__name__
elif len(data) > 42:
data = data[:42] + "..."
except Exception:
data = node.data.__class__.__name__
else:
data = node.data.__name__
return data if not data.startswith("Runnable") else data[8:]
return draw(
{node.id: node_data(node) for node in self.nodes.values()},
[(edge.source, edge.target) for edge in self.edges],
)
def print_ascii(self) -> None:
print(self.draw_ascii())