mirror of
https://github.com/hwchase17/langchain
synced 2024-11-10 01:10:59 +00:00
core: Fix implementation of trim_first_node/trim_last_node to use exact same definition of first/last node as in the getter methods (#24802)
This commit is contained in:
parent
c2706cfb9e
commit
68ecebf1ec
@ -13,6 +13,7 @@ from typing import (
|
||||
NamedTuple,
|
||||
Optional,
|
||||
Protocol,
|
||||
Sequence,
|
||||
Tuple,
|
||||
Type,
|
||||
TypedDict,
|
||||
@ -448,48 +449,27 @@ class Graph:
|
||||
"""Find the single node that is not a target of any edge.
|
||||
If there is no such node, or there are multiple, return None.
|
||||
When drawing the graph, this node would be the origin."""
|
||||
targets = {edge.target for edge in self.edges}
|
||||
found: List[Node] = []
|
||||
for node in self.nodes.values():
|
||||
if node.id not in targets:
|
||||
found.append(node)
|
||||
return found[0] if len(found) == 1 else None
|
||||
return _first_node(self)
|
||||
|
||||
def last_node(self) -> Optional[Node]:
|
||||
"""Find the single node that is not a source of any edge.
|
||||
If there is no such node, or there are multiple, return None.
|
||||
When drawing the graph, this node would be the destination.
|
||||
"""
|
||||
sources = {edge.source for edge in self.edges}
|
||||
found: List[Node] = []
|
||||
for node in self.nodes.values():
|
||||
if node.id not in sources:
|
||||
found.append(node)
|
||||
return found[0] if len(found) == 1 else None
|
||||
When drawing the graph, this node would be the destination."""
|
||||
return _last_node(self)
|
||||
|
||||
def trim_first_node(self) -> None:
|
||||
"""Remove the first node if it exists and has a single outgoing edge,
|
||||
i.e., if removing it would not leave the graph without a "first" node."""
|
||||
first_node = self.first_node()
|
||||
if first_node:
|
||||
if (
|
||||
len(self.nodes) == 1
|
||||
or len([edge for edge in self.edges if edge.source == first_node.id])
|
||||
== 1
|
||||
):
|
||||
self.remove_node(first_node)
|
||||
if first_node and _first_node(self, exclude=[first_node.id]):
|
||||
self.remove_node(first_node)
|
||||
|
||||
def trim_last_node(self) -> None:
|
||||
"""Remove the last node if it exists and has a single incoming edge,
|
||||
i.e., if removing it would not leave the graph without a "last" node."""
|
||||
last_node = self.last_node()
|
||||
if last_node:
|
||||
if (
|
||||
len(self.nodes) == 1
|
||||
or len([edge for edge in self.edges if edge.target == last_node.id])
|
||||
== 1
|
||||
):
|
||||
self.remove_node(last_node)
|
||||
if last_node and _last_node(self, exclude=[last_node.id]):
|
||||
self.remove_node(last_node)
|
||||
|
||||
def draw_ascii(self) -> str:
|
||||
"""Draw the graph as an ASCII art string."""
|
||||
@ -631,3 +611,29 @@ class Graph:
|
||||
background_color=background_color,
|
||||
padding=padding,
|
||||
)
|
||||
|
||||
|
||||
def _first_node(graph: Graph, exclude: Sequence[str] = ()) -> Optional[Node]:
|
||||
"""Find the single node that is not a target of any edge.
|
||||
Exclude nodes/sources with ids in the exclude list.
|
||||
If there is no such node, or there are multiple, return None.
|
||||
When drawing the graph, this node would be the origin."""
|
||||
targets = {edge.target for edge in graph.edges if edge.source not in exclude}
|
||||
found: List[Node] = []
|
||||
for node in graph.nodes.values():
|
||||
if node.id not in exclude and node.id not in targets:
|
||||
found.append(node)
|
||||
return found[0] if len(found) == 1 else None
|
||||
|
||||
|
||||
def _last_node(graph: Graph, exclude: Sequence[str] = ()) -> Optional[Node]:
|
||||
"""Find the single node that is not a source of any edge.
|
||||
Exclude nodes/targets with ids in the exclude list.
|
||||
If there is no such node, or there are multiple, return None.
|
||||
When drawing the graph, this node would be the destination."""
|
||||
sources = {edge.source for edge in graph.edges if edge.target not in exclude}
|
||||
found: List[Node] = []
|
||||
for node in graph.nodes.values():
|
||||
if node.id not in exclude and node.id not in sources:
|
||||
found.append(node)
|
||||
return found[0] if len(found) == 1 else None
|
||||
|
@ -1097,3 +1097,65 @@
|
||||
|
||||
'''
|
||||
# ---
|
||||
# name: test_trim
|
||||
dict({
|
||||
'edges': list([
|
||||
dict({
|
||||
'source': '__start__',
|
||||
'target': 'ask_question',
|
||||
}),
|
||||
dict({
|
||||
'source': 'ask_question',
|
||||
'target': 'answer_question',
|
||||
}),
|
||||
dict({
|
||||
'conditional': True,
|
||||
'source': 'answer_question',
|
||||
'target': 'ask_question',
|
||||
}),
|
||||
dict({
|
||||
'conditional': True,
|
||||
'source': 'answer_question',
|
||||
'target': '__end__',
|
||||
}),
|
||||
]),
|
||||
'nodes': list([
|
||||
dict({
|
||||
'data': '__start__',
|
||||
'id': '__start__',
|
||||
'type': 'schema',
|
||||
}),
|
||||
dict({
|
||||
'data': dict({
|
||||
'id': list([
|
||||
'langchain',
|
||||
'schema',
|
||||
'output_parser',
|
||||
'StrOutputParser',
|
||||
]),
|
||||
'name': 'ask_question',
|
||||
}),
|
||||
'id': 'ask_question',
|
||||
'type': 'runnable',
|
||||
}),
|
||||
dict({
|
||||
'data': dict({
|
||||
'id': list([
|
||||
'langchain',
|
||||
'schema',
|
||||
'output_parser',
|
||||
'StrOutputParser',
|
||||
]),
|
||||
'name': 'answer_question',
|
||||
}),
|
||||
'id': 'answer_question',
|
||||
'type': 'runnable',
|
||||
}),
|
||||
dict({
|
||||
'data': '__end__',
|
||||
'id': '__end__',
|
||||
'type': 'schema',
|
||||
}),
|
||||
]),
|
||||
})
|
||||
# ---
|
||||
|
@ -7,7 +7,9 @@ from langchain_core.output_parsers.list import CommaSeparatedListOutputParser
|
||||
from langchain_core.output_parsers.string import StrOutputParser
|
||||
from langchain_core.output_parsers.xml import XMLOutputParser
|
||||
from langchain_core.prompts.prompt import PromptTemplate
|
||||
from langchain_core.pydantic_v1 import BaseModel
|
||||
from langchain_core.runnables.base import Runnable, RunnableConfig
|
||||
from langchain_core.runnables.graph import Graph
|
||||
from langchain_core.runnables.graph_mermaid import _escape_node_label
|
||||
|
||||
|
||||
@ -27,6 +29,42 @@ def test_graph_single_runnable(snapshot: SnapshotAssertion) -> None:
|
||||
assert graph.draw_ascii() == snapshot(name="ascii")
|
||||
assert graph.draw_mermaid() == snapshot(name="mermaid")
|
||||
|
||||
graph.trim_first_node()
|
||||
first_node = graph.first_node()
|
||||
assert first_node is not None
|
||||
assert first_node.data == runnable
|
||||
|
||||
graph.trim_last_node()
|
||||
last_node = graph.last_node()
|
||||
assert last_node is not None
|
||||
assert last_node.data == runnable
|
||||
|
||||
|
||||
def test_trim(snapshot: SnapshotAssertion) -> None:
|
||||
runnable = StrOutputParser()
|
||||
|
||||
class Schema(BaseModel):
|
||||
a: int
|
||||
|
||||
graph = Graph()
|
||||
start = graph.add_node(Schema, id="__start__")
|
||||
ask = graph.add_node(runnable, id="ask_question")
|
||||
answer = graph.add_node(runnable, id="answer_question")
|
||||
end = graph.add_node(Schema, id="__end__")
|
||||
graph.add_edge(start, ask)
|
||||
graph.add_edge(ask, answer)
|
||||
graph.add_edge(answer, ask, conditional=True)
|
||||
graph.add_edge(answer, end, conditional=True)
|
||||
|
||||
assert graph.to_json() == snapshot
|
||||
assert graph.first_node() is start
|
||||
assert graph.last_node() is end
|
||||
# can't trim start or end node
|
||||
graph.trim_first_node()
|
||||
assert graph.first_node() is start
|
||||
graph.trim_last_node()
|
||||
assert graph.last_node() is end
|
||||
|
||||
|
||||
def test_graph_sequence(snapshot: SnapshotAssertion) -> None:
|
||||
fake_llm = FakeListLLM(responses=["a"])
|
||||
|
Loading…
Reference in New Issue
Block a user