@ -2,11 +2,19 @@
from datetime import datetime
from datetime import datetime
from functools import partial
from functools import partial
from typing import Optional , Type , Union
from typing import Optional , Type , Union
from unittest . mock import MagicMock
import pydantic
import pydantic
import pytest
import pytest
from pydantic import BaseModel
from pydantic import BaseModel
from langchain . agents . agent import Agent
from langchain . agents . chat . base import ChatAgent
from langchain . agents . conversational . base import ConversationalAgent
from langchain . agents . conversational_chat . base import ConversationalChatAgent
from langchain . agents . mrkl . base import ZeroShotAgent
from langchain . agents . react . base import ReActDocstoreAgent , ReActTextWorldAgent
from langchain . agents . self_ask_with_search . base import SelfAskWithSearchAgent
from langchain . agents . tools import Tool , tool
from langchain . agents . tools import Tool , tool
from langchain . tools . base import BaseTool , SchemaAnnotationError
from langchain . tools . base import BaseTool , SchemaAnnotationError
@ -152,6 +160,7 @@ def test_decorated_function_schema_equivalent() -> None:
return f " { arg1 } { arg2 } { arg3 } "
return f " { arg1 } { arg2 } { arg3 } "
assert isinstance ( structured_tool_input , Tool )
assert isinstance ( structured_tool_input , Tool )
assert structured_tool_input . args_schema is not None
assert (
assert (
structured_tool_input . args_schema . schema ( ) [ " properties " ]
structured_tool_input . args_schema . schema ( ) [ " properties " ]
== _MockSchema . schema ( ) [ " properties " ]
== _MockSchema . schema ( ) [ " properties " ]
@ -309,33 +318,38 @@ def test_tool_with_kwargs() -> None:
@tool ( return_direct = True )
@tool ( return_direct = True )
def search_api (
def search_api (
arg_1 : float ,
arg_0 : str ,
arg_1 : float = 4.3 ,
ping : str = " hi " ,
ping : str = " hi " ,
) - > str :
) - > str :
""" Search the API for the query. """
""" Search the API for the query. """
return f " arg_ 1={ arg_1 } , ping= { ping } "
return f " arg_ 0={ arg_0 } , arg_ 1={ arg_1 } , ping= { ping } "
assert isinstance ( search_api , Tool )
assert isinstance ( search_api , Tool )
result = search_api . run (
result = search_api . run (
tool_input = {
tool_input = {
" arg_0 " : " foo " ,
" arg_1 " : 3.2 ,
" arg_1 " : 3.2 ,
" ping " : " pong " ,
" ping " : " pong " ,
}
}
)
)
assert result == " arg_ 1=3.2, ping=pong"
assert result == " arg_ 0=foo, arg_ 1=3.2, ping=pong"
result = search_api . run (
result = search_api . run (
tool_input = {
tool_input = {
" arg_ 1" : 3.2 ,
" arg_ 0" : " foo " ,
}
}
)
)
assert result == " arg_1=3.2, ping=hi "
assert result == " arg_0=foo, arg_1=4.3, ping=hi "
# For backwards compatibility, we still accept a single str arg
result = search_api . run ( " foobar " )
assert result == " arg_0=foobar, arg_1=4.3, ping=hi "
def test_missing_docstring ( ) - > None :
def test_missing_docstring ( ) - > None :
""" Test error is raised when docstring is missing. """
""" Test error is raised when docstring is missing. """
# expect to throw a value error if theres no docstring
# expect to throw a value error if theres no docstring
with pytest . raises ( AssertionError ):
with pytest . raises ( AssertionError , match = " Function must have a docstring " ):
@tool
@tool
def search_api ( query : str ) - > str :
def search_api ( query : str ) - > str :
@ -348,11 +362,13 @@ def test_create_tool_positional_args() -> None:
assert test_tool ( " foo " ) == " foo "
assert test_tool ( " foo " ) == " foo "
assert test_tool . name == " test_name "
assert test_tool . name == " test_name "
assert test_tool . description == " test_description "
assert test_tool . description == " test_description "
assert test_tool . is_single_input
def test_create_tool_keyword_args ( ) - > None :
def test_create_tool_keyword_args ( ) - > None :
""" Test that keyword arguments are allowed. """
""" Test that keyword arguments are allowed. """
test_tool = Tool ( name = " test_name " , func = lambda x : x , description = " test_description " )
test_tool = Tool ( name = " test_name " , func = lambda x : x , description = " test_description " )
assert test_tool . is_single_input
assert test_tool ( " foo " ) == " foo "
assert test_tool ( " foo " ) == " foo "
assert test_tool . name == " test_name "
assert test_tool . name == " test_name "
assert test_tool . description == " test_description "
assert test_tool . description == " test_description "
@ -371,8 +387,39 @@ async def test_create_async_tool() -> None:
description = " test_description " ,
description = " test_description " ,
coroutine = _test_func ,
coroutine = _test_func ,
)
)
assert test_tool . is_single_input
assert test_tool ( " foo " ) == " foo "
assert test_tool ( " foo " ) == " foo "
assert test_tool . name == " test_name "
assert test_tool . name == " test_name "
assert test_tool . description == " test_description "
assert test_tool . description == " test_description "
assert test_tool . coroutine is not None
assert test_tool . coroutine is not None
assert await test_tool . arun ( " foo " ) == " foo "
assert await test_tool . arun ( " foo " ) == " foo "
@pytest.mark.parametrize (
" agent_cls " ,
[
ChatAgent ,
ZeroShotAgent ,
ConversationalChatAgent ,
ConversationalAgent ,
ReActDocstoreAgent ,
ReActTextWorldAgent ,
SelfAskWithSearchAgent ,
] ,
)
def test_single_input_agent_raises_error_on_structured_tool (
agent_cls : Type [ Agent ] ,
) - > None :
""" Test that older agents raise errors on older tools. """
@tool
def the_tool ( foo : str , bar : str ) - > str :
""" Return the concat of foo and bar. """
return foo + bar
with pytest . raises (
ValueError ,
match = f " { agent_cls . __name__ } does not support " # type: ignore
f " multi-input tool { the_tool . name } . " ,
) :
agent_cls . from_llm_and_tools ( MagicMock ( ) , [ the_tool ] ) # type: ignore