From a7b41750912d5986035d773bcb5aaf4753ea2d75 Mon Sep 17 00:00:00 2001 From: ccurme Date: Thu, 20 Jun 2024 17:20:11 -0400 Subject: [PATCH] standard tests: add test for tool calling (#23234) Including streaming --- .../integration_tests/chat_models.py | 43 ++++++++++++++++++- 1 file changed, 42 insertions(+), 1 deletion(-) diff --git a/libs/standard-tests/langchain_standard_tests/integration_tests/chat_models.py b/libs/standard-tests/langchain_standard_tests/integration_tests/chat_models.py index 44bf02a0ba..90ef83fd4f 100644 --- a/libs/standard-tests/langchain_standard_tests/integration_tests/chat_models.py +++ b/libs/standard-tests/langchain_standard_tests/integration_tests/chat_models.py @@ -1,10 +1,18 @@ import base64 import json +from typing import Optional import httpx import pytest from langchain_core.language_models import BaseChatModel -from langchain_core.messages import AIMessage, AIMessageChunk, HumanMessage, ToolMessage +from langchain_core.messages import ( + AIMessage, + AIMessageChunk, + BaseMessageChunk, + HumanMessage, + ToolMessage, +) +from langchain_core.tools import tool from langchain_standard_tests.unit_tests.chat_models import ( ChatModelTests, @@ -12,6 +20,21 @@ from langchain_standard_tests.unit_tests.chat_models import ( ) +@tool +def magic_function(input: int) -> int: + """Applies a magic function to an input.""" + return input + 2 + + +def _validate_tool_call_message(message: AIMessage) -> None: + assert isinstance(message, AIMessage) + assert len(message.tool_calls) == 1 + tool_call = message.tool_calls[0] + assert tool_call["name"] == "magic_function" + assert tool_call["args"] == {"input": 3} + assert tool_call["id"] is not None + + class ChatModelIntegrationTests(ChatModelTests): def test_invoke(self, model: BaseChatModel) -> None: result = model.invoke("Hello") @@ -98,6 +121,24 @@ class ChatModelIntegrationTests(ChatModelTests): result = custom_model.invoke("hi") assert isinstance(result, AIMessage) + def test_tool_calling(self, model: BaseChatModel) -> None: + if not self.has_tool_calling: + pytest.skip("Test requires tool calling.") + model_with_tools = model.bind_tools([magic_function]) + + # Test invoke + query = "What is the value of magic_function(3)? Use the tool." + result = model_with_tools.invoke(query) + assert isinstance(result, AIMessage) + _validate_tool_call_message(result) + + # Test stream + full: Optional[BaseMessageChunk] = None + for chunk in model_with_tools.stream(query): + full = chunk if full is None else full + chunk # type: ignore + assert isinstance(full, AIMessage) + _validate_tool_call_message(full) + def test_tool_message_histories_string_content( self, model: BaseChatModel,