From 961a0e200fb5cb12f1fee6788a298266bc5266d1 Mon Sep 17 00:00:00 2001 From: SlapDrone <32279503+SlapDrone@users.noreply.github.com> Date: Mon, 24 Jul 2023 03:00:22 +0200 Subject: [PATCH] Implement AgentExecutorIterator (#6929) - Description: Implements a `.iter()` method for the `AgentExecutor` class. This allows hooking into and intercepting intermediate agent steps. - Issue: #6925 - Dependencies: None - Tag maintainer: @vowelparrot @agola11 - Twitter handle: @SlapDron3 @lacicocodes --------- Co-authored-by: Lacico Co-authored-by: Bagatur --- .../modules/agents/how_to/agent_iter.ipynb | 245 +++++++++ libs/langchain/langchain/agents/__init__.py | 2 + libs/langchain/langchain/agents/agent.py | 19 + .../langchain/agents/agent_iterator.py | 501 ++++++++++++++++++ .../tests/unit_tests/agents/test_agent.py | 8 +- .../unit_tests/agents/test_agent_iterator.py | 360 +++++++++++++ .../unit_tests/agents/test_public_api.py | 1 + 7 files changed, 1135 insertions(+), 1 deletion(-) create mode 100644 docs/extras/modules/agents/how_to/agent_iter.ipynb create mode 100644 libs/langchain/langchain/agents/agent_iterator.py create mode 100644 libs/langchain/tests/unit_tests/agents/test_agent_iterator.py diff --git a/docs/extras/modules/agents/how_to/agent_iter.ipynb b/docs/extras/modules/agents/how_to/agent_iter.ipynb new file mode 100644 index 0000000000..a7baf8a115 --- /dev/null +++ b/docs/extras/modules/agents/how_to/agent_iter.ipynb @@ -0,0 +1,245 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "feb31cc6", + "metadata": {}, + "source": [ + "# Running Agent as an Iterator\n", + "\n", + "To demonstrate the `AgentExecutorIterator` functionality, we will set up a problem where an Agent must:\n", + "\n", + "- Retrieve three prime numbers from a Tool\n", + "- Multiply these together. \n", + "\n", + "In this simple problem we can demonstrate adding some logic to verify intermediate steps by checking whether their outputs are prime." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "8167db11", + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "\n", + "import dotenv\n", + "import pydantic\n", + "from langchain.agents import AgentExecutor, initialize_agent, AgentType\n", + "from langchain.schema import AgentFinish\n", + "from langchain.agents.tools import Tool\n", + "from langchain import LLMMathChain\n", + "from langchain.chat_models import ChatOpenAI" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "7e41b9e6", + "metadata": {}, + "outputs": [], + "source": [ + "# Uncomment if you have a .env in root of repo contains OPENAI_API_KEY\n", + "# dotenv.load_dotenv(\"../../../../../.env\")\n", + "\n", + "# need to use GPT-4 here as GPT-3.5 does not understand, however hard you insist, that\n", + "# it should use the calculator to perform the final calculation\n", + "llm = ChatOpenAI(temperature=0, model=\"gpt-4\")\n", + "llm_math_chain = LLMMathChain.from_llm(llm=llm, verbose=True)" + ] + }, + { + "cell_type": "markdown", + "id": "81e88aa5", + "metadata": {}, + "source": [ + "Define tools which provide:\n", + "- The `n`th prime number (using a small subset for this example) \n", + "- The LLMMathChain to act as a calculator" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "86f04b55", + "metadata": {}, + "outputs": [], + "source": [ + "primes = {998: 7901, 999: 7907, 1000: 7919}\n", + "\n", + "\n", + "class CalculatorInput(pydantic.BaseModel):\n", + " question: str = pydantic.Field()\n", + "\n", + "\n", + "class PrimeInput(pydantic.BaseModel):\n", + " n: int = pydantic.Field()\n", + "\n", + "\n", + "def is_prime(n: int) -> bool:\n", + " if n <= 1 or (n % 2 == 0 and n > 2):\n", + " return False\n", + " for i in range(3, int(n**0.5) + 1, 2):\n", + " if n % i == 0:\n", + " return False\n", + " return True\n", + "\n", + "\n", + "def get_prime(n: int, primes: dict = primes) -> str:\n", + " return str(primes.get(int(n)))\n", + "\n", + "\n", + "async def aget_prime(n: int, primes: dict = primes) -> str:\n", + " return str(primes.get(int(n)))\n", + "\n", + "\n", + "tools = [\n", + " Tool(\n", + " name=\"GetPrime\",\n", + " func=get_prime,\n", + " description=\"A tool that returns the `n`th prime number\",\n", + " args_schema=PrimeInput,\n", + " coroutine=aget_prime,\n", + " ),\n", + " Tool.from_function(\n", + " func=llm_math_chain.run,\n", + " name=\"Calculator\",\n", + " description=\"Useful for when you need to compute mathematical expressions\",\n", + " args_schema=CalculatorInput,\n", + " coroutine=llm_math_chain.arun,\n", + " ),\n", + "]" + ] + }, + { + "cell_type": "markdown", + "id": "0e660ee6", + "metadata": {}, + "source": [ + "Construct the agent. We will use the default agent type here." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "21c775b0", + "metadata": {}, + "outputs": [], + "source": [ + "agent = initialize_agent(\n", + " tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "a233fe4e", + "metadata": {}, + "source": [ + "Run the iteration and perform a custom check on certain steps:" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "582d61f4", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\n", + "\u001b[1m> Entering new chain...\u001b[0m\n", + "\u001b[32;1m\u001b[1;3mI need to find the 998th, 999th and 1000th prime numbers first.\n", + "Action: GetPrime\n", + "Action Input: 998\u001b[0m\n", + "Observation: \u001b[36;1m\u001b[1;3m7901\u001b[0m\n", + "Thought:Checking whether 7901 is prime...\n", + "Should the agent continue (Y/n)?:\n", + "Y\n", + "\u001b[32;1m\u001b[1;3mI have the 998th prime number. Now I need to find the 999th prime number.\n", + "Action: GetPrime\n", + "Action Input: 999\u001b[0m\n", + "Observation: \u001b[36;1m\u001b[1;3m7907\u001b[0m\n", + "Thought:Checking whether 7907 is prime...\n", + "Should the agent continue (Y/n)?:\n", + "Y\n", + "\u001b[32;1m\u001b[1;3mI have the 999th prime number. Now I need to find the 1000th prime number.\n", + "Action: GetPrime\n", + "Action Input: 1000\u001b[0m\n", + "Observation: \u001b[36;1m\u001b[1;3m7919\u001b[0m\n", + "Thought:Checking whether 7919 is prime...\n", + "Should the agent continue (Y/n)?:\n", + "Y\n", + "\u001b[32;1m\u001b[1;3mI have all three prime numbers. Now I need to calculate the product of these numbers.\n", + "Action: Calculator\n", + "Action Input: 7901 * 7907 * 7919\u001b[0m\n", + "\n", + "\u001b[1m> Entering new chain...\u001b[0m\n", + "7901 * 7907 * 7919\u001b[32;1m\u001b[1;3m```text\n", + "7901 * 7907 * 7919\n", + "```\n", + "...numexpr.evaluate(\"7901 * 7907 * 7919\")...\n", + "\u001b[0m\n", + "Answer: \u001b[33;1m\u001b[1;3m494725326233\u001b[0m\n", + "\u001b[1m> Finished chain.\u001b[0m\n", + "\n", + "Observation: \u001b[33;1m\u001b[1;3mAnswer: 494725326233\u001b[0m\n", + "Thought:Should the agent continue (Y/n)?:\n", + "Y\n", + "\u001b[32;1m\u001b[1;3mI now know the final answer\n", + "Final Answer: 494725326233\u001b[0m\n", + "\n", + "\u001b[1m> Finished chain.\u001b[0m\n" + ] + } + ], + "source": [ + "question = \"What is the product of the 998th, 999th and 1000th prime numbers?\"\n", + "\n", + "for step in agent.iter(question):\n", + " if output := step.get(\"intermediate_step\"):\n", + " action, value = output[0]\n", + " if action.tool == \"GetPrime\":\n", + " print(f\"Checking whether {value} is prime...\")\n", + " assert is_prime(int(value))\n", + " # Ask user if they want to continue\n", + " _continue = input(\"Should the agent continue (Y/n)?:\\n\")\n", + " if _continue != \"Y\":\n", + " break" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6934ff8e", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "venv", + "language": "python", + "name": "venv" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.3" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/libs/langchain/langchain/agents/__init__.py b/libs/langchain/langchain/agents/__init__.py index a7213959a8..f045082fac 100644 --- a/libs/langchain/langchain/agents/__init__.py +++ b/libs/langchain/langchain/agents/__init__.py @@ -7,6 +7,7 @@ from langchain.agents.agent import ( BaseSingleActionAgent, LLMSingleActionAgent, ) +from langchain.agents.agent_iterator import AgentExecutorIterator from langchain.agents.agent_toolkits import ( create_csv_agent, create_json_agent, @@ -42,6 +43,7 @@ from langchain.agents.tools import Tool, tool __all__ = [ "Agent", "AgentExecutor", + "AgentExecutorIterator", "AgentOutputParser", "AgentType", "BaseMultiActionAgent", diff --git a/libs/langchain/langchain/agents/agent.py b/libs/langchain/langchain/agents/agent.py index ee4f7ae815..fb6404bf6c 100644 --- a/libs/langchain/langchain/agents/agent.py +++ b/libs/langchain/langchain/agents/agent.py @@ -12,6 +12,7 @@ from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union import yaml from pydantic import BaseModel, root_validator +from langchain.agents.agent_iterator import AgentExecutorIterator from langchain.agents.agent_types import AgentType from langchain.agents.tools import InvalidTool from langchain.callbacks.base import BaseCallbackManager @@ -732,6 +733,24 @@ s """Save the underlying agent.""" return self.agent.save(file_path) + def iter( + self, + inputs: Any, + callbacks: Callbacks = None, + *, + include_run_info: bool = False, + async_: bool = False, + ) -> AgentExecutorIterator: + """Enables iteration over steps taken to reach final output.""" + return AgentExecutorIterator( + self, + inputs, + callbacks, + tags=self.tags, + include_run_info=include_run_info, + async_=async_, + ) + @property def input_keys(self) -> List[str]: """Return the input keys. diff --git a/libs/langchain/langchain/agents/agent_iterator.py b/libs/langchain/langchain/agents/agent_iterator.py new file mode 100644 index 0000000000..40193387f9 --- /dev/null +++ b/libs/langchain/langchain/agents/agent_iterator.py @@ -0,0 +1,501 @@ +from __future__ import annotations + +import logging +import time +from abc import ABC, abstractmethod +from asyncio import CancelledError +from functools import wraps +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + List, + NoReturn, + Optional, + Tuple, + Type, + Union, +) + +from langchain.callbacks.manager import ( + AsyncCallbackManager, + AsyncCallbackManagerForChainRun, + CallbackManager, + CallbackManagerForChainRun, + Callbacks, +) +from langchain.input import get_color_mapping +from langchain.load.dump import dumpd +from langchain.schema import RUN_KEY, AgentAction, AgentFinish, RunInfo +from langchain.tools import BaseTool +from langchain.utilities.asyncio import asyncio_timeout + +if TYPE_CHECKING: + from langchain.agents.agent import AgentExecutor + +logger = logging.getLogger(__name__) + + +class BaseAgentExecutorIterator(ABC): + @abstractmethod + def build_callback_manager(self) -> None: + pass + + +def rebuild_callback_manager_on_set( + setter_method: Callable[..., None] +) -> Callable[..., None]: + """Decorator to force setters to rebuild callback mgr""" + + @wraps(setter_method) + def wrapper(self: BaseAgentExecutorIterator, *args: Any, **kwargs: Any) -> None: + setter_method(self, *args, **kwargs) + self.build_callback_manager() + + return wrapper + + +class AgentExecutorIterator(BaseAgentExecutorIterator): + def __init__( + self, + agent_executor: AgentExecutor, + inputs: Any, + callbacks: Callbacks = None, + *, + tags: Optional[list[str]] = None, + include_run_info: bool = False, + async_: bool = False, + ): + """ + Initialize the AgentExecutorIterator with the given AgentExecutor, + inputs, and optional callbacks. + """ + self._agent_executor = agent_executor + self.inputs = inputs + self.async_ = async_ + # build callback manager on tags setter + self._callbacks = callbacks + self.tags = tags + self.include_run_info = include_run_info + self.run_manager = None + self.reset() + + _callback_manager: Union[AsyncCallbackManager, CallbackManager] + _inputs: dict[str, str] + _final_outputs: Optional[dict[str, str]] + run_manager: Optional[ + Union[AsyncCallbackManagerForChainRun, CallbackManagerForChainRun] + ] + timeout_manager: Any # TODO: Fix a type here; the shim makes it tricky. + + @property + def inputs(self) -> dict[str, str]: + return self._inputs + + @inputs.setter + def inputs(self, inputs: Any) -> None: + self._inputs = self.agent_executor.prep_inputs(inputs) + + @property + def callbacks(self) -> Callbacks: + return self._callbacks + + @callbacks.setter + @rebuild_callback_manager_on_set + def callbacks(self, callbacks: Callbacks) -> None: + """When callbacks are changed after __init__, rebuild callback mgr""" + self._callbacks = callbacks + + @property + def tags(self) -> Optional[List[str]]: + return self._tags + + @tags.setter + @rebuild_callback_manager_on_set + def tags(self, tags: Optional[List[str]]) -> None: + """When tags are changed after __init__, rebuild callback mgr""" + self._tags = tags + + @property + def agent_executor(self) -> AgentExecutor: + return self._agent_executor + + @agent_executor.setter + @rebuild_callback_manager_on_set + def agent_executor(self, agent_executor: AgentExecutor) -> None: + self._agent_executor = agent_executor + # force re-prep inputs in case agent_executor's prep_inputs fn changed + self.inputs = self.inputs + + @property + def callback_manager(self) -> Union[AsyncCallbackManager, CallbackManager]: + return self._callback_manager + + def build_callback_manager(self) -> None: + """ + Create and configure the callback manager based on the current + callbacks and tags. + """ + CallbackMgr: Union[Type[AsyncCallbackManager], Type[CallbackManager]] = ( + AsyncCallbackManager if self.async_ else CallbackManager + ) + self._callback_manager = CallbackMgr.configure( + self.callbacks, + self.agent_executor.callbacks, + self.agent_executor.verbose, + self.tags, + self.agent_executor.tags, + ) + + @property + def name_to_tool_map(self) -> dict[str, BaseTool]: + return {tool.name: tool for tool in self.agent_executor.tools} + + @property + def color_mapping(self) -> dict[str, str]: + return get_color_mapping( + [tool.name for tool in self.agent_executor.tools], + excluded_colors=["green", "red"], + ) + + def reset(self) -> None: + """ + Reset the iterator to its initial state, clearing intermediate steps, + iterations, and time elapsed. + """ + logger.debug("(Re)setting AgentExecutorIterator to fresh state") + self.intermediate_steps: list[tuple[AgentAction, str]] = [] + self.iterations = 0 + # maybe better to start these on the first __anext__ call? + self.time_elapsed = 0.0 + self.start_time = time.time() + self._final_outputs = None + + def update_iterations(self) -> None: + """ + Increment the number of iterations and update the time elapsed. + """ + self.iterations += 1 + self.time_elapsed = time.time() - self.start_time + logger.debug( + f"Agent Iterations: {self.iterations} ({self.time_elapsed:.2f}s elapsed)" + ) + + def raise_stopiteration(self, output: Any) -> NoReturn: + """ + Raise a StopIteration exception with the given output. + """ + logger.debug("Chain end: stop iteration") + raise StopIteration(output) + + async def raise_stopasynciteration(self, output: Any) -> NoReturn: + """ + Raise a StopAsyncIteration exception with the given output. + Close the timeout context manager. + """ + logger.debug("Chain end: stop async iteration") + if self.timeout_manager is not None: + await self.timeout_manager.__aexit__(None, None, None) + raise StopAsyncIteration(output) + + @property + def final_outputs(self) -> Optional[dict[str, Any]]: + return self._final_outputs + + @final_outputs.setter + def final_outputs(self, outputs: Optional[Dict[str, Any]]) -> None: + # have access to intermediate steps by design in iterator, + # so return only outputs may as well always be true. + + self._final_outputs = None + if outputs: + prepared_outputs: dict[str, Any] = self.agent_executor.prep_outputs( + self.inputs, outputs, return_only_outputs=True + ) + if self.include_run_info and self.run_manager is not None: + logger.debug("Assign run key") + prepared_outputs[RUN_KEY] = RunInfo(run_id=self.run_manager.run_id) + self._final_outputs = prepared_outputs + + def __iter__(self: "AgentExecutorIterator") -> "AgentExecutorIterator": + logger.debug("Initialising AgentExecutorIterator") + self.reset() + assert isinstance(self.callback_manager, CallbackManager) + self.run_manager = self.callback_manager.on_chain_start( + dumpd(self.agent_executor), + self.inputs, + ) + return self + + def __aiter__(self) -> "AgentExecutorIterator": + """ + N.B. __aiter__ must be a normal method, so need to initialise async run manager + on first __anext__ call where we can await it + """ + logger.debug("Initialising AgentExecutorIterator (async)") + self.reset() + if self.agent_executor.max_execution_time: + self.timeout_manager = asyncio_timeout( + self.agent_executor.max_execution_time + ) + else: + self.timeout_manager = None + return self + + def _on_first_step(self) -> None: + """ + Perform any necessary setup for the first step of the synchronous iterator. + """ + pass + + async def _on_first_async_step(self) -> None: + """ + Perform any necessary setup for the first step of the asynchronous iterator. + """ + # on first step, need to await callback manager and start async timeout ctxmgr + if self.iterations == 0: + assert isinstance(self.callback_manager, AsyncCallbackManager) + self.run_manager = await self.callback_manager.on_chain_start( + dumpd(self.agent_executor), + self.inputs, + ) + if self.timeout_manager: + await self.timeout_manager.__aenter__() + + def __next__(self) -> dict[str, Any]: + """ + AgentExecutor AgentExecutorIterator + __call__ (__iter__ ->) __next__ + _call <=> _call_next + _take_next_step _take_next_step + """ + # first step + if self.iterations == 0: + self._on_first_step() + # N.B. timeout taken care of by "_should_continue" in sync case + try: + return self._call_next() + except StopIteration: + raise + except (KeyboardInterrupt, Exception) as e: + if self.run_manager: + self.run_manager.on_chain_error(e) + raise + + async def __anext__(self) -> dict[str, Any]: + """ + AgentExecutor AgentExecutorIterator + acall (__aiter__ ->) __anext__ + _acall <=> _acall_next + _atake_next_step _atake_next_step + """ + if self.iterations == 0: + await self._on_first_async_step() + try: + return await self._acall_next() + except StopAsyncIteration: + raise + except (TimeoutError, CancelledError): + await self.timeout_manager.__aexit__(None, None, None) + self.timeout_manager = None + return await self._astop() + except (KeyboardInterrupt, Exception) as e: + if self.run_manager: + assert isinstance(self.run_manager, AsyncCallbackManagerForChainRun) + await self.run_manager.on_chain_error(e) + raise + + def _execute_next_step( + self, run_manager: Optional[CallbackManagerForChainRun] + ) -> Union[AgentFinish, List[Tuple[AgentAction, str]]]: + """ + Execute the next step in the chain using the + AgentExecutor's _take_next_step method. + """ + return self.agent_executor._take_next_step( + self.name_to_tool_map, + self.color_mapping, + self.inputs, + self.intermediate_steps, + run_manager=run_manager, + ) + + async def _execute_next_async_step( + self, run_manager: Optional[AsyncCallbackManagerForChainRun] + ) -> Union[AgentFinish, List[Tuple[AgentAction, str]]]: + """ + Execute the next step in the chain using the + AgentExecutor's _atake_next_step method. + """ + return await self.agent_executor._atake_next_step( + self.name_to_tool_map, + self.color_mapping, + self.inputs, + self.intermediate_steps, + run_manager=run_manager, + ) + + def _process_next_step_output( + self, + next_step_output: Union[AgentFinish, List[Tuple[AgentAction, str]]], + run_manager: Optional[CallbackManagerForChainRun], + ) -> Dict[str, Union[str, List[Tuple[AgentAction, str]]]]: + """ + Process the output of the next step, + handling AgentFinish and tool return cases. + """ + logger.debug("Processing output of Agent loop step") + if isinstance(next_step_output, AgentFinish): + logger.debug( + "Hit AgentFinish: _return -> on_chain_end -> run final output logic" + ) + output = self.agent_executor._return( + next_step_output, self.intermediate_steps, run_manager=run_manager + ) + if self.run_manager: + self.run_manager.on_chain_end(output) + self.final_outputs = output + return output + + self.intermediate_steps.extend(next_step_output) + logger.debug("Updated intermediate_steps with step output") + + # Check for tool return + if len(next_step_output) == 1: + next_step_action = next_step_output[0] + tool_return = self.agent_executor._get_tool_return(next_step_action) + if tool_return is not None: + output = self.agent_executor._return( + tool_return, self.intermediate_steps, run_manager=run_manager + ) + if self.run_manager: + self.run_manager.on_chain_end(output) + self.final_outputs = output + return output + + output = {"intermediate_step": next_step_output} + return output + + async def _aprocess_next_step_output( + self, + next_step_output: Union[AgentFinish, List[Tuple[AgentAction, str]]], + run_manager: Optional[AsyncCallbackManagerForChainRun], + ) -> Dict[str, Union[str, List[Tuple[AgentAction, str]]]]: + """ + Process the output of the next async step, + handling AgentFinish and tool return cases. + """ + logger.debug("Processing output of async Agent loop step") + if isinstance(next_step_output, AgentFinish): + logger.debug( + "Hit AgentFinish: _areturn -> on_chain_end -> run final output logic" + ) + output = await self.agent_executor._areturn( + next_step_output, self.intermediate_steps, run_manager=run_manager + ) + if run_manager: + await run_manager.on_chain_end(output) + self.final_outputs = output + return output + + self.intermediate_steps.extend(next_step_output) + logger.debug("Updated intermediate_steps with step output") + + # Check for tool return + if len(next_step_output) == 1: + next_step_action = next_step_output[0] + tool_return = self.agent_executor._get_tool_return(next_step_action) + if tool_return is not None: + output = await self.agent_executor._areturn( + tool_return, self.intermediate_steps, run_manager=run_manager + ) + if run_manager: + await run_manager.on_chain_end(output) + self.final_outputs = output + return output + + output = {"intermediate_step": next_step_output} + return output + + def _stop(self) -> dict[str, Any]: + """ + Stop the iterator and raise a StopIteration exception with the stopped response. + """ + logger.warning("Stopping agent prematurely due to triggering stop condition") + # this manually constructs agent finish with output key + output = self.agent_executor.agent.return_stopped_response( + self.agent_executor.early_stopping_method, + self.intermediate_steps, + **self.inputs, + ) + assert ( + isinstance(self.run_manager, CallbackManagerForChainRun) + or self.run_manager is None + ) + returned_output = self.agent_executor._return( + output, self.intermediate_steps, run_manager=self.run_manager + ) + self.final_outputs = returned_output + return returned_output + + async def _astop(self) -> dict[str, Any]: + """ + Stop the async iterator and raise a StopAsyncIteration exception with + the stopped response. + """ + logger.warning("Stopping agent prematurely due to triggering stop condition") + output = self.agent_executor.agent.return_stopped_response( + self.agent_executor.early_stopping_method, + self.intermediate_steps, + **self.inputs, + ) + assert ( + isinstance(self.run_manager, AsyncCallbackManagerForChainRun) + or self.run_manager is None + ) + returned_output = await self.agent_executor._areturn( + output, self.intermediate_steps, run_manager=self.run_manager + ) + self.final_outputs = returned_output + return returned_output + + def _call_next(self) -> dict[str, Any]: + """ + Perform a single iteration of the synchronous AgentExecutorIterator. + """ + # final output already reached: stopiteration (final output) + if self.final_outputs is not None: + self.raise_stopiteration(self.final_outputs) + # timeout/max iterations: stopiteration (stopped response) + if not self.agent_executor._should_continue(self.iterations, self.time_elapsed): + return self._stop() + assert ( + isinstance(self.run_manager, CallbackManagerForChainRun) + or self.run_manager is None + ) + next_step_output = self._execute_next_step(self.run_manager) + output = self._process_next_step_output(next_step_output, self.run_manager) + self.update_iterations() + return output + + async def _acall_next(self) -> dict[str, Any]: + """ + Perform a single iteration of the asynchronous AgentExecutorIterator. + """ + # final output already reached: stopiteration (final output) + if self.final_outputs is not None: + await self.raise_stopasynciteration(self.final_outputs) + # timeout/max iterations: stopiteration (stopped response) + if not self.agent_executor._should_continue(self.iterations, self.time_elapsed): + return await self._astop() + assert ( + isinstance(self.run_manager, AsyncCallbackManagerForChainRun) + or self.run_manager is None + ) + next_step_output = await self._execute_next_async_step(self.run_manager) + output = await self._aprocess_next_step_output( + next_step_output, self.run_manager + ) + self.update_iterations() + return output diff --git a/libs/langchain/tests/unit_tests/agents/test_agent.py b/libs/langchain/tests/unit_tests/agents/test_agent.py index cdb7bb0c3b..2405da88f9 100644 --- a/libs/langchain/tests/unit_tests/agents/test_agent.py +++ b/libs/langchain/tests/unit_tests/agents/test_agent.py @@ -32,6 +32,9 @@ class FakeListLLM(LLM): """Return number of tokens in text.""" return len(text.split()) + async def _acall(self, *args: Any, **kwargs: Any) -> str: + return self._call(*args, **kwargs) + @property def _identifying_params(self) -> Mapping[str, Any]: return {} @@ -49,7 +52,8 @@ def _get_agent(**kwargs: Any) -> AgentExecutor: f"I'm turning evil\nAction: {bad_action_name}\nAction Input: misalignment", "Oh well\nFinal Answer: curses foiled again", ] - fake_llm = FakeListLLM(responses=responses) + fake_llm = FakeListLLM(cache=False, responses=responses) + tools = [ Tool( name="Search", @@ -62,6 +66,7 @@ def _get_agent(**kwargs: Any) -> AgentExecutor: description="Useful for looking up things in a table", ), ] + agent = initialize_agent( tools, fake_llm, @@ -194,6 +199,7 @@ def test_agent_tool_return_direct_in_intermediate_steps() -> None: ) resp = agent("when was langchain made") + assert isinstance(resp, dict) assert resp["output"] == "misalignment" assert len(resp["intermediate_steps"]) == 1 action, _action_intput = resp["intermediate_steps"][0] diff --git a/libs/langchain/tests/unit_tests/agents/test_agent_iterator.py b/libs/langchain/tests/unit_tests/agents/test_agent_iterator.py new file mode 100644 index 0000000000..6861d67618 --- /dev/null +++ b/libs/langchain/tests/unit_tests/agents/test_agent_iterator.py @@ -0,0 +1,360 @@ +import pytest + +from langchain.agents import ( + AgentExecutor, + AgentExecutorIterator, + AgentType, + initialize_agent, +) +from langchain.agents.tools import Tool +from langchain.llms import FakeListLLM +from tests.unit_tests.agents.test_agent import _get_agent +from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler + + +def test_agent_iterator_bad_action() -> None: + """Test react chain iterator when bad action given.""" + agent = _get_agent() + agent_iter = agent.iter(inputs="when was langchain made") + + outputs = [] + for step in agent_iter: + outputs.append(step) + + assert isinstance(outputs[-1], dict) + assert outputs[-1]["output"] == "curses foiled again" + + +def test_agent_iterator_stopped_early() -> None: + """ + Test react chain iterator when max iterations or + max execution time is exceeded. + """ + # iteration limit + agent = _get_agent(max_iterations=1) + agent_iter = agent.iter(inputs="when was langchain made") + + outputs = [] + for step in agent_iter: + outputs.append(step) + # NOTE: we don't use agent.run like in the test for the regular agent executor, + # so the dict structure for outputs stays intact + assert isinstance(outputs[-1], dict) + assert ( + outputs[-1]["output"] == "Agent stopped due to iteration limit or time limit." + ) + + # execution time limit + agent = _get_agent(max_execution_time=1e-5) + agent_iter = agent.iter(inputs="when was langchain made") + + outputs = [] + for step in agent_iter: + outputs.append(step) + assert isinstance(outputs[-1], dict) + assert ( + outputs[-1]["output"] == "Agent stopped due to iteration limit or time limit." + ) + + +@pytest.mark.asyncio +async def test_agent_async_iterator_stopped_early() -> None: + """ + Test react chain async iterator when max iterations or + max execution time is exceeded. + """ + # iteration limit + agent = _get_agent(max_iterations=1) + agent_async_iter = agent.iter(inputs="when was langchain made", async_=True) + + outputs = [] + assert isinstance(agent_async_iter, AgentExecutorIterator) + async for step in agent_async_iter: + outputs.append(step) + + assert isinstance(outputs[-1], dict) + assert ( + outputs[-1]["output"] == "Agent stopped due to iteration limit or time limit." + ) + + # execution time limit + agent = _get_agent(max_execution_time=1e-5) + agent_async_iter = agent.iter(inputs="when was langchain made", async_=True) + assert isinstance(agent_async_iter, AgentExecutorIterator) + + outputs = [] + async for step in agent_async_iter: + outputs.append(step) + + assert ( + outputs[-1]["output"] == "Agent stopped due to iteration limit or time limit." + ) + + +def test_agent_iterator_with_callbacks() -> None: + """Test react chain iterator with callbacks by setting verbose globally.""" + handler1 = FakeCallbackHandler() + handler2 = FakeCallbackHandler() + bad_action_name = "BadAction" + responses = [ + f"I'm turning evil\nAction: {bad_action_name}\nAction Input: misalignment", + "Oh well\nFinal Answer: curses foiled again", + ] + fake_llm = FakeListLLM(cache=False, responses=responses, callbacks=[handler2]) + + tools = [ + Tool( + name="Search", + func=lambda x: x, + description="Useful for searching", + ), + Tool( + name="Lookup", + func=lambda x: x, + description="Useful for looking up things in a table", + ), + ] + + agent = initialize_agent( + tools, fake_llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True + ) + agent_iter = agent.iter(inputs="when was langchain made", callbacks=[handler1]) + + outputs = [] + for step in agent_iter: + outputs.append(step) + assert isinstance(outputs[-1], dict) + assert outputs[-1]["output"] == "curses foiled again" + + # 1 top level chain run runs, 2 LLMChain runs, 2 LLM runs, 1 tool run + assert handler1.chain_starts == handler1.chain_ends == 3 + assert handler1.llm_starts == handler1.llm_ends == 2 + assert handler1.tool_starts == 1 + assert handler1.tool_ends == 1 + # 1 extra agent action + assert handler1.starts == 7 + # 1 extra agent end + assert handler1.ends == 7 + print("h:", handler1) + assert handler1.errors == 0 + # during LLMChain + assert handler1.text == 2 + + assert handler2.llm_starts == 2 + assert handler2.llm_ends == 2 + assert ( + handler2.chain_starts + == handler2.tool_starts + == handler2.tool_ends + == handler2.chain_ends + == 0 + ) + + +@pytest.mark.asyncio +async def test_agent_async_iterator_with_callbacks() -> None: + """Test react chain async iterator with callbacks by setting verbose globally.""" + handler1 = FakeCallbackHandler() + handler2 = FakeCallbackHandler() + + bad_action_name = "BadAction" + responses = [ + f"I'm turning evil\nAction: {bad_action_name}\nAction Input: misalignment", + "Oh well\nFinal Answer: curses foiled again", + ] + fake_llm = FakeListLLM(cache=False, responses=responses, callbacks=[handler2]) + + tools = [ + Tool( + name="Search", + func=lambda x: x, + description="Useful for searching", + ), + Tool( + name="Lookup", + func=lambda x: x, + description="Useful for looking up things in a table", + ), + ] + + agent = initialize_agent( + tools, fake_llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True + ) + agent_async_iter = agent.iter( + inputs="when was langchain made", + callbacks=[handler1], + async_=True, + ) + assert isinstance(agent_async_iter, AgentExecutorIterator) + + outputs = [] + async for step in agent_async_iter: + outputs.append(step) + + assert outputs[-1]["output"] == "curses foiled again" + + # 1 top level chain run runs, 2 LLMChain runs, 2 LLM runs, 1 tool run + assert handler1.chain_starts == handler1.chain_ends == 3 + assert handler1.llm_starts == handler1.llm_ends == 2 + assert handler1.tool_starts == 1 + assert handler1.tool_ends == 1 + # 1 extra agent action + assert handler1.starts == 7 + # 1 extra agent end + assert handler1.ends == 7 + assert handler1.errors == 0 + # during LLMChain + assert handler1.text == 2 + + assert handler2.llm_starts == 2 + assert handler2.llm_ends == 2 + assert ( + handler2.chain_starts + == handler2.tool_starts + == handler2.tool_ends + == handler2.chain_ends + == 0 + ) + + +def test_agent_iterator_properties_and_setters() -> None: + """Test properties and setters of AgentExecutorIterator.""" + agent = _get_agent() + agent.tags = None + agent_iter = agent.iter(inputs="when was langchain made") + + assert isinstance(agent_iter, AgentExecutorIterator) + assert isinstance(agent_iter.inputs, dict) + assert isinstance(agent_iter.callbacks, type(None)) + assert isinstance(agent_iter.tags, type(None)) + assert isinstance(agent_iter.agent_executor, AgentExecutor) + + agent_iter.inputs = "New input" # type: ignore + assert isinstance(agent_iter.inputs, dict) + + agent_iter.callbacks = [FakeCallbackHandler()] + assert isinstance(agent_iter.callbacks, list) + + agent_iter.tags = ["test"] + assert isinstance(agent_iter.tags, list) + + new_agent = _get_agent() + agent_iter.agent_executor = new_agent + assert isinstance(agent_iter.agent_executor, AgentExecutor) + + +def test_agent_iterator_reset() -> None: + """Test reset functionality of AgentExecutorIterator.""" + agent = _get_agent() + agent_iter = agent.iter(inputs="when was langchain made") + assert isinstance(agent_iter, AgentExecutorIterator) + + # Perform one iteration + next(agent_iter) + + # Check if properties are updated + assert agent_iter.iterations == 1 + assert agent_iter.time_elapsed > 0.0 + assert agent_iter.intermediate_steps + + # Reset the iterator + agent_iter.reset() + + # Check if properties are reset + assert agent_iter.iterations == 0 + assert agent_iter.time_elapsed == 0.0 + assert not agent_iter.intermediate_steps + + +def test_agent_iterator_output_structure() -> None: + """Test the output structure of AgentExecutorIterator.""" + agent = _get_agent() + agent_iter = agent.iter(inputs="when was langchain made") + + for step in agent_iter: + assert isinstance(step, dict) + if "intermediate_step" in step: + assert isinstance(step["intermediate_step"], list) + elif "output" in step: + assert isinstance(step["output"], str) + else: + assert False, "Unexpected output structure" + + +@pytest.mark.asyncio +async def test_agent_async_iterator_output_structure() -> None: + """Test the async output structure of AgentExecutorIterator.""" + agent = _get_agent() + agent_async_iter = agent.iter(inputs="when was langchain made", async_=True) + + assert isinstance(agent_async_iter, AgentExecutorIterator) + async for step in agent_async_iter: + assert isinstance(step, dict) + if "intermediate_step" in step: + assert isinstance(step["intermediate_step"], list) + elif "output" in step: + assert isinstance(step["output"], str) + else: + assert False, "Unexpected output structure" + + +def test_agent_iterator_empty_input() -> None: + """Test AgentExecutorIterator with empty input.""" + agent = _get_agent() + agent_iter = agent.iter(inputs="") + + outputs = [] + for step in agent_iter: + outputs.append(step) + + assert isinstance(outputs[-1], dict) + assert outputs[-1]["output"] # Check if there is an output + + +def test_agent_iterator_custom_stopping_condition() -> None: + """Test AgentExecutorIterator with a custom stopping condition.""" + agent = _get_agent() + + class CustomAgentExecutorIterator(AgentExecutorIterator): + def _should_continue(self) -> bool: + return self.iterations < 2 # Custom stopping condition + + agent_iter = CustomAgentExecutorIterator(agent, inputs="when was langchain made") + + outputs = [] + for step in agent_iter: + outputs.append(step) + + assert len(outputs) == 2 # Check if the custom stopping condition is respected + + +def test_agent_iterator_failing_tool() -> None: + """Test AgentExecutorIterator with a tool that raises an exception.""" + + # Get agent for testing. + bad_action_name = "FailingTool" + responses = [ + f"I'm turning evil\nAction: {bad_action_name}\nAction Input: misalignment", + "Oh well\nFinal Answer: curses foiled again", + ] + fake_llm = FakeListLLM(responses=responses) + + tools = [ + Tool( + name="FailingTool", + func=lambda x: 1 / 0, # This tool will raise a ZeroDivisionError + description="A tool that fails", + ), + ] + + agent = initialize_agent( + tools, fake_llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True + ) + + agent_iter = agent.iter(inputs="when was langchain made") + assert isinstance(agent_iter, AgentExecutorIterator) + # initialise iterator + iter(agent_iter) + + with pytest.raises(ZeroDivisionError): + next(agent_iter) diff --git a/libs/langchain/tests/unit_tests/agents/test_public_api.py b/libs/langchain/tests/unit_tests/agents/test_public_api.py index 1ae893edfa..9040e48475 100644 --- a/libs/langchain/tests/unit_tests/agents/test_public_api.py +++ b/libs/langchain/tests/unit_tests/agents/test_public_api.py @@ -3,6 +3,7 @@ from langchain.agents import __all__ as agents_all _EXPECTED = [ "Agent", "AgentExecutor", + "AgentExecutorIterator", "AgentOutputParser", "AgentType", "BaseMultiActionAgent",