diff --git a/docs/modules/agents/tools/custom_tools.ipynb b/docs/modules/agents/tools/custom_tools.ipynb index a0fea125..5bfc026e 100644 --- a/docs/modules/agents/tools/custom_tools.ipynb +++ b/docs/modules/agents/tools/custom_tools.ipynb @@ -839,6 +839,127 @@ "source": [ "agent.run(\"whats 2**.12\")" ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "f1da459d", + "metadata": {}, + "source": [ + "## Handling Tool Errors \n", + "When a tool encounters an error and the exception is not caught, the agent will stop executing. If you want the agent to continue execution, you can raise a `ToolException` and set `handle_tool_error` accordingly. \n", + "\n", + "When `ToolException` is thrown, the agent will not stop working, but will handle the exception according to the `handle_tool_error` variable of the tool, and the processing result will be returned to the agent as observation, and printed in red.\n", + "\n", + "You can set `handle_tool_error` to `True`, set it a unified string value, or set it as a function. If it's set as a function, the function should take a `ToolException` as a parameter and return a `str` value.\n", + "\n", + "Please note that only raising a `ToolException` won't be effective. You need to first set the `handle_tool_error` of the tool because its default value is `False`." + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "ad16fbcf", + "metadata": {}, + "outputs": [], + "source": [ + "from langchain.schema import ToolException\n", + "\n", + "from langchain import SerpAPIWrapper\n", + "from langchain.agents import AgentType, initialize_agent\n", + "from langchain.chat_models import ChatOpenAI\n", + "from langchain.tools import Tool\n", + "\n", + "from langchain.chat_models import ChatOpenAI\n", + "\n", + "def _handle_error(error:ToolException) -> str:\n", + " return \"The following errors occurred during tool execution:\" + error.args[0]+ \"Please try another tool.\"\n", + "def search_tool1(s: str):raise ToolException(\"The search tool1 is not available.\")\n", + "def search_tool2(s: str):raise ToolException(\"The search tool2 is not available.\")\n", + "search_tool3 = SerpAPIWrapper()" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "c05aa75b", + "metadata": {}, + "outputs": [], + "source": [ + "description=\"useful for when you need to answer questions about current events.You should give priority to using it.\"\n", + "tools = [\n", + " Tool.from_function(\n", + " func=search_tool1,\n", + " name=\"Search_tool1\",\n", + " description=description,\n", + " handle_tool_error=True,\n", + " ),\n", + " Tool.from_function(\n", + " func=search_tool2,\n", + " name=\"Search_tool2\",\n", + " description=description,\n", + " handle_tool_error=_handle_error,\n", + " ),\n", + " Tool.from_function(\n", + " func=search_tool3.run,\n", + " name=\"Search_tool3\",\n", + " description=\"useful for when you need to answer questions about current events\",\n", + " ),\n", + "]\n", + "\n", + "agent = initialize_agent(\n", + " tools,\n", + " ChatOpenAI(temperature=0),\n", + " agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,\n", + " verbose=True,\n", + ")\n" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "cff8b4b5", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\n", + "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n", + "\u001b[32;1m\u001b[1;3mI should use Search_tool1 to find recent news articles about Leo DiCaprio's personal life.\n", + "Action: Search_tool1\n", + "Action Input: \"Leo DiCaprio girlfriend\"\u001b[0m\n", + "Observation: \u001b[31;1m\u001b[1;3mThe search tool1 is not available.\u001b[0m\n", + "Thought:\u001b[32;1m\u001b[1;3mI should try using Search_tool2 instead.\n", + "Action: Search_tool2\n", + "Action Input: \"Leo DiCaprio girlfriend\"\u001b[0m\n", + "Observation: \u001b[31;1m\u001b[1;3mThe following errors occurred during tool execution:The search tool2 is not available.Please try another tool.\u001b[0m\n", + "Thought:\u001b[32;1m\u001b[1;3mI should try using Search_tool3 as a last resort.\n", + "Action: Search_tool3\n", + "Action Input: \"Leo DiCaprio girlfriend\"\u001b[0m\n", + "Observation: \u001b[38;5;200m\u001b[1;3mLeonardo DiCaprio and Gigi Hadid were recently spotted at a pre-Oscars party, sparking interest once again in their rumored romance. The Revenant actor and the model first made headlines when they were spotted together at a New York Fashion Week afterparty in September 2022.\u001b[0m\n", + "Thought:\u001b[32;1m\u001b[1;3mBased on the information from Search_tool3, it seems that Gigi Hadid is currently rumored to be Leo DiCaprio's girlfriend.\n", + "Final Answer: Gigi Hadid is currently rumored to be Leo DiCaprio's girlfriend.\u001b[0m\n", + "\n", + "\u001b[1m> Finished chain.\u001b[0m\n" + ] + }, + { + "data": { + "text/plain": [ + "\"Gigi Hadid is currently rumored to be Leo DiCaprio's girlfriend.\"" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "agent.run(\"Who is Leo DiCaprio's girlfriend?\")" + ] } ], "metadata": { @@ -857,7 +978,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.2" + "version": "3.11.3" }, "vscode": { "interpreter": { diff --git a/langchain/agents/agent.py b/langchain/agents/agent.py index c0241277..a1da5e7a 100644 --- a/langchain/agents/agent.py +++ b/langchain/agents/agent.py @@ -941,7 +941,7 @@ class AgentExecutor(Chain): name_to_tool_map = {tool.name: tool for tool in self.tools} # We construct a mapping from each tool to a color, used for logging. color_mapping = get_color_mapping( - [tool.name for tool in self.tools], excluded_colors=["green"] + [tool.name for tool in self.tools], excluded_colors=["green", "red"] ) intermediate_steps: List[Tuple[AgentAction, str]] = [] # Let's start tracking the number of iterations and time elapsed diff --git a/langchain/tools/base.py b/langchain/tools/base.py index 52f959ac..ea69731b 100644 --- a/langchain/tools/base.py +++ b/langchain/tools/base.py @@ -112,6 +112,18 @@ def create_schema_from_function( ) +class ToolException(Exception): + """An optional exception that tool throws when execution error occurs. + + When this exception is thrown, the agent will not stop working, + but will handle the exception according to the handle_tool_error + variable of the tool, and the processing result will be returned + to the agent as observation, and printed in red on the console. + """ + + pass + + class BaseTool(ABC, BaseModel, metaclass=ToolMetaclass): """Interface LangChain tools must implement.""" @@ -137,6 +149,11 @@ class BaseTool(ABC, BaseModel, metaclass=ToolMetaclass): callback_manager: Optional[BaseCallbackManager] = Field(default=None, exclude=True) """Deprecated. Please use callbacks instead.""" + handle_tool_error: Optional[ + Union[bool, str, Callable[[ToolException], str]] + ] = False + """Handle the content of the ToolException thrown.""" + class Config: """Configuration for this pydantic object.""" @@ -250,11 +267,36 @@ class BaseTool(ABC, BaseModel, metaclass=ToolMetaclass): if new_arg_supported else self._run(*tool_args, **tool_kwargs) ) + except ToolException as e: + if not self.handle_tool_error: + run_manager.on_tool_error(e) + raise e + elif isinstance(self.handle_tool_error, bool): + if e.args: + observation = e.args[0] + else: + observation = "Tool execution error" + elif isinstance(self.handle_tool_error, str): + observation = self.handle_tool_error + elif callable(self.handle_tool_error): + observation = self.handle_tool_error(e) + else: + raise ValueError( + f"Got unexpected type of `handle_tool_error`. Expected bool, str " + f"or callable. Received: {self.handle_tool_error}" + ) + run_manager.on_tool_end( + str(observation), color="red", name=self.name, **kwargs + ) + return observation except (Exception, KeyboardInterrupt) as e: run_manager.on_tool_error(e) raise e - run_manager.on_tool_end(str(observation), color=color, name=self.name, **kwargs) - return observation + else: + run_manager.on_tool_end( + str(observation), color=color, name=self.name, **kwargs + ) + return observation async def arun( self, @@ -289,13 +331,36 @@ class BaseTool(ABC, BaseModel, metaclass=ToolMetaclass): if new_arg_supported else await self._arun(*tool_args, **tool_kwargs) ) + except ToolException as e: + if not self.handle_tool_error: + await run_manager.on_tool_error(e) + raise e + elif isinstance(self.handle_tool_error, bool): + if e.args: + observation = e.args[0] + else: + observation = "Tool execution error" + elif isinstance(self.handle_tool_error, str): + observation = self.handle_tool_error + elif callable(self.handle_tool_error): + observation = self.handle_tool_error(e) + else: + raise ValueError( + f"Got unexpected type of `handle_tool_error`. Expected bool, str " + f"or callable. Received: {self.handle_tool_error}" + ) + await run_manager.on_tool_end( + str(observation), color="red", name=self.name, **kwargs + ) + return observation except (Exception, KeyboardInterrupt) as e: await run_manager.on_tool_error(e) raise e - await run_manager.on_tool_end( - str(observation), color=color, name=self.name, **kwargs - ) - return observation + else: + await run_manager.on_tool_end( + str(observation), color=color, name=self.name, **kwargs + ) + return observation def __call__(self, tool_input: str, callbacks: Callbacks = None) -> str: """Make tool callable.""" diff --git a/tests/unit_tests/tools/test_base.py b/tests/unit_tests/tools/test_base.py index a5cf9c9f..6638a095 100644 --- a/tests/unit_tests/tools/test_base.py +++ b/tests/unit_tests/tools/test_base.py @@ -13,7 +13,12 @@ from langchain.callbacks.manager import ( AsyncCallbackManagerForToolRun, CallbackManagerForToolRun, ) -from langchain.tools.base import BaseTool, SchemaAnnotationError, StructuredTool +from langchain.tools.base import ( + BaseTool, + SchemaAnnotationError, + StructuredTool, + ToolException, +) def test_unnamed_decorator() -> None: @@ -479,3 +484,75 @@ async def test_create_async_tool() -> None: assert test_tool.description == "test_description" assert test_tool.coroutine is not None assert await test_tool.arun("foo") == "foo" + + +class _FakeExceptionTool(BaseTool): + name = "exception" + description = "an exception-throwing tool" + exception: Exception = ToolException() + + def _run(self) -> str: + raise self.exception + + async def _arun(self) -> str: + raise self.exception + + +def test_exception_handling_bool() -> None: + _tool = _FakeExceptionTool(handle_tool_error=True) + expected = "Tool execution error" + actual = _tool.run({}) + assert expected == actual + + +def test_exception_handling_str() -> None: + expected = "foo bar" + _tool = _FakeExceptionTool(handle_tool_error=expected) + actual = _tool.run({}) + assert expected == actual + + +def test_exception_handling_callable() -> None: + expected = "foo bar" + handling = lambda _: expected # noqa: E731 + _tool = _FakeExceptionTool(handle_tool_error=handling) + actual = _tool.run({}) + assert expected == actual + + +def test_exception_handling_non_tool_exception() -> None: + _tool = _FakeExceptionTool(exception=ValueError()) + with pytest.raises(ValueError): + _tool.run({}) + + +@pytest.mark.asyncio +async def test_async_exception_handling_bool() -> None: + _tool = _FakeExceptionTool(handle_tool_error=True) + expected = "Tool execution error" + actual = await _tool.arun({}) + assert expected == actual + + +@pytest.mark.asyncio +async def test_async_exception_handling_str() -> None: + expected = "foo bar" + _tool = _FakeExceptionTool(handle_tool_error=expected) + actual = await _tool.arun({}) + assert expected == actual + + +@pytest.mark.asyncio +async def test_async_exception_handling_callable() -> None: + expected = "foo bar" + handling = lambda _: expected # noqa: E731 + _tool = _FakeExceptionTool(handle_tool_error=handling) + actual = await _tool.arun({}) + assert expected == actual + + +@pytest.mark.asyncio +async def test_async_exception_handling_non_tool_exception() -> None: + _tool = _FakeExceptionTool(exception=ValueError()) + with pytest.raises(ValueError): + await _tool.arun({})