Fix tests

This commit is contained in:
William Fu-Hinthorn 2024-04-26 18:29:14 -07:00
parent 894cf7824b
commit 1df6da2583
2 changed files with 8 additions and 7 deletions

View File

@ -210,9 +210,9 @@ class ChildTool(BaseTool):
You can use these to eg identify a specific instance of a tool with its use case.
"""
handle_tool_error: Optional[Union[bool, str, Callable[[ToolException], str]]] = (
False
)
handle_tool_error: Optional[
Union[bool, str, Callable[[ToolException], str]]
] = False
"""Handle the content of the ToolException thrown."""
handle_validation_error: Optional[
@ -838,7 +838,7 @@ class StructuredTool(BaseTool):
# Description example:
# search_api(query: str) - Searches the API for the query.
sig = signature(source_function)
description = f"{name}{sig} - {description_.strip()}"
description_ = f"{name}{sig} - {description_.strip()}"
_args_schema = args_schema
if _args_schema is None and infer_schema:
# schema name is appended within function

View File

@ -3,6 +3,7 @@
import asyncio
import json
import sys
import textwrap
from datetime import datetime
from enum import Enum
from functools import partial
@ -333,7 +334,7 @@ def test_structured_tool_from_function_docstring() -> None:
prefix = "foo(bar: int, baz: str) -> str - "
assert foo.__doc__ is not None
assert structured_tool.description == prefix + foo.__doc__.strip()
assert structured_tool.description == prefix + textwrap.dedent(foo.__doc__.strip())
def test_structured_tool_from_function_docstring_complex_args() -> None:
@ -366,7 +367,7 @@ def test_structured_tool_from_function_docstring_complex_args() -> None:
prefix = "foo(bar: int, baz: List[str]) -> str - "
assert foo.__doc__ is not None
assert structured_tool.description == prefix + foo.__doc__.strip()
assert structured_tool.description == prefix + textwrap.dedent(foo.__doc__).strip()
def test_structured_tool_lambda_multi_args_schema() -> None:
@ -701,7 +702,7 @@ def test_structured_tool_from_function() -> None:
prefix = "foo(bar: int, baz: str) -> str - "
assert foo.__doc__ is not None
assert structured_tool.description == prefix + foo.__doc__.strip()
assert structured_tool.description == prefix + textwrap.dedent(foo.__doc__.strip())
def test_validation_error_handling_bool() -> None: