|
|
|
import json
|
|
|
|
from abc import ABC, abstractmethod
|
|
|
|
from typing import Type
|
|
|
|
|
|
|
|
import pytest
|
|
|
|
from langchain_core.language_models import BaseChatModel
|
|
|
|
from langchain_core.messages import AIMessage, AIMessageChunk, HumanMessage, ToolMessage
|
|
|
|
from langchain_core.pydantic_v1 import BaseModel, Field
|
|
|
|
from langchain_core.tools import tool
|
|
|
|
|
|
|
|
|
|
|
|
# Sample schema used by structured-output tests.
# NOTE(review): intentionally no class docstring — pydantic v1 (the
# langchain_core.pydantic_v1 shim) surfaces a docstring as the model's
# JSON-schema "description", which would change what is sent to providers.
class Person(BaseModel):
    # Field descriptions below are part of the schema shown to the model.
    name: str = Field(..., description="The name of the person.")
    age: int = Field(..., description="The age of the person.")
|
|
|
|
|
|
|
|
|
|
|
|
# Simple tool used by the tool-calling tests below. The @tool decorator
# derives the tool's input schema from the parameter names/annotations and
# uses the docstring as the tool's description — both are part of the
# contract sent to model providers, so do not edit them casually.
@tool
def my_adder_tool(a: int, b: int) -> int:
    """Takes two integers, a and b, and returns their sum."""
    return a + b
