Feature: AgentExecutor execution time limit (#2399)

`AgentExecutor` already has support for limiting the number of
iterations. But the amount of time taken for each iteration can vary
quite a bit, so it is difficult to place limits on the execution time.
This PR adds a new field `max_execution_time` to the `AgentExecutor`
model. When called asynchronously, the agent loop is wrapped in an
`asyncio.timeout()` context which triggers the early stopping response
if the time limit is reached. When called synchronously, the agent loop
checks for both the max_iteration limit and the time limit after each
iteration.

When used asynchronously `max_execution_time` gives really tight control
over the max time for an execution chain. When used synchronously, the
chain can unfortunately exceed max_execution_time, but it still gives
more control than trying to estimate the number of max_iterations needed
to cap the execution time.

---------

Co-authored-by: Zachary Jones <zjones@zetaglobal.com>
doc
Zach Jones 1 year ago committed by GitHub
parent 5b34931948
commit 13d1df2140
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -3,6 +3,7 @@ from __future__ import annotations
import json
import logging
import time
from abc import abstractmethod
from pathlib import Path
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
@ -26,6 +27,7 @@ from langchain.schema import (
BaseOutputParser,
)
from langchain.tools.base import BaseTool
from langchain.utilities.asyncio import asyncio_timeout
logger = logging.getLogger()
@ -88,7 +90,9 @@ class BaseSingleActionAgent(BaseModel):
"""Return response when agent has been stopped due to max iterations."""
if early_stopping_method == "force":
# `force` just returns a constant string
return AgentFinish({"output": "Agent stopped due to max iterations."}, "")
return AgentFinish(
{"output": "Agent stopped due to iteration limit or time limit."}, ""
)
else:
raise ValueError(
f"Got unsupported early_stopping_method `{early_stopping_method}`"
@ -506,7 +510,9 @@ class Agent(BaseSingleActionAgent):
"""Return response when agent has been stopped due to max iterations."""
if early_stopping_method == "force":
# `force` just returns a constant string
return AgentFinish({"output": "Agent stopped due to max iterations."}, "")
return AgentFinish(
{"output": "Agent stopped due to iteration limit or time limit."}, ""
)
elif early_stopping_method == "generate":
# Generate does one final forward pass
thoughts = ""
@ -555,6 +561,7 @@ class AgentExecutor(Chain):
tools: Sequence[BaseTool]
return_intermediate_steps: bool = False
max_iterations: Optional[int] = 15
max_execution_time: Optional[float] = None
early_stopping_method: str = "force"
@classmethod
@ -633,11 +640,16 @@ class AgentExecutor(Chain):
"""Lookup tool by name."""
return {tool.name: tool for tool in self.tools}[name]
def _should_continue(self, iterations: int) -> bool:
if self.max_iterations is None:
return True
else:
return iterations < self.max_iterations
def _should_continue(self, iterations: int, time_elapsed: float) -> bool:
if self.max_iterations is not None and iterations >= self.max_iterations:
return False
if (
self.max_execution_time is not None
and time_elapsed >= self.max_execution_time
):
return False
return True
def _return(self, output: AgentFinish, intermediate_steps: list) -> Dict[str, Any]:
self.callback_manager.on_agent_finish(
@ -783,10 +795,12 @@ class AgentExecutor(Chain):
[tool.name for tool in self.tools], excluded_colors=["green"]
)
intermediate_steps: List[Tuple[AgentAction, str]] = []
# Let's start tracking the iterations the agent has gone through
# Let's start tracking the number of iterations and time elapsed
iterations = 0
time_elapsed = 0.0
start_time = time.time()
# We now enter the agent loop (until it returns something).
while self._should_continue(iterations):
while self._should_continue(iterations, time_elapsed):
next_step_output = self._take_next_step(
name_to_tool_map, color_mapping, inputs, intermediate_steps
)
@ -801,6 +815,7 @@ class AgentExecutor(Chain):
if tool_return is not None:
return self._return(tool_return, intermediate_steps)
iterations += 1
time_elapsed = time.time() - start_time
output = self.agent.return_stopped_response(
self.early_stopping_method, intermediate_steps, **inputs
)
@ -815,29 +830,40 @@ class AgentExecutor(Chain):
[tool.name for tool in self.tools], excluded_colors=["green"]
)
intermediate_steps: List[Tuple[AgentAction, str]] = []
# Let's start tracking the iterations the agent has gone through
# Let's start tracking the number of iterations and time elapsed
iterations = 0
time_elapsed = 0.0
start_time = time.time()
# We now enter the agent loop (until it returns something).
while self._should_continue(iterations):
next_step_output = await self._atake_next_step(
name_to_tool_map, color_mapping, inputs, intermediate_steps
)
if isinstance(next_step_output, AgentFinish):
return await self._areturn(next_step_output, intermediate_steps)
intermediate_steps.extend(next_step_output)
if len(next_step_output) == 1:
next_step_action = next_step_output[0]
# See if tool should return directly
tool_return = self._get_tool_return(next_step_action)
if tool_return is not None:
return await self._areturn(tool_return, intermediate_steps)
iterations += 1
output = self.agent.return_stopped_response(
self.early_stopping_method, intermediate_steps, **inputs
)
return await self._areturn(output, intermediate_steps)
async with asyncio_timeout(self.max_execution_time):
try:
while self._should_continue(iterations, time_elapsed):
next_step_output = await self._atake_next_step(
name_to_tool_map, color_mapping, inputs, intermediate_steps
)
if isinstance(next_step_output, AgentFinish):
return await self._areturn(next_step_output, intermediate_steps)
intermediate_steps.extend(next_step_output)
if len(next_step_output) == 1:
next_step_action = next_step_output[0]
# See if tool should return directly
tool_return = self._get_tool_return(next_step_action)
if tool_return is not None:
return await self._areturn(tool_return, intermediate_steps)
iterations += 1
time_elapsed = time.time() - start_time
output = self.agent.return_stopped_response(
self.early_stopping_method, intermediate_steps, **inputs
)
return await self._areturn(output, intermediate_steps)
except TimeoutError:
# stop early when interrupted by the async timeout
output = self.agent.return_stopped_response(
self.early_stopping_method, intermediate_steps, **inputs
)
return await self._areturn(output, intermediate_steps)
def _get_tool_return(
self, next_step_output: Tuple[AgentAction, str]

@ -0,0 +1,11 @@
"""Shims for asyncio features that may be missing from older python versions"""
import sys
if sys.version_info[:2] < (3, 11):
from async_timeout import timeout as asyncio_timeout
else:
from asyncio import timeout as asyncio_timeout
__all__ = ["asyncio_timeout"]

@ -57,6 +57,7 @@ pgvector = {version = "^0.1.6", optional = true}
psycopg2-binary = {version = "^2.9.5", optional = true}
boto3 = {version = "^1.26.96", optional = true}
pyowm = {version = "^3.3.0", optional = true}
async-timeout = {version = "^4.0.0", python = "<3.11"}
[tool.poetry.group.docs.dependencies]
autodoc_pydantic = "^1.8.0"

@ -70,10 +70,16 @@ def test_agent_bad_action() -> None:
def test_agent_stopped_early() -> None:
"""Test react chain when bad action given."""
"""Test react chain when max iterations or max execution time is exceeded."""
# iteration limit
agent = _get_agent(max_iterations=0)
output = agent.run("when was langchain made")
assert output == "Agent stopped due to max iterations."
assert output == "Agent stopped due to iteration limit or time limit."
# execution time limit
agent = _get_agent(max_execution_time=0.0)
output = agent.run("when was langchain made")
assert output == "Agent stopped due to iteration limit or time limit."
def test_agent_with_callbacks_global() -> None:

Loading…
Cancel
Save