mirror of
https://github.com/hwchase17/langchain
synced 2024-11-04 06:00:26 +00:00
core[minor]: add validation error handler to BaseTool
(#14007)
- **Description:** add a ValidationError handler as a field of [`BaseTool`](https://github.com/langchain-ai/langchain/blob/master/libs/core/langchain_core/tools.py#L101) and add unit tests for the code change. - **Issue:** #12721 #13662 - **Dependencies:** None - **Tag maintainer:** - **Twitter handle:** @hmdev3 - **NOTE:** - I'm wondering if the update of document is required. --------- Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
This commit is contained in:
parent
bdacfafa05
commit
cc17334473
@ -20,6 +20,7 @@ from langchain_core.pydantic_v1 import (
|
||||
BaseModel,
|
||||
Extra,
|
||||
Field,
|
||||
ValidationError,
|
||||
create_model,
|
||||
root_validator,
|
||||
validate_arguments,
|
||||
@ -169,6 +170,11 @@ class ChildTool(BaseTool):
|
||||
] = False
|
||||
"""Handle the content of the ToolException thrown."""
|
||||
|
||||
handle_validation_error: Optional[
|
||||
Union[bool, str, Callable[[ValidationError], str]]
|
||||
] = False
|
||||
"""Handle the content of the ValidationError thrown."""
|
||||
|
||||
class Config(Serializable.Config):
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
@ -346,6 +352,21 @@ class ChildTool(BaseTool):
|
||||
if new_arg_supported
|
||||
else self._run(*tool_args, **tool_kwargs)
|
||||
)
|
||||
except ValidationError as e:
|
||||
if not self.handle_validation_error:
|
||||
raise e
|
||||
elif isinstance(self.handle_validation_error, bool):
|
||||
observation = "Tool input validation error"
|
||||
elif isinstance(self.handle_validation_error, str):
|
||||
observation = self.handle_validation_error
|
||||
elif callable(self.handle_validation_error):
|
||||
observation = self.handle_validation_error(e)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Got unexpected type of `handle_validation_error`. Expected bool, "
|
||||
f"str or callable. Received: {self.handle_validation_error}"
|
||||
)
|
||||
return observation
|
||||
except ToolException as e:
|
||||
if not self.handle_tool_error:
|
||||
run_manager.on_tool_error(e)
|
||||
@ -422,6 +443,21 @@ class ChildTool(BaseTool):
|
||||
if new_arg_supported
|
||||
else await self._arun(*tool_args, **tool_kwargs)
|
||||
)
|
||||
except ValidationError as e:
|
||||
if not self.handle_validation_error:
|
||||
raise e
|
||||
elif isinstance(self.handle_validation_error, bool):
|
||||
observation = "Tool input validation error"
|
||||
elif isinstance(self.handle_validation_error, str):
|
||||
observation = self.handle_validation_error
|
||||
elif callable(self.handle_validation_error):
|
||||
observation = self.handle_validation_error(e)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Got unexpected type of `handle_validation_error`. Expected bool, "
|
||||
f"str or callable. Received: {self.handle_validation_error}"
|
||||
)
|
||||
return observation
|
||||
except ToolException as e:
|
||||
if not self.handle_tool_error:
|
||||
await run_manager.on_tool_error(e)
|
||||
|
@ -3,7 +3,7 @@ import json
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from functools import partial
|
||||
from typing import Any, List, Optional, Type, Union
|
||||
from typing import Any, Callable, Dict, List, Optional, Type, Union
|
||||
|
||||
import pytest
|
||||
|
||||
@ -11,7 +11,7 @@ from langchain_core.callbacks import (
|
||||
AsyncCallbackManagerForToolRun,
|
||||
CallbackManagerForToolRun,
|
||||
)
|
||||
from langchain_core.pydantic_v1 import BaseModel
|
||||
from langchain_core.pydantic_v1 import BaseModel, ValidationError
|
||||
from langchain_core.tools import (
|
||||
BaseTool,
|
||||
SchemaAnnotationError,
|
||||
@ -620,7 +620,10 @@ def test_exception_handling_str() -> None:
|
||||
|
||||
def test_exception_handling_callable() -> None:
|
||||
expected = "foo bar"
|
||||
handling = lambda _: expected # noqa: E731
|
||||
|
||||
def handling(e: ToolException) -> str:
|
||||
return expected # noqa: E731
|
||||
|
||||
_tool = _FakeExceptionTool(handle_tool_error=handling)
|
||||
actual = _tool.run({})
|
||||
assert expected == actual
|
||||
@ -648,7 +651,10 @@ async def test_async_exception_handling_str() -> None:
|
||||
|
||||
async def test_async_exception_handling_callable() -> None:
|
||||
expected = "foo bar"
|
||||
handling = lambda _: expected # noqa: E731
|
||||
|
||||
def handling(e: ToolException) -> str:
|
||||
return expected # noqa: E731
|
||||
|
||||
_tool = _FakeExceptionTool(handle_tool_error=handling)
|
||||
actual = await _tool.arun({})
|
||||
assert expected == actual
|
||||
@ -691,3 +697,127 @@ def test_structured_tool_from_function() -> None:
|
||||
prefix = "foo(bar: int, baz: str) -> str - "
|
||||
assert foo.__doc__ is not None
|
||||
assert structured_tool.description == prefix + foo.__doc__.strip()
|
||||
|
||||
|
||||
def test_validation_error_handling_bool() -> None:
|
||||
"""Test that validation errors are handled correctly."""
|
||||
expected = "Tool input validation error"
|
||||
_tool = _MockStructuredTool(handle_validation_error=True)
|
||||
actual = _tool.run({})
|
||||
assert expected == actual
|
||||
|
||||
|
||||
def test_validation_error_handling_str() -> None:
|
||||
"""Test that validation errors are handled correctly."""
|
||||
expected = "foo bar"
|
||||
_tool = _MockStructuredTool(handle_validation_error=expected)
|
||||
actual = _tool.run({})
|
||||
assert expected == actual
|
||||
|
||||
|
||||
def test_validation_error_handling_callable() -> None:
|
||||
"""Test that validation errors are handled correctly."""
|
||||
expected = "foo bar"
|
||||
|
||||
def handling(e: ValidationError) -> str:
|
||||
return expected # noqa: E731
|
||||
|
||||
_tool = _MockStructuredTool(handle_validation_error=handling)
|
||||
actual = _tool.run({})
|
||||
assert expected == actual
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"handler",
|
||||
[
|
||||
True,
|
||||
"foo bar",
|
||||
lambda _: "foo bar",
|
||||
],
|
||||
)
|
||||
def test_validation_error_handling_non_validation_error(
|
||||
handler: Union[bool, str, Callable[[ValidationError], str]]
|
||||
) -> None:
|
||||
"""Test that validation errors are handled correctly."""
|
||||
|
||||
class _RaiseNonValidationErrorTool(BaseTool):
|
||||
name = "raise_non_validation_error_tool"
|
||||
description = "A tool that raises a non-validation error"
|
||||
|
||||
def _parse_input(
|
||||
self,
|
||||
tool_input: Union[str, Dict],
|
||||
) -> Union[str, Dict[str, Any]]:
|
||||
raise NotImplementedError()
|
||||
|
||||
def _run(self) -> str:
|
||||
return "dummy"
|
||||
|
||||
async def _arun(self) -> str:
|
||||
return "dummy"
|
||||
|
||||
_tool = _RaiseNonValidationErrorTool(handle_validation_error=handler)
|
||||
with pytest.raises(NotImplementedError):
|
||||
_tool.run({})
|
||||
|
||||
|
||||
async def test_async_validation_error_handling_bool() -> None:
|
||||
"""Test that validation errors are handled correctly."""
|
||||
expected = "Tool input validation error"
|
||||
_tool = _MockStructuredTool(handle_validation_error=True)
|
||||
actual = await _tool.arun({})
|
||||
assert expected == actual
|
||||
|
||||
|
||||
async def test_async_validation_error_handling_str() -> None:
|
||||
"""Test that validation errors are handled correctly."""
|
||||
expected = "foo bar"
|
||||
_tool = _MockStructuredTool(handle_validation_error=expected)
|
||||
actual = await _tool.arun({})
|
||||
assert expected == actual
|
||||
|
||||
|
||||
async def test_async_validation_error_handling_callable() -> None:
|
||||
"""Test that validation errors are handled correctly."""
|
||||
expected = "foo bar"
|
||||
|
||||
def handling(e: ValidationError) -> str:
|
||||
return expected # noqa: E731
|
||||
|
||||
_tool = _MockStructuredTool(handle_validation_error=handling)
|
||||
actual = await _tool.arun({})
|
||||
assert expected == actual
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"handler",
|
||||
[
|
||||
True,
|
||||
"foo bar",
|
||||
lambda _: "foo bar",
|
||||
],
|
||||
)
|
||||
async def test_async_validation_error_handling_non_validation_error(
|
||||
handler: Union[bool, str, Callable[[ValidationError], str]]
|
||||
) -> None:
|
||||
"""Test that validation errors are handled correctly."""
|
||||
|
||||
class _RaiseNonValidationErrorTool(BaseTool):
|
||||
name = "raise_non_validation_error_tool"
|
||||
description = "A tool that raises a non-validation error"
|
||||
|
||||
def _parse_input(
|
||||
self,
|
||||
tool_input: Union[str, Dict],
|
||||
) -> Union[str, Dict[str, Any]]:
|
||||
raise NotImplementedError()
|
||||
|
||||
def _run(self) -> str:
|
||||
return "dummy"
|
||||
|
||||
async def _arun(self) -> str:
|
||||
return "dummy"
|
||||
|
||||
_tool = _RaiseNonValidationErrorTool(handle_validation_error=handler)
|
||||
with pytest.raises(NotImplementedError):
|
||||
await _tool.arun({})
|
||||
|
Loading…
Reference in New Issue
Block a user