core: Fix implementation of trim_first_node/trim_last_node to use exact same definition of first/last node as in the getter methods (#24802)

This commit is contained in:
Nuno Campos 2024-07-30 08:44:27 -07:00 committed by GitHub
parent c2706cfb9e
commit 68ecebf1ec
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 134 additions and 28 deletions

View File

@ -13,6 +13,7 @@ from typing import (
NamedTuple,
Optional,
Protocol,
Sequence,
Tuple,
Type,
TypedDict,
@ -448,48 +449,27 @@ class Graph:
"""Find the single node that is not a target of any edge.
If there is no such node, or there are multiple, return None.
When drawing the graph, this node would be the origin."""
targets = {edge.target for edge in self.edges}
found: List[Node] = []
for node in self.nodes.values():
if node.id not in targets:
found.append(node)
return found[0] if len(found) == 1 else None
return _first_node(self)
def last_node(self) -> Optional[Node]:
"""Find the single node that is not a source of any edge.
If there is no such node, or there are multiple, return None.
When drawing the graph, this node would be the destination.
"""
sources = {edge.source for edge in self.edges}
found: List[Node] = []
for node in self.nodes.values():
if node.id not in sources:
found.append(node)
return found[0] if len(found) == 1 else None
When drawing the graph, this node would be the destination."""
return _last_node(self)
def trim_first_node(self) -> None:
"""Remove the first node if it exists and has a single outgoing edge,
i.e., if removing it would not leave the graph without a "first" node."""
first_node = self.first_node()
if first_node:
if (
len(self.nodes) == 1
or len([edge for edge in self.edges if edge.source == first_node.id])
== 1
):
self.remove_node(first_node)
if first_node and _first_node(self, exclude=[first_node.id]):
self.remove_node(first_node)
def trim_last_node(self) -> None:
"""Remove the last node if it exists and has a single incoming edge,
i.e., if removing it would not leave the graph without a "last" node."""
last_node = self.last_node()
if last_node:
if (
len(self.nodes) == 1
or len([edge for edge in self.edges if edge.target == last_node.id])
== 1
):
self.remove_node(last_node)
if last_node and _last_node(self, exclude=[last_node.id]):
self.remove_node(last_node)
def draw_ascii(self) -> str:
"""Draw the graph as an ASCII art string."""
@ -631,3 +611,29 @@ class Graph:
background_color=background_color,
padding=padding,
)
def _first_node(graph: Graph, exclude: Sequence[str] = ()) -> Optional[Node]:
"""Find the single node that is not a target of any edge.
Exclude nodes/sources with ids in the exclude list.
If there is no such node, or there are multiple, return None.
When drawing the graph, this node would be the origin."""
targets = {edge.target for edge in graph.edges if edge.source not in exclude}
found: List[Node] = []
for node in graph.nodes.values():
if node.id not in exclude and node.id not in targets:
found.append(node)
return found[0] if len(found) == 1 else None
def _last_node(graph: Graph, exclude: Sequence[str] = ()) -> Optional[Node]:
"""Find the single node that is not a source of any edge.
Exclude nodes/targets with ids in the exclude list.
If there is no such node, or there are multiple, return None.
When drawing the graph, this node would be the destination."""
sources = {edge.source for edge in graph.edges if edge.target not in exclude}
found: List[Node] = []
for node in graph.nodes.values():
if node.id not in exclude and node.id not in sources:
found.append(node)
return found[0] if len(found) == 1 else None

View File

@ -1097,3 +1097,65 @@
'''
# ---
# name: test_trim
dict({
'edges': list([
dict({
'source': '__start__',
'target': 'ask_question',
}),
dict({
'source': 'ask_question',
'target': 'answer_question',
}),
dict({
'conditional': True,
'source': 'answer_question',
'target': 'ask_question',
}),
dict({
'conditional': True,
'source': 'answer_question',
'target': '__end__',
}),
]),
'nodes': list([
dict({
'data': '__start__',
'id': '__start__',
'type': 'schema',
}),
dict({
'data': dict({
'id': list([
'langchain',
'schema',
'output_parser',
'StrOutputParser',
]),
'name': 'ask_question',
}),
'id': 'ask_question',
'type': 'runnable',
}),
dict({
'data': dict({
'id': list([
'langchain',
'schema',
'output_parser',
'StrOutputParser',
]),
'name': 'answer_question',
}),
'id': 'answer_question',
'type': 'runnable',
}),
dict({
'data': '__end__',
'id': '__end__',
'type': 'schema',
}),
]),
})
# ---

View File

@ -7,7 +7,9 @@ from langchain_core.output_parsers.list import CommaSeparatedListOutputParser
from langchain_core.output_parsers.string import StrOutputParser
from langchain_core.output_parsers.xml import XMLOutputParser
from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.pydantic_v1 import BaseModel
from langchain_core.runnables.base import Runnable, RunnableConfig
from langchain_core.runnables.graph import Graph
from langchain_core.runnables.graph_mermaid import _escape_node_label
@ -27,6 +29,42 @@ def test_graph_single_runnable(snapshot: SnapshotAssertion) -> None:
assert graph.draw_ascii() == snapshot(name="ascii")
assert graph.draw_mermaid() == snapshot(name="mermaid")
graph.trim_first_node()
first_node = graph.first_node()
assert first_node is not None
assert first_node.data == runnable
graph.trim_last_node()
last_node = graph.last_node()
assert last_node is not None
assert last_node.data == runnable
def test_trim(snapshot: SnapshotAssertion) -> None:
runnable = StrOutputParser()
class Schema(BaseModel):
a: int
graph = Graph()
start = graph.add_node(Schema, id="__start__")
ask = graph.add_node(runnable, id="ask_question")
answer = graph.add_node(runnable, id="answer_question")
end = graph.add_node(Schema, id="__end__")
graph.add_edge(start, ask)
graph.add_edge(ask, answer)
graph.add_edge(answer, ask, conditional=True)
graph.add_edge(answer, end, conditional=True)
assert graph.to_json() == snapshot
assert graph.first_node() is start
assert graph.last_node() is end
# can't trim start or end node
graph.trim_first_node()
assert graph.first_node() is start
graph.trim_last_node()
assert graph.last_node() is end
def test_graph_sequence(snapshot: SnapshotAssertion) -> None:
fake_llm = FakeListLLM(responses=["a"])