mirror of
https://github.com/hwchase17/langchain
synced 2024-11-06 03:20:49 +00:00
Implement AgentExecutorIterator (#6929)
- Description: Implements a `.iter()` method for the `AgentExecutor` class. This allows hooking into and intercepting intermediate agent steps. - Issue: #6925 - Dependencies: None - Tag maintainer: @vowelparrot @agola11 - Twitter handle: @SlapDron3 @lacicocodes --------- Co-authored-by: Lacico <Lacicocodes@gmail.com> Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
parent
77bf75c236
commit
961a0e200f
245
docs/extras/modules/agents/how_to/agent_iter.ipynb
Normal file
245
docs/extras/modules/agents/how_to/agent_iter.ipynb
Normal file
@ -0,0 +1,245 @@
|
|||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "feb31cc6",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"# Running Agent as an Iterator\n",
|
||||||
|
"\n",
|
||||||
|
"To demonstrate the `AgentExecutorIterator` functionality, we will set up a problem where an Agent must:\n",
|
||||||
|
"\n",
|
||||||
|
"- Retrieve three prime numbers from a Tool\n",
|
||||||
|
"- Multiply these together. \n",
|
||||||
|
"\n",
|
||||||
|
"In this simple problem we can demonstrate adding some logic to verify intermediate steps by checking whether their outputs are prime."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 1,
|
||||||
|
"id": "8167db11",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"import os\n",
|
||||||
|
"\n",
|
||||||
|
"import dotenv\n",
|
||||||
|
"import pydantic\n",
|
||||||
|
"from langchain.agents import AgentExecutor, initialize_agent, AgentType\n",
|
||||||
|
"from langchain.schema import AgentFinish\n",
|
||||||
|
"from langchain.agents.tools import Tool\n",
|
||||||
|
"from langchain import LLMMathChain\n",
|
||||||
|
"from langchain.chat_models import ChatOpenAI"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 2,
|
||||||
|
"id": "7e41b9e6",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# Uncomment if you have a .env in root of repo contains OPENAI_API_KEY\n",
|
||||||
|
"# dotenv.load_dotenv(\"../../../../../.env\")\n",
|
||||||
|
"\n",
|
||||||
|
"# need to use GPT-4 here as GPT-3.5 does not understand, however hard you insist, that\n",
|
||||||
|
"# it should use the calculator to perform the final calculation\n",
|
||||||
|
"llm = ChatOpenAI(temperature=0, model=\"gpt-4\")\n",
|
||||||
|
"llm_math_chain = LLMMathChain.from_llm(llm=llm, verbose=True)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "81e88aa5",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"Define tools which provide:\n",
|
||||||
|
"- The `n`th prime number (using a small subset for this example) \n",
|
||||||
|
"- The LLMMathChain to act as a calculator"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 3,
|
||||||
|
"id": "86f04b55",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"primes = {998: 7901, 999: 7907, 1000: 7919}\n",
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
"class CalculatorInput(pydantic.BaseModel):\n",
|
||||||
|
" question: str = pydantic.Field()\n",
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
"class PrimeInput(pydantic.BaseModel):\n",
|
||||||
|
" n: int = pydantic.Field()\n",
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
"def is_prime(n: int) -> bool:\n",
|
||||||
|
" if n <= 1 or (n % 2 == 0 and n > 2):\n",
|
||||||
|
" return False\n",
|
||||||
|
" for i in range(3, int(n**0.5) + 1, 2):\n",
|
||||||
|
" if n % i == 0:\n",
|
||||||
|
" return False\n",
|
||||||
|
" return True\n",
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
"def get_prime(n: int, primes: dict = primes) -> str:\n",
|
||||||
|
" return str(primes.get(int(n)))\n",
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
"async def aget_prime(n: int, primes: dict = primes) -> str:\n",
|
||||||
|
" return str(primes.get(int(n)))\n",
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
"tools = [\n",
|
||||||
|
" Tool(\n",
|
||||||
|
" name=\"GetPrime\",\n",
|
||||||
|
" func=get_prime,\n",
|
||||||
|
" description=\"A tool that returns the `n`th prime number\",\n",
|
||||||
|
" args_schema=PrimeInput,\n",
|
||||||
|
" coroutine=aget_prime,\n",
|
||||||
|
" ),\n",
|
||||||
|
" Tool.from_function(\n",
|
||||||
|
" func=llm_math_chain.run,\n",
|
||||||
|
" name=\"Calculator\",\n",
|
||||||
|
" description=\"Useful for when you need to compute mathematical expressions\",\n",
|
||||||
|
" args_schema=CalculatorInput,\n",
|
||||||
|
" coroutine=llm_math_chain.arun,\n",
|
||||||
|
" ),\n",
|
||||||
|
"]"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "0e660ee6",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"Construct the agent. We will use the default agent type here."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 4,
|
||||||
|
"id": "21c775b0",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"agent = initialize_agent(\n",
|
||||||
|
" tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True\n",
|
||||||
|
")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "a233fe4e",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"Run the iteration and perform a custom check on certain steps:"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 8,
|
||||||
|
"id": "582d61f4",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
"\u001b[1m> Entering new chain...\u001b[0m\n",
|
||||||
|
"\u001b[32;1m\u001b[1;3mI need to find the 998th, 999th and 1000th prime numbers first.\n",
|
||||||
|
"Action: GetPrime\n",
|
||||||
|
"Action Input: 998\u001b[0m\n",
|
||||||
|
"Observation: \u001b[36;1m\u001b[1;3m7901\u001b[0m\n",
|
||||||
|
"Thought:Checking whether 7901 is prime...\n",
|
||||||
|
"Should the agent continue (Y/n)?:\n",
|
||||||
|
"Y\n",
|
||||||
|
"\u001b[32;1m\u001b[1;3mI have the 998th prime number. Now I need to find the 999th prime number.\n",
|
||||||
|
"Action: GetPrime\n",
|
||||||
|
"Action Input: 999\u001b[0m\n",
|
||||||
|
"Observation: \u001b[36;1m\u001b[1;3m7907\u001b[0m\n",
|
||||||
|
"Thought:Checking whether 7907 is prime...\n",
|
||||||
|
"Should the agent continue (Y/n)?:\n",
|
||||||
|
"Y\n",
|
||||||
|
"\u001b[32;1m\u001b[1;3mI have the 999th prime number. Now I need to find the 1000th prime number.\n",
|
||||||
|
"Action: GetPrime\n",
|
||||||
|
"Action Input: 1000\u001b[0m\n",
|
||||||
|
"Observation: \u001b[36;1m\u001b[1;3m7919\u001b[0m\n",
|
||||||
|
"Thought:Checking whether 7919 is prime...\n",
|
||||||
|
"Should the agent continue (Y/n)?:\n",
|
||||||
|
"Y\n",
|
||||||
|
"\u001b[32;1m\u001b[1;3mI have all three prime numbers. Now I need to calculate the product of these numbers.\n",
|
||||||
|
"Action: Calculator\n",
|
||||||
|
"Action Input: 7901 * 7907 * 7919\u001b[0m\n",
|
||||||
|
"\n",
|
||||||
|
"\u001b[1m> Entering new chain...\u001b[0m\n",
|
||||||
|
"7901 * 7907 * 7919\u001b[32;1m\u001b[1;3m```text\n",
|
||||||
|
"7901 * 7907 * 7919\n",
|
||||||
|
"```\n",
|
||||||
|
"...numexpr.evaluate(\"7901 * 7907 * 7919\")...\n",
|
||||||
|
"\u001b[0m\n",
|
||||||
|
"Answer: \u001b[33;1m\u001b[1;3m494725326233\u001b[0m\n",
|
||||||
|
"\u001b[1m> Finished chain.\u001b[0m\n",
|
||||||
|
"\n",
|
||||||
|
"Observation: \u001b[33;1m\u001b[1;3mAnswer: 494725326233\u001b[0m\n",
|
||||||
|
"Thought:Should the agent continue (Y/n)?:\n",
|
||||||
|
"Y\n",
|
||||||
|
"\u001b[32;1m\u001b[1;3mI now know the final answer\n",
|
||||||
|
"Final Answer: 494725326233\u001b[0m\n",
|
||||||
|
"\n",
|
||||||
|
"\u001b[1m> Finished chain.\u001b[0m\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"question = \"What is the product of the 998th, 999th and 1000th prime numbers?\"\n",
|
||||||
|
"\n",
|
||||||
|
"for step in agent.iter(question):\n",
|
||||||
|
" if output := step.get(\"intermediate_step\"):\n",
|
||||||
|
" action, value = output[0]\n",
|
||||||
|
" if action.tool == \"GetPrime\":\n",
|
||||||
|
" print(f\"Checking whether {value} is prime...\")\n",
|
||||||
|
" assert is_prime(int(value))\n",
|
||||||
|
" # Ask user if they want to continue\n",
|
||||||
|
" _continue = input(\"Should the agent continue (Y/n)?:\\n\")\n",
|
||||||
|
" if _continue != \"Y\":\n",
|
||||||
|
" break"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "6934ff8e",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": []
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"kernelspec": {
|
||||||
|
"display_name": "venv",
|
||||||
|
"language": "python",
|
||||||
|
"name": "venv"
|
||||||
|
},
|
||||||
|
"language_info": {
|
||||||
|
"codemirror_mode": {
|
||||||
|
"name": "ipython",
|
||||||
|
"version": 3
|
||||||
|
},
|
||||||
|
"file_extension": ".py",
|
||||||
|
"mimetype": "text/x-python",
|
||||||
|
"name": "python",
|
||||||
|
"nbconvert_exporter": "python",
|
||||||
|
"pygments_lexer": "ipython3",
|
||||||
|
"version": "3.11.3"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 5
|
||||||
|
}
|
@ -7,6 +7,7 @@ from langchain.agents.agent import (
|
|||||||
BaseSingleActionAgent,
|
BaseSingleActionAgent,
|
||||||
LLMSingleActionAgent,
|
LLMSingleActionAgent,
|
||||||
)
|
)
|
||||||
|
from langchain.agents.agent_iterator import AgentExecutorIterator
|
||||||
from langchain.agents.agent_toolkits import (
|
from langchain.agents.agent_toolkits import (
|
||||||
create_csv_agent,
|
create_csv_agent,
|
||||||
create_json_agent,
|
create_json_agent,
|
||||||
@ -42,6 +43,7 @@ from langchain.agents.tools import Tool, tool
|
|||||||
__all__ = [
|
__all__ = [
|
||||||
"Agent",
|
"Agent",
|
||||||
"AgentExecutor",
|
"AgentExecutor",
|
||||||
|
"AgentExecutorIterator",
|
||||||
"AgentOutputParser",
|
"AgentOutputParser",
|
||||||
"AgentType",
|
"AgentType",
|
||||||
"BaseMultiActionAgent",
|
"BaseMultiActionAgent",
|
||||||
|
@ -12,6 +12,7 @@ from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
|
|||||||
import yaml
|
import yaml
|
||||||
from pydantic import BaseModel, root_validator
|
from pydantic import BaseModel, root_validator
|
||||||
|
|
||||||
|
from langchain.agents.agent_iterator import AgentExecutorIterator
|
||||||
from langchain.agents.agent_types import AgentType
|
from langchain.agents.agent_types import AgentType
|
||||||
from langchain.agents.tools import InvalidTool
|
from langchain.agents.tools import InvalidTool
|
||||||
from langchain.callbacks.base import BaseCallbackManager
|
from langchain.callbacks.base import BaseCallbackManager
|
||||||
@ -732,6 +733,24 @@ s
|
|||||||
"""Save the underlying agent."""
|
"""Save the underlying agent."""
|
||||||
return self.agent.save(file_path)
|
return self.agent.save(file_path)
|
||||||
|
|
||||||
|
def iter(
|
||||||
|
self,
|
||||||
|
inputs: Any,
|
||||||
|
callbacks: Callbacks = None,
|
||||||
|
*,
|
||||||
|
include_run_info: bool = False,
|
||||||
|
async_: bool = False,
|
||||||
|
) -> AgentExecutorIterator:
|
||||||
|
"""Enables iteration over steps taken to reach final output."""
|
||||||
|
return AgentExecutorIterator(
|
||||||
|
self,
|
||||||
|
inputs,
|
||||||
|
callbacks,
|
||||||
|
tags=self.tags,
|
||||||
|
include_run_info=include_run_info,
|
||||||
|
async_=async_,
|
||||||
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def input_keys(self) -> List[str]:
|
def input_keys(self) -> List[str]:
|
||||||
"""Return the input keys.
|
"""Return the input keys.
|
||||||
|
501
libs/langchain/langchain/agents/agent_iterator.py
Normal file
501
libs/langchain/langchain/agents/agent_iterator.py
Normal file
@ -0,0 +1,501 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from asyncio import CancelledError
|
||||||
|
from functools import wraps
|
||||||
|
from typing import (
|
||||||
|
TYPE_CHECKING,
|
||||||
|
Any,
|
||||||
|
Callable,
|
||||||
|
Dict,
|
||||||
|
List,
|
||||||
|
NoReturn,
|
||||||
|
Optional,
|
||||||
|
Tuple,
|
||||||
|
Type,
|
||||||
|
Union,
|
||||||
|
)
|
||||||
|
|
||||||
|
from langchain.callbacks.manager import (
|
||||||
|
AsyncCallbackManager,
|
||||||
|
AsyncCallbackManagerForChainRun,
|
||||||
|
CallbackManager,
|
||||||
|
CallbackManagerForChainRun,
|
||||||
|
Callbacks,
|
||||||
|
)
|
||||||
|
from langchain.input import get_color_mapping
|
||||||
|
from langchain.load.dump import dumpd
|
||||||
|
from langchain.schema import RUN_KEY, AgentAction, AgentFinish, RunInfo
|
||||||
|
from langchain.tools import BaseTool
|
||||||
|
from langchain.utilities.asyncio import asyncio_timeout
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from langchain.agents.agent import AgentExecutor
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class BaseAgentExecutorIterator(ABC):
|
||||||
|
@abstractmethod
|
||||||
|
def build_callback_manager(self) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def rebuild_callback_manager_on_set(
|
||||||
|
setter_method: Callable[..., None]
|
||||||
|
) -> Callable[..., None]:
|
||||||
|
"""Decorator to force setters to rebuild callback mgr"""
|
||||||
|
|
||||||
|
@wraps(setter_method)
|
||||||
|
def wrapper(self: BaseAgentExecutorIterator, *args: Any, **kwargs: Any) -> None:
|
||||||
|
setter_method(self, *args, **kwargs)
|
||||||
|
self.build_callback_manager()
|
||||||
|
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
|
class AgentExecutorIterator(BaseAgentExecutorIterator):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
agent_executor: AgentExecutor,
|
||||||
|
inputs: Any,
|
||||||
|
callbacks: Callbacks = None,
|
||||||
|
*,
|
||||||
|
tags: Optional[list[str]] = None,
|
||||||
|
include_run_info: bool = False,
|
||||||
|
async_: bool = False,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Initialize the AgentExecutorIterator with the given AgentExecutor,
|
||||||
|
inputs, and optional callbacks.
|
||||||
|
"""
|
||||||
|
self._agent_executor = agent_executor
|
||||||
|
self.inputs = inputs
|
||||||
|
self.async_ = async_
|
||||||
|
# build callback manager on tags setter
|
||||||
|
self._callbacks = callbacks
|
||||||
|
self.tags = tags
|
||||||
|
self.include_run_info = include_run_info
|
||||||
|
self.run_manager = None
|
||||||
|
self.reset()
|
||||||
|
|
||||||
|
_callback_manager: Union[AsyncCallbackManager, CallbackManager]
|
||||||
|
_inputs: dict[str, str]
|
||||||
|
_final_outputs: Optional[dict[str, str]]
|
||||||
|
run_manager: Optional[
|
||||||
|
Union[AsyncCallbackManagerForChainRun, CallbackManagerForChainRun]
|
||||||
|
]
|
||||||
|
timeout_manager: Any # TODO: Fix a type here; the shim makes it tricky.
|
||||||
|
|
||||||
|
@property
|
||||||
|
def inputs(self) -> dict[str, str]:
|
||||||
|
return self._inputs
|
||||||
|
|
||||||
|
@inputs.setter
|
||||||
|
def inputs(self, inputs: Any) -> None:
|
||||||
|
self._inputs = self.agent_executor.prep_inputs(inputs)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def callbacks(self) -> Callbacks:
|
||||||
|
return self._callbacks
|
||||||
|
|
||||||
|
@callbacks.setter
|
||||||
|
@rebuild_callback_manager_on_set
|
||||||
|
def callbacks(self, callbacks: Callbacks) -> None:
|
||||||
|
"""When callbacks are changed after __init__, rebuild callback mgr"""
|
||||||
|
self._callbacks = callbacks
|
||||||
|
|
||||||
|
@property
|
||||||
|
def tags(self) -> Optional[List[str]]:
|
||||||
|
return self._tags
|
||||||
|
|
||||||
|
@tags.setter
|
||||||
|
@rebuild_callback_manager_on_set
|
||||||
|
def tags(self, tags: Optional[List[str]]) -> None:
|
||||||
|
"""When tags are changed after __init__, rebuild callback mgr"""
|
||||||
|
self._tags = tags
|
||||||
|
|
||||||
|
@property
|
||||||
|
def agent_executor(self) -> AgentExecutor:
|
||||||
|
return self._agent_executor
|
||||||
|
|
||||||
|
@agent_executor.setter
|
||||||
|
@rebuild_callback_manager_on_set
|
||||||
|
def agent_executor(self, agent_executor: AgentExecutor) -> None:
|
||||||
|
self._agent_executor = agent_executor
|
||||||
|
# force re-prep inputs in case agent_executor's prep_inputs fn changed
|
||||||
|
self.inputs = self.inputs
|
||||||
|
|
||||||
|
@property
|
||||||
|
def callback_manager(self) -> Union[AsyncCallbackManager, CallbackManager]:
|
||||||
|
return self._callback_manager
|
||||||
|
|
||||||
|
def build_callback_manager(self) -> None:
|
||||||
|
"""
|
||||||
|
Create and configure the callback manager based on the current
|
||||||
|
callbacks and tags.
|
||||||
|
"""
|
||||||
|
CallbackMgr: Union[Type[AsyncCallbackManager], Type[CallbackManager]] = (
|
||||||
|
AsyncCallbackManager if self.async_ else CallbackManager
|
||||||
|
)
|
||||||
|
self._callback_manager = CallbackMgr.configure(
|
||||||
|
self.callbacks,
|
||||||
|
self.agent_executor.callbacks,
|
||||||
|
self.agent_executor.verbose,
|
||||||
|
self.tags,
|
||||||
|
self.agent_executor.tags,
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name_to_tool_map(self) -> dict[str, BaseTool]:
|
||||||
|
return {tool.name: tool for tool in self.agent_executor.tools}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def color_mapping(self) -> dict[str, str]:
|
||||||
|
return get_color_mapping(
|
||||||
|
[tool.name for tool in self.agent_executor.tools],
|
||||||
|
excluded_colors=["green", "red"],
|
||||||
|
)
|
||||||
|
|
||||||
|
def reset(self) -> None:
|
||||||
|
"""
|
||||||
|
Reset the iterator to its initial state, clearing intermediate steps,
|
||||||
|
iterations, and time elapsed.
|
||||||
|
"""
|
||||||
|
logger.debug("(Re)setting AgentExecutorIterator to fresh state")
|
||||||
|
self.intermediate_steps: list[tuple[AgentAction, str]] = []
|
||||||
|
self.iterations = 0
|
||||||
|
# maybe better to start these on the first __anext__ call?
|
||||||
|
self.time_elapsed = 0.0
|
||||||
|
self.start_time = time.time()
|
||||||
|
self._final_outputs = None
|
||||||
|
|
||||||
|
def update_iterations(self) -> None:
|
||||||
|
"""
|
||||||
|
Increment the number of iterations and update the time elapsed.
|
||||||
|
"""
|
||||||
|
self.iterations += 1
|
||||||
|
self.time_elapsed = time.time() - self.start_time
|
||||||
|
logger.debug(
|
||||||
|
f"Agent Iterations: {self.iterations} ({self.time_elapsed:.2f}s elapsed)"
|
||||||
|
)
|
||||||
|
|
||||||
|
def raise_stopiteration(self, output: Any) -> NoReturn:
|
||||||
|
"""
|
||||||
|
Raise a StopIteration exception with the given output.
|
||||||
|
"""
|
||||||
|
logger.debug("Chain end: stop iteration")
|
||||||
|
raise StopIteration(output)
|
||||||
|
|
||||||
|
async def raise_stopasynciteration(self, output: Any) -> NoReturn:
|
||||||
|
"""
|
||||||
|
Raise a StopAsyncIteration exception with the given output.
|
||||||
|
Close the timeout context manager.
|
||||||
|
"""
|
||||||
|
logger.debug("Chain end: stop async iteration")
|
||||||
|
if self.timeout_manager is not None:
|
||||||
|
await self.timeout_manager.__aexit__(None, None, None)
|
||||||
|
raise StopAsyncIteration(output)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def final_outputs(self) -> Optional[dict[str, Any]]:
|
||||||
|
return self._final_outputs
|
||||||
|
|
||||||
|
@final_outputs.setter
|
||||||
|
def final_outputs(self, outputs: Optional[Dict[str, Any]]) -> None:
|
||||||
|
# have access to intermediate steps by design in iterator,
|
||||||
|
# so return only outputs may as well always be true.
|
||||||
|
|
||||||
|
self._final_outputs = None
|
||||||
|
if outputs:
|
||||||
|
prepared_outputs: dict[str, Any] = self.agent_executor.prep_outputs(
|
||||||
|
self.inputs, outputs, return_only_outputs=True
|
||||||
|
)
|
||||||
|
if self.include_run_info and self.run_manager is not None:
|
||||||
|
logger.debug("Assign run key")
|
||||||
|
prepared_outputs[RUN_KEY] = RunInfo(run_id=self.run_manager.run_id)
|
||||||
|
self._final_outputs = prepared_outputs
|
||||||
|
|
||||||
|
def __iter__(self: "AgentExecutorIterator") -> "AgentExecutorIterator":
|
||||||
|
logger.debug("Initialising AgentExecutorIterator")
|
||||||
|
self.reset()
|
||||||
|
assert isinstance(self.callback_manager, CallbackManager)
|
||||||
|
self.run_manager = self.callback_manager.on_chain_start(
|
||||||
|
dumpd(self.agent_executor),
|
||||||
|
self.inputs,
|
||||||
|
)
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __aiter__(self) -> "AgentExecutorIterator":
|
||||||
|
"""
|
||||||
|
N.B. __aiter__ must be a normal method, so need to initialise async run manager
|
||||||
|
on first __anext__ call where we can await it
|
||||||
|
"""
|
||||||
|
logger.debug("Initialising AgentExecutorIterator (async)")
|
||||||
|
self.reset()
|
||||||
|
if self.agent_executor.max_execution_time:
|
||||||
|
self.timeout_manager = asyncio_timeout(
|
||||||
|
self.agent_executor.max_execution_time
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.timeout_manager = None
|
||||||
|
return self
|
||||||
|
|
||||||
|
def _on_first_step(self) -> None:
|
||||||
|
"""
|
||||||
|
Perform any necessary setup for the first step of the synchronous iterator.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def _on_first_async_step(self) -> None:
|
||||||
|
"""
|
||||||
|
Perform any necessary setup for the first step of the asynchronous iterator.
|
||||||
|
"""
|
||||||
|
# on first step, need to await callback manager and start async timeout ctxmgr
|
||||||
|
if self.iterations == 0:
|
||||||
|
assert isinstance(self.callback_manager, AsyncCallbackManager)
|
||||||
|
self.run_manager = await self.callback_manager.on_chain_start(
|
||||||
|
dumpd(self.agent_executor),
|
||||||
|
self.inputs,
|
||||||
|
)
|
||||||
|
if self.timeout_manager:
|
||||||
|
await self.timeout_manager.__aenter__()
|
||||||
|
|
||||||
|
def __next__(self) -> dict[str, Any]:
|
||||||
|
"""
|
||||||
|
AgentExecutor AgentExecutorIterator
|
||||||
|
__call__ (__iter__ ->) __next__
|
||||||
|
_call <=> _call_next
|
||||||
|
_take_next_step _take_next_step
|
||||||
|
"""
|
||||||
|
# first step
|
||||||
|
if self.iterations == 0:
|
||||||
|
self._on_first_step()
|
||||||
|
# N.B. timeout taken care of by "_should_continue" in sync case
|
||||||
|
try:
|
||||||
|
return self._call_next()
|
||||||
|
except StopIteration:
|
||||||
|
raise
|
||||||
|
except (KeyboardInterrupt, Exception) as e:
|
||||||
|
if self.run_manager:
|
||||||
|
self.run_manager.on_chain_error(e)
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def __anext__(self) -> dict[str, Any]:
|
||||||
|
"""
|
||||||
|
AgentExecutor AgentExecutorIterator
|
||||||
|
acall (__aiter__ ->) __anext__
|
||||||
|
_acall <=> _acall_next
|
||||||
|
_atake_next_step _atake_next_step
|
||||||
|
"""
|
||||||
|
if self.iterations == 0:
|
||||||
|
await self._on_first_async_step()
|
||||||
|
try:
|
||||||
|
return await self._acall_next()
|
||||||
|
except StopAsyncIteration:
|
||||||
|
raise
|
||||||
|
except (TimeoutError, CancelledError):
|
||||||
|
await self.timeout_manager.__aexit__(None, None, None)
|
||||||
|
self.timeout_manager = None
|
||||||
|
return await self._astop()
|
||||||
|
except (KeyboardInterrupt, Exception) as e:
|
||||||
|
if self.run_manager:
|
||||||
|
assert isinstance(self.run_manager, AsyncCallbackManagerForChainRun)
|
||||||
|
await self.run_manager.on_chain_error(e)
|
||||||
|
raise
|
||||||
|
|
||||||
|
def _execute_next_step(
|
||||||
|
self, run_manager: Optional[CallbackManagerForChainRun]
|
||||||
|
) -> Union[AgentFinish, List[Tuple[AgentAction, str]]]:
|
||||||
|
"""
|
||||||
|
Execute the next step in the chain using the
|
||||||
|
AgentExecutor's _take_next_step method.
|
||||||
|
"""
|
||||||
|
return self.agent_executor._take_next_step(
|
||||||
|
self.name_to_tool_map,
|
||||||
|
self.color_mapping,
|
||||||
|
self.inputs,
|
||||||
|
self.intermediate_steps,
|
||||||
|
run_manager=run_manager,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _execute_next_async_step(
|
||||||
|
self, run_manager: Optional[AsyncCallbackManagerForChainRun]
|
||||||
|
) -> Union[AgentFinish, List[Tuple[AgentAction, str]]]:
|
||||||
|
"""
|
||||||
|
Execute the next step in the chain using the
|
||||||
|
AgentExecutor's _atake_next_step method.
|
||||||
|
"""
|
||||||
|
return await self.agent_executor._atake_next_step(
|
||||||
|
self.name_to_tool_map,
|
||||||
|
self.color_mapping,
|
||||||
|
self.inputs,
|
||||||
|
self.intermediate_steps,
|
||||||
|
run_manager=run_manager,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _process_next_step_output(
|
||||||
|
self,
|
||||||
|
next_step_output: Union[AgentFinish, List[Tuple[AgentAction, str]]],
|
||||||
|
run_manager: Optional[CallbackManagerForChainRun],
|
||||||
|
) -> Dict[str, Union[str, List[Tuple[AgentAction, str]]]]:
|
||||||
|
"""
|
||||||
|
Process the output of the next step,
|
||||||
|
handling AgentFinish and tool return cases.
|
||||||
|
"""
|
||||||
|
logger.debug("Processing output of Agent loop step")
|
||||||
|
if isinstance(next_step_output, AgentFinish):
|
||||||
|
logger.debug(
|
||||||
|
"Hit AgentFinish: _return -> on_chain_end -> run final output logic"
|
||||||
|
)
|
||||||
|
output = self.agent_executor._return(
|
||||||
|
next_step_output, self.intermediate_steps, run_manager=run_manager
|
||||||
|
)
|
||||||
|
if self.run_manager:
|
||||||
|
self.run_manager.on_chain_end(output)
|
||||||
|
self.final_outputs = output
|
||||||
|
return output
|
||||||
|
|
||||||
|
self.intermediate_steps.extend(next_step_output)
|
||||||
|
logger.debug("Updated intermediate_steps with step output")
|
||||||
|
|
||||||
|
# Check for tool return
|
||||||
|
if len(next_step_output) == 1:
|
||||||
|
next_step_action = next_step_output[0]
|
||||||
|
tool_return = self.agent_executor._get_tool_return(next_step_action)
|
||||||
|
if tool_return is not None:
|
||||||
|
output = self.agent_executor._return(
|
||||||
|
tool_return, self.intermediate_steps, run_manager=run_manager
|
||||||
|
)
|
||||||
|
if self.run_manager:
|
||||||
|
self.run_manager.on_chain_end(output)
|
||||||
|
self.final_outputs = output
|
||||||
|
return output
|
||||||
|
|
||||||
|
output = {"intermediate_step": next_step_output}
|
||||||
|
return output
|
||||||
|
|
||||||
|
async def _aprocess_next_step_output(
|
||||||
|
self,
|
||||||
|
next_step_output: Union[AgentFinish, List[Tuple[AgentAction, str]]],
|
||||||
|
run_manager: Optional[AsyncCallbackManagerForChainRun],
|
||||||
|
) -> Dict[str, Union[str, List[Tuple[AgentAction, str]]]]:
|
||||||
|
"""
|
||||||
|
Process the output of the next async step,
|
||||||
|
handling AgentFinish and tool return cases.
|
||||||
|
"""
|
||||||
|
logger.debug("Processing output of async Agent loop step")
|
||||||
|
if isinstance(next_step_output, AgentFinish):
|
||||||
|
logger.debug(
|
||||||
|
"Hit AgentFinish: _areturn -> on_chain_end -> run final output logic"
|
||||||
|
)
|
||||||
|
output = await self.agent_executor._areturn(
|
||||||
|
next_step_output, self.intermediate_steps, run_manager=run_manager
|
||||||
|
)
|
||||||
|
if run_manager:
|
||||||
|
await run_manager.on_chain_end(output)
|
||||||
|
self.final_outputs = output
|
||||||
|
return output
|
||||||
|
|
||||||
|
self.intermediate_steps.extend(next_step_output)
|
||||||
|
logger.debug("Updated intermediate_steps with step output")
|
||||||
|
|
||||||
|
# Check for tool return
|
||||||
|
if len(next_step_output) == 1:
|
||||||
|
next_step_action = next_step_output[0]
|
||||||
|
tool_return = self.agent_executor._get_tool_return(next_step_action)
|
||||||
|
if tool_return is not None:
|
||||||
|
output = await self.agent_executor._areturn(
|
||||||
|
tool_return, self.intermediate_steps, run_manager=run_manager
|
||||||
|
)
|
||||||
|
if run_manager:
|
||||||
|
await run_manager.on_chain_end(output)
|
||||||
|
self.final_outputs = output
|
||||||
|
return output
|
||||||
|
|
||||||
|
output = {"intermediate_step": next_step_output}
|
||||||
|
return output
|
||||||
|
|
||||||
|
def _stop(self) -> dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Stop the iterator and raise a StopIteration exception with the stopped response.
|
||||||
|
"""
|
||||||
|
logger.warning("Stopping agent prematurely due to triggering stop condition")
|
||||||
|
# this manually constructs agent finish with output key
|
||||||
|
output = self.agent_executor.agent.return_stopped_response(
|
||||||
|
self.agent_executor.early_stopping_method,
|
||||||
|
self.intermediate_steps,
|
||||||
|
**self.inputs,
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
isinstance(self.run_manager, CallbackManagerForChainRun)
|
||||||
|
or self.run_manager is None
|
||||||
|
)
|
||||||
|
returned_output = self.agent_executor._return(
|
||||||
|
output, self.intermediate_steps, run_manager=self.run_manager
|
||||||
|
)
|
||||||
|
self.final_outputs = returned_output
|
||||||
|
return returned_output
|
||||||
|
|
||||||
|
async def _astop(self) -> dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Stop the async iterator and raise a StopAsyncIteration exception with
|
||||||
|
the stopped response.
|
||||||
|
"""
|
||||||
|
logger.warning("Stopping agent prematurely due to triggering stop condition")
|
||||||
|
output = self.agent_executor.agent.return_stopped_response(
|
||||||
|
self.agent_executor.early_stopping_method,
|
||||||
|
self.intermediate_steps,
|
||||||
|
**self.inputs,
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
isinstance(self.run_manager, AsyncCallbackManagerForChainRun)
|
||||||
|
or self.run_manager is None
|
||||||
|
)
|
||||||
|
returned_output = await self.agent_executor._areturn(
|
||||||
|
output, self.intermediate_steps, run_manager=self.run_manager
|
||||||
|
)
|
||||||
|
self.final_outputs = returned_output
|
||||||
|
return returned_output
|
||||||
|
|
||||||
|
def _call_next(self) -> dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Perform a single iteration of the synchronous AgentExecutorIterator.
|
||||||
|
"""
|
||||||
|
# final output already reached: stopiteration (final output)
|
||||||
|
if self.final_outputs is not None:
|
||||||
|
self.raise_stopiteration(self.final_outputs)
|
||||||
|
# timeout/max iterations: stopiteration (stopped response)
|
||||||
|
if not self.agent_executor._should_continue(self.iterations, self.time_elapsed):
|
||||||
|
return self._stop()
|
||||||
|
assert (
|
||||||
|
isinstance(self.run_manager, CallbackManagerForChainRun)
|
||||||
|
or self.run_manager is None
|
||||||
|
)
|
||||||
|
next_step_output = self._execute_next_step(self.run_manager)
|
||||||
|
output = self._process_next_step_output(next_step_output, self.run_manager)
|
||||||
|
self.update_iterations()
|
||||||
|
return output
|
||||||
|
|
||||||
|
async def _acall_next(self) -> dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Perform a single iteration of the asynchronous AgentExecutorIterator.
|
||||||
|
"""
|
||||||
|
# final output already reached: stopiteration (final output)
|
||||||
|
if self.final_outputs is not None:
|
||||||
|
await self.raise_stopasynciteration(self.final_outputs)
|
||||||
|
# timeout/max iterations: stopiteration (stopped response)
|
||||||
|
if not self.agent_executor._should_continue(self.iterations, self.time_elapsed):
|
||||||
|
return await self._astop()
|
||||||
|
assert (
|
||||||
|
isinstance(self.run_manager, AsyncCallbackManagerForChainRun)
|
||||||
|
or self.run_manager is None
|
||||||
|
)
|
||||||
|
next_step_output = await self._execute_next_async_step(self.run_manager)
|
||||||
|
output = await self._aprocess_next_step_output(
|
||||||
|
next_step_output, self.run_manager
|
||||||
|
)
|
||||||
|
self.update_iterations()
|
||||||
|
return output
|
@ -32,6 +32,9 @@ class FakeListLLM(LLM):
|
|||||||
"""Return number of tokens in text."""
|
"""Return number of tokens in text."""
|
||||||
return len(text.split())
|
return len(text.split())
|
||||||
|
|
||||||
|
async def _acall(self, *args: Any, **kwargs: Any) -> str:
|
||||||
|
return self._call(*args, **kwargs)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _identifying_params(self) -> Mapping[str, Any]:
|
def _identifying_params(self) -> Mapping[str, Any]:
|
||||||
return {}
|
return {}
|
||||||
@ -49,7 +52,8 @@ def _get_agent(**kwargs: Any) -> AgentExecutor:
|
|||||||
f"I'm turning evil\nAction: {bad_action_name}\nAction Input: misalignment",
|
f"I'm turning evil\nAction: {bad_action_name}\nAction Input: misalignment",
|
||||||
"Oh well\nFinal Answer: curses foiled again",
|
"Oh well\nFinal Answer: curses foiled again",
|
||||||
]
|
]
|
||||||
fake_llm = FakeListLLM(responses=responses)
|
fake_llm = FakeListLLM(cache=False, responses=responses)
|
||||||
|
|
||||||
tools = [
|
tools = [
|
||||||
Tool(
|
Tool(
|
||||||
name="Search",
|
name="Search",
|
||||||
@ -62,6 +66,7 @@ def _get_agent(**kwargs: Any) -> AgentExecutor:
|
|||||||
description="Useful for looking up things in a table",
|
description="Useful for looking up things in a table",
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
|
|
||||||
agent = initialize_agent(
|
agent = initialize_agent(
|
||||||
tools,
|
tools,
|
||||||
fake_llm,
|
fake_llm,
|
||||||
@ -194,6 +199,7 @@ def test_agent_tool_return_direct_in_intermediate_steps() -> None:
|
|||||||
)
|
)
|
||||||
|
|
||||||
resp = agent("when was langchain made")
|
resp = agent("when was langchain made")
|
||||||
|
assert isinstance(resp, dict)
|
||||||
assert resp["output"] == "misalignment"
|
assert resp["output"] == "misalignment"
|
||||||
assert len(resp["intermediate_steps"]) == 1
|
assert len(resp["intermediate_steps"]) == 1
|
||||||
action, _action_intput = resp["intermediate_steps"][0]
|
action, _action_intput = resp["intermediate_steps"][0]
|
||||||
|
360
libs/langchain/tests/unit_tests/agents/test_agent_iterator.py
Normal file
360
libs/langchain/tests/unit_tests/agents/test_agent_iterator.py
Normal file
@ -0,0 +1,360 @@
|
|||||||
|
import pytest
|
||||||
|
|
||||||
|
from langchain.agents import (
|
||||||
|
AgentExecutor,
|
||||||
|
AgentExecutorIterator,
|
||||||
|
AgentType,
|
||||||
|
initialize_agent,
|
||||||
|
)
|
||||||
|
from langchain.agents.tools import Tool
|
||||||
|
from langchain.llms import FakeListLLM
|
||||||
|
from tests.unit_tests.agents.test_agent import _get_agent
|
||||||
|
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
|
||||||
|
|
||||||
|
|
||||||
|
def test_agent_iterator_bad_action() -> None:
|
||||||
|
"""Test react chain iterator when bad action given."""
|
||||||
|
agent = _get_agent()
|
||||||
|
agent_iter = agent.iter(inputs="when was langchain made")
|
||||||
|
|
||||||
|
outputs = []
|
||||||
|
for step in agent_iter:
|
||||||
|
outputs.append(step)
|
||||||
|
|
||||||
|
assert isinstance(outputs[-1], dict)
|
||||||
|
assert outputs[-1]["output"] == "curses foiled again"
|
||||||
|
|
||||||
|
|
||||||
|
def test_agent_iterator_stopped_early() -> None:
|
||||||
|
"""
|
||||||
|
Test react chain iterator when max iterations or
|
||||||
|
max execution time is exceeded.
|
||||||
|
"""
|
||||||
|
# iteration limit
|
||||||
|
agent = _get_agent(max_iterations=1)
|
||||||
|
agent_iter = agent.iter(inputs="when was langchain made")
|
||||||
|
|
||||||
|
outputs = []
|
||||||
|
for step in agent_iter:
|
||||||
|
outputs.append(step)
|
||||||
|
# NOTE: we don't use agent.run like in the test for the regular agent executor,
|
||||||
|
# so the dict structure for outputs stays intact
|
||||||
|
assert isinstance(outputs[-1], dict)
|
||||||
|
assert (
|
||||||
|
outputs[-1]["output"] == "Agent stopped due to iteration limit or time limit."
|
||||||
|
)
|
||||||
|
|
||||||
|
# execution time limit
|
||||||
|
agent = _get_agent(max_execution_time=1e-5)
|
||||||
|
agent_iter = agent.iter(inputs="when was langchain made")
|
||||||
|
|
||||||
|
outputs = []
|
||||||
|
for step in agent_iter:
|
||||||
|
outputs.append(step)
|
||||||
|
assert isinstance(outputs[-1], dict)
|
||||||
|
assert (
|
||||||
|
outputs[-1]["output"] == "Agent stopped due to iteration limit or time limit."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_agent_async_iterator_stopped_early() -> None:
|
||||||
|
"""
|
||||||
|
Test react chain async iterator when max iterations or
|
||||||
|
max execution time is exceeded.
|
||||||
|
"""
|
||||||
|
# iteration limit
|
||||||
|
agent = _get_agent(max_iterations=1)
|
||||||
|
agent_async_iter = agent.iter(inputs="when was langchain made", async_=True)
|
||||||
|
|
||||||
|
outputs = []
|
||||||
|
assert isinstance(agent_async_iter, AgentExecutorIterator)
|
||||||
|
async for step in agent_async_iter:
|
||||||
|
outputs.append(step)
|
||||||
|
|
||||||
|
assert isinstance(outputs[-1], dict)
|
||||||
|
assert (
|
||||||
|
outputs[-1]["output"] == "Agent stopped due to iteration limit or time limit."
|
||||||
|
)
|
||||||
|
|
||||||
|
# execution time limit
|
||||||
|
agent = _get_agent(max_execution_time=1e-5)
|
||||||
|
agent_async_iter = agent.iter(inputs="when was langchain made", async_=True)
|
||||||
|
assert isinstance(agent_async_iter, AgentExecutorIterator)
|
||||||
|
|
||||||
|
outputs = []
|
||||||
|
async for step in agent_async_iter:
|
||||||
|
outputs.append(step)
|
||||||
|
|
||||||
|
assert (
|
||||||
|
outputs[-1]["output"] == "Agent stopped due to iteration limit or time limit."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_agent_iterator_with_callbacks() -> None:
|
||||||
|
"""Test react chain iterator with callbacks by setting verbose globally."""
|
||||||
|
handler1 = FakeCallbackHandler()
|
||||||
|
handler2 = FakeCallbackHandler()
|
||||||
|
bad_action_name = "BadAction"
|
||||||
|
responses = [
|
||||||
|
f"I'm turning evil\nAction: {bad_action_name}\nAction Input: misalignment",
|
||||||
|
"Oh well\nFinal Answer: curses foiled again",
|
||||||
|
]
|
||||||
|
fake_llm = FakeListLLM(cache=False, responses=responses, callbacks=[handler2])
|
||||||
|
|
||||||
|
tools = [
|
||||||
|
Tool(
|
||||||
|
name="Search",
|
||||||
|
func=lambda x: x,
|
||||||
|
description="Useful for searching",
|
||||||
|
),
|
||||||
|
Tool(
|
||||||
|
name="Lookup",
|
||||||
|
func=lambda x: x,
|
||||||
|
description="Useful for looking up things in a table",
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
agent = initialize_agent(
|
||||||
|
tools, fake_llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True
|
||||||
|
)
|
||||||
|
agent_iter = agent.iter(inputs="when was langchain made", callbacks=[handler1])
|
||||||
|
|
||||||
|
outputs = []
|
||||||
|
for step in agent_iter:
|
||||||
|
outputs.append(step)
|
||||||
|
assert isinstance(outputs[-1], dict)
|
||||||
|
assert outputs[-1]["output"] == "curses foiled again"
|
||||||
|
|
||||||
|
# 1 top level chain run runs, 2 LLMChain runs, 2 LLM runs, 1 tool run
|
||||||
|
assert handler1.chain_starts == handler1.chain_ends == 3
|
||||||
|
assert handler1.llm_starts == handler1.llm_ends == 2
|
||||||
|
assert handler1.tool_starts == 1
|
||||||
|
assert handler1.tool_ends == 1
|
||||||
|
# 1 extra agent action
|
||||||
|
assert handler1.starts == 7
|
||||||
|
# 1 extra agent end
|
||||||
|
assert handler1.ends == 7
|
||||||
|
print("h:", handler1)
|
||||||
|
assert handler1.errors == 0
|
||||||
|
# during LLMChain
|
||||||
|
assert handler1.text == 2
|
||||||
|
|
||||||
|
assert handler2.llm_starts == 2
|
||||||
|
assert handler2.llm_ends == 2
|
||||||
|
assert (
|
||||||
|
handler2.chain_starts
|
||||||
|
== handler2.tool_starts
|
||||||
|
== handler2.tool_ends
|
||||||
|
== handler2.chain_ends
|
||||||
|
== 0
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_agent_async_iterator_with_callbacks() -> None:
|
||||||
|
"""Test react chain async iterator with callbacks by setting verbose globally."""
|
||||||
|
handler1 = FakeCallbackHandler()
|
||||||
|
handler2 = FakeCallbackHandler()
|
||||||
|
|
||||||
|
bad_action_name = "BadAction"
|
||||||
|
responses = [
|
||||||
|
f"I'm turning evil\nAction: {bad_action_name}\nAction Input: misalignment",
|
||||||
|
"Oh well\nFinal Answer: curses foiled again",
|
||||||
|
]
|
||||||
|
fake_llm = FakeListLLM(cache=False, responses=responses, callbacks=[handler2])
|
||||||
|
|
||||||
|
tools = [
|
||||||
|
Tool(
|
||||||
|
name="Search",
|
||||||
|
func=lambda x: x,
|
||||||
|
description="Useful for searching",
|
||||||
|
),
|
||||||
|
Tool(
|
||||||
|
name="Lookup",
|
||||||
|
func=lambda x: x,
|
||||||
|
description="Useful for looking up things in a table",
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
agent = initialize_agent(
|
||||||
|
tools, fake_llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True
|
||||||
|
)
|
||||||
|
agent_async_iter = agent.iter(
|
||||||
|
inputs="when was langchain made",
|
||||||
|
callbacks=[handler1],
|
||||||
|
async_=True,
|
||||||
|
)
|
||||||
|
assert isinstance(agent_async_iter, AgentExecutorIterator)
|
||||||
|
|
||||||
|
outputs = []
|
||||||
|
async for step in agent_async_iter:
|
||||||
|
outputs.append(step)
|
||||||
|
|
||||||
|
assert outputs[-1]["output"] == "curses foiled again"
|
||||||
|
|
||||||
|
# 1 top level chain run runs, 2 LLMChain runs, 2 LLM runs, 1 tool run
|
||||||
|
assert handler1.chain_starts == handler1.chain_ends == 3
|
||||||
|
assert handler1.llm_starts == handler1.llm_ends == 2
|
||||||
|
assert handler1.tool_starts == 1
|
||||||
|
assert handler1.tool_ends == 1
|
||||||
|
# 1 extra agent action
|
||||||
|
assert handler1.starts == 7
|
||||||
|
# 1 extra agent end
|
||||||
|
assert handler1.ends == 7
|
||||||
|
assert handler1.errors == 0
|
||||||
|
# during LLMChain
|
||||||
|
assert handler1.text == 2
|
||||||
|
|
||||||
|
assert handler2.llm_starts == 2
|
||||||
|
assert handler2.llm_ends == 2
|
||||||
|
assert (
|
||||||
|
handler2.chain_starts
|
||||||
|
== handler2.tool_starts
|
||||||
|
== handler2.tool_ends
|
||||||
|
== handler2.chain_ends
|
||||||
|
== 0
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_agent_iterator_properties_and_setters() -> None:
|
||||||
|
"""Test properties and setters of AgentExecutorIterator."""
|
||||||
|
agent = _get_agent()
|
||||||
|
agent.tags = None
|
||||||
|
agent_iter = agent.iter(inputs="when was langchain made")
|
||||||
|
|
||||||
|
assert isinstance(agent_iter, AgentExecutorIterator)
|
||||||
|
assert isinstance(agent_iter.inputs, dict)
|
||||||
|
assert isinstance(agent_iter.callbacks, type(None))
|
||||||
|
assert isinstance(agent_iter.tags, type(None))
|
||||||
|
assert isinstance(agent_iter.agent_executor, AgentExecutor)
|
||||||
|
|
||||||
|
agent_iter.inputs = "New input" # type: ignore
|
||||||
|
assert isinstance(agent_iter.inputs, dict)
|
||||||
|
|
||||||
|
agent_iter.callbacks = [FakeCallbackHandler()]
|
||||||
|
assert isinstance(agent_iter.callbacks, list)
|
||||||
|
|
||||||
|
agent_iter.tags = ["test"]
|
||||||
|
assert isinstance(agent_iter.tags, list)
|
||||||
|
|
||||||
|
new_agent = _get_agent()
|
||||||
|
agent_iter.agent_executor = new_agent
|
||||||
|
assert isinstance(agent_iter.agent_executor, AgentExecutor)
|
||||||
|
|
||||||
|
|
||||||
|
def test_agent_iterator_reset() -> None:
|
||||||
|
"""Test reset functionality of AgentExecutorIterator."""
|
||||||
|
agent = _get_agent()
|
||||||
|
agent_iter = agent.iter(inputs="when was langchain made")
|
||||||
|
assert isinstance(agent_iter, AgentExecutorIterator)
|
||||||
|
|
||||||
|
# Perform one iteration
|
||||||
|
next(agent_iter)
|
||||||
|
|
||||||
|
# Check if properties are updated
|
||||||
|
assert agent_iter.iterations == 1
|
||||||
|
assert agent_iter.time_elapsed > 0.0
|
||||||
|
assert agent_iter.intermediate_steps
|
||||||
|
|
||||||
|
# Reset the iterator
|
||||||
|
agent_iter.reset()
|
||||||
|
|
||||||
|
# Check if properties are reset
|
||||||
|
assert agent_iter.iterations == 0
|
||||||
|
assert agent_iter.time_elapsed == 0.0
|
||||||
|
assert not agent_iter.intermediate_steps
|
||||||
|
|
||||||
|
|
||||||
|
def test_agent_iterator_output_structure() -> None:
|
||||||
|
"""Test the output structure of AgentExecutorIterator."""
|
||||||
|
agent = _get_agent()
|
||||||
|
agent_iter = agent.iter(inputs="when was langchain made")
|
||||||
|
|
||||||
|
for step in agent_iter:
|
||||||
|
assert isinstance(step, dict)
|
||||||
|
if "intermediate_step" in step:
|
||||||
|
assert isinstance(step["intermediate_step"], list)
|
||||||
|
elif "output" in step:
|
||||||
|
assert isinstance(step["output"], str)
|
||||||
|
else:
|
||||||
|
assert False, "Unexpected output structure"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_agent_async_iterator_output_structure() -> None:
|
||||||
|
"""Test the async output structure of AgentExecutorIterator."""
|
||||||
|
agent = _get_agent()
|
||||||
|
agent_async_iter = agent.iter(inputs="when was langchain made", async_=True)
|
||||||
|
|
||||||
|
assert isinstance(agent_async_iter, AgentExecutorIterator)
|
||||||
|
async for step in agent_async_iter:
|
||||||
|
assert isinstance(step, dict)
|
||||||
|
if "intermediate_step" in step:
|
||||||
|
assert isinstance(step["intermediate_step"], list)
|
||||||
|
elif "output" in step:
|
||||||
|
assert isinstance(step["output"], str)
|
||||||
|
else:
|
||||||
|
assert False, "Unexpected output structure"
|
||||||
|
|
||||||
|
|
||||||
|
def test_agent_iterator_empty_input() -> None:
|
||||||
|
"""Test AgentExecutorIterator with empty input."""
|
||||||
|
agent = _get_agent()
|
||||||
|
agent_iter = agent.iter(inputs="")
|
||||||
|
|
||||||
|
outputs = []
|
||||||
|
for step in agent_iter:
|
||||||
|
outputs.append(step)
|
||||||
|
|
||||||
|
assert isinstance(outputs[-1], dict)
|
||||||
|
assert outputs[-1]["output"] # Check if there is an output
|
||||||
|
|
||||||
|
|
||||||
|
def test_agent_iterator_custom_stopping_condition() -> None:
|
||||||
|
"""Test AgentExecutorIterator with a custom stopping condition."""
|
||||||
|
agent = _get_agent()
|
||||||
|
|
||||||
|
class CustomAgentExecutorIterator(AgentExecutorIterator):
|
||||||
|
def _should_continue(self) -> bool:
|
||||||
|
return self.iterations < 2 # Custom stopping condition
|
||||||
|
|
||||||
|
agent_iter = CustomAgentExecutorIterator(agent, inputs="when was langchain made")
|
||||||
|
|
||||||
|
outputs = []
|
||||||
|
for step in agent_iter:
|
||||||
|
outputs.append(step)
|
||||||
|
|
||||||
|
assert len(outputs) == 2 # Check if the custom stopping condition is respected
|
||||||
|
|
||||||
|
|
||||||
|
def test_agent_iterator_failing_tool() -> None:
|
||||||
|
"""Test AgentExecutorIterator with a tool that raises an exception."""
|
||||||
|
|
||||||
|
# Get agent for testing.
|
||||||
|
bad_action_name = "FailingTool"
|
||||||
|
responses = [
|
||||||
|
f"I'm turning evil\nAction: {bad_action_name}\nAction Input: misalignment",
|
||||||
|
"Oh well\nFinal Answer: curses foiled again",
|
||||||
|
]
|
||||||
|
fake_llm = FakeListLLM(responses=responses)
|
||||||
|
|
||||||
|
tools = [
|
||||||
|
Tool(
|
||||||
|
name="FailingTool",
|
||||||
|
func=lambda x: 1 / 0, # This tool will raise a ZeroDivisionError
|
||||||
|
description="A tool that fails",
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
agent = initialize_agent(
|
||||||
|
tools, fake_llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True
|
||||||
|
)
|
||||||
|
|
||||||
|
agent_iter = agent.iter(inputs="when was langchain made")
|
||||||
|
assert isinstance(agent_iter, AgentExecutorIterator)
|
||||||
|
# initialise iterator
|
||||||
|
iter(agent_iter)
|
||||||
|
|
||||||
|
with pytest.raises(ZeroDivisionError):
|
||||||
|
next(agent_iter)
|
@ -3,6 +3,7 @@ from langchain.agents import __all__ as agents_all
|
|||||||
_EXPECTED = [
|
_EXPECTED = [
|
||||||
"Agent",
|
"Agent",
|
||||||
"AgentExecutor",
|
"AgentExecutor",
|
||||||
|
"AgentExecutorIterator",
|
||||||
"AgentOutputParser",
|
"AgentOutputParser",
|
||||||
"AgentType",
|
"AgentType",
|
||||||
"BaseMultiActionAgent",
|
"BaseMultiActionAgent",
|
||||||
|
Loading…
Reference in New Issue
Block a user