[Patch] Dedent docstring (#20959)

Technically a slight prompt breaking change, but I think positive EV in
that it saves tokens and results in more sane / in-distribution prompts
pull/21043/head
William FH 4 weeks ago committed by GitHub
parent 845d8e0025
commit 5c63ac3dd7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -21,6 +21,7 @@ from __future__ import annotations
import asyncio
import inspect
import textwrap
import uuid
import warnings
from abc import ABC, abstractmethod
@ -825,16 +826,19 @@ class StructuredTool(BaseTool):
else:
raise ValueError("Function and/or coroutine must be provided")
name = name or source_function.__name__
description = description or source_function.__doc__
if description is None:
description_ = description or source_function.__doc__
if description_ is None:
raise ValueError(
"Function must have a docstring if description not provided."
)
if description is None:
# Only apply if using the function's docstring
description_ = textwrap.dedent(description_).strip()
# Description example:
# search_api(query: str) - Searches the API for the query.
sig = signature(source_function)
description = f"{name}{sig} - {description.strip()}"
description_ = f"{name}{sig} - {description_.strip()}"
_args_schema = args_schema
if _args_schema is None and infer_schema:
# schema name is appended within function
@ -844,7 +848,7 @@ class StructuredTool(BaseTool):
func=func,
coroutine=coroutine,
args_schema=_args_schema, # type: ignore[arg-type]
description=description,
description=description_,
return_direct=return_direct,
**kwargs,
)

@ -3,6 +3,7 @@
import asyncio
import json
import sys
import textwrap
from datetime import datetime
from enum import Enum
from functools import partial
@ -333,7 +334,7 @@ def test_structured_tool_from_function_docstring() -> None:
prefix = "foo(bar: int, baz: str) -> str - "
assert foo.__doc__ is not None
assert structured_tool.description == prefix + foo.__doc__.strip()
assert structured_tool.description == prefix + textwrap.dedent(foo.__doc__.strip())
def test_structured_tool_from_function_docstring_complex_args() -> None:
@ -366,7 +367,7 @@ def test_structured_tool_from_function_docstring_complex_args() -> None:
prefix = "foo(bar: int, baz: List[str]) -> str - "
assert foo.__doc__ is not None
assert structured_tool.description == prefix + foo.__doc__.strip()
assert structured_tool.description == prefix + textwrap.dedent(foo.__doc__).strip()
def test_structured_tool_lambda_multi_args_schema() -> None:
@ -701,7 +702,7 @@ def test_structured_tool_from_function() -> None:
prefix = "foo(bar: int, baz: str) -> str - "
assert foo.__doc__ is not None
assert structured_tool.description == prefix + foo.__doc__.strip()
assert structured_tool.description == prefix + textwrap.dedent(foo.__doc__.strip())
def test_validation_error_handling_bool() -> None:

Loading…
Cancel
Save