mirror of https://github.com/hwchase17/langchain
Implement AgentExecutorIterator (#6929)
- Description: Implements a `.iter()` method for the `AgentExecutor` class. This allows hooking into and intercepting intermediate agent steps. - Issue: #6925 - Dependencies: None - Tag maintainer: @vowelparrot @agola11 - Twitter handle: @SlapDron3 @lacicocodes --------- Co-authored-by: Lacico <Lacicocodes@gmail.com> Co-authored-by: Bagatur <baskaryan@gmail.com>pull/8162/head
parent
77bf75c236
commit
961a0e200f
@ -0,0 +1,245 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "feb31cc6",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Running Agent as an Iterator\n",
|
||||
"\n",
|
||||
"To demonstrate the `AgentExecutorIterator` functionality, we will set up a problem where an Agent must:\n",
|
||||
"\n",
|
||||
"- Retrieve three prime numbers from a Tool\n",
|
||||
"- Multiply these together. \n",
|
||||
"\n",
|
||||
"In this simple problem we can demonstrate adding some logic to verify intermediate steps by checking whether their outputs are prime."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"id": "8167db11",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"\n",
|
||||
"import dotenv\n",
|
||||
"import pydantic\n",
|
||||
"from langchain.agents import AgentExecutor, initialize_agent, AgentType\n",
|
||||
"from langchain.schema import AgentFinish\n",
|
||||
"from langchain.agents.tools import Tool\n",
|
||||
"from langchain import LLMMathChain\n",
|
||||
"from langchain.chat_models import ChatOpenAI"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"id": "7e41b9e6",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Uncomment if you have a .env in root of repo contains OPENAI_API_KEY\n",
|
||||
"# dotenv.load_dotenv(\"../../../../../.env\")\n",
|
||||
"\n",
|
||||
"# need to use GPT-4 here as GPT-3.5 does not understand, however hard you insist, that\n",
|
||||
"# it should use the calculator to perform the final calculation\n",
|
||||
"llm = ChatOpenAI(temperature=0, model=\"gpt-4\")\n",
|
||||
"llm_math_chain = LLMMathChain.from_llm(llm=llm, verbose=True)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "81e88aa5",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Define tools which provide:\n",
|
||||
"- The `n`th prime number (using a small subset for this example) \n",
|
||||
"- The LLMMathChain to act as a calculator"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"id": "86f04b55",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"primes = {998: 7901, 999: 7907, 1000: 7919}\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"class CalculatorInput(pydantic.BaseModel):\n",
|
||||
" question: str = pydantic.Field()\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"class PrimeInput(pydantic.BaseModel):\n",
|
||||
" n: int = pydantic.Field()\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def is_prime(n: int) -> bool:\n",
|
||||
" if n <= 1 or (n % 2 == 0 and n > 2):\n",
|
||||
" return False\n",
|
||||
" for i in range(3, int(n**0.5) + 1, 2):\n",
|
||||
" if n % i == 0:\n",
|
||||
" return False\n",
|
||||
" return True\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def get_prime(n: int, primes: dict = primes) -> str:\n",
|
||||
" return str(primes.get(int(n)))\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"async def aget_prime(n: int, primes: dict = primes) -> str:\n",
|
||||
" return str(primes.get(int(n)))\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"tools = [\n",
|
||||
" Tool(\n",
|
||||
" name=\"GetPrime\",\n",
|
||||
" func=get_prime,\n",
|
||||
" description=\"A tool that returns the `n`th prime number\",\n",
|
||||
" args_schema=PrimeInput,\n",
|
||||
" coroutine=aget_prime,\n",
|
||||
" ),\n",
|
||||
" Tool.from_function(\n",
|
||||
" func=llm_math_chain.run,\n",
|
||||
" name=\"Calculator\",\n",
|
||||
" description=\"Useful for when you need to compute mathematical expressions\",\n",
|
||||
" args_schema=CalculatorInput,\n",
|
||||
" coroutine=llm_math_chain.arun,\n",
|
||||
" ),\n",
|
||||
"]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "0e660ee6",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Construct the agent. We will use the default agent type here."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"id": "21c775b0",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"agent = initialize_agent(\n",
|
||||
" tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "a233fe4e",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Run the iteration and perform a custom check on certain steps:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"id": "582d61f4",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n",
|
||||
"\n",
|
||||
"\u001b[1m> Entering new chain...\u001b[0m\n",
|
||||
"\u001b[32;1m\u001b[1;3mI need to find the 998th, 999th and 1000th prime numbers first.\n",
|
||||
"Action: GetPrime\n",
|
||||
"Action Input: 998\u001b[0m\n",
|
||||
"Observation: \u001b[36;1m\u001b[1;3m7901\u001b[0m\n",
|
||||
"Thought:Checking whether 7901 is prime...\n",
|
||||
"Should the agent continue (Y/n)?:\n",
|
||||
"Y\n",
|
||||
"\u001b[32;1m\u001b[1;3mI have the 998th prime number. Now I need to find the 999th prime number.\n",
|
||||
"Action: GetPrime\n",
|
||||
"Action Input: 999\u001b[0m\n",
|
||||
"Observation: \u001b[36;1m\u001b[1;3m7907\u001b[0m\n",
|
||||
"Thought:Checking whether 7907 is prime...\n",
|
||||
"Should the agent continue (Y/n)?:\n",
|
||||
"Y\n",
|
||||
"\u001b[32;1m\u001b[1;3mI have the 999th prime number. Now I need to find the 1000th prime number.\n",
|
||||
"Action: GetPrime\n",
|
||||
"Action Input: 1000\u001b[0m\n",
|
||||
"Observation: \u001b[36;1m\u001b[1;3m7919\u001b[0m\n",
|
||||
"Thought:Checking whether 7919 is prime...\n",
|
||||
"Should the agent continue (Y/n)?:\n",
|
||||
"Y\n",
|
||||
"\u001b[32;1m\u001b[1;3mI have all three prime numbers. Now I need to calculate the product of these numbers.\n",
|
||||
"Action: Calculator\n",
|
||||
"Action Input: 7901 * 7907 * 7919\u001b[0m\n",
|
||||
"\n",
|
||||
"\u001b[1m> Entering new chain...\u001b[0m\n",
|
||||
"7901 * 7907 * 7919\u001b[32;1m\u001b[1;3m```text\n",
|
||||
"7901 * 7907 * 7919\n",
|
||||
"```\n",
|
||||
"...numexpr.evaluate(\"7901 * 7907 * 7919\")...\n",
|
||||
"\u001b[0m\n",
|
||||
"Answer: \u001b[33;1m\u001b[1;3m494725326233\u001b[0m\n",
|
||||
"\u001b[1m> Finished chain.\u001b[0m\n",
|
||||
"\n",
|
||||
"Observation: \u001b[33;1m\u001b[1;3mAnswer: 494725326233\u001b[0m\n",
|
||||
"Thought:Should the agent continue (Y/n)?:\n",
|
||||
"Y\n",
|
||||
"\u001b[32;1m\u001b[1;3mI now know the final answer\n",
|
||||
"Final Answer: 494725326233\u001b[0m\n",
|
||||
"\n",
|
||||
"\u001b[1m> Finished chain.\u001b[0m\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"question = \"What is the product of the 998th, 999th and 1000th prime numbers?\"\n",
|
||||
"\n",
|
||||
"for step in agent.iter(question):\n",
|
||||
" if output := step.get(\"intermediate_step\"):\n",
|
||||
" action, value = output[0]\n",
|
||||
" if action.tool == \"GetPrime\":\n",
|
||||
" print(f\"Checking whether {value} is prime...\")\n",
|
||||
" assert is_prime(int(value))\n",
|
||||
" # Ask user if they want to continue\n",
|
||||
" _continue = input(\"Should the agent continue (Y/n)?:\\n\")\n",
|
||||
" if _continue != \"Y\":\n",
|
||||
" break"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "6934ff8e",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "venv",
|
||||
"language": "python",
|
||||
"name": "venv"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.11.3"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
@ -0,0 +1,501 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from asyncio import CancelledError
|
||||
from functools import wraps
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
List,
|
||||
NoReturn,
|
||||
Optional,
|
||||
Tuple,
|
||||
Type,
|
||||
Union,
|
||||
)
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManager,
|
||||
AsyncCallbackManagerForChainRun,
|
||||
CallbackManager,
|
||||
CallbackManagerForChainRun,
|
||||
Callbacks,
|
||||
)
|
||||
from langchain.input import get_color_mapping
|
||||
from langchain.load.dump import dumpd
|
||||
from langchain.schema import RUN_KEY, AgentAction, AgentFinish, RunInfo
|
||||
from langchain.tools import BaseTool
|
||||
from langchain.utilities.asyncio import asyncio_timeout
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain.agents.agent import AgentExecutor
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BaseAgentExecutorIterator(ABC):
|
||||
@abstractmethod
|
||||
def build_callback_manager(self) -> None:
|
||||
pass
|
||||
|
||||
|
||||
def rebuild_callback_manager_on_set(
|
||||
setter_method: Callable[..., None]
|
||||
) -> Callable[..., None]:
|
||||
"""Decorator to force setters to rebuild callback mgr"""
|
||||
|
||||
@wraps(setter_method)
|
||||
def wrapper(self: BaseAgentExecutorIterator, *args: Any, **kwargs: Any) -> None:
|
||||
setter_method(self, *args, **kwargs)
|
||||
self.build_callback_manager()
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
class AgentExecutorIterator(BaseAgentExecutorIterator):
|
||||
def __init__(
|
||||
self,
|
||||
agent_executor: AgentExecutor,
|
||||
inputs: Any,
|
||||
callbacks: Callbacks = None,
|
||||
*,
|
||||
tags: Optional[list[str]] = None,
|
||||
include_run_info: bool = False,
|
||||
async_: bool = False,
|
||||
):
|
||||
"""
|
||||
Initialize the AgentExecutorIterator with the given AgentExecutor,
|
||||
inputs, and optional callbacks.
|
||||
"""
|
||||
self._agent_executor = agent_executor
|
||||
self.inputs = inputs
|
||||
self.async_ = async_
|
||||
# build callback manager on tags setter
|
||||
self._callbacks = callbacks
|
||||
self.tags = tags
|
||||
self.include_run_info = include_run_info
|
||||
self.run_manager = None
|
||||
self.reset()
|
||||
|
||||
_callback_manager: Union[AsyncCallbackManager, CallbackManager]
|
||||
_inputs: dict[str, str]
|
||||
_final_outputs: Optional[dict[str, str]]
|
||||
run_manager: Optional[
|
||||
Union[AsyncCallbackManagerForChainRun, CallbackManagerForChainRun]
|
||||
]
|
||||
timeout_manager: Any # TODO: Fix a type here; the shim makes it tricky.
|
||||
|
||||
@property
|
||||
def inputs(self) -> dict[str, str]:
|
||||
return self._inputs
|
||||
|
||||
@inputs.setter
|
||||
def inputs(self, inputs: Any) -> None:
|
||||
self._inputs = self.agent_executor.prep_inputs(inputs)
|
||||
|
||||
@property
|
||||
def callbacks(self) -> Callbacks:
|
||||
return self._callbacks
|
||||
|
||||
@callbacks.setter
|
||||
@rebuild_callback_manager_on_set
|
||||
def callbacks(self, callbacks: Callbacks) -> None:
|
||||
"""When callbacks are changed after __init__, rebuild callback mgr"""
|
||||
self._callbacks = callbacks
|
||||
|
||||
@property
|
||||
def tags(self) -> Optional[List[str]]:
|
||||
return self._tags
|
||||
|
||||
@tags.setter
|
||||
@rebuild_callback_manager_on_set
|
||||
def tags(self, tags: Optional[List[str]]) -> None:
|
||||
"""When tags are changed after __init__, rebuild callback mgr"""
|
||||
self._tags = tags
|
||||
|
||||
@property
|
||||
def agent_executor(self) -> AgentExecutor:
|
||||
return self._agent_executor
|
||||
|
||||
@agent_executor.setter
|
||||
@rebuild_callback_manager_on_set
|
||||
def agent_executor(self, agent_executor: AgentExecutor) -> None:
|
||||
self._agent_executor = agent_executor
|
||||
# force re-prep inputs in case agent_executor's prep_inputs fn changed
|
||||
self.inputs = self.inputs
|
||||
|
||||
@property
|
||||
def callback_manager(self) -> Union[AsyncCallbackManager, CallbackManager]:
|
||||
return self._callback_manager
|
||||
|
||||
def build_callback_manager(self) -> None:
|
||||
"""
|
||||
Create and configure the callback manager based on the current
|
||||
callbacks and tags.
|
||||
"""
|
||||
CallbackMgr: Union[Type[AsyncCallbackManager], Type[CallbackManager]] = (
|
||||
AsyncCallbackManager if self.async_ else CallbackManager
|
||||
)
|
||||
self._callback_manager = CallbackMgr.configure(
|
||||
self.callbacks,
|
||||
self.agent_executor.callbacks,
|
||||
self.agent_executor.verbose,
|
||||
self.tags,
|
||||
self.agent_executor.tags,
|
||||
)
|
||||
|
||||
@property
|
||||
def name_to_tool_map(self) -> dict[str, BaseTool]:
|
||||
return {tool.name: tool for tool in self.agent_executor.tools}
|
||||
|
||||
@property
|
||||
def color_mapping(self) -> dict[str, str]:
|
||||
return get_color_mapping(
|
||||
[tool.name for tool in self.agent_executor.tools],
|
||||
excluded_colors=["green", "red"],
|
||||
)
|
||||
|
||||
def reset(self) -> None:
|
||||
"""
|
||||
Reset the iterator to its initial state, clearing intermediate steps,
|
||||
iterations, and time elapsed.
|
||||
"""
|
||||
logger.debug("(Re)setting AgentExecutorIterator to fresh state")
|
||||
self.intermediate_steps: list[tuple[AgentAction, str]] = []
|
||||
self.iterations = 0
|
||||
# maybe better to start these on the first __anext__ call?
|
||||
self.time_elapsed = 0.0
|
||||
self.start_time = time.time()
|
||||
self._final_outputs = None
|
||||
|
||||
def update_iterations(self) -> None:
|
||||
"""
|
||||
Increment the number of iterations and update the time elapsed.
|
||||
"""
|
||||
self.iterations += 1
|
||||
self.time_elapsed = time.time() - self.start_time
|
||||
logger.debug(
|
||||
f"Agent Iterations: {self.iterations} ({self.time_elapsed:.2f}s elapsed)"
|
||||
)
|
||||
|
||||
def raise_stopiteration(self, output: Any) -> NoReturn:
|
||||
"""
|
||||
Raise a StopIteration exception with the given output.
|
||||
"""
|
||||
logger.debug("Chain end: stop iteration")
|
||||
raise StopIteration(output)
|
||||
|
||||
async def raise_stopasynciteration(self, output: Any) -> NoReturn:
|
||||
"""
|
||||
Raise a StopAsyncIteration exception with the given output.
|
||||
Close the timeout context manager.
|
||||
"""
|
||||
logger.debug("Chain end: stop async iteration")
|
||||
if self.timeout_manager is not None:
|
||||
await self.timeout_manager.__aexit__(None, None, None)
|
||||
raise StopAsyncIteration(output)
|
||||
|
||||
@property
|
||||
def final_outputs(self) -> Optional[dict[str, Any]]:
|
||||
return self._final_outputs
|
||||
|
||||
@final_outputs.setter
|
||||
def final_outputs(self, outputs: Optional[Dict[str, Any]]) -> None:
|
||||
# have access to intermediate steps by design in iterator,
|
||||
# so return only outputs may as well always be true.
|
||||
|
||||
self._final_outputs = None
|
||||
if outputs:
|
||||
prepared_outputs: dict[str, Any] = self.agent_executor.prep_outputs(
|
||||
self.inputs, outputs, return_only_outputs=True
|
||||
)
|
||||
if self.include_run_info and self.run_manager is not None:
|
||||
logger.debug("Assign run key")
|
||||
prepared_outputs[RUN_KEY] = RunInfo(run_id=self.run_manager.run_id)
|
||||
self._final_outputs = prepared_outputs
|
||||
|
||||
def __iter__(self: "AgentExecutorIterator") -> "AgentExecutorIterator":
|
||||
logger.debug("Initialising AgentExecutorIterator")
|
||||
self.reset()
|
||||
assert isinstance(self.callback_manager, CallbackManager)
|
||||
self.run_manager = self.callback_manager.on_chain_start(
|
||||
dumpd(self.agent_executor),
|
||||
self.inputs,
|
||||
)
|
||||
return self
|
||||
|
||||
def __aiter__(self) -> "AgentExecutorIterator":
|
||||
"""
|
||||
N.B. __aiter__ must be a normal method, so need to initialise async run manager
|
||||
on first __anext__ call where we can await it
|
||||
"""
|
||||
logger.debug("Initialising AgentExecutorIterator (async)")
|
||||
self.reset()
|
||||
if self.agent_executor.max_execution_time:
|
||||
self.timeout_manager = asyncio_timeout(
|
||||
self.agent_executor.max_execution_time
|
||||
)
|
||||
else:
|
||||
self.timeout_manager = None
|
||||
return self
|
||||
|
||||
def _on_first_step(self) -> None:
|
||||
"""
|
||||
Perform any necessary setup for the first step of the synchronous iterator.
|
||||
"""
|
||||
pass
|
||||
|
||||
async def _on_first_async_step(self) -> None:
|
||||
"""
|
||||
Perform any necessary setup for the first step of the asynchronous iterator.
|
||||
"""
|
||||
# on first step, need to await callback manager and start async timeout ctxmgr
|
||||
if self.iterations == 0:
|
||||
assert isinstance(self.callback_manager, AsyncCallbackManager)
|
||||
self.run_manager = await self.callback_manager.on_chain_start(
|
||||
dumpd(self.agent_executor),
|
||||
self.inputs,
|
||||
)
|
||||
if self.timeout_manager:
|
||||
await self.timeout_manager.__aenter__()
|
||||
|
||||
def __next__(self) -> dict[str, Any]:
|
||||
"""
|
||||
AgentExecutor AgentExecutorIterator
|
||||
__call__ (__iter__ ->) __next__
|
||||
_call <=> _call_next
|
||||
_take_next_step _take_next_step
|
||||
"""
|
||||
# first step
|
||||
if self.iterations == 0:
|
||||
self._on_first_step()
|
||||
# N.B. timeout taken care of by "_should_continue" in sync case
|
||||
try:
|
||||
return self._call_next()
|
||||
except StopIteration:
|
||||
raise
|
||||
except (KeyboardInterrupt, Exception) as e:
|
||||
if self.run_manager:
|
||||
self.run_manager.on_chain_error(e)
|
||||
raise
|
||||
|
||||
async def __anext__(self) -> dict[str, Any]:
|
||||
"""
|
||||
AgentExecutor AgentExecutorIterator
|
||||
acall (__aiter__ ->) __anext__
|
||||
_acall <=> _acall_next
|
||||
_atake_next_step _atake_next_step
|
||||
"""
|
||||
if self.iterations == 0:
|
||||
await self._on_first_async_step()
|
||||
try:
|
||||
return await self._acall_next()
|
||||
except StopAsyncIteration:
|
||||
raise
|
||||
except (TimeoutError, CancelledError):
|
||||
await self.timeout_manager.__aexit__(None, None, None)
|
||||
self.timeout_manager = None
|
||||
return await self._astop()
|
||||
except (KeyboardInterrupt, Exception) as e:
|
||||
if self.run_manager:
|
||||
assert isinstance(self.run_manager, AsyncCallbackManagerForChainRun)
|
||||
await self.run_manager.on_chain_error(e)
|
||||
raise
|
||||
|
||||
def _execute_next_step(
|
||||
self, run_manager: Optional[CallbackManagerForChainRun]
|
||||
) -> Union[AgentFinish, List[Tuple[AgentAction, str]]]:
|
||||
"""
|
||||
Execute the next step in the chain using the
|
||||
AgentExecutor's _take_next_step method.
|
||||
"""
|
||||
return self.agent_executor._take_next_step(
|
||||
self.name_to_tool_map,
|
||||
self.color_mapping,
|
||||
self.inputs,
|
||||
self.intermediate_steps,
|
||||
run_manager=run_manager,
|
||||
)
|
||||
|
||||
async def _execute_next_async_step(
|
||||
self, run_manager: Optional[AsyncCallbackManagerForChainRun]
|
||||
) -> Union[AgentFinish, List[Tuple[AgentAction, str]]]:
|
||||
"""
|
||||
Execute the next step in the chain using the
|
||||
AgentExecutor's _atake_next_step method.
|
||||
"""
|
||||
return await self.agent_executor._atake_next_step(
|
||||
self.name_to_tool_map,
|
||||
self.color_mapping,
|
||||
self.inputs,
|
||||
self.intermediate_steps,
|
||||
run_manager=run_manager,
|
||||
)
|
||||
|
||||
def _process_next_step_output(
|
||||
self,
|
||||
next_step_output: Union[AgentFinish, List[Tuple[AgentAction, str]]],
|
||||
run_manager: Optional[CallbackManagerForChainRun],
|
||||
) -> Dict[str, Union[str, List[Tuple[AgentAction, str]]]]:
|
||||
"""
|
||||
Process the output of the next step,
|
||||
handling AgentFinish and tool return cases.
|
||||
"""
|
||||
logger.debug("Processing output of Agent loop step")
|
||||
if isinstance(next_step_output, AgentFinish):
|
||||
logger.debug(
|
||||
"Hit AgentFinish: _return -> on_chain_end -> run final output logic"
|
||||
)
|
||||
output = self.agent_executor._return(
|
||||
next_step_output, self.intermediate_steps, run_manager=run_manager
|
||||
)
|
||||
if self.run_manager:
|
||||
self.run_manager.on_chain_end(output)
|
||||
self.final_outputs = output
|
||||
return output
|
||||
|
||||
self.intermediate_steps.extend(next_step_output)
|
||||
logger.debug("Updated intermediate_steps with step output")
|
||||
|
||||
# Check for tool return
|
||||
if len(next_step_output) == 1:
|
||||
next_step_action = next_step_output[0]
|
||||
tool_return = self.agent_executor._get_tool_return(next_step_action)
|
||||
if tool_return is not None:
|
||||
output = self.agent_executor._return(
|
||||
tool_return, self.intermediate_steps, run_manager=run_manager
|
||||
)
|
||||
if self.run_manager:
|
||||
self.run_manager.on_chain_end(output)
|
||||
self.final_outputs = output
|
||||
return output
|
||||
|
||||
output = {"intermediate_step": next_step_output}
|
||||
return output
|
||||
|
||||
async def _aprocess_next_step_output(
|
||||
self,
|
||||
next_step_output: Union[AgentFinish, List[Tuple[AgentAction, str]]],
|
||||
run_manager: Optional[AsyncCallbackManagerForChainRun],
|
||||
) -> Dict[str, Union[str, List[Tuple[AgentAction, str]]]]:
|
||||
"""
|
||||
Process the output of the next async step,
|
||||
handling AgentFinish and tool return cases.
|
||||
"""
|
||||
logger.debug("Processing output of async Agent loop step")
|
||||
if isinstance(next_step_output, AgentFinish):
|
||||
logger.debug(
|
||||
"Hit AgentFinish: _areturn -> on_chain_end -> run final output logic"
|
||||
)
|
||||
output = await self.agent_executor._areturn(
|
||||
next_step_output, self.intermediate_steps, run_manager=run_manager
|
||||
)
|
||||
if run_manager:
|
||||
await run_manager.on_chain_end(output)
|
||||
self.final_outputs = output
|
||||
return output
|
||||
|
||||
self.intermediate_steps.extend(next_step_output)
|
||||
logger.debug("Updated intermediate_steps with step output")
|
||||
|
||||
# Check for tool return
|
||||
if len(next_step_output) == 1:
|
||||
next_step_action = next_step_output[0]
|
||||
tool_return = self.agent_executor._get_tool_return(next_step_action)
|
||||
if tool_return is not None:
|
||||
output = await self.agent_executor._areturn(
|
||||
tool_return, self.intermediate_steps, run_manager=run_manager
|
||||
)
|
||||
if run_manager:
|
||||
await run_manager.on_chain_end(output)
|
||||
self.final_outputs = output
|
||||
return output
|
||||
|
||||
output = {"intermediate_step": next_step_output}
|
||||
return output
|
||||
|
||||
def _stop(self) -> dict[str, Any]:
|
||||
"""
|
||||
Stop the iterator and raise a StopIteration exception with the stopped response.
|
||||
"""
|
||||
logger.warning("Stopping agent prematurely due to triggering stop condition")
|
||||
# this manually constructs agent finish with output key
|
||||
output = self.agent_executor.agent.return_stopped_response(
|
||||
self.agent_executor.early_stopping_method,
|
||||
self.intermediate_steps,
|
||||
**self.inputs,
|
||||
)
|
||||
assert (
|
||||
isinstance(self.run_manager, CallbackManagerForChainRun)
|
||||
or self.run_manager is None
|
||||
)
|
||||
returned_output = self.agent_executor._return(
|
||||
output, self.intermediate_steps, run_manager=self.run_manager
|
||||
)
|
||||
self.final_outputs = returned_output
|
||||
return returned_output
|
||||
|
||||
async def _astop(self) -> dict[str, Any]:
|
||||
"""
|
||||
Stop the async iterator and raise a StopAsyncIteration exception with
|
||||
the stopped response.
|
||||
"""
|
||||
logger.warning("Stopping agent prematurely due to triggering stop condition")
|
||||
output = self.agent_executor.agent.return_stopped_response(
|
||||
self.agent_executor.early_stopping_method,
|
||||
self.intermediate_steps,
|
||||
**self.inputs,
|
||||
)
|
||||
assert (
|
||||
isinstance(self.run_manager, AsyncCallbackManagerForChainRun)
|
||||
or self.run_manager is None
|
||||
)
|
||||
returned_output = await self.agent_executor._areturn(
|
||||
output, self.intermediate_steps, run_manager=self.run_manager
|
||||
)
|
||||
self.final_outputs = returned_output
|
||||
return returned_output
|
||||
|
||||
def _call_next(self) -> dict[str, Any]:
|
||||
"""
|
||||
Perform a single iteration of the synchronous AgentExecutorIterator.
|
||||
"""
|
||||
# final output already reached: stopiteration (final output)
|
||||
if self.final_outputs is not None:
|
||||
self.raise_stopiteration(self.final_outputs)
|
||||
# timeout/max iterations: stopiteration (stopped response)
|
||||
if not self.agent_executor._should_continue(self.iterations, self.time_elapsed):
|
||||
return self._stop()
|
||||
assert (
|
||||
isinstance(self.run_manager, CallbackManagerForChainRun)
|
||||
or self.run_manager is None
|
||||
)
|
||||
next_step_output = self._execute_next_step(self.run_manager)
|
||||
output = self._process_next_step_output(next_step_output, self.run_manager)
|
||||
self.update_iterations()
|
||||
return output
|
||||
|
||||
async def _acall_next(self) -> dict[str, Any]:
|
||||
"""
|
||||
Perform a single iteration of the asynchronous AgentExecutorIterator.
|
||||
"""
|
||||
# final output already reached: stopiteration (final output)
|
||||
if self.final_outputs is not None:
|
||||
await self.raise_stopasynciteration(self.final_outputs)
|
||||
# timeout/max iterations: stopiteration (stopped response)
|
||||
if not self.agent_executor._should_continue(self.iterations, self.time_elapsed):
|
||||
return await self._astop()
|
||||
assert (
|
||||
isinstance(self.run_manager, AsyncCallbackManagerForChainRun)
|
||||
or self.run_manager is None
|
||||
)
|
||||
next_step_output = await self._execute_next_async_step(self.run_manager)
|
||||
output = await self._aprocess_next_step_output(
|
||||
next_step_output, self.run_manager
|
||||
)
|
||||
self.update_iterations()
|
||||
return output
|
@ -0,0 +1,360 @@
|
||||
import pytest
|
||||
|
||||
from langchain.agents import (
|
||||
AgentExecutor,
|
||||
AgentExecutorIterator,
|
||||
AgentType,
|
||||
initialize_agent,
|
||||
)
|
||||
from langchain.agents.tools import Tool
|
||||
from langchain.llms import FakeListLLM
|
||||
from tests.unit_tests.agents.test_agent import _get_agent
|
||||
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
|
||||
|
||||
|
||||
def test_agent_iterator_bad_action() -> None:
|
||||
"""Test react chain iterator when bad action given."""
|
||||
agent = _get_agent()
|
||||
agent_iter = agent.iter(inputs="when was langchain made")
|
||||
|
||||
outputs = []
|
||||
for step in agent_iter:
|
||||
outputs.append(step)
|
||||
|
||||
assert isinstance(outputs[-1], dict)
|
||||
assert outputs[-1]["output"] == "curses foiled again"
|
||||
|
||||
|
||||
def test_agent_iterator_stopped_early() -> None:
|
||||
"""
|
||||
Test react chain iterator when max iterations or
|
||||
max execution time is exceeded.
|
||||
"""
|
||||
# iteration limit
|
||||
agent = _get_agent(max_iterations=1)
|
||||
agent_iter = agent.iter(inputs="when was langchain made")
|
||||
|
||||
outputs = []
|
||||
for step in agent_iter:
|
||||
outputs.append(step)
|
||||
# NOTE: we don't use agent.run like in the test for the regular agent executor,
|
||||
# so the dict structure for outputs stays intact
|
||||
assert isinstance(outputs[-1], dict)
|
||||
assert (
|
||||
outputs[-1]["output"] == "Agent stopped due to iteration limit or time limit."
|
||||
)
|
||||
|
||||
# execution time limit
|
||||
agent = _get_agent(max_execution_time=1e-5)
|
||||
agent_iter = agent.iter(inputs="when was langchain made")
|
||||
|
||||
outputs = []
|
||||
for step in agent_iter:
|
||||
outputs.append(step)
|
||||
assert isinstance(outputs[-1], dict)
|
||||
assert (
|
||||
outputs[-1]["output"] == "Agent stopped due to iteration limit or time limit."
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_agent_async_iterator_stopped_early() -> None:
|
||||
"""
|
||||
Test react chain async iterator when max iterations or
|
||||
max execution time is exceeded.
|
||||
"""
|
||||
# iteration limit
|
||||
agent = _get_agent(max_iterations=1)
|
||||
agent_async_iter = agent.iter(inputs="when was langchain made", async_=True)
|
||||
|
||||
outputs = []
|
||||
assert isinstance(agent_async_iter, AgentExecutorIterator)
|
||||
async for step in agent_async_iter:
|
||||
outputs.append(step)
|
||||
|
||||
assert isinstance(outputs[-1], dict)
|
||||
assert (
|
||||
outputs[-1]["output"] == "Agent stopped due to iteration limit or time limit."
|
||||
)
|
||||
|
||||
# execution time limit
|
||||
agent = _get_agent(max_execution_time=1e-5)
|
||||
agent_async_iter = agent.iter(inputs="when was langchain made", async_=True)
|
||||
assert isinstance(agent_async_iter, AgentExecutorIterator)
|
||||
|
||||
outputs = []
|
||||
async for step in agent_async_iter:
|
||||
outputs.append(step)
|
||||
|
||||
assert (
|
||||
outputs[-1]["output"] == "Agent stopped due to iteration limit or time limit."
|
||||
)
|
||||
|
||||
|
||||
def test_agent_iterator_with_callbacks() -> None:
|
||||
"""Test react chain iterator with callbacks by setting verbose globally."""
|
||||
handler1 = FakeCallbackHandler()
|
||||
handler2 = FakeCallbackHandler()
|
||||
bad_action_name = "BadAction"
|
||||
responses = [
|
||||
f"I'm turning evil\nAction: {bad_action_name}\nAction Input: misalignment",
|
||||
"Oh well\nFinal Answer: curses foiled again",
|
||||
]
|
||||
fake_llm = FakeListLLM(cache=False, responses=responses, callbacks=[handler2])
|
||||
|
||||
tools = [
|
||||
Tool(
|
||||
name="Search",
|
||||
func=lambda x: x,
|
||||
description="Useful for searching",
|
||||
),
|
||||
Tool(
|
||||
name="Lookup",
|
||||
func=lambda x: x,
|
||||
description="Useful for looking up things in a table",
|
||||
),
|
||||
]
|
||||
|
||||
agent = initialize_agent(
|
||||
tools, fake_llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True
|
||||
)
|
||||
agent_iter = agent.iter(inputs="when was langchain made", callbacks=[handler1])
|
||||
|
||||
outputs = []
|
||||
for step in agent_iter:
|
||||
outputs.append(step)
|
||||
assert isinstance(outputs[-1], dict)
|
||||
assert outputs[-1]["output"] == "curses foiled again"
|
||||
|
||||
# 1 top level chain run runs, 2 LLMChain runs, 2 LLM runs, 1 tool run
|
||||
assert handler1.chain_starts == handler1.chain_ends == 3
|
||||
assert handler1.llm_starts == handler1.llm_ends == 2
|
||||
assert handler1.tool_starts == 1
|
||||
assert handler1.tool_ends == 1
|
||||
# 1 extra agent action
|
||||
assert handler1.starts == 7
|
||||
# 1 extra agent end
|
||||
assert handler1.ends == 7
|
||||
print("h:", handler1)
|
||||
assert handler1.errors == 0
|
||||
# during LLMChain
|
||||
assert handler1.text == 2
|
||||
|
||||
assert handler2.llm_starts == 2
|
||||
assert handler2.llm_ends == 2
|
||||
assert (
|
||||
handler2.chain_starts
|
||||
== handler2.tool_starts
|
||||
== handler2.tool_ends
|
||||
== handler2.chain_ends
|
||||
== 0
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_agent_async_iterator_with_callbacks() -> None:
|
||||
"""Test react chain async iterator with callbacks by setting verbose globally."""
|
||||
handler1 = FakeCallbackHandler()
|
||||
handler2 = FakeCallbackHandler()
|
||||
|
||||
bad_action_name = "BadAction"
|
||||
responses = [
|
||||
f"I'm turning evil\nAction: {bad_action_name}\nAction Input: misalignment",
|
||||
"Oh well\nFinal Answer: curses foiled again",
|
||||
]
|
||||
fake_llm = FakeListLLM(cache=False, responses=responses, callbacks=[handler2])
|
||||
|
||||
tools = [
|
||||
Tool(
|
||||
name="Search",
|
||||
func=lambda x: x,
|
||||
description="Useful for searching",
|
||||
),
|
||||
Tool(
|
||||
name="Lookup",
|
||||
func=lambda x: x,
|
||||
description="Useful for looking up things in a table",
|
||||
),
|
||||
]
|
||||
|
||||
agent = initialize_agent(
|
||||
tools, fake_llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True
|
||||
)
|
||||
agent_async_iter = agent.iter(
|
||||
inputs="when was langchain made",
|
||||
callbacks=[handler1],
|
||||
async_=True,
|
||||
)
|
||||
assert isinstance(agent_async_iter, AgentExecutorIterator)
|
||||
|
||||
outputs = []
|
||||
async for step in agent_async_iter:
|
||||
outputs.append(step)
|
||||
|
||||
assert outputs[-1]["output"] == "curses foiled again"
|
||||
|
||||
# 1 top level chain run runs, 2 LLMChain runs, 2 LLM runs, 1 tool run
|
||||
assert handler1.chain_starts == handler1.chain_ends == 3
|
||||
assert handler1.llm_starts == handler1.llm_ends == 2
|
||||
assert handler1.tool_starts == 1
|
||||
assert handler1.tool_ends == 1
|
||||
# 1 extra agent action
|
||||
assert handler1.starts == 7
|
||||
# 1 extra agent end
|
||||
assert handler1.ends == 7
|
||||
assert handler1.errors == 0
|
||||
# during LLMChain
|
||||
assert handler1.text == 2
|
||||
|
||||
assert handler2.llm_starts == 2
|
||||
assert handler2.llm_ends == 2
|
||||
assert (
|
||||
handler2.chain_starts
|
||||
== handler2.tool_starts
|
||||
== handler2.tool_ends
|
||||
== handler2.chain_ends
|
||||
== 0
|
||||
)
|
||||
|
||||
|
||||
def test_agent_iterator_properties_and_setters() -> None:
|
||||
"""Test properties and setters of AgentExecutorIterator."""
|
||||
agent = _get_agent()
|
||||
agent.tags = None
|
||||
agent_iter = agent.iter(inputs="when was langchain made")
|
||||
|
||||
assert isinstance(agent_iter, AgentExecutorIterator)
|
||||
assert isinstance(agent_iter.inputs, dict)
|
||||
assert isinstance(agent_iter.callbacks, type(None))
|
||||
assert isinstance(agent_iter.tags, type(None))
|
||||
assert isinstance(agent_iter.agent_executor, AgentExecutor)
|
||||
|
||||
agent_iter.inputs = "New input" # type: ignore
|
||||
assert isinstance(agent_iter.inputs, dict)
|
||||
|
||||
agent_iter.callbacks = [FakeCallbackHandler()]
|
||||
assert isinstance(agent_iter.callbacks, list)
|
||||
|
||||
agent_iter.tags = ["test"]
|
||||
assert isinstance(agent_iter.tags, list)
|
||||
|
||||
new_agent = _get_agent()
|
||||
agent_iter.agent_executor = new_agent
|
||||
assert isinstance(agent_iter.agent_executor, AgentExecutor)
|
||||
|
||||
|
||||
def test_agent_iterator_reset() -> None:
|
||||
"""Test reset functionality of AgentExecutorIterator."""
|
||||
agent = _get_agent()
|
||||
agent_iter = agent.iter(inputs="when was langchain made")
|
||||
assert isinstance(agent_iter, AgentExecutorIterator)
|
||||
|
||||
# Perform one iteration
|
||||
next(agent_iter)
|
||||
|
||||
# Check if properties are updated
|
||||
assert agent_iter.iterations == 1
|
||||
assert agent_iter.time_elapsed > 0.0
|
||||
assert agent_iter.intermediate_steps
|
||||
|
||||
# Reset the iterator
|
||||
agent_iter.reset()
|
||||
|
||||
# Check if properties are reset
|
||||
assert agent_iter.iterations == 0
|
||||
assert agent_iter.time_elapsed == 0.0
|
||||
assert not agent_iter.intermediate_steps
|
||||
|
||||
|
||||
def test_agent_iterator_output_structure() -> None:
|
||||
"""Test the output structure of AgentExecutorIterator."""
|
||||
agent = _get_agent()
|
||||
agent_iter = agent.iter(inputs="when was langchain made")
|
||||
|
||||
for step in agent_iter:
|
||||
assert isinstance(step, dict)
|
||||
if "intermediate_step" in step:
|
||||
assert isinstance(step["intermediate_step"], list)
|
||||
elif "output" in step:
|
||||
assert isinstance(step["output"], str)
|
||||
else:
|
||||
assert False, "Unexpected output structure"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_agent_async_iterator_output_structure() -> None:
|
||||
"""Test the async output structure of AgentExecutorIterator."""
|
||||
agent = _get_agent()
|
||||
agent_async_iter = agent.iter(inputs="when was langchain made", async_=True)
|
||||
|
||||
assert isinstance(agent_async_iter, AgentExecutorIterator)
|
||||
async for step in agent_async_iter:
|
||||
assert isinstance(step, dict)
|
||||
if "intermediate_step" in step:
|
||||
assert isinstance(step["intermediate_step"], list)
|
||||
elif "output" in step:
|
||||
assert isinstance(step["output"], str)
|
||||
else:
|
||||
assert False, "Unexpected output structure"
|
||||
|
||||
|
||||
def test_agent_iterator_empty_input() -> None:
|
||||
"""Test AgentExecutorIterator with empty input."""
|
||||
agent = _get_agent()
|
||||
agent_iter = agent.iter(inputs="")
|
||||
|
||||
outputs = []
|
||||
for step in agent_iter:
|
||||
outputs.append(step)
|
||||
|
||||
assert isinstance(outputs[-1], dict)
|
||||
assert outputs[-1]["output"] # Check if there is an output
|
||||
|
||||
|
||||
def test_agent_iterator_custom_stopping_condition() -> None:
|
||||
"""Test AgentExecutorIterator with a custom stopping condition."""
|
||||
agent = _get_agent()
|
||||
|
||||
class CustomAgentExecutorIterator(AgentExecutorIterator):
|
||||
def _should_continue(self) -> bool:
|
||||
return self.iterations < 2 # Custom stopping condition
|
||||
|
||||
agent_iter = CustomAgentExecutorIterator(agent, inputs="when was langchain made")
|
||||
|
||||
outputs = []
|
||||
for step in agent_iter:
|
||||
outputs.append(step)
|
||||
|
||||
assert len(outputs) == 2 # Check if the custom stopping condition is respected
|
||||
|
||||
|
||||
def test_agent_iterator_failing_tool() -> None:
|
||||
"""Test AgentExecutorIterator with a tool that raises an exception."""
|
||||
|
||||
# Get agent for testing.
|
||||
bad_action_name = "FailingTool"
|
||||
responses = [
|
||||
f"I'm turning evil\nAction: {bad_action_name}\nAction Input: misalignment",
|
||||
"Oh well\nFinal Answer: curses foiled again",
|
||||
]
|
||||
fake_llm = FakeListLLM(responses=responses)
|
||||
|
||||
tools = [
|
||||
Tool(
|
||||
name="FailingTool",
|
||||
func=lambda x: 1 / 0, # This tool will raise a ZeroDivisionError
|
||||
description="A tool that fails",
|
||||
),
|
||||
]
|
||||
|
||||
agent = initialize_agent(
|
||||
tools, fake_llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True
|
||||
)
|
||||
|
||||
agent_iter = agent.iter(inputs="when was langchain made")
|
||||
assert isinstance(agent_iter, AgentExecutorIterator)
|
||||
# initialise iterator
|
||||
iter(agent_iter)
|
||||
|
||||
with pytest.raises(ZeroDivisionError):
|
||||
next(agent_iter)
|
Loading…
Reference in New Issue