standard-tests: split tool calling test (#20803)

just making it a bit easier to grok
pull/20659/head^2
Erick Friis authored 1 month ago, committed by GitHub
parent 6622829c67
commit ddc2274aea

@@ -117,13 +117,16 @@ class ChatModelIntegrationTests(ABC):
         assert isinstance(result.content, str)
         assert len(result.content) > 0
 
-    def test_tool_message_histories(
+    def test_tool_message_histories_string_content(
         self,
         chat_model_class: Type[BaseChatModel],
         chat_model_params: dict,
         chat_model_has_tool_calling: bool,
     ) -> None:
-        """Test that message histories are compatible across providers."""
+        """
+        Test that message histories are compatible with string tool contents
+        (e.g. OpenAI).
+        """
         if not chat_model_has_tool_calling:
             pytest.skip("Test requires tool calling.")
         model = chat_model_class(**chat_model_params)
@@ -131,55 +134,71 @@ class ChatModelIntegrationTests(ABC):
         function_name = "my_adder_tool"
         function_args = {"a": "1", "b": "2"}
 
-        human_message = HumanMessage(content="What is 1 + 2")
-        tool_message = ToolMessage(
-            name=function_name,
-            content=json.dumps({"result": 3}),
-            tool_call_id="abc123",
-        )
-
-        # String content (e.g., OpenAI)
-        string_content_msg = AIMessage(
-            content="",
-            tool_calls=[
-                {
-                    "name": function_name,
-                    "args": function_args,
-                    "id": "abc123",
-                },
-            ],
-        )
-        messages = [
-            human_message,
-            string_content_msg,
-            tool_message,
+        messages_string_content = [
+            HumanMessage(content="What is 1 + 2"),
+            # string content (e.g. OpenAI)
+            AIMessage(
+                content="",
+                tool_calls=[
+                    {
+                        "name": function_name,
+                        "args": function_args,
+                        "id": "abc123",
+                    },
+                ],
+            ),
+            ToolMessage(
+                name=function_name,
+                content=json.dumps({"result": 3}),
+                tool_call_id="abc123",
+            ),
         ]
-        result = model_with_tools.invoke(messages)
-        assert isinstance(result, AIMessage)
+        result_string_content = model_with_tools.invoke(messages_string_content)
+        assert isinstance(result_string_content, AIMessage)
 
-        # List content (e.g., Anthropic)
-        list_content_msg = AIMessage(
-            content=[
-                {"type": "text", "text": "some text"},
-                {
-                    "type": "tool_use",
-                    "id": "abc123",
-                    "name": function_name,
-                    "input": function_args,
-                },
-            ],
-            tool_calls=[
-                {
-                    "name": function_name,
-                    "args": function_args,
-                    "id": "abc123",
-                },
-            ],
-        )
-        messages = [
-            human_message,
-            list_content_msg,
-            tool_message,
+    def test_tool_message_histories_list_content(
+        self,
+        chat_model_class: Type[BaseChatModel],
+        chat_model_params: dict,
+        chat_model_has_tool_calling: bool,
+    ) -> None:
+        """
+        Test that message histories are compatible with list tool contents
+        (e.g. Anthropic).
+        """
+        if not chat_model_has_tool_calling:
+            pytest.skip("Test requires tool calling.")
+        model = chat_model_class(**chat_model_params)
+        model_with_tools = model.bind_tools([my_adder_tool])
+        function_name = "my_adder_tool"
+        function_args = {"a": 1, "b": 2}
+
+        messages_list_content = [
+            HumanMessage(content="What is 1 + 2"),
+            # List content (e.g., Anthropic)
+            AIMessage(
+                content=[
+                    {"type": "text", "text": "some text"},
+                    {
+                        "type": "tool_use",
+                        "id": "abc123",
+                        "name": function_name,
+                        "input": function_args,
+                    },
+                ],
+                tool_calls=[
+                    {
+                        "name": function_name,
+                        "args": function_args,
+                        "id": "abc123",
+                    },
+                ],
+            ),
+            ToolMessage(
+                name=function_name,
+                content=json.dumps({"result": 3}),
+                tool_call_id="abc123",
+            ),
         ]
-        result = model_with_tools.invoke(messages)
-        assert isinstance(result, AIMessage)
+        result_list_content = model_with_tools.invoke(messages_list_content)
+        assert isinstance(result_list_content, AIMessage)
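
For context, a provider integration would typically run the two split tests by subclassing the suite and supplying the pytest fixtures named in the test signatures. The sketch below is illustrative only and not part of this diff: the import path for ChatModelIntegrationTests, the fixture-based wiring, the ChatOpenAI example, and its parameters are assumptions.

# Illustrative sketch only: assumes the standard-tests suite is consumed by
# subclassing ChatModelIntegrationTests and defining pytest fixtures matching
# the parameter names used in the tests above. Import paths and the example
# provider (ChatOpenAI) are assumptions, not taken from this diff.
from typing import Type

import pytest
from langchain_core.language_models import BaseChatModel
from langchain_openai import ChatOpenAI
from langchain_standard_tests.integration_tests import ChatModelIntegrationTests


class TestChatOpenAIStandard(ChatModelIntegrationTests):
    @pytest.fixture
    def chat_model_class(self) -> Type[BaseChatModel]:
        # The chat model class under test.
        return ChatOpenAI

    @pytest.fixture
    def chat_model_params(self) -> dict:
        # Constructor kwargs passed to the class above.
        return {"model": "gpt-3.5-turbo", "temperature": 0}

    @pytest.fixture
    def chat_model_has_tool_calling(self) -> bool:
        # Gates both split tests; each calls pytest.skip() when this is False.
        return True

With those fixtures in place, both test_tool_message_histories_string_content and test_tool_message_histories_list_content run against the provider, exercising the OpenAI-style string tool content and the Anthropic-style list tool content separately.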
