Fix tests

This commit is contained in:
William Fu-Hinthorn 2024-04-26 18:29:14 -07:00
parent 894cf7824b
commit 1df6da2583
2 changed files with 8 additions and 7 deletions

View File

@ -210,9 +210,9 @@ class ChildTool(BaseTool):
You can use these to eg identify a specific instance of a tool with its use case. You can use these to eg identify a specific instance of a tool with its use case.
""" """
handle_tool_error: Optional[Union[bool, str, Callable[[ToolException], str]]] = ( handle_tool_error: Optional[
False Union[bool, str, Callable[[ToolException], str]]
) ] = False
"""Handle the content of the ToolException thrown.""" """Handle the content of the ToolException thrown."""
handle_validation_error: Optional[ handle_validation_error: Optional[
@ -838,7 +838,7 @@ class StructuredTool(BaseTool):
# Description example: # Description example:
# search_api(query: str) - Searches the API for the query. # search_api(query: str) - Searches the API for the query.
sig = signature(source_function) sig = signature(source_function)
description = f"{name}{sig} - {description_.strip()}" description_ = f"{name}{sig} - {description_.strip()}"
_args_schema = args_schema _args_schema = args_schema
if _args_schema is None and infer_schema: if _args_schema is None and infer_schema:
# schema name is appended within function # schema name is appended within function

View File

@ -3,6 +3,7 @@
import asyncio import asyncio
import json import json
import sys import sys
import textwrap
from datetime import datetime from datetime import datetime
from enum import Enum from enum import Enum
from functools import partial from functools import partial
@ -333,7 +334,7 @@ def test_structured_tool_from_function_docstring() -> None:
prefix = "foo(bar: int, baz: str) -> str - " prefix = "foo(bar: int, baz: str) -> str - "
assert foo.__doc__ is not None assert foo.__doc__ is not None
assert structured_tool.description == prefix + foo.__doc__.strip() assert structured_tool.description == prefix + textwrap.dedent(foo.__doc__.strip())
def test_structured_tool_from_function_docstring_complex_args() -> None: def test_structured_tool_from_function_docstring_complex_args() -> None:
@ -366,7 +367,7 @@ def test_structured_tool_from_function_docstring_complex_args() -> None:
prefix = "foo(bar: int, baz: List[str]) -> str - " prefix = "foo(bar: int, baz: List[str]) -> str - "
assert foo.__doc__ is not None assert foo.__doc__ is not None
assert structured_tool.description == prefix + foo.__doc__.strip() assert structured_tool.description == prefix + textwrap.dedent(foo.__doc__).strip()
def test_structured_tool_lambda_multi_args_schema() -> None: def test_structured_tool_lambda_multi_args_schema() -> None:
@ -701,7 +702,7 @@ def test_structured_tool_from_function() -> None:
prefix = "foo(bar: int, baz: str) -> str - " prefix = "foo(bar: int, baz: str) -> str - "
assert foo.__doc__ is not None assert foo.__doc__ is not None
assert structured_tool.description == prefix + foo.__doc__.strip() assert structured_tool.description == prefix + textwrap.dedent(foo.__doc__.strip())
def test_validation_error_handling_bool() -> None: def test_validation_error_handling_bool() -> None: