@ -3,7 +3,7 @@ import json
from datetime import datetime
from enum import Enum
from functools import partial
from typing import Any , List, Optional , Type , Union
from typing import Any , Callable, Dict , List, Optional , Type , Union
import pytest
@ -11,7 +11,7 @@ from langchain_core.callbacks import (
AsyncCallbackManagerForToolRun ,
CallbackManagerForToolRun ,
)
from langchain_core . pydantic_v1 import BaseModel
from langchain_core . pydantic_v1 import BaseModel , ValidationError
from langchain_core . tools import (
BaseTool ,
SchemaAnnotationError ,
@ -620,7 +620,10 @@ def test_exception_handling_str() -> None:
def test_exception_handling_callable ( ) - > None :
expected = " foo bar "
handling = lambda _ : expected # noqa: E731
def handling ( e : ToolException ) - > str :
return expected # noqa: E731
_tool = _FakeExceptionTool ( handle_tool_error = handling )
actual = _tool . run ( { } )
assert expected == actual
@ -648,7 +651,10 @@ async def test_async_exception_handling_str() -> None:
async def test_async_exception_handling_callable ( ) - > None :
expected = " foo bar "
handling = lambda _ : expected # noqa: E731
def handling ( e : ToolException ) - > str :
return expected # noqa: E731
_tool = _FakeExceptionTool ( handle_tool_error = handling )
actual = await _tool . arun ( { } )
assert expected == actual
@ -691,3 +697,127 @@ def test_structured_tool_from_function() -> None:
prefix = " foo(bar: int, baz: str) -> str - "
assert foo . __doc__ is not None
assert structured_tool . description == prefix + foo . __doc__ . strip ( )
def test_validation_error_handling_bool ( ) - > None :
""" Test that validation errors are handled correctly. """
expected = " Tool input validation error "
_tool = _MockStructuredTool ( handle_validation_error = True )
actual = _tool . run ( { } )
assert expected == actual
def test_validation_error_handling_str ( ) - > None :
""" Test that validation errors are handled correctly. """
expected = " foo bar "
_tool = _MockStructuredTool ( handle_validation_error = expected )
actual = _tool . run ( { } )
assert expected == actual
def test_validation_error_handling_callable ( ) - > None :
""" Test that validation errors are handled correctly. """
expected = " foo bar "
def handling ( e : ValidationError ) - > str :
return expected # noqa: E731
_tool = _MockStructuredTool ( handle_validation_error = handling )
actual = _tool . run ( { } )
assert expected == actual
@pytest.mark.parametrize (
" handler " ,
[
True ,
" foo bar " ,
lambda _ : " foo bar " ,
] ,
)
def test_validation_error_handling_non_validation_error (
handler : Union [ bool , str , Callable [ [ ValidationError ] , str ] ]
) - > None :
""" Test that validation errors are handled correctly. """
class _RaiseNonValidationErrorTool ( BaseTool ) :
name = " raise_non_validation_error_tool "
description = " A tool that raises a non-validation error "
def _parse_input (
self ,
tool_input : Union [ str , Dict ] ,
) - > Union [ str , Dict [ str , Any ] ] :
raise NotImplementedError ( )
def _run ( self ) - > str :
return " dummy "
async def _arun ( self ) - > str :
return " dummy "
_tool = _RaiseNonValidationErrorTool ( handle_validation_error = handler )
with pytest . raises ( NotImplementedError ) :
_tool . run ( { } )
async def test_async_validation_error_handling_bool ( ) - > None :
""" Test that validation errors are handled correctly. """
expected = " Tool input validation error "
_tool = _MockStructuredTool ( handle_validation_error = True )
actual = await _tool . arun ( { } )
assert expected == actual
async def test_async_validation_error_handling_str ( ) - > None :
""" Test that validation errors are handled correctly. """
expected = " foo bar "
_tool = _MockStructuredTool ( handle_validation_error = expected )
actual = await _tool . arun ( { } )
assert expected == actual
async def test_async_validation_error_handling_callable ( ) - > None :
""" Test that validation errors are handled correctly. """
expected = " foo bar "
def handling ( e : ValidationError ) - > str :
return expected # noqa: E731
_tool = _MockStructuredTool ( handle_validation_error = handling )
actual = await _tool . arun ( { } )
assert expected == actual
@pytest.mark.parametrize (
" handler " ,
[
True ,
" foo bar " ,
lambda _ : " foo bar " ,
] ,
)
async def test_async_validation_error_handling_non_validation_error (
handler : Union [ bool , str , Callable [ [ ValidationError ] , str ] ]
) - > None :
""" Test that validation errors are handled correctly. """
class _RaiseNonValidationErrorTool ( BaseTool ) :
name = " raise_non_validation_error_tool "
description = " A tool that raises a non-validation error "
def _parse_input (
self ,
tool_input : Union [ str , Dict ] ,
) - > Union [ str , Dict [ str , Any ] ] :
raise NotImplementedError ( )
def _run ( self ) - > str :
return " dummy "
async def _arun ( self ) - > str :
return " dummy "
_tool = _RaiseNonValidationErrorTool ( handle_validation_error = handler )
with pytest . raises ( NotImplementedError ) :
await _tool . arun ( { } )