core[minor]: add validation error handler to BaseTool (#14007)

- **Description:** add a ValidationError handler as a field of
[`BaseTool`](https://github.com/langchain-ai/langchain/blob/master/libs/core/langchain_core/tools.py#L101)
and add unit tests for the code change.
- **Issue:** #12721 #13662
- **Dependencies:** None
- **Tag maintainer:** 
- **Twitter handle:** @hmdev3
- **NOTE:**
  - I'm wondering if the update of document is required.

---------

Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
This commit is contained in:
hmasdev 2024-02-02 13:09:19 +09:00 committed by GitHub
parent bdacfafa05
commit cc17334473
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 170 additions and 4 deletions

View File

@ -20,6 +20,7 @@ from langchain_core.pydantic_v1 import (
BaseModel,
Extra,
Field,
ValidationError,
create_model,
root_validator,
validate_arguments,
@ -169,6 +170,11 @@ class ChildTool(BaseTool):
] = False
"""Handle the content of the ToolException thrown."""
handle_validation_error: Optional[
Union[bool, str, Callable[[ValidationError], str]]
] = False
"""Handle the content of the ValidationError thrown."""
class Config(Serializable.Config):
"""Configuration for this pydantic object."""
@ -346,6 +352,21 @@ class ChildTool(BaseTool):
if new_arg_supported
else self._run(*tool_args, **tool_kwargs)
)
except ValidationError as e:
if not self.handle_validation_error:
raise e
elif isinstance(self.handle_validation_error, bool):
observation = "Tool input validation error"
elif isinstance(self.handle_validation_error, str):
observation = self.handle_validation_error
elif callable(self.handle_validation_error):
observation = self.handle_validation_error(e)
else:
raise ValueError(
f"Got unexpected type of `handle_validation_error`. Expected bool, "
f"str or callable. Received: {self.handle_validation_error}"
)
return observation
except ToolException as e:
if not self.handle_tool_error:
run_manager.on_tool_error(e)
@ -422,6 +443,21 @@ class ChildTool(BaseTool):
if new_arg_supported
else await self._arun(*tool_args, **tool_kwargs)
)
except ValidationError as e:
if not self.handle_validation_error:
raise e
elif isinstance(self.handle_validation_error, bool):
observation = "Tool input validation error"
elif isinstance(self.handle_validation_error, str):
observation = self.handle_validation_error
elif callable(self.handle_validation_error):
observation = self.handle_validation_error(e)
else:
raise ValueError(
f"Got unexpected type of `handle_validation_error`. Expected bool, "
f"str or callable. Received: {self.handle_validation_error}"
)
return observation
except ToolException as e:
if not self.handle_tool_error:
await run_manager.on_tool_error(e)

View File

@ -3,7 +3,7 @@ import json
from datetime import datetime
from enum import Enum
from functools import partial
from typing import Any, List, Optional, Type, Union
from typing import Any, Callable, Dict, List, Optional, Type, Union
import pytest
@ -11,7 +11,7 @@ from langchain_core.callbacks import (
AsyncCallbackManagerForToolRun,
CallbackManagerForToolRun,
)
from langchain_core.pydantic_v1 import BaseModel
from langchain_core.pydantic_v1 import BaseModel, ValidationError
from langchain_core.tools import (
BaseTool,
SchemaAnnotationError,
@ -620,7 +620,10 @@ def test_exception_handling_str() -> None:
def test_exception_handling_callable() -> None:
expected = "foo bar"
handling = lambda _: expected # noqa: E731
def handling(e: ToolException) -> str:
return expected # noqa: E731
_tool = _FakeExceptionTool(handle_tool_error=handling)
actual = _tool.run({})
assert expected == actual
@ -648,7 +651,10 @@ async def test_async_exception_handling_str() -> None:
async def test_async_exception_handling_callable() -> None:
expected = "foo bar"
handling = lambda _: expected # noqa: E731
def handling(e: ToolException) -> str:
return expected # noqa: E731
_tool = _FakeExceptionTool(handle_tool_error=handling)
actual = await _tool.arun({})
assert expected == actual
@ -691,3 +697,127 @@ def test_structured_tool_from_function() -> None:
prefix = "foo(bar: int, baz: str) -> str - "
assert foo.__doc__ is not None
assert structured_tool.description == prefix + foo.__doc__.strip()
def test_validation_error_handling_bool() -> None:
"""Test that validation errors are handled correctly."""
expected = "Tool input validation error"
_tool = _MockStructuredTool(handle_validation_error=True)
actual = _tool.run({})
assert expected == actual
def test_validation_error_handling_str() -> None:
"""Test that validation errors are handled correctly."""
expected = "foo bar"
_tool = _MockStructuredTool(handle_validation_error=expected)
actual = _tool.run({})
assert expected == actual
def test_validation_error_handling_callable() -> None:
"""Test that validation errors are handled correctly."""
expected = "foo bar"
def handling(e: ValidationError) -> str:
return expected # noqa: E731
_tool = _MockStructuredTool(handle_validation_error=handling)
actual = _tool.run({})
assert expected == actual
@pytest.mark.parametrize(
"handler",
[
True,
"foo bar",
lambda _: "foo bar",
],
)
def test_validation_error_handling_non_validation_error(
handler: Union[bool, str, Callable[[ValidationError], str]]
) -> None:
"""Test that validation errors are handled correctly."""
class _RaiseNonValidationErrorTool(BaseTool):
name = "raise_non_validation_error_tool"
description = "A tool that raises a non-validation error"
def _parse_input(
self,
tool_input: Union[str, Dict],
) -> Union[str, Dict[str, Any]]:
raise NotImplementedError()
def _run(self) -> str:
return "dummy"
async def _arun(self) -> str:
return "dummy"
_tool = _RaiseNonValidationErrorTool(handle_validation_error=handler)
with pytest.raises(NotImplementedError):
_tool.run({})
async def test_async_validation_error_handling_bool() -> None:
"""Test that validation errors are handled correctly."""
expected = "Tool input validation error"
_tool = _MockStructuredTool(handle_validation_error=True)
actual = await _tool.arun({})
assert expected == actual
async def test_async_validation_error_handling_str() -> None:
"""Test that validation errors are handled correctly."""
expected = "foo bar"
_tool = _MockStructuredTool(handle_validation_error=expected)
actual = await _tool.arun({})
assert expected == actual
async def test_async_validation_error_handling_callable() -> None:
"""Test that validation errors are handled correctly."""
expected = "foo bar"
def handling(e: ValidationError) -> str:
return expected # noqa: E731
_tool = _MockStructuredTool(handle_validation_error=handling)
actual = await _tool.arun({})
assert expected == actual
@pytest.mark.parametrize(
"handler",
[
True,
"foo bar",
lambda _: "foo bar",
],
)
async def test_async_validation_error_handling_non_validation_error(
handler: Union[bool, str, Callable[[ValidationError], str]]
) -> None:
"""Test that validation errors are handled correctly."""
class _RaiseNonValidationErrorTool(BaseTool):
name = "raise_non_validation_error_tool"
description = "A tool that raises a non-validation error"
def _parse_input(
self,
tool_input: Union[str, Dict],
) -> Union[str, Dict[str, Any]]:
raise NotImplementedError()
def _run(self) -> str:
return "dummy"
async def _arun(self) -> str:
return "dummy"
_tool = _RaiseNonValidationErrorTool(handle_validation_error=handler)
with pytest.raises(NotImplementedError):
await _tool.arun({})