core[patch]: support ValidationError from pydantic v1 in tools (#27194)

This commit is contained in:
Vadym Barda 2024-10-08 10:19:04 -04:00 committed by GitHub
parent 16f5fdb38b
commit 8d27325dbc
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 16 additions and 8 deletions

View File

@ -35,6 +35,7 @@ from pydantic import (
validate_arguments,
)
from pydantic.v1 import BaseModel as BaseModelV1
from pydantic.v1 import ValidationError as ValidationErrorV1
from pydantic.v1 import validate_arguments as validate_arguments_v1
from langchain_core._api import deprecated
@ -404,7 +405,7 @@ class ChildTool(BaseTool):
"""Handle the content of the ToolException thrown."""
handle_validation_error: Optional[
Union[bool, str, Callable[[ValidationError], str]]
Union[bool, str, Callable[[Union[ValidationError, ValidationErrorV1]], str]]
] = False
"""Handle the content of the ValidationError thrown."""
@ -667,7 +668,7 @@ class ChildTool(BaseTool):
else:
content = response
status = "success"
except ValidationError as e:
except (ValidationError, ValidationErrorV1) as e:
if not self.handle_validation_error:
error_to_raise = e
else:
@ -819,9 +820,11 @@ def _is_tool_call(x: Any) -> bool:
def _handle_validation_error(
e: ValidationError,
e: Union[ValidationError, ValidationErrorV1],
*,
flag: Union[Literal[True], str, Callable[[ValidationError], str]],
flag: Union[
Literal[True], str, Callable[[Union[ValidationError, ValidationErrorV1]], str]
],
) -> str:
if isinstance(flag, bool):
content = "Tool input validation error"

View File

@ -22,6 +22,7 @@ from typing import (
import pytest
from pydantic import BaseModel, Field, ValidationError
from pydantic.v1 import BaseModel as BaseModelV1
from pydantic.v1 import ValidationError as ValidationErrorV1
from typing_extensions import TypedDict
from langchain_core import tools
@ -825,7 +826,7 @@ def test_validation_error_handling_callable() -> None:
"""Test that validation errors are handled correctly."""
expected = "foo bar"
def handling(e: ValidationError) -> str:
def handling(e: Union[ValidationError, ValidationErrorV1]) -> str:
return expected
_tool = _MockStructuredTool(handle_validation_error=handling)
@ -842,7 +843,9 @@ def test_validation_error_handling_callable() -> None:
],
)
def test_validation_error_handling_non_validation_error(
handler: Union[bool, str, Callable[[ValidationError], str]],
handler: Union[
bool, str, Callable[[Union[ValidationError, ValidationErrorV1]], str]
],
) -> None:
"""Test that validation errors are handled correctly."""
@ -887,7 +890,7 @@ async def test_async_validation_error_handling_callable() -> None:
"""Test that validation errors are handled correctly."""
expected = "foo bar"
def handling(e: ValidationError) -> str:
def handling(e: Union[ValidationError, ValidationErrorV1]) -> str:
return expected
_tool = _MockStructuredTool(handle_validation_error=handling)
@ -904,7 +907,9 @@ async def test_async_validation_error_handling_callable() -> None:
],
)
async def test_async_validation_error_handling_non_validation_error(
handler: Union[bool, str, Callable[[ValidationError], str]],
handler: Union[
bool, str, Callable[[Union[ValidationError, ValidationErrorV1]], str]
],
) -> None:
"""Test that validation errors are handled correctly."""