From cf5803e44c0b0ed8664b5948a0ed45ed670ed5fb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=B0=8F=E9=93=AD?= <1223398803@qq.com> Date: Tue, 30 May 2023 04:05:58 +0800 Subject: [PATCH] Add ToolException that a tool can throw. (#5050) # Add ToolException that a tool can throw This is an optional exception that tool throws when execution error occurs. When this exception is thrown, the agent will not stop working,but will handle the exception according to the handle_tool_error variable of the tool,and the processing result will be returned to the agent as observation,and printed in pink on the console.It can be used like this: ```python from langchain.schema import ToolException from langchain import LLMMathChain, SerpAPIWrapper, OpenAI from langchain.agents import AgentType, initialize_agent from langchain.chat_models import ChatOpenAI from langchain.tools import BaseTool, StructuredTool, Tool, tool from langchain.chat_models import ChatOpenAI llm = ChatOpenAI(temperature=0) llm_math_chain = LLMMathChain(llm=llm, verbose=True) class Error_tool: def run(self, s: str): raise ToolException('The current search tool is not available.') def handle_tool_error(error) -> str: return "The following errors occurred during tool execution:"+str(error) search_tool1 = Error_tool() search_tool2 = SerpAPIWrapper() tools = [ Tool.from_function( func=search_tool1.run, name="Search_tool1", description="useful for when you need to answer questions about current events.You should give priority to using it.", handle_tool_error=handle_tool_error, ), Tool.from_function( func=search_tool2.run, name="Search_tool2", description="useful for when you need to answer questions about current events", return_direct=True, ) ] agent = initialize_agent(tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True, handle_tool_errors=handle_tool_error) agent.run("Who is Leo DiCaprio's girlfriend? What is her current age raised to the 0.43 power?") ``` ![image](https://github.com/hwchase17/langchain/assets/32786500/51930410-b26e-4f85-a1e1-e6a6fb450ada) ## Who can review? - @vowelparrot --------- Co-authored-by: Dev 2049 --- docs/modules/agents/tools/custom_tools.ipynb | 123 ++++++++++++++++++- langchain/agents/agent.py | 2 +- langchain/tools/base.py | 77 +++++++++++- tests/unit_tests/tools/test_base.py | 79 +++++++++++- 4 files changed, 272 insertions(+), 9 deletions(-) diff --git a/docs/modules/agents/tools/custom_tools.ipynb b/docs/modules/agents/tools/custom_tools.ipynb index a0fea125..5bfc026e 100644 --- a/docs/modules/agents/tools/custom_tools.ipynb +++ b/docs/modules/agents/tools/custom_tools.ipynb @@ -839,6 +839,127 @@ "source": [ "agent.run(\"whats 2**.12\")" ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "f1da459d", + "metadata": {}, + "source": [ + "## Handling Tool Errors \n", + "When a tool encounters an error and the exception is not caught, the agent will stop executing. If you want the agent to continue execution, you can raise a `ToolException` and set `handle_tool_error` accordingly. \n", + "\n", + "When `ToolException` is thrown, the agent will not stop working, but will handle the exception according to the `handle_tool_error` variable of the tool, and the processing result will be returned to the agent as observation, and printed in red.\n", + "\n", + "You can set `handle_tool_error` to `True`, set it a unified string value, or set it as a function. If it's set as a function, the function should take a `ToolException` as a parameter and return a `str` value.\n", + "\n", + "Please note that only raising a `ToolException` won't be effective. You need to first set the `handle_tool_error` of the tool because its default value is `False`." + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "ad16fbcf", + "metadata": {}, + "outputs": [], + "source": [ + "from langchain.schema import ToolException\n", + "\n", + "from langchain import SerpAPIWrapper\n", + "from langchain.agents import AgentType, initialize_agent\n", + "from langchain.chat_models import ChatOpenAI\n", + "from langchain.tools import Tool\n", + "\n", + "from langchain.chat_models import ChatOpenAI\n", + "\n", + "def _handle_error(error:ToolException) -> str:\n", + " return \"The following errors occurred during tool execution:\" + error.args[0]+ \"Please try another tool.\"\n", + "def search_tool1(s: str):raise ToolException(\"The search tool1 is not available.\")\n", + "def search_tool2(s: str):raise ToolException(\"The search tool2 is not available.\")\n", + "search_tool3 = SerpAPIWrapper()" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "c05aa75b", + "metadata": {}, + "outputs": [], + "source": [ + "description=\"useful for when you need to answer questions about current events.You should give priority to using it.\"\n", + "tools = [\n", + " Tool.from_function(\n", + " func=search_tool1,\n", + " name=\"Search_tool1\",\n", + " description=description,\n", + " handle_tool_error=True,\n", + " ),\n", + " Tool.from_function(\n", + " func=search_tool2,\n", + " name=\"Search_tool2\",\n", + " description=description,\n", + " handle_tool_error=_handle_error,\n", + " ),\n", + " Tool.from_function(\n", + " func=search_tool3.run,\n", + " name=\"Search_tool3\",\n", + " description=\"useful for when you need to answer questions about current events\",\n", + " ),\n", + "]\n", + "\n", + "agent = initialize_agent(\n", + " tools,\n", + " ChatOpenAI(temperature=0),\n", + " agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,\n", + " verbose=True,\n", + ")\n" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "cff8b4b5", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\n", + "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n", + "\u001b[32;1m\u001b[1;3mI should use Search_tool1 to find recent news articles about Leo DiCaprio's personal life.\n", + "Action: Search_tool1\n", + "Action Input: \"Leo DiCaprio girlfriend\"\u001b[0m\n", + "Observation: \u001b[31;1m\u001b[1;3mThe search tool1 is not available.\u001b[0m\n", + "Thought:\u001b[32;1m\u001b[1;3mI should try using Search_tool2 instead.\n", + "Action: Search_tool2\n", + "Action Input: \"Leo DiCaprio girlfriend\"\u001b[0m\n", + "Observation: \u001b[31;1m\u001b[1;3mThe following errors occurred during tool execution:The search tool2 is not available.Please try another tool.\u001b[0m\n", + "Thought:\u001b[32;1m\u001b[1;3mI should try using Search_tool3 as a last resort.\n", + "Action: Search_tool3\n", + "Action Input: \"Leo DiCaprio girlfriend\"\u001b[0m\n", + "Observation: \u001b[38;5;200m\u001b[1;3mLeonardo DiCaprio and Gigi Hadid were recently spotted at a pre-Oscars party, sparking interest once again in their rumored romance. The Revenant actor and the model first made headlines when they were spotted together at a New York Fashion Week afterparty in September 2022.\u001b[0m\n", + "Thought:\u001b[32;1m\u001b[1;3mBased on the information from Search_tool3, it seems that Gigi Hadid is currently rumored to be Leo DiCaprio's girlfriend.\n", + "Final Answer: Gigi Hadid is currently rumored to be Leo DiCaprio's girlfriend.\u001b[0m\n", + "\n", + "\u001b[1m> Finished chain.\u001b[0m\n" + ] + }, + { + "data": { + "text/plain": [ + "\"Gigi Hadid is currently rumored to be Leo DiCaprio's girlfriend.\"" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "agent.run(\"Who is Leo DiCaprio's girlfriend?\")" + ] } ], "metadata": { @@ -857,7 +978,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.2" + "version": "3.11.3" }, "vscode": { "interpreter": { diff --git a/langchain/agents/agent.py b/langchain/agents/agent.py index c0241277..a1da5e7a 100644 --- a/langchain/agents/agent.py +++ b/langchain/agents/agent.py @@ -941,7 +941,7 @@ class AgentExecutor(Chain): name_to_tool_map = {tool.name: tool for tool in self.tools} # We construct a mapping from each tool to a color, used for logging. color_mapping = get_color_mapping( - [tool.name for tool in self.tools], excluded_colors=["green"] + [tool.name for tool in self.tools], excluded_colors=["green", "red"] ) intermediate_steps: List[Tuple[AgentAction, str]] = [] # Let's start tracking the number of iterations and time elapsed diff --git a/langchain/tools/base.py b/langchain/tools/base.py index 52f959ac..ea69731b 100644 --- a/langchain/tools/base.py +++ b/langchain/tools/base.py @@ -112,6 +112,18 @@ def create_schema_from_function( ) +class ToolException(Exception): + """An optional exception that tool throws when execution error occurs. + + When this exception is thrown, the agent will not stop working, + but will handle the exception according to the handle_tool_error + variable of the tool, and the processing result will be returned + to the agent as observation, and printed in red on the console. + """ + + pass + + class BaseTool(ABC, BaseModel, metaclass=ToolMetaclass): """Interface LangChain tools must implement.""" @@ -137,6 +149,11 @@ class BaseTool(ABC, BaseModel, metaclass=ToolMetaclass): callback_manager: Optional[BaseCallbackManager] = Field(default=None, exclude=True) """Deprecated. Please use callbacks instead.""" + handle_tool_error: Optional[ + Union[bool, str, Callable[[ToolException], str]] + ] = False + """Handle the content of the ToolException thrown.""" + class Config: """Configuration for this pydantic object.""" @@ -250,11 +267,36 @@ class BaseTool(ABC, BaseModel, metaclass=ToolMetaclass): if new_arg_supported else self._run(*tool_args, **tool_kwargs) ) + except ToolException as e: + if not self.handle_tool_error: + run_manager.on_tool_error(e) + raise e + elif isinstance(self.handle_tool_error, bool): + if e.args: + observation = e.args[0] + else: + observation = "Tool execution error" + elif isinstance(self.handle_tool_error, str): + observation = self.handle_tool_error + elif callable(self.handle_tool_error): + observation = self.handle_tool_error(e) + else: + raise ValueError( + f"Got unexpected type of `handle_tool_error`. Expected bool, str " + f"or callable. Received: {self.handle_tool_error}" + ) + run_manager.on_tool_end( + str(observation), color="red", name=self.name, **kwargs + ) + return observation except (Exception, KeyboardInterrupt) as e: run_manager.on_tool_error(e) raise e - run_manager.on_tool_end(str(observation), color=color, name=self.name, **kwargs) - return observation + else: + run_manager.on_tool_end( + str(observation), color=color, name=self.name, **kwargs + ) + return observation async def arun( self, @@ -289,13 +331,36 @@ class BaseTool(ABC, BaseModel, metaclass=ToolMetaclass): if new_arg_supported else await self._arun(*tool_args, **tool_kwargs) ) + except ToolException as e: + if not self.handle_tool_error: + await run_manager.on_tool_error(e) + raise e + elif isinstance(self.handle_tool_error, bool): + if e.args: + observation = e.args[0] + else: + observation = "Tool execution error" + elif isinstance(self.handle_tool_error, str): + observation = self.handle_tool_error + elif callable(self.handle_tool_error): + observation = self.handle_tool_error(e) + else: + raise ValueError( + f"Got unexpected type of `handle_tool_error`. Expected bool, str " + f"or callable. Received: {self.handle_tool_error}" + ) + await run_manager.on_tool_end( + str(observation), color="red", name=self.name, **kwargs + ) + return observation except (Exception, KeyboardInterrupt) as e: await run_manager.on_tool_error(e) raise e - await run_manager.on_tool_end( - str(observation), color=color, name=self.name, **kwargs - ) - return observation + else: + await run_manager.on_tool_end( + str(observation), color=color, name=self.name, **kwargs + ) + return observation def __call__(self, tool_input: str, callbacks: Callbacks = None) -> str: """Make tool callable.""" diff --git a/tests/unit_tests/tools/test_base.py b/tests/unit_tests/tools/test_base.py index a5cf9c9f..6638a095 100644 --- a/tests/unit_tests/tools/test_base.py +++ b/tests/unit_tests/tools/test_base.py @@ -13,7 +13,12 @@ from langchain.callbacks.manager import ( AsyncCallbackManagerForToolRun, CallbackManagerForToolRun, ) -from langchain.tools.base import BaseTool, SchemaAnnotationError, StructuredTool +from langchain.tools.base import ( + BaseTool, + SchemaAnnotationError, + StructuredTool, + ToolException, +) def test_unnamed_decorator() -> None: @@ -479,3 +484,75 @@ async def test_create_async_tool() -> None: assert test_tool.description == "test_description" assert test_tool.coroutine is not None assert await test_tool.arun("foo") == "foo" + + +class _FakeExceptionTool(BaseTool): + name = "exception" + description = "an exception-throwing tool" + exception: Exception = ToolException() + + def _run(self) -> str: + raise self.exception + + async def _arun(self) -> str: + raise self.exception + + +def test_exception_handling_bool() -> None: + _tool = _FakeExceptionTool(handle_tool_error=True) + expected = "Tool execution error" + actual = _tool.run({}) + assert expected == actual + + +def test_exception_handling_str() -> None: + expected = "foo bar" + _tool = _FakeExceptionTool(handle_tool_error=expected) + actual = _tool.run({}) + assert expected == actual + + +def test_exception_handling_callable() -> None: + expected = "foo bar" + handling = lambda _: expected # noqa: E731 + _tool = _FakeExceptionTool(handle_tool_error=handling) + actual = _tool.run({}) + assert expected == actual + + +def test_exception_handling_non_tool_exception() -> None: + _tool = _FakeExceptionTool(exception=ValueError()) + with pytest.raises(ValueError): + _tool.run({}) + + +@pytest.mark.asyncio +async def test_async_exception_handling_bool() -> None: + _tool = _FakeExceptionTool(handle_tool_error=True) + expected = "Tool execution error" + actual = await _tool.arun({}) + assert expected == actual + + +@pytest.mark.asyncio +async def test_async_exception_handling_str() -> None: + expected = "foo bar" + _tool = _FakeExceptionTool(handle_tool_error=expected) + actual = await _tool.arun({}) + assert expected == actual + + +@pytest.mark.asyncio +async def test_async_exception_handling_callable() -> None: + expected = "foo bar" + handling = lambda _: expected # noqa: E731 + _tool = _FakeExceptionTool(handle_tool_error=handling) + actual = await _tool.arun({}) + assert expected == actual + + +@pytest.mark.asyncio +async def test_async_exception_handling_non_tool_exception() -> None: + _tool = _FakeExceptionTool(exception=ValueError()) + with pytest.raises(ValueError): + await _tool.arun({})