|
|
|
@ -1,9 +1,10 @@
|
|
|
|
|
import json
|
|
|
|
|
from abc import ABC, abstractmethod
|
|
|
|
|
from typing import Type
|
|
|
|
|
|
|
|
|
|
import pytest
|
|
|
|
|
from langchain_core.language_models import BaseChatModel
|
|
|
|
|
from langchain_core.messages import AIMessage, AIMessageChunk
|
|
|
|
|
from langchain_core.messages import AIMessage, AIMessageChunk, HumanMessage, ToolMessage
|
|
|
|
|
from langchain_core.pydantic_v1 import BaseModel, Field
|
|
|
|
|
from langchain_core.tools import tool
|
|
|
|
|
|
|
|
|
@ -115,3 +116,38 @@ class ChatModelIntegrationTests(ABC):
|
|
|
|
|
assert isinstance(result, AIMessage)
|
|
|
|
|
assert isinstance(result.content, str)
|
|
|
|
|
assert len(result.content) > 0
|
|
|
|
|
|
|
|
|
|
def test_tool_message(
|
|
|
|
|
self,
|
|
|
|
|
chat_model_class: Type[BaseChatModel],
|
|
|
|
|
chat_model_params: dict,
|
|
|
|
|
chat_model_has_tool_calling: bool,
|
|
|
|
|
) -> None:
|
|
|
|
|
if not chat_model_has_tool_calling:
|
|
|
|
|
pytest.skip("Test requires tool calling.")
|
|
|
|
|
model = chat_model_class(**chat_model_params)
|
|
|
|
|
model_with_tools = model.bind_tools([my_adder_tool])
|
|
|
|
|
function_name = "my_adder_tool"
|
|
|
|
|
function_args = {"a": "1", "b": "2"}
|
|
|
|
|
|
|
|
|
|
messages = [
|
|
|
|
|
HumanMessage(content="What is 1 + 2"),
|
|
|
|
|
AIMessage(
|
|
|
|
|
content="",
|
|
|
|
|
tool_calls=[
|
|
|
|
|
{
|
|
|
|
|
"name": function_name,
|
|
|
|
|
"args": function_args,
|
|
|
|
|
"id": "abc123",
|
|
|
|
|
},
|
|
|
|
|
],
|
|
|
|
|
),
|
|
|
|
|
ToolMessage(
|
|
|
|
|
name=function_name,
|
|
|
|
|
content=json.dumps({"result": 3}),
|
|
|
|
|
tool_call_id="abc123",
|
|
|
|
|
),
|
|
|
|
|
]
|
|
|
|
|
|
|
|
|
|
result = model_with_tools.invoke(messages)
|
|
|
|
|
assert isinstance(result, AIMessage)
|
|
|
|
|