Implement AgentExecutorIterator (#6929)

- Description: Implements a `.iter()` method for the `AgentExecutor`
class. This allows hooking into and intercepting intermediate agent
steps.
  - Issue: #6925 
  - Dependencies: None
  - Tag maintainer: @vowelparrot @agola11 
  - Twitter handle: @SlapDron3 @lacicocodes

---------

Co-authored-by: Lacico <Lacicocodes@gmail.com>
Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
SlapDrone 2023-07-24 03:00:22 +02:00 committed by GitHub
parent 77bf75c236
commit 961a0e200f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 1135 additions and 1 deletions

View File

@ -0,0 +1,245 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "feb31cc6",
"metadata": {},
"source": [
"# Running Agent as an Iterator\n",
"\n",
"To demonstrate the `AgentExecutorIterator` functionality, we will set up a problem where an Agent must:\n",
"\n",
"- Retrieve three prime numbers from a Tool\n",
"- Multiply these together. \n",
"\n",
"In this simple problem we can demonstrate adding some logic to verify intermediate steps by checking whether their outputs are prime."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "8167db11",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"\n",
"import dotenv\n",
"import pydantic\n",
"from langchain.agents import AgentExecutor, initialize_agent, AgentType\n",
"from langchain.schema import AgentFinish\n",
"from langchain.agents.tools import Tool\n",
"from langchain import LLMMathChain\n",
"from langchain.chat_models import ChatOpenAI"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "7e41b9e6",
"metadata": {},
"outputs": [],
"source": [
"# Uncomment if you have a .env in root of repo contains OPENAI_API_KEY\n",
"# dotenv.load_dotenv(\"../../../../../.env\")\n",
"\n",
"# need to use GPT-4 here as GPT-3.5 does not understand, however hard you insist, that\n",
"# it should use the calculator to perform the final calculation\n",
"llm = ChatOpenAI(temperature=0, model=\"gpt-4\")\n",
"llm_math_chain = LLMMathChain.from_llm(llm=llm, verbose=True)"
]
},
{
"cell_type": "markdown",
"id": "81e88aa5",
"metadata": {},
"source": [
"Define tools which provide:\n",
"- The `n`th prime number (using a small subset for this example) \n",
"- The LLMMathChain to act as a calculator"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "86f04b55",
"metadata": {},
"outputs": [],
"source": [
"primes = {998: 7901, 999: 7907, 1000: 7919}\n",
"\n",
"\n",
"class CalculatorInput(pydantic.BaseModel):\n",
" question: str = pydantic.Field()\n",
"\n",
"\n",
"class PrimeInput(pydantic.BaseModel):\n",
" n: int = pydantic.Field()\n",
"\n",
"\n",
"def is_prime(n: int) -> bool:\n",
" if n <= 1 or (n % 2 == 0 and n > 2):\n",
" return False\n",
" for i in range(3, int(n**0.5) + 1, 2):\n",
" if n % i == 0:\n",
" return False\n",
" return True\n",
"\n",
"\n",
"def get_prime(n: int, primes: dict = primes) -> str:\n",
" return str(primes.get(int(n)))\n",
"\n",
"\n",
"async def aget_prime(n: int, primes: dict = primes) -> str:\n",
" return str(primes.get(int(n)))\n",
"\n",
"\n",
"tools = [\n",
" Tool(\n",
" name=\"GetPrime\",\n",
" func=get_prime,\n",
" description=\"A tool that returns the `n`th prime number\",\n",
" args_schema=PrimeInput,\n",
" coroutine=aget_prime,\n",
" ),\n",
" Tool.from_function(\n",
" func=llm_math_chain.run,\n",
" name=\"Calculator\",\n",
" description=\"Useful for when you need to compute mathematical expressions\",\n",
" args_schema=CalculatorInput,\n",
" coroutine=llm_math_chain.arun,\n",
" ),\n",
"]"
]
},
{
"cell_type": "markdown",
"id": "0e660ee6",
"metadata": {},
"source": [
"Construct the agent. We will use the default agent type here."
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "21c775b0",
"metadata": {},
"outputs": [],
"source": [
"agent = initialize_agent(\n",
" tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True\n",
")"
]
},
{
"cell_type": "markdown",
"id": "a233fe4e",
"metadata": {},
"source": [
"Run the iteration and perform a custom check on certain steps:"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "582d61f4",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"\u001b[1m> Entering new chain...\u001b[0m\n",
"\u001b[32;1m\u001b[1;3mI need to find the 998th, 999th and 1000th prime numbers first.\n",
"Action: GetPrime\n",
"Action Input: 998\u001b[0m\n",
"Observation: \u001b[36;1m\u001b[1;3m7901\u001b[0m\n",
"Thought:Checking whether 7901 is prime...\n",
"Should the agent continue (Y/n)?:\n",
"Y\n",
"\u001b[32;1m\u001b[1;3mI have the 998th prime number. Now I need to find the 999th prime number.\n",
"Action: GetPrime\n",
"Action Input: 999\u001b[0m\n",
"Observation: \u001b[36;1m\u001b[1;3m7907\u001b[0m\n",
"Thought:Checking whether 7907 is prime...\n",
"Should the agent continue (Y/n)?:\n",
"Y\n",
"\u001b[32;1m\u001b[1;3mI have the 999th prime number. Now I need to find the 1000th prime number.\n",
"Action: GetPrime\n",
"Action Input: 1000\u001b[0m\n",
"Observation: \u001b[36;1m\u001b[1;3m7919\u001b[0m\n",
"Thought:Checking whether 7919 is prime...\n",
"Should the agent continue (Y/n)?:\n",
"Y\n",
"\u001b[32;1m\u001b[1;3mI have all three prime numbers. Now I need to calculate the product of these numbers.\n",
"Action: Calculator\n",
"Action Input: 7901 * 7907 * 7919\u001b[0m\n",
"\n",
"\u001b[1m> Entering new chain...\u001b[0m\n",
"7901 * 7907 * 7919\u001b[32;1m\u001b[1;3m```text\n",
"7901 * 7907 * 7919\n",
"```\n",
"...numexpr.evaluate(\"7901 * 7907 * 7919\")...\n",
"\u001b[0m\n",
"Answer: \u001b[33;1m\u001b[1;3m494725326233\u001b[0m\n",
"\u001b[1m> Finished chain.\u001b[0m\n",
"\n",
"Observation: \u001b[33;1m\u001b[1;3mAnswer: 494725326233\u001b[0m\n",
"Thought:Should the agent continue (Y/n)?:\n",
"Y\n",
"\u001b[32;1m\u001b[1;3mI now know the final answer\n",
"Final Answer: 494725326233\u001b[0m\n",
"\n",
"\u001b[1m> Finished chain.\u001b[0m\n"
]
}
],
"source": [
"question = \"What is the product of the 998th, 999th and 1000th prime numbers?\"\n",
"\n",
"for step in agent.iter(question):\n",
" if output := step.get(\"intermediate_step\"):\n",
" action, value = output[0]\n",
" if action.tool == \"GetPrime\":\n",
" print(f\"Checking whether {value} is prime...\")\n",
" assert is_prime(int(value))\n",
" # Ask user if they want to continue\n",
" _continue = input(\"Should the agent continue (Y/n)?:\\n\")\n",
" if _continue != \"Y\":\n",
" break"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6934ff8e",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "venv",
"language": "python",
"name": "venv"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.3"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View File

@ -7,6 +7,7 @@ from langchain.agents.agent import (
BaseSingleActionAgent, BaseSingleActionAgent,
LLMSingleActionAgent, LLMSingleActionAgent,
) )
from langchain.agents.agent_iterator import AgentExecutorIterator
from langchain.agents.agent_toolkits import ( from langchain.agents.agent_toolkits import (
create_csv_agent, create_csv_agent,
create_json_agent, create_json_agent,
@ -42,6 +43,7 @@ from langchain.agents.tools import Tool, tool
__all__ = [ __all__ = [
"Agent", "Agent",
"AgentExecutor", "AgentExecutor",
"AgentExecutorIterator",
"AgentOutputParser", "AgentOutputParser",
"AgentType", "AgentType",
"BaseMultiActionAgent", "BaseMultiActionAgent",

View File

@ -12,6 +12,7 @@ from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
import yaml import yaml
from pydantic import BaseModel, root_validator from pydantic import BaseModel, root_validator
from langchain.agents.agent_iterator import AgentExecutorIterator
from langchain.agents.agent_types import AgentType from langchain.agents.agent_types import AgentType
from langchain.agents.tools import InvalidTool from langchain.agents.tools import InvalidTool
from langchain.callbacks.base import BaseCallbackManager from langchain.callbacks.base import BaseCallbackManager
@ -732,6 +733,24 @@ s
"""Save the underlying agent.""" """Save the underlying agent."""
return self.agent.save(file_path) return self.agent.save(file_path)
def iter(
self,
inputs: Any,
callbacks: Callbacks = None,
*,
include_run_info: bool = False,
async_: bool = False,
) -> AgentExecutorIterator:
"""Enables iteration over steps taken to reach final output."""
return AgentExecutorIterator(
self,
inputs,
callbacks,
tags=self.tags,
include_run_info=include_run_info,
async_=async_,
)
@property @property
def input_keys(self) -> List[str]: def input_keys(self) -> List[str]:
"""Return the input keys. """Return the input keys.

View File

@ -0,0 +1,501 @@
from __future__ import annotations
import logging
import time
from abc import ABC, abstractmethod
from asyncio import CancelledError
from functools import wraps
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
List,
NoReturn,
Optional,
Tuple,
Type,
Union,
)
from langchain.callbacks.manager import (
AsyncCallbackManager,
AsyncCallbackManagerForChainRun,
CallbackManager,
CallbackManagerForChainRun,
Callbacks,
)
from langchain.input import get_color_mapping
from langchain.load.dump import dumpd
from langchain.schema import RUN_KEY, AgentAction, AgentFinish, RunInfo
from langchain.tools import BaseTool
from langchain.utilities.asyncio import asyncio_timeout
if TYPE_CHECKING:
from langchain.agents.agent import AgentExecutor
logger = logging.getLogger(__name__)
class BaseAgentExecutorIterator(ABC):
@abstractmethod
def build_callback_manager(self) -> None:
pass
def rebuild_callback_manager_on_set(
setter_method: Callable[..., None]
) -> Callable[..., None]:
"""Decorator to force setters to rebuild callback mgr"""
@wraps(setter_method)
def wrapper(self: BaseAgentExecutorIterator, *args: Any, **kwargs: Any) -> None:
setter_method(self, *args, **kwargs)
self.build_callback_manager()
return wrapper
class AgentExecutorIterator(BaseAgentExecutorIterator):
def __init__(
self,
agent_executor: AgentExecutor,
inputs: Any,
callbacks: Callbacks = None,
*,
tags: Optional[list[str]] = None,
include_run_info: bool = False,
async_: bool = False,
):
"""
Initialize the AgentExecutorIterator with the given AgentExecutor,
inputs, and optional callbacks.
"""
self._agent_executor = agent_executor
self.inputs = inputs
self.async_ = async_
# build callback manager on tags setter
self._callbacks = callbacks
self.tags = tags
self.include_run_info = include_run_info
self.run_manager = None
self.reset()
_callback_manager: Union[AsyncCallbackManager, CallbackManager]
_inputs: dict[str, str]
_final_outputs: Optional[dict[str, str]]
run_manager: Optional[
Union[AsyncCallbackManagerForChainRun, CallbackManagerForChainRun]
]
timeout_manager: Any # TODO: Fix a type here; the shim makes it tricky.
@property
def inputs(self) -> dict[str, str]:
return self._inputs
@inputs.setter
def inputs(self, inputs: Any) -> None:
self._inputs = self.agent_executor.prep_inputs(inputs)
@property
def callbacks(self) -> Callbacks:
return self._callbacks
@callbacks.setter
@rebuild_callback_manager_on_set
def callbacks(self, callbacks: Callbacks) -> None:
"""When callbacks are changed after __init__, rebuild callback mgr"""
self._callbacks = callbacks
@property
def tags(self) -> Optional[List[str]]:
return self._tags
@tags.setter
@rebuild_callback_manager_on_set
def tags(self, tags: Optional[List[str]]) -> None:
"""When tags are changed after __init__, rebuild callback mgr"""
self._tags = tags
@property
def agent_executor(self) -> AgentExecutor:
return self._agent_executor
@agent_executor.setter
@rebuild_callback_manager_on_set
def agent_executor(self, agent_executor: AgentExecutor) -> None:
self._agent_executor = agent_executor
# force re-prep inputs in case agent_executor's prep_inputs fn changed
self.inputs = self.inputs
@property
def callback_manager(self) -> Union[AsyncCallbackManager, CallbackManager]:
return self._callback_manager
def build_callback_manager(self) -> None:
"""
Create and configure the callback manager based on the current
callbacks and tags.
"""
CallbackMgr: Union[Type[AsyncCallbackManager], Type[CallbackManager]] = (
AsyncCallbackManager if self.async_ else CallbackManager
)
self._callback_manager = CallbackMgr.configure(
self.callbacks,
self.agent_executor.callbacks,
self.agent_executor.verbose,
self.tags,
self.agent_executor.tags,
)
@property
def name_to_tool_map(self) -> dict[str, BaseTool]:
return {tool.name: tool for tool in self.agent_executor.tools}
@property
def color_mapping(self) -> dict[str, str]:
return get_color_mapping(
[tool.name for tool in self.agent_executor.tools],
excluded_colors=["green", "red"],
)
def reset(self) -> None:
"""
Reset the iterator to its initial state, clearing intermediate steps,
iterations, and time elapsed.
"""
logger.debug("(Re)setting AgentExecutorIterator to fresh state")
self.intermediate_steps: list[tuple[AgentAction, str]] = []
self.iterations = 0
# maybe better to start these on the first __anext__ call?
self.time_elapsed = 0.0
self.start_time = time.time()
self._final_outputs = None
def update_iterations(self) -> None:
"""
Increment the number of iterations and update the time elapsed.
"""
self.iterations += 1
self.time_elapsed = time.time() - self.start_time
logger.debug(
f"Agent Iterations: {self.iterations} ({self.time_elapsed:.2f}s elapsed)"
)
def raise_stopiteration(self, output: Any) -> NoReturn:
"""
Raise a StopIteration exception with the given output.
"""
logger.debug("Chain end: stop iteration")
raise StopIteration(output)
async def raise_stopasynciteration(self, output: Any) -> NoReturn:
"""
Raise a StopAsyncIteration exception with the given output.
Close the timeout context manager.
"""
logger.debug("Chain end: stop async iteration")
if self.timeout_manager is not None:
await self.timeout_manager.__aexit__(None, None, None)
raise StopAsyncIteration(output)
@property
def final_outputs(self) -> Optional[dict[str, Any]]:
return self._final_outputs
@final_outputs.setter
def final_outputs(self, outputs: Optional[Dict[str, Any]]) -> None:
# have access to intermediate steps by design in iterator,
# so return only outputs may as well always be true.
self._final_outputs = None
if outputs:
prepared_outputs: dict[str, Any] = self.agent_executor.prep_outputs(
self.inputs, outputs, return_only_outputs=True
)
if self.include_run_info and self.run_manager is not None:
logger.debug("Assign run key")
prepared_outputs[RUN_KEY] = RunInfo(run_id=self.run_manager.run_id)
self._final_outputs = prepared_outputs
def __iter__(self: "AgentExecutorIterator") -> "AgentExecutorIterator":
logger.debug("Initialising AgentExecutorIterator")
self.reset()
assert isinstance(self.callback_manager, CallbackManager)
self.run_manager = self.callback_manager.on_chain_start(
dumpd(self.agent_executor),
self.inputs,
)
return self
def __aiter__(self) -> "AgentExecutorIterator":
"""
N.B. __aiter__ must be a normal method, so need to initialise async run manager
on first __anext__ call where we can await it
"""
logger.debug("Initialising AgentExecutorIterator (async)")
self.reset()
if self.agent_executor.max_execution_time:
self.timeout_manager = asyncio_timeout(
self.agent_executor.max_execution_time
)
else:
self.timeout_manager = None
return self
def _on_first_step(self) -> None:
"""
Perform any necessary setup for the first step of the synchronous iterator.
"""
pass
async def _on_first_async_step(self) -> None:
"""
Perform any necessary setup for the first step of the asynchronous iterator.
"""
# on first step, need to await callback manager and start async timeout ctxmgr
if self.iterations == 0:
assert isinstance(self.callback_manager, AsyncCallbackManager)
self.run_manager = await self.callback_manager.on_chain_start(
dumpd(self.agent_executor),
self.inputs,
)
if self.timeout_manager:
await self.timeout_manager.__aenter__()
def __next__(self) -> dict[str, Any]:
"""
AgentExecutor AgentExecutorIterator
__call__ (__iter__ ->) __next__
_call <=> _call_next
_take_next_step _take_next_step
"""
# first step
if self.iterations == 0:
self._on_first_step()
# N.B. timeout taken care of by "_should_continue" in sync case
try:
return self._call_next()
except StopIteration:
raise
except (KeyboardInterrupt, Exception) as e:
if self.run_manager:
self.run_manager.on_chain_error(e)
raise
async def __anext__(self) -> dict[str, Any]:
"""
AgentExecutor AgentExecutorIterator
acall (__aiter__ ->) __anext__
_acall <=> _acall_next
_atake_next_step _atake_next_step
"""
if self.iterations == 0:
await self._on_first_async_step()
try:
return await self._acall_next()
except StopAsyncIteration:
raise
except (TimeoutError, CancelledError):
await self.timeout_manager.__aexit__(None, None, None)
self.timeout_manager = None
return await self._astop()
except (KeyboardInterrupt, Exception) as e:
if self.run_manager:
assert isinstance(self.run_manager, AsyncCallbackManagerForChainRun)
await self.run_manager.on_chain_error(e)
raise
def _execute_next_step(
self, run_manager: Optional[CallbackManagerForChainRun]
) -> Union[AgentFinish, List[Tuple[AgentAction, str]]]:
"""
Execute the next step in the chain using the
AgentExecutor's _take_next_step method.
"""
return self.agent_executor._take_next_step(
self.name_to_tool_map,
self.color_mapping,
self.inputs,
self.intermediate_steps,
run_manager=run_manager,
)
async def _execute_next_async_step(
self, run_manager: Optional[AsyncCallbackManagerForChainRun]
) -> Union[AgentFinish, List[Tuple[AgentAction, str]]]:
"""
Execute the next step in the chain using the
AgentExecutor's _atake_next_step method.
"""
return await self.agent_executor._atake_next_step(
self.name_to_tool_map,
self.color_mapping,
self.inputs,
self.intermediate_steps,
run_manager=run_manager,
)
def _process_next_step_output(
self,
next_step_output: Union[AgentFinish, List[Tuple[AgentAction, str]]],
run_manager: Optional[CallbackManagerForChainRun],
) -> Dict[str, Union[str, List[Tuple[AgentAction, str]]]]:
"""
Process the output of the next step,
handling AgentFinish and tool return cases.
"""
logger.debug("Processing output of Agent loop step")
if isinstance(next_step_output, AgentFinish):
logger.debug(
"Hit AgentFinish: _return -> on_chain_end -> run final output logic"
)
output = self.agent_executor._return(
next_step_output, self.intermediate_steps, run_manager=run_manager
)
if self.run_manager:
self.run_manager.on_chain_end(output)
self.final_outputs = output
return output
self.intermediate_steps.extend(next_step_output)
logger.debug("Updated intermediate_steps with step output")
# Check for tool return
if len(next_step_output) == 1:
next_step_action = next_step_output[0]
tool_return = self.agent_executor._get_tool_return(next_step_action)
if tool_return is not None:
output = self.agent_executor._return(
tool_return, self.intermediate_steps, run_manager=run_manager
)
if self.run_manager:
self.run_manager.on_chain_end(output)
self.final_outputs = output
return output
output = {"intermediate_step": next_step_output}
return output
async def _aprocess_next_step_output(
self,
next_step_output: Union[AgentFinish, List[Tuple[AgentAction, str]]],
run_manager: Optional[AsyncCallbackManagerForChainRun],
) -> Dict[str, Union[str, List[Tuple[AgentAction, str]]]]:
"""
Process the output of the next async step,
handling AgentFinish and tool return cases.
"""
logger.debug("Processing output of async Agent loop step")
if isinstance(next_step_output, AgentFinish):
logger.debug(
"Hit AgentFinish: _areturn -> on_chain_end -> run final output logic"
)
output = await self.agent_executor._areturn(
next_step_output, self.intermediate_steps, run_manager=run_manager
)
if run_manager:
await run_manager.on_chain_end(output)
self.final_outputs = output
return output
self.intermediate_steps.extend(next_step_output)
logger.debug("Updated intermediate_steps with step output")
# Check for tool return
if len(next_step_output) == 1:
next_step_action = next_step_output[0]
tool_return = self.agent_executor._get_tool_return(next_step_action)
if tool_return is not None:
output = await self.agent_executor._areturn(
tool_return, self.intermediate_steps, run_manager=run_manager
)
if run_manager:
await run_manager.on_chain_end(output)
self.final_outputs = output
return output
output = {"intermediate_step": next_step_output}
return output
def _stop(self) -> dict[str, Any]:
"""
Stop the iterator and raise a StopIteration exception with the stopped response.
"""
logger.warning("Stopping agent prematurely due to triggering stop condition")
# this manually constructs agent finish with output key
output = self.agent_executor.agent.return_stopped_response(
self.agent_executor.early_stopping_method,
self.intermediate_steps,
**self.inputs,
)
assert (
isinstance(self.run_manager, CallbackManagerForChainRun)
or self.run_manager is None
)
returned_output = self.agent_executor._return(
output, self.intermediate_steps, run_manager=self.run_manager
)
self.final_outputs = returned_output
return returned_output
async def _astop(self) -> dict[str, Any]:
"""
Stop the async iterator and raise a StopAsyncIteration exception with
the stopped response.
"""
logger.warning("Stopping agent prematurely due to triggering stop condition")
output = self.agent_executor.agent.return_stopped_response(
self.agent_executor.early_stopping_method,
self.intermediate_steps,
**self.inputs,
)
assert (
isinstance(self.run_manager, AsyncCallbackManagerForChainRun)
or self.run_manager is None
)
returned_output = await self.agent_executor._areturn(
output, self.intermediate_steps, run_manager=self.run_manager
)
self.final_outputs = returned_output
return returned_output
def _call_next(self) -> dict[str, Any]:
"""
Perform a single iteration of the synchronous AgentExecutorIterator.
"""
# final output already reached: stopiteration (final output)
if self.final_outputs is not None:
self.raise_stopiteration(self.final_outputs)
# timeout/max iterations: stopiteration (stopped response)
if not self.agent_executor._should_continue(self.iterations, self.time_elapsed):
return self._stop()
assert (
isinstance(self.run_manager, CallbackManagerForChainRun)
or self.run_manager is None
)
next_step_output = self._execute_next_step(self.run_manager)
output = self._process_next_step_output(next_step_output, self.run_manager)
self.update_iterations()
return output
async def _acall_next(self) -> dict[str, Any]:
"""
Perform a single iteration of the asynchronous AgentExecutorIterator.
"""
# final output already reached: stopiteration (final output)
if self.final_outputs is not None:
await self.raise_stopasynciteration(self.final_outputs)
# timeout/max iterations: stopiteration (stopped response)
if not self.agent_executor._should_continue(self.iterations, self.time_elapsed):
return await self._astop()
assert (
isinstance(self.run_manager, AsyncCallbackManagerForChainRun)
or self.run_manager is None
)
next_step_output = await self._execute_next_async_step(self.run_manager)
output = await self._aprocess_next_step_output(
next_step_output, self.run_manager
)
self.update_iterations()
return output

View File

@ -32,6 +32,9 @@ class FakeListLLM(LLM):
"""Return number of tokens in text.""" """Return number of tokens in text."""
return len(text.split()) return len(text.split())
async def _acall(self, *args: Any, **kwargs: Any) -> str:
return self._call(*args, **kwargs)
@property @property
def _identifying_params(self) -> Mapping[str, Any]: def _identifying_params(self) -> Mapping[str, Any]:
return {} return {}
@ -49,7 +52,8 @@ def _get_agent(**kwargs: Any) -> AgentExecutor:
f"I'm turning evil\nAction: {bad_action_name}\nAction Input: misalignment", f"I'm turning evil\nAction: {bad_action_name}\nAction Input: misalignment",
"Oh well\nFinal Answer: curses foiled again", "Oh well\nFinal Answer: curses foiled again",
] ]
fake_llm = FakeListLLM(responses=responses) fake_llm = FakeListLLM(cache=False, responses=responses)
tools = [ tools = [
Tool( Tool(
name="Search", name="Search",
@ -62,6 +66,7 @@ def _get_agent(**kwargs: Any) -> AgentExecutor:
description="Useful for looking up things in a table", description="Useful for looking up things in a table",
), ),
] ]
agent = initialize_agent( agent = initialize_agent(
tools, tools,
fake_llm, fake_llm,
@ -194,6 +199,7 @@ def test_agent_tool_return_direct_in_intermediate_steps() -> None:
) )
resp = agent("when was langchain made") resp = agent("when was langchain made")
assert isinstance(resp, dict)
assert resp["output"] == "misalignment" assert resp["output"] == "misalignment"
assert len(resp["intermediate_steps"]) == 1 assert len(resp["intermediate_steps"]) == 1
action, _action_intput = resp["intermediate_steps"][0] action, _action_intput = resp["intermediate_steps"][0]

View File

@ -0,0 +1,360 @@
import pytest
from langchain.agents import (
AgentExecutor,
AgentExecutorIterator,
AgentType,
initialize_agent,
)
from langchain.agents.tools import Tool
from langchain.llms import FakeListLLM
from tests.unit_tests.agents.test_agent import _get_agent
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
def test_agent_iterator_bad_action() -> None:
"""Test react chain iterator when bad action given."""
agent = _get_agent()
agent_iter = agent.iter(inputs="when was langchain made")
outputs = []
for step in agent_iter:
outputs.append(step)
assert isinstance(outputs[-1], dict)
assert outputs[-1]["output"] == "curses foiled again"
def test_agent_iterator_stopped_early() -> None:
"""
Test react chain iterator when max iterations or
max execution time is exceeded.
"""
# iteration limit
agent = _get_agent(max_iterations=1)
agent_iter = agent.iter(inputs="when was langchain made")
outputs = []
for step in agent_iter:
outputs.append(step)
# NOTE: we don't use agent.run like in the test for the regular agent executor,
# so the dict structure for outputs stays intact
assert isinstance(outputs[-1], dict)
assert (
outputs[-1]["output"] == "Agent stopped due to iteration limit or time limit."
)
# execution time limit
agent = _get_agent(max_execution_time=1e-5)
agent_iter = agent.iter(inputs="when was langchain made")
outputs = []
for step in agent_iter:
outputs.append(step)
assert isinstance(outputs[-1], dict)
assert (
outputs[-1]["output"] == "Agent stopped due to iteration limit or time limit."
)
@pytest.mark.asyncio
async def test_agent_async_iterator_stopped_early() -> None:
"""
Test react chain async iterator when max iterations or
max execution time is exceeded.
"""
# iteration limit
agent = _get_agent(max_iterations=1)
agent_async_iter = agent.iter(inputs="when was langchain made", async_=True)
outputs = []
assert isinstance(agent_async_iter, AgentExecutorIterator)
async for step in agent_async_iter:
outputs.append(step)
assert isinstance(outputs[-1], dict)
assert (
outputs[-1]["output"] == "Agent stopped due to iteration limit or time limit."
)
# execution time limit
agent = _get_agent(max_execution_time=1e-5)
agent_async_iter = agent.iter(inputs="when was langchain made", async_=True)
assert isinstance(agent_async_iter, AgentExecutorIterator)
outputs = []
async for step in agent_async_iter:
outputs.append(step)
assert (
outputs[-1]["output"] == "Agent stopped due to iteration limit or time limit."
)
def test_agent_iterator_with_callbacks() -> None:
"""Test react chain iterator with callbacks by setting verbose globally."""
handler1 = FakeCallbackHandler()
handler2 = FakeCallbackHandler()
bad_action_name = "BadAction"
responses = [
f"I'm turning evil\nAction: {bad_action_name}\nAction Input: misalignment",
"Oh well\nFinal Answer: curses foiled again",
]
fake_llm = FakeListLLM(cache=False, responses=responses, callbacks=[handler2])
tools = [
Tool(
name="Search",
func=lambda x: x,
description="Useful for searching",
),
Tool(
name="Lookup",
func=lambda x: x,
description="Useful for looking up things in a table",
),
]
agent = initialize_agent(
tools, fake_llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True
)
agent_iter = agent.iter(inputs="when was langchain made", callbacks=[handler1])
outputs = []
for step in agent_iter:
outputs.append(step)
assert isinstance(outputs[-1], dict)
assert outputs[-1]["output"] == "curses foiled again"
# 1 top level chain run runs, 2 LLMChain runs, 2 LLM runs, 1 tool run
assert handler1.chain_starts == handler1.chain_ends == 3
assert handler1.llm_starts == handler1.llm_ends == 2
assert handler1.tool_starts == 1
assert handler1.tool_ends == 1
# 1 extra agent action
assert handler1.starts == 7
# 1 extra agent end
assert handler1.ends == 7
print("h:", handler1)
assert handler1.errors == 0
# during LLMChain
assert handler1.text == 2
assert handler2.llm_starts == 2
assert handler2.llm_ends == 2
assert (
handler2.chain_starts
== handler2.tool_starts
== handler2.tool_ends
== handler2.chain_ends
== 0
)
@pytest.mark.asyncio
async def test_agent_async_iterator_with_callbacks() -> None:
"""Test react chain async iterator with callbacks by setting verbose globally."""
handler1 = FakeCallbackHandler()
handler2 = FakeCallbackHandler()
bad_action_name = "BadAction"
responses = [
f"I'm turning evil\nAction: {bad_action_name}\nAction Input: misalignment",
"Oh well\nFinal Answer: curses foiled again",
]
fake_llm = FakeListLLM(cache=False, responses=responses, callbacks=[handler2])
tools = [
Tool(
name="Search",
func=lambda x: x,
description="Useful for searching",
),
Tool(
name="Lookup",
func=lambda x: x,
description="Useful for looking up things in a table",
),
]
agent = initialize_agent(
tools, fake_llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True
)
agent_async_iter = agent.iter(
inputs="when was langchain made",
callbacks=[handler1],
async_=True,
)
assert isinstance(agent_async_iter, AgentExecutorIterator)
outputs = []
async for step in agent_async_iter:
outputs.append(step)
assert outputs[-1]["output"] == "curses foiled again"
# 1 top level chain run runs, 2 LLMChain runs, 2 LLM runs, 1 tool run
assert handler1.chain_starts == handler1.chain_ends == 3
assert handler1.llm_starts == handler1.llm_ends == 2
assert handler1.tool_starts == 1
assert handler1.tool_ends == 1
# 1 extra agent action
assert handler1.starts == 7
# 1 extra agent end
assert handler1.ends == 7
assert handler1.errors == 0
# during LLMChain
assert handler1.text == 2
assert handler2.llm_starts == 2
assert handler2.llm_ends == 2
assert (
handler2.chain_starts
== handler2.tool_starts
== handler2.tool_ends
== handler2.chain_ends
== 0
)
def test_agent_iterator_properties_and_setters() -> None:
"""Test properties and setters of AgentExecutorIterator."""
agent = _get_agent()
agent.tags = None
agent_iter = agent.iter(inputs="when was langchain made")
assert isinstance(agent_iter, AgentExecutorIterator)
assert isinstance(agent_iter.inputs, dict)
assert isinstance(agent_iter.callbacks, type(None))
assert isinstance(agent_iter.tags, type(None))
assert isinstance(agent_iter.agent_executor, AgentExecutor)
agent_iter.inputs = "New input" # type: ignore
assert isinstance(agent_iter.inputs, dict)
agent_iter.callbacks = [FakeCallbackHandler()]
assert isinstance(agent_iter.callbacks, list)
agent_iter.tags = ["test"]
assert isinstance(agent_iter.tags, list)
new_agent = _get_agent()
agent_iter.agent_executor = new_agent
assert isinstance(agent_iter.agent_executor, AgentExecutor)
def test_agent_iterator_reset() -> None:
"""Test reset functionality of AgentExecutorIterator."""
agent = _get_agent()
agent_iter = agent.iter(inputs="when was langchain made")
assert isinstance(agent_iter, AgentExecutorIterator)
# Perform one iteration
next(agent_iter)
# Check if properties are updated
assert agent_iter.iterations == 1
assert agent_iter.time_elapsed > 0.0
assert agent_iter.intermediate_steps
# Reset the iterator
agent_iter.reset()
# Check if properties are reset
assert agent_iter.iterations == 0
assert agent_iter.time_elapsed == 0.0
assert not agent_iter.intermediate_steps
def test_agent_iterator_output_structure() -> None:
"""Test the output structure of AgentExecutorIterator."""
agent = _get_agent()
agent_iter = agent.iter(inputs="when was langchain made")
for step in agent_iter:
assert isinstance(step, dict)
if "intermediate_step" in step:
assert isinstance(step["intermediate_step"], list)
elif "output" in step:
assert isinstance(step["output"], str)
else:
assert False, "Unexpected output structure"
@pytest.mark.asyncio
async def test_agent_async_iterator_output_structure() -> None:
"""Test the async output structure of AgentExecutorIterator."""
agent = _get_agent()
agent_async_iter = agent.iter(inputs="when was langchain made", async_=True)
assert isinstance(agent_async_iter, AgentExecutorIterator)
async for step in agent_async_iter:
assert isinstance(step, dict)
if "intermediate_step" in step:
assert isinstance(step["intermediate_step"], list)
elif "output" in step:
assert isinstance(step["output"], str)
else:
assert False, "Unexpected output structure"
def test_agent_iterator_empty_input() -> None:
"""Test AgentExecutorIterator with empty input."""
agent = _get_agent()
agent_iter = agent.iter(inputs="")
outputs = []
for step in agent_iter:
outputs.append(step)
assert isinstance(outputs[-1], dict)
assert outputs[-1]["output"] # Check if there is an output
def test_agent_iterator_custom_stopping_condition() -> None:
"""Test AgentExecutorIterator with a custom stopping condition."""
agent = _get_agent()
class CustomAgentExecutorIterator(AgentExecutorIterator):
def _should_continue(self) -> bool:
return self.iterations < 2 # Custom stopping condition
agent_iter = CustomAgentExecutorIterator(agent, inputs="when was langchain made")
outputs = []
for step in agent_iter:
outputs.append(step)
assert len(outputs) == 2 # Check if the custom stopping condition is respected
def test_agent_iterator_failing_tool() -> None:
"""Test AgentExecutorIterator with a tool that raises an exception."""
# Get agent for testing.
bad_action_name = "FailingTool"
responses = [
f"I'm turning evil\nAction: {bad_action_name}\nAction Input: misalignment",
"Oh well\nFinal Answer: curses foiled again",
]
fake_llm = FakeListLLM(responses=responses)
tools = [
Tool(
name="FailingTool",
func=lambda x: 1 / 0, # This tool will raise a ZeroDivisionError
description="A tool that fails",
),
]
agent = initialize_agent(
tools, fake_llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True
)
agent_iter = agent.iter(inputs="when was langchain made")
assert isinstance(agent_iter, AgentExecutorIterator)
# initialise iterator
iter(agent_iter)
with pytest.raises(ZeroDivisionError):
next(agent_iter)

View File

@ -3,6 +3,7 @@ from langchain.agents import __all__ as agents_all
_EXPECTED = [ _EXPECTED = [
"Agent", "Agent",
"AgentExecutor", "AgentExecutor",
"AgentExecutorIterator",
"AgentOutputParser", "AgentOutputParser",
"AgentType", "AgentType",
"BaseMultiActionAgent", "BaseMultiActionAgent",