|
|
|
@ -9,6 +9,7 @@ from functools import partial
|
|
|
|
|
from typing import Any, Callable, Dict, List, Optional, Type, Union
|
|
|
|
|
|
|
|
|
|
import pytest
|
|
|
|
|
from typing_extensions import Annotated
|
|
|
|
|
|
|
|
|
|
from langchain_core.callbacks import (
|
|
|
|
|
AsyncCallbackManagerForToolRun,
|
|
|
|
@ -23,6 +24,7 @@ from langchain_core.tools import (
|
|
|
|
|
Tool,
|
|
|
|
|
ToolException,
|
|
|
|
|
_create_subset_model,
|
|
|
|
|
create_schema_from_function,
|
|
|
|
|
tool,
|
|
|
|
|
)
|
|
|
|
|
from tests.unit_tests.fake.callbacks import FakeCallbackHandler
|
|
|
|
@ -53,7 +55,12 @@ class _MockStructuredTool(BaseTool):
|
|
|
|
|
args_schema: Type[BaseModel] = _MockSchema
|
|
|
|
|
description: str = "A Structured Tool"
|
|
|
|
|
|
|
|
|
|
def _run(self, arg1: int, arg2: bool, arg3: Optional[dict] = None) -> str:
|
|
|
|
|
def _run(
|
|
|
|
|
self,
|
|
|
|
|
arg1: int,
|
|
|
|
|
arg2: bool,
|
|
|
|
|
arg3: Optional[dict] = None,
|
|
|
|
|
) -> str:
|
|
|
|
|
return f"{arg1} {arg2} {arg3}"
|
|
|
|
|
|
|
|
|
|
async def _arun(self, arg1: int, arg2: bool, arg3: Optional[dict] = None) -> str:
|
|
|
|
@ -70,6 +77,33 @@ def test_structured_args() -> None:
|
|
|
|
|
assert structured_api.run(args) == expected_result
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.skipif(sys.version_info < (3, 10), reason="Requires Python 3.10 or above")
|
|
|
|
|
def test_structured_args_description() -> None:
|
|
|
|
|
class _AnnotatedTool(BaseTool):
|
|
|
|
|
name: str = "structured_api"
|
|
|
|
|
description: str = "A Structured Tool"
|
|
|
|
|
|
|
|
|
|
def _run(
|
|
|
|
|
self,
|
|
|
|
|
arg1: int,
|
|
|
|
|
arg2: Annotated[bool, "V important"],
|
|
|
|
|
arg3: Optional[dict] = None,
|
|
|
|
|
) -> str:
|
|
|
|
|
return f"{arg1} {arg2} {arg3}"
|
|
|
|
|
|
|
|
|
|
async def _arun(
|
|
|
|
|
self, arg1: int, arg2: bool, arg3: Optional[dict] = None
|
|
|
|
|
) -> str:
|
|
|
|
|
raise NotImplementedError
|
|
|
|
|
|
|
|
|
|
expected = {
|
|
|
|
|
"arg1": {"title": "Arg1", "type": "integer"},
|
|
|
|
|
"arg2": {"title": "Arg2", "type": "boolean", "description": "V important"},
|
|
|
|
|
"arg3": {"title": "Arg3", "type": "object"},
|
|
|
|
|
}
|
|
|
|
|
assert _AnnotatedTool().args == expected
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_misannotated_base_tool_raises_error() -> None:
|
|
|
|
|
"""Test that a BaseTool with the incorrect typehint raises an exception.""" ""
|
|
|
|
|
with pytest.raises(SchemaAnnotationError):
|
|
|
|
@ -876,6 +910,73 @@ def test_tool_invoke_optional_args(inputs: dict, expected: Optional[dict]) -> No
|
|
|
|
|
foo.invoke(inputs) # type: ignore
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.skipif(sys.version_info < (3, 10), reason="Requires Python 3.10 or above")
|
|
|
|
|
def test_create_schema_from_function_with_descriptions() -> None:
|
|
|
|
|
def foo(bar: int, baz: str) -> str:
|
|
|
|
|
"""Docstring
|
|
|
|
|
Args:
|
|
|
|
|
bar: int
|
|
|
|
|
baz: str
|
|
|
|
|
"""
|
|
|
|
|
raise NotImplementedError()
|
|
|
|
|
|
|
|
|
|
schema = create_schema_from_function("foo", foo)
|
|
|
|
|
assert schema.schema() == {
|
|
|
|
|
"title": "fooSchema",
|
|
|
|
|
"type": "object",
|
|
|
|
|
"properties": {
|
|
|
|
|
"bar": {"title": "Bar", "type": "integer"},
|
|
|
|
|
"baz": {"title": "Baz", "type": "string"},
|
|
|
|
|
},
|
|
|
|
|
"required": ["bar", "baz"],
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
def foo_annotated(
|
|
|
|
|
bar: Annotated[int, "This is bar", {"gte": 5}, "it's useful"],
|
|
|
|
|
) -> str:
|
|
|
|
|
"""Docstring
|
|
|
|
|
Args:
|
|
|
|
|
bar: int
|
|
|
|
|
"""
|
|
|
|
|
raise bar
|
|
|
|
|
|
|
|
|
|
schema = create_schema_from_function("foo_annotated", foo_annotated)
|
|
|
|
|
assert schema.schema() == {
|
|
|
|
|
"title": "foo_annotatedSchema",
|
|
|
|
|
"type": "object",
|
|
|
|
|
"properties": {
|
|
|
|
|
"bar": {
|
|
|
|
|
"title": "Bar",
|
|
|
|
|
"type": "integer",
|
|
|
|
|
"description": "This is bar\nit's useful",
|
|
|
|
|
},
|
|
|
|
|
},
|
|
|
|
|
"required": ["bar"],
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_annotated_tool_typing() -> None:
|
|
|
|
|
@tool
|
|
|
|
|
def foo(bar: Annotated[int, "This is bar", {"gte": 5}, "it's useful"]) -> str:
|
|
|
|
|
"""The foo."""
|
|
|
|
|
return str(bar)
|
|
|
|
|
|
|
|
|
|
assert foo.invoke({"bar": 5}) == "5" # type: ignore
|
|
|
|
|
with pytest.raises(ValidationError):
|
|
|
|
|
foo.invoke({"bar": 4}) # type: ignore
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def test_annotated_async_tool_typing() -> None:
|
|
|
|
|
@tool
|
|
|
|
|
async def foo(bar: Annotated[int, "This is bar", {"gte": 5}, "it's useful"]) -> str:
|
|
|
|
|
"""The foo."""
|
|
|
|
|
return str(bar)
|
|
|
|
|
|
|
|
|
|
assert await foo.ainvoke({"bar": 5}) == "5" # type: ignore
|
|
|
|
|
with pytest.raises(ValidationError):
|
|
|
|
|
await foo.ainvoke({"bar": 4}) # type: ignore
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_tool_pass_context() -> None:
|
|
|
|
|
@tool
|
|
|
|
|
def foo(bar: str) -> str:
|
|
|
|
|