|
|
|
|
|
|
|
|
|
|
|
|
class ChatModelIntegrationTests(ABC):
    """Standard battery of integration tests for chat model implementations.

    Subclass this and implement the ``chat_model_class`` fixture (and, if the
    model needs constructor arguments, override ``chat_model_params``) to run
    the suite against a concrete ``BaseChatModel``. Tool-calling tests skip
    themselves when the class does not override ``bind_tools``.
    """

    @abstractmethod
    @pytest.fixture
    def chat_model_class(self) -> Type[BaseChatModel]:
        """Return the chat model class under test. Must be implemented by subclasses."""
        ...

    @pytest.fixture
    def chat_model_params(self) -> dict:
        """Constructor kwargs passed to ``chat_model_class``; override as needed."""
        return {}

    @pytest.fixture
    def chat_model_has_tool_calling(
        self, chat_model_class: Type[BaseChatModel]
    ) -> bool:
        """Whether the class supports tool calling.

        Detected by checking that ``bind_tools`` is a different function
        object than the ``BaseChatModel`` default, i.e. it was overridden.
        """
        return chat_model_class.bind_tools is not BaseChatModel.bind_tools

    @pytest.fixture
    def chat_model_has_structured_output(
        self, chat_model_class: Type[BaseChatModel]
    ) -> bool:
        """Whether the class overrides ``with_structured_output``."""
        return (
            chat_model_class.with_structured_output
            is not BaseChatModel.with_structured_output
        )

    def test_invoke(
        self, chat_model_class: Type[BaseChatModel], chat_model_params: dict
    ) -> None:
        """Synchronous invoke returns a non-empty ``AIMessage`` with string content."""
        model = chat_model_class(**chat_model_params)
        result = model.invoke("Hello")
        assert result is not None
        assert isinstance(result, AIMessage)
        assert isinstance(result.content, str)
        assert len(result.content) > 0

    async def test_ainvoke(
        self, chat_model_class: Type[BaseChatModel], chat_model_params: dict
    ) -> None:
        """Async invoke returns a non-empty ``AIMessage`` with string content."""
        model = chat_model_class(**chat_model_params)
        result = await model.ainvoke("Hello")
        assert result is not None
        assert isinstance(result, AIMessage)
        assert isinstance(result.content, str)
        assert len(result.content) > 0

    def test_stream(
        self, chat_model_class: Type[BaseChatModel], chat_model_params: dict
    ) -> None:
        """Streaming yields ``AIMessageChunk`` tokens with non-empty total content."""
        model = chat_model_class(**chat_model_params)
        num_tokens = 0
        for token in model.stream("Hello"):
            assert token is not None
            assert isinstance(token, AIMessageChunk)
            # Chunks may be empty individually; only the total must be non-empty.
            num_tokens += len(token.content)
        assert num_tokens > 0

    async def test_astream(
        self, chat_model_class: Type[BaseChatModel], chat_model_params: dict
    ) -> None:
        """Async streaming yields ``AIMessageChunk`` tokens with non-empty total content."""
        model = chat_model_class(**chat_model_params)
        num_tokens = 0
        async for token in model.astream("Hello"):
            assert token is not None
            assert isinstance(token, AIMessageChunk)
            num_tokens += len(token.content)
        assert num_tokens > 0

    def test_batch(
        self, chat_model_class: Type[BaseChatModel], chat_model_params: dict
    ) -> None:
        """Batch invoke returns one non-empty ``AIMessage`` per input, in a list."""
        model = chat_model_class(**chat_model_params)
        batch_results = model.batch(["Hello", "Hey"])
        assert batch_results is not None
        assert isinstance(batch_results, list)
        assert len(batch_results) == 2
        for result in batch_results:
            assert result is not None
            assert isinstance(result, AIMessage)
            assert isinstance(result.content, str)
            assert len(result.content) > 0

    async def test_abatch(
        self, chat_model_class: Type[BaseChatModel], chat_model_params: dict
    ) -> None:
        """Async batch invoke returns one non-empty ``AIMessage`` per input."""
        model = chat_model_class(**chat_model_params)
        batch_results = await model.abatch(["Hello", "Hey"])
        assert batch_results is not None
        assert isinstance(batch_results, list)
        assert len(batch_results) == 2
        for result in batch_results:
            assert result is not None
            assert isinstance(result, AIMessage)
            assert isinstance(result.content, str)
            assert len(result.content) > 0

    def test_conversation(
        self, chat_model_class: Type[BaseChatModel], chat_model_params: dict
    ) -> None:
        """Model accepts a multi-turn human/AI message history as input."""
        model = chat_model_class(**chat_model_params)
        messages = [
            HumanMessage(content="hello"),
            AIMessage(content="hello"),
            HumanMessage(content="how are you"),
        ]
        result = model.invoke(messages)
        assert result is not None
        assert isinstance(result, AIMessage)
        assert isinstance(result.content, str)
        assert len(result.content) > 0

    def test_usage_metadata(
        self, chat_model_class: Type[BaseChatModel], chat_model_params: dict
    ) -> None:
        """Invoke result carries ``usage_metadata`` with integer token counts."""
        model = chat_model_class(**chat_model_params)
        result = model.invoke("Hello")
        assert result is not None
        assert isinstance(result, AIMessage)
        assert result.usage_metadata is not None
        assert isinstance(result.usage_metadata["input_tokens"], int)
        assert isinstance(result.usage_metadata["output_tokens"], int)
        assert isinstance(result.usage_metadata["total_tokens"], int)

    def test_stop_sequence(
        self, chat_model_class: Type[BaseChatModel], chat_model_params: dict
    ) -> None:
        """Stop sequences are accepted both per-call and as a constructor param."""
        # Per-call stop parameter.
        model = chat_model_class(**chat_model_params)
        result = model.invoke("hi", stop=["you"])
        assert isinstance(result, AIMessage)

        # Stop sequence configured at construction time.
        model = chat_model_class(**chat_model_params, stop=["you"])
        result = model.invoke("hi")
        assert isinstance(result, AIMessage)

    def test_tool_message_histories_string_content(
        self,
        chat_model_class: Type[BaseChatModel],
        chat_model_params: dict,
        chat_model_has_tool_calling: bool,
    ) -> None:
        """
        Test that message histories are compatible with string tool contents
        (e.g. OpenAI).
        """
        if not chat_model_has_tool_calling:
            pytest.skip("Test requires tool calling.")
        model = chat_model_class(**chat_model_params)
        model_with_tools = model.bind_tools([my_adder_tool])
        function_name = "my_adder_tool"
        # NOTE(review): args are strings here (ints in the other tool tests) —
        # presumably mirroring providers that emit stringified arguments;
        # confirm this is intentional.
        function_args = {"a": "1", "b": "2"}

        messages_string_content = [
            HumanMessage(content="What is 1 + 2"),
            # string content (e.g. OpenAI)
            AIMessage(
                content="",
                tool_calls=[
                    {
                        "name": function_name,
                        "args": function_args,
                        "id": "abc123",
                    },
                ],
            ),
            ToolMessage(
                name=function_name,
                content=json.dumps({"result": 3}),
                tool_call_id="abc123",
            ),
        ]
        result_string_content = model_with_tools.invoke(messages_string_content)
        assert isinstance(result_string_content, AIMessage)

    def test_tool_message_histories_list_content(
        self,
        chat_model_class: Type[BaseChatModel],
        chat_model_params: dict,
        chat_model_has_tool_calling: bool,
    ) -> None:
        """
        Test that message histories are compatible with list tool contents
        (e.g. Anthropic).
        """
        if not chat_model_has_tool_calling:
            pytest.skip("Test requires tool calling.")
        model = chat_model_class(**chat_model_params)
        model_with_tools = model.bind_tools([my_adder_tool])
        function_name = "my_adder_tool"
        function_args = {"a": 1, "b": 2}

        messages_list_content = [
            HumanMessage(content="What is 1 + 2"),
            # List content (e.g., Anthropic)
            AIMessage(
                # The tool call appears both in the provider-style content
                # block and in the standardized tool_calls attribute.
                content=[
                    {"type": "text", "text": "some text"},
                    {
                        "type": "tool_use",
                        "id": "abc123",
                        "name": function_name,
                        "input": function_args,
                    },
                ],
                tool_calls=[
                    {
                        "name": function_name,
                        "args": function_args,
                        "id": "abc123",
                    },
                ],
            ),
            ToolMessage(
                name=function_name,
                content=json.dumps({"result": 3}),
                tool_call_id="abc123",
            ),
        ]
        result_list_content = model_with_tools.invoke(messages_list_content)
        assert isinstance(result_list_content, AIMessage)

    def test_structured_few_shot_examples(
        self,
        chat_model_class: Type[BaseChatModel],
        chat_model_params: dict,
        chat_model_has_tool_calling: bool,
    ) -> None:
        """
        Test that model can process few-shot examples with tool calls.
        """
        if not chat_model_has_tool_calling:
            pytest.skip("Test requires tool calling.")
        model = chat_model_class(**chat_model_params)
        model_with_tools = model.bind_tools([my_adder_tool])
        function_name = "my_adder_tool"
        function_args = {"a": 1, "b": 2}
        function_result = json.dumps({"result": 3})

        # A complete worked example (call -> tool result -> final answer)
        # followed by a fresh question.
        messages_string_content = [
            HumanMessage(content="What is 1 + 2"),
            AIMessage(
                content="",
                tool_calls=[
                    {
                        "name": function_name,
                        "args": function_args,
                        "id": "abc123",
                    },
                ],
            ),
            ToolMessage(
                name=function_name,
                content=function_result,
                tool_call_id="abc123",
            ),
            AIMessage(content=function_result),
            HumanMessage(content="What is 3 + 4"),
        ]
        result_string_content = model_with_tools.invoke(messages_string_content)
        assert isinstance(result_string_content, AIMessage)