diff --git a/langchain/agents/agent.py b/langchain/agents/agent.py index 3c948642..447f6a10 100644 --- a/langchain/agents/agent.py +++ b/langchain/agents/agent.py @@ -3,6 +3,7 @@ from __future__ import annotations import json import logging +import time from abc import abstractmethod from pathlib import Path from typing import Any, Dict, List, Optional, Sequence, Tuple, Union @@ -26,6 +27,7 @@ from langchain.schema import ( BaseOutputParser, ) from langchain.tools.base import BaseTool +from langchain.utilities.asyncio import asyncio_timeout logger = logging.getLogger() @@ -88,7 +90,9 @@ class BaseSingleActionAgent(BaseModel): """Return response when agent has been stopped due to max iterations.""" if early_stopping_method == "force": # `force` just returns a constant string - return AgentFinish({"output": "Agent stopped due to max iterations."}, "") + return AgentFinish( + {"output": "Agent stopped due to iteration limit or time limit."}, "" + ) else: raise ValueError( f"Got unsupported early_stopping_method `{early_stopping_method}`" @@ -506,7 +510,9 @@ class Agent(BaseSingleActionAgent): """Return response when agent has been stopped due to max iterations.""" if early_stopping_method == "force": # `force` just returns a constant string - return AgentFinish({"output": "Agent stopped due to max iterations."}, "") + return AgentFinish( + {"output": "Agent stopped due to iteration limit or time limit."}, "" + ) elif early_stopping_method == "generate": # Generate does one final forward pass thoughts = "" @@ -555,6 +561,7 @@ class AgentExecutor(Chain): tools: Sequence[BaseTool] return_intermediate_steps: bool = False max_iterations: Optional[int] = 15 + max_execution_time: Optional[float] = None early_stopping_method: str = "force" @classmethod @@ -633,11 +640,16 @@ class AgentExecutor(Chain): """Lookup tool by name.""" return {tool.name: tool for tool in self.tools}[name] - def _should_continue(self, iterations: int) -> bool: - if self.max_iterations is None: - return True - else: - return iterations < self.max_iterations + def _should_continue(self, iterations: int, time_elapsed: float) -> bool: + if self.max_iterations is not None and iterations >= self.max_iterations: + return False + if ( + self.max_execution_time is not None + and time_elapsed >= self.max_execution_time + ): + return False + + return True def _return(self, output: AgentFinish, intermediate_steps: list) -> Dict[str, Any]: self.callback_manager.on_agent_finish( @@ -783,10 +795,12 @@ class AgentExecutor(Chain): [tool.name for tool in self.tools], excluded_colors=["green"] ) intermediate_steps: List[Tuple[AgentAction, str]] = [] - # Let's start tracking the iterations the agent has gone through + # Let's start tracking the number of iterations and time elapsed iterations = 0 + time_elapsed = 0.0 + start_time = time.time() # We now enter the agent loop (until it returns something). - while self._should_continue(iterations): + while self._should_continue(iterations, time_elapsed): next_step_output = self._take_next_step( name_to_tool_map, color_mapping, inputs, intermediate_steps ) @@ -801,6 +815,7 @@ class AgentExecutor(Chain): if tool_return is not None: return self._return(tool_return, intermediate_steps) iterations += 1 + time_elapsed = time.time() - start_time output = self.agent.return_stopped_response( self.early_stopping_method, intermediate_steps, **inputs ) @@ -815,29 +830,40 @@ class AgentExecutor(Chain): [tool.name for tool in self.tools], excluded_colors=["green"] ) intermediate_steps: List[Tuple[AgentAction, str]] = [] - # Let's start tracking the iterations the agent has gone through + # Let's start tracking the number of iterations and time elapsed iterations = 0 + time_elapsed = 0.0 + start_time = time.time() # We now enter the agent loop (until it returns something). - while self._should_continue(iterations): - next_step_output = await self._atake_next_step( - name_to_tool_map, color_mapping, inputs, intermediate_steps - ) - if isinstance(next_step_output, AgentFinish): - return await self._areturn(next_step_output, intermediate_steps) - - intermediate_steps.extend(next_step_output) - if len(next_step_output) == 1: - next_step_action = next_step_output[0] - # See if tool should return directly - tool_return = self._get_tool_return(next_step_action) - if tool_return is not None: - return await self._areturn(tool_return, intermediate_steps) - - iterations += 1 - output = self.agent.return_stopped_response( - self.early_stopping_method, intermediate_steps, **inputs - ) - return await self._areturn(output, intermediate_steps) + async with asyncio_timeout(self.max_execution_time): + try: + while self._should_continue(iterations, time_elapsed): + next_step_output = await self._atake_next_step( + name_to_tool_map, color_mapping, inputs, intermediate_steps + ) + if isinstance(next_step_output, AgentFinish): + return await self._areturn(next_step_output, intermediate_steps) + + intermediate_steps.extend(next_step_output) + if len(next_step_output) == 1: + next_step_action = next_step_output[0] + # See if tool should return directly + tool_return = self._get_tool_return(next_step_action) + if tool_return is not None: + return await self._areturn(tool_return, intermediate_steps) + + iterations += 1 + time_elapsed = time.time() - start_time + output = self.agent.return_stopped_response( + self.early_stopping_method, intermediate_steps, **inputs + ) + return await self._areturn(output, intermediate_steps) + except TimeoutError: + # stop early when interrupted by the async timeout + output = self.agent.return_stopped_response( + self.early_stopping_method, intermediate_steps, **inputs + ) + return await self._areturn(output, intermediate_steps) def _get_tool_return( self, next_step_output: Tuple[AgentAction, str] diff --git a/langchain/utilities/asyncio.py b/langchain/utilities/asyncio.py new file mode 100644 index 00000000..d7db052e --- /dev/null +++ b/langchain/utilities/asyncio.py @@ -0,0 +1,11 @@ +"""Shims for asyncio features that may be missing from older python versions""" + +import sys + +if sys.version_info[:2] < (3, 11): + from async_timeout import timeout as asyncio_timeout +else: + from asyncio import timeout as asyncio_timeout + + +__all__ = ["asyncio_timeout"] diff --git a/pyproject.toml b/pyproject.toml index f57910ca..665963b4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -57,6 +57,7 @@ pgvector = {version = "^0.1.6", optional = true} psycopg2-binary = {version = "^2.9.5", optional = true} boto3 = {version = "^1.26.96", optional = true} pyowm = {version = "^3.3.0", optional = true} +async-timeout = {version = "^4.0.0", python = "<3.11"} [tool.poetry.group.docs.dependencies] autodoc_pydantic = "^1.8.0" diff --git a/tests/unit_tests/agents/test_agent.py b/tests/unit_tests/agents/test_agent.py index f33573b4..ce052170 100644 --- a/tests/unit_tests/agents/test_agent.py +++ b/tests/unit_tests/agents/test_agent.py @@ -70,10 +70,16 @@ def test_agent_bad_action() -> None: def test_agent_stopped_early() -> None: - """Test react chain when bad action given.""" + """Test react chain when max iterations or max execution time is exceeded.""" + # iteration limit agent = _get_agent(max_iterations=0) output = agent.run("when was langchain made") - assert output == "Agent stopped due to max iterations." + assert output == "Agent stopped due to iteration limit or time limit." + + # execution time limit + agent = _get_agent(max_execution_time=0.0) + output = agent.run("when was langchain made") + assert output == "Agent stopped due to iteration limit or time limit." def test_agent_with_callbacks_global() -> None: