standard tests: add test for tool calling (#23234)

Including streaming
pull/23211/head^2
ccurme 2 weeks ago committed by GitHub
parent 12e0c28a6e
commit a7b4175091
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -1,10 +1,18 @@
import base64
import json
from typing import Optional
import httpx
import pytest
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import AIMessage, AIMessageChunk, HumanMessage, ToolMessage
from langchain_core.messages import (
AIMessage,
AIMessageChunk,
BaseMessageChunk,
HumanMessage,
ToolMessage,
)
from langchain_core.tools import tool
from langchain_standard_tests.unit_tests.chat_models import (
ChatModelTests,
@ -12,6 +20,21 @@ from langchain_standard_tests.unit_tests.chat_models import (
)
@tool
def magic_function(input: int) -> int:
"""Applies a magic function to an input."""
return input + 2
def _validate_tool_call_message(message: AIMessage) -> None:
assert isinstance(message, AIMessage)
assert len(message.tool_calls) == 1
tool_call = message.tool_calls[0]
assert tool_call["name"] == "magic_function"
assert tool_call["args"] == {"input": 3}
assert tool_call["id"] is not None
class ChatModelIntegrationTests(ChatModelTests):
def test_invoke(self, model: BaseChatModel) -> None:
result = model.invoke("Hello")
@ -98,6 +121,24 @@ class ChatModelIntegrationTests(ChatModelTests):
result = custom_model.invoke("hi")
assert isinstance(result, AIMessage)
def test_tool_calling(self, model: BaseChatModel) -> None:
if not self.has_tool_calling:
pytest.skip("Test requires tool calling.")
model_with_tools = model.bind_tools([magic_function])
# Test invoke
query = "What is the value of magic_function(3)? Use the tool."
result = model_with_tools.invoke(query)
assert isinstance(result, AIMessage)
_validate_tool_call_message(result)
# Test stream
full: Optional[BaseMessageChunk] = None
for chunk in model_with_tools.stream(query):
full = chunk if full is None else full + chunk # type: ignore
assert isinstance(full, AIMessage)
_validate_tool_call_message(full)
def test_tool_message_histories_string_content(
self,
model: BaseChatModel,

Loading…
Cancel
Save