|
|
@ -1,10 +1,18 @@
|
|
|
|
import base64
|
|
|
|
import base64
|
|
|
|
import json
|
|
|
|
import json
|
|
|
|
|
|
|
|
from typing import Optional
|
|
|
|
|
|
|
|
|
|
|
|
import httpx
|
|
|
|
import httpx
|
|
|
|
import pytest
|
|
|
|
import pytest
|
|
|
|
from langchain_core.language_models import BaseChatModel
|
|
|
|
from langchain_core.language_models import BaseChatModel
|
|
|
|
from langchain_core.messages import AIMessage, AIMessageChunk, HumanMessage, ToolMessage
|
|
|
|
from langchain_core.messages import (
|
|
|
|
|
|
|
|
AIMessage,
|
|
|
|
|
|
|
|
AIMessageChunk,
|
|
|
|
|
|
|
|
BaseMessageChunk,
|
|
|
|
|
|
|
|
HumanMessage,
|
|
|
|
|
|
|
|
ToolMessage,
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
from langchain_core.tools import tool
|
|
|
|
|
|
|
|
|
|
|
|
from langchain_standard_tests.unit_tests.chat_models import (
|
|
|
|
from langchain_standard_tests.unit_tests.chat_models import (
|
|
|
|
ChatModelTests,
|
|
|
|
ChatModelTests,
|
|
|
@ -12,6 +20,21 @@ from langchain_standard_tests.unit_tests.chat_models import (
|
|
|
|
)
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@tool
|
|
|
|
|
|
|
|
def magic_function(input: int) -> int:
|
|
|
|
|
|
|
|
"""Applies a magic function to an input."""
|
|
|
|
|
|
|
|
return input + 2
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _validate_tool_call_message(message: AIMessage) -> None:
|
|
|
|
|
|
|
|
assert isinstance(message, AIMessage)
|
|
|
|
|
|
|
|
assert len(message.tool_calls) == 1
|
|
|
|
|
|
|
|
tool_call = message.tool_calls[0]
|
|
|
|
|
|
|
|
assert tool_call["name"] == "magic_function"
|
|
|
|
|
|
|
|
assert tool_call["args"] == {"input": 3}
|
|
|
|
|
|
|
|
assert tool_call["id"] is not None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ChatModelIntegrationTests(ChatModelTests):
|
|
|
|
class ChatModelIntegrationTests(ChatModelTests):
|
|
|
|
def test_invoke(self, model: BaseChatModel) -> None:
|
|
|
|
def test_invoke(self, model: BaseChatModel) -> None:
|
|
|
|
result = model.invoke("Hello")
|
|
|
|
result = model.invoke("Hello")
|
|
|
@ -98,6 +121,24 @@ class ChatModelIntegrationTests(ChatModelTests):
|
|
|
|
result = custom_model.invoke("hi")
|
|
|
|
result = custom_model.invoke("hi")
|
|
|
|
assert isinstance(result, AIMessage)
|
|
|
|
assert isinstance(result, AIMessage)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_tool_calling(self, model: BaseChatModel) -> None:
|
|
|
|
|
|
|
|
if not self.has_tool_calling:
|
|
|
|
|
|
|
|
pytest.skip("Test requires tool calling.")
|
|
|
|
|
|
|
|
model_with_tools = model.bind_tools([magic_function])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Test invoke
|
|
|
|
|
|
|
|
query = "What is the value of magic_function(3)? Use the tool."
|
|
|
|
|
|
|
|
result = model_with_tools.invoke(query)
|
|
|
|
|
|
|
|
assert isinstance(result, AIMessage)
|
|
|
|
|
|
|
|
_validate_tool_call_message(result)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Test stream
|
|
|
|
|
|
|
|
full: Optional[BaseMessageChunk] = None
|
|
|
|
|
|
|
|
for chunk in model_with_tools.stream(query):
|
|
|
|
|
|
|
|
full = chunk if full is None else full + chunk # type: ignore
|
|
|
|
|
|
|
|
assert isinstance(full, AIMessage)
|
|
|
|
|
|
|
|
_validate_tool_call_message(full)
|
|
|
|
|
|
|
|
|
|
|
|
def test_tool_message_histories_string_content(
|
|
|
|
def test_tool_message_histories_string_content(
|
|
|
|
self,
|
|
|
|
self,
|
|
|
|
model: BaseChatModel,
|
|
|
|
model: BaseChatModel,
|
|
|
|