langchain/libs/partners/anthropic/tests/unit_tests/test_chat_models.py

550 lines
18 KiB
Python
Raw Normal View History

"""Test chat model integration."""
import os
from typing import Any, Callable, Dict, Literal, Type, cast
import pytest
from anthropic.types import Message, TextBlock, Usage
2024-04-04 20:22:48 +00:00
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
from langchain_core.outputs import ChatGeneration, ChatResult
from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr
from langchain_core.runnables import RunnableBinding
2024-04-04 20:22:48 +00:00
from langchain_core.tools import BaseTool
from pytest import CaptureFixture, MonkeyPatch
from langchain_anthropic import ChatAnthropic
from langchain_anthropic.chat_models import (
_format_messages,
_merge_messages,
convert_to_anthropic_tool,
)
# Placeholder API key installed in the process environment at import time.
# NOTE(review): several tests below construct ChatAnthropic without passing a
# key, so they presumably rely on this variable being set — confirm.
os.environ["ANTHROPIC_API_KEY"] = "foo"
def test_initialization() -> None:
    """Check that two keyword spellings of the constructor configure the model.

    One instance is built with ``model_name`` / ``api_key`` / ``timeout``, the
    other with ``model`` / ``anthropic_api_key`` / ``default_request_timeout`` /
    ``base_url``. Both must report the same model, secret key, timeout and URL.
    """
    first_spelling = ChatAnthropic(model_name="claude-instant-1.2", api_key="xyz", timeout=2)  # type: ignore[arg-type, call-arg]
    second_spelling = ChatAnthropic(  # type: ignore[call-arg, call-arg, call-arg]
        model="claude-instant-1.2",
        anthropic_api_key="xyz",
        default_request_timeout=2,
        base_url="https://api.anthropic.com",
    )

    for model in (first_spelling, second_spelling):
        assert model.model == "claude-instant-1.2"
        secret = cast(SecretStr, model.anthropic_api_key)
        assert secret.get_secret_value() == "xyz"
        assert model.default_request_timeout == 2.0
        assert model.anthropic_api_url == "https://api.anthropic.com"
@pytest.mark.requires("anthropic")
def test_anthropic_model_name_param() -> None:
    """A value passed as ``model_name`` must be readable back from ``.model``."""
    chat_model = ChatAnthropic(model_name="foo")  # type: ignore[call-arg, call-arg]
    assert chat_model.model == "foo"
@pytest.mark.requires("anthropic")
def test_anthropic_model_param() -> None:
    """A value passed as ``model`` must be readable back from ``.model``."""
    chat_model = ChatAnthropic(model="foo")  # type: ignore[call-arg]
    assert chat_model.model == "foo"
@pytest.mark.requires("anthropic")
def test_anthropic_model_kwargs() -> None:
    """Explicit ``model_kwargs`` must be stored on the instance unchanged."""
    chat_model = ChatAnthropic(model_name="foo", model_kwargs={"foo": "bar"})  # type: ignore[call-arg, call-arg]
    assert chat_model.model_kwargs == {"foo": "bar"}
@pytest.mark.requires("anthropic")
def test_anthropic_invalid_model_kwargs() -> None:
    """Passing ``max_tokens_to_sample`` inside ``model_kwargs`` must raise ``ValueError``."""
    rejected_kwargs = {"max_tokens_to_sample": 5}
    with pytest.raises(ValueError):
        ChatAnthropic(model="foo", model_kwargs=rejected_kwargs)  # type: ignore[call-arg]
@pytest.mark.requires("anthropic")
def test_anthropic_incorrect_field() -> None:
    """An unknown constructor keyword must warn and end up in ``model_kwargs``."""
    with pytest.warns(match="not default parameter"):
        chat_model = ChatAnthropic(model="foo", foo="bar")  # type: ignore[call-arg, call-arg]

    # The unrecognised keyword is expected to be kept, not dropped.
    assert chat_model.model_kwargs == {"foo": "bar"}
@pytest.mark.requires("anthropic")
def test_anthropic_initialization() -> None:
    """Test anthropic initialization with the key given as a constructor argument.

    The secret key is supplied as a parameter rather than through an
    environment variable; the test passes if construction does not raise.
    """
    ChatAnthropic(model="test", anthropic_api_key="test")  # type: ignore[call-arg, call-arg]
def test__format_output() -> None:
    """Check ``ChatAnthropic._format_output`` on a ``Message`` with one text block.

    The expected ``ChatResult`` holds the block text as ``AIMessage`` content,
    usage metadata that includes a summed ``total_tokens``, and the remaining
    message fields under ``llm_output``.
    """
    raw_message = Message(
        id="foo",
        content=[TextBlock(type="text", text="bar")],
        model="baz",
        role="assistant",
        stop_reason=None,
        stop_sequence=None,
        usage=Usage(input_tokens=2, output_tokens=1),
        type="message",
    )

    expected_message = AIMessage(  # type: ignore[misc]
        "bar",
        usage_metadata={
            "input_tokens": 2,
            "output_tokens": 1,
            "total_tokens": 3,
        },
    )
    expected_llm_output = {
        "id": "foo",
        "model": "baz",
        "stop_reason": None,
        "stop_sequence": None,
        "usage": {"input_tokens": 2, "output_tokens": 1},
    }
    expected = ChatResult(
        generations=[ChatGeneration(message=expected_message)],
        llm_output=expected_llm_output,
    )

    llm = ChatAnthropic(model="test", anthropic_api_key="test")  # type: ignore[call-arg, call-arg]
    assert expected == llm._format_output(raw_message)
def test__merge_messages() -> None:
    """Check that ``_merge_messages`` folds tool results into the next human turn.

    Two ``ToolMessage``s followed by a ``HumanMessage`` are expected to become
    a single ``HumanMessage`` whose content lists both ``tool_result`` blocks
    and then the human text. The system, human and AI messages before them
    are expected to come back unchanged.
    """

    def assistant_turn() -> AIMessage:
        # Built fresh on every call so the input and the expectation never
        # share message or content objects.
        return AIMessage(  # type: ignore[misc]
            [
                {"text": "baz", "type": "text"},
                {
                    "tool_input": {"a": "b"},
                    "type": "tool_use",
                    "id": "1",
                    "text": None,
                    "name": "buz",
                },
                {"text": "baz", "type": "text"},
                {
                    "tool_input": {"a": "c"},
                    "type": "tool_use",
                    "id": "2",
                    "text": None,
                    "name": "blah",
                },
            ]
        )

    messages = [
        SystemMessage("foo"),  # type: ignore[misc]
        HumanMessage("bar"),  # type: ignore[misc]
        assistant_turn(),
        ToolMessage("buz output", tool_call_id="1"),  # type: ignore[misc]
        ToolMessage("blah output", tool_call_id="2"),  # type: ignore[misc]
        HumanMessage("next thing"),  # type: ignore[misc]
    ]

    merged_user_turn = HumanMessage(  # type: ignore[misc]
        [
            {"type": "tool_result", "content": "buz output", "tool_use_id": "1"},
            {"type": "tool_result", "content": "blah output", "tool_use_id": "2"},
            {"type": "text", "text": "next thing"},
        ]
    )
    expected = [
        SystemMessage("foo"),  # type: ignore[misc]
        HumanMessage("bar"),  # type: ignore[misc]
        assistant_turn(),
        merged_user_turn,
    ]

    actual = _merge_messages(messages)
    assert expected == actual
def test__merge_messages_mutation() -> None:
    """Check that ``_merge_messages`` leaves its input list equal to a pristine copy.

    Two consecutive ``HumanMessage``s are expected to merge into one with two
    text blocks, while the list that was passed in still compares equal to an
    independently built copy.
    """

    def two_human_turns() -> list:
        # Fresh objects per call: one list is handed to the function, the
        # other is kept untouched for the comparison afterwards.
        return [
            HumanMessage([{"type": "text", "text": "bar"}]),  # type: ignore[misc]
            HumanMessage("next thing"),  # type: ignore[misc]
        ]

    original_messages = two_human_turns()
    messages = two_human_turns()

    expected = [
        HumanMessage(  # type: ignore[misc]
            [{"type": "text", "text": "bar"}, {"type": "text", "text": "next thing"}]
        ),
    ]

    actual = _merge_messages(messages)
    assert expected == actual
    assert messages == original_messages
@pytest.fixture()
def pydantic() -> Type[BaseModel]:
    """Provide the dummy tool expressed as a pydantic model class."""

    # The class name and docstring are part of the fixture's data: the
    # conversion test compares them against the expected tool name/description.
    class dummy_function(BaseModel):
        """dummy function"""

        arg1: int = Field(..., description="foo")
        arg2: Literal["bar", "baz"] = Field(..., description="one of 'bar', 'baz'")

    return dummy_function
@pytest.fixture()
def function() -> Callable:
    """Provide the dummy tool expressed as a plain, documented Python function.

    The inner function's docstring is data, not just documentation: the
    conversion test expects the description ``"dummy function"`` and the
    per-argument descriptions to be derived from it. It is therefore laid out
    in Google style (summary line, blank line, indented ``Args:`` entries).
    """

    def dummy_function(arg1: int, arg2: Literal["bar", "baz"]) -> None:
        """dummy function

        Args:
            arg1: foo
            arg2: one of 'bar', 'baz'
        """
        pass

    return dummy_function
@pytest.fixture()
def dummy_tool() -> BaseTool:
    """Provide the dummy tool expressed as a ``BaseTool`` instance."""

    class Schema(BaseModel):
        # Argument schema attached to the tool below.
        arg1: int = Field(..., description="foo")
        arg2: Literal["bar", "baz"] = Field(..., description="one of 'bar', 'baz'")

    class DummyFunction(BaseTool):
        args_schema: Type[BaseModel] = Schema
        name: str = "dummy_function"
        description: str = "dummy function"

        def _run(self, *args: Any, **kwargs: Any) -> Any:
            # Never executed for its result; only the tool's metadata matters here.
            return None

    tool_instance = DummyFunction()
    return tool_instance
@pytest.fixture()
def json_schema() -> Dict:
    """Provide the dummy tool expressed as a JSON-schema dictionary."""
    properties = {
        "arg1": {"description": "foo", "type": "integer"},
        "arg2": {
            "description": "one of 'bar', 'baz'",
            "enum": ["bar", "baz"],
            "type": "string",
        },
    }
    schema = {
        "title": "dummy_function",
        "description": "dummy function",
        "type": "object",
        "properties": properties,
        "required": ["arg1", "arg2"],
    }
    return schema
@pytest.fixture()
def openai_function() -> Dict:
    """Provide the dummy tool expressed as an OpenAI-style function dictionary."""
    parameters = {
        "type": "object",
        "properties": {
            "arg1": {"description": "foo", "type": "integer"},
            "arg2": {
                "description": "one of 'bar', 'baz'",
                "enum": ["bar", "baz"],
                "type": "string",
            },
        },
        "required": ["arg1", "arg2"],
    }
    spec = {
        "name": "dummy_function",
        "description": "dummy function",
        "parameters": parameters,
    }
    return spec
def test_convert_to_anthropic_tool(
    pydantic: Type[BaseModel],
    function: Callable,
    dummy_tool: BaseTool,
    json_schema: Dict,
    openai_function: Dict,
) -> None:
    """Check that every supported tool representation converts to one Anthropic tool.

    The inputs are the five fixtures plus the expected dictionary itself, so a
    value already in the target format must also come back equal to it.
    """
    input_schema = {
        "type": "object",
        "properties": {
            "arg1": {"description": "foo", "type": "integer"},
            "arg2": {
                "description": "one of 'bar', 'baz'",
                "enum": ["bar", "baz"],
                "type": "string",
            },
        },
        "required": ["arg1", "arg2"],
    }
    expected = {
        "name": "dummy_function",
        "description": "dummy function",
        "input_schema": input_schema,
    }

    tool_representations = (
        pydantic,
        function,
        dummy_tool,
        json_schema,
        expected,
        openai_function,
    )
    for tool_like in tool_representations:
        converted = convert_to_anthropic_tool(tool_like)  # type: ignore
        assert converted == expected
def test__format_messages_with_tool_calls() -> None:
    """Check ``_format_messages`` for an AI message with empty content and one tool call.

    The expected result is a ``(system, messages)`` tuple in which the tool call
    becomes a ``tool_use`` block and the ``ToolMessage`` becomes a user-role
    ``tool_result`` block.
    """
    messages = [
        SystemMessage("fuzz"),  # type: ignore[misc]
        HumanMessage("foo"),  # type: ignore[misc]
        AIMessage(  # type: ignore[misc]
            "",
            tool_calls=[{"name": "bar", "id": "1", "args": {"baz": "buzz"}}],
        ),
        ToolMessage(  # type: ignore[misc]
            "blurb",
            tool_call_id="1",
        ),
    ]

    assistant_entry = {
        "role": "assistant",
        "content": [
            {
                "type": "tool_use",
                "name": "bar",
                "id": "1",
                "input": {"baz": "buzz"},
            }
        ],
    }
    tool_result_entry = {
        "role": "user",
        "content": [{"type": "tool_result", "content": "blurb", "tool_use_id": "1"}],
    }
    expected = (
        "fuzz",
        [{"role": "user", "content": "foo"}, assistant_entry, tool_result_entry],
    )

    actual = _format_messages(messages)
    assert expected == actual
def test__format_messages_with_str_content_and_tool_calls() -> None:
    """Check ``_format_messages`` when an AI message has string content and tool calls.

    If content and tool_calls are both given and content is a string, both are
    expected in the output, with the text block first.
    """
    messages = [
        SystemMessage("fuzz"),  # type: ignore[misc]
        HumanMessage("foo"),  # type: ignore[misc]
        AIMessage(  # type: ignore[misc]
            "thought",
            tool_calls=[{"name": "bar", "id": "1", "args": {"baz": "buzz"}}],
        ),
        ToolMessage("blurb", tool_call_id="1"),  # type: ignore[misc]
    ]

    assistant_entry = {
        "role": "assistant",
        "content": [
            {"type": "text", "text": "thought"},
            {
                "type": "tool_use",
                "name": "bar",
                "id": "1",
                "input": {"baz": "buzz"},
            },
        ],
    }
    tool_result_entry = {
        "role": "user",
        "content": [{"type": "tool_result", "content": "blurb", "tool_use_id": "1"}],
    }
    expected = (
        "fuzz",
        [{"role": "user", "content": "foo"}, assistant_entry, tool_result_entry],
    )

    actual = _format_messages(messages)
    assert expected == actual
def test__format_messages_with_list_content_and_tool_calls() -> None:
    """Check ``_format_messages`` when an AI message has list content and tool calls.

    If content and tool_calls are both given and content is a list, the content
    is preferred: the expected assistant entry holds only the text block.
    """
    messages = [
        SystemMessage("fuzz"),  # type: ignore[misc]
        HumanMessage("foo"),  # type: ignore[misc]
        AIMessage(  # type: ignore[misc]
            [{"type": "text", "text": "thought"}],
            tool_calls=[{"name": "bar", "id": "1", "args": {"baz": "buzz"}}],
        ),
        ToolMessage(  # type: ignore[misc]
            "blurb",
            tool_call_id="1",
        ),
    ]

    assistant_entry = {
        "role": "assistant",
        "content": [{"type": "text", "text": "thought"}],
    }
    tool_result_entry = {
        "role": "user",
        "content": [{"type": "tool_result", "content": "blurb", "tool_use_id": "1"}],
    }
    expected = (
        "fuzz",
        [{"role": "user", "content": "foo"}, assistant_entry, tool_result_entry],
    )

    actual = _format_messages(messages)
    assert expected == actual
def test__format_messages_with_tool_use_blocks_and_tool_calls() -> None:
    """Show that tool_calls are preferred to tool_use blocks when both have same id."""
    # NOTE: the tool_use block in the content and the tool_calls entry share
    # id "1" but deliberately carry different arguments.
    ai_message = AIMessage(  # type: ignore[misc]
        [
            {"type": "text", "text": "thought"},
            {
                "type": "tool_use",
                "name": "bar",
                "id": "1",
                "input": {"baz": "NOT_BUZZ"},
            },
        ],
        tool_calls=[{"name": "bar", "id": "1", "args": {"baz": "BUZZ"}}],
    )
    messages = [
        SystemMessage("fuzz"),  # type: ignore[misc]
        HumanMessage("foo"),  # type: ignore[misc]
        ai_message,
        ToolMessage("blurb", tool_call_id="1"),  # type: ignore[misc]
    ]

    assistant_entry = {
        "role": "assistant",
        "content": [
            {"type": "text", "text": "thought"},
            {
                "type": "tool_use",
                "name": "bar",
                "id": "1",
                "input": {"baz": "BUZZ"},  # tool_calls value preferred.
            },
        ],
    }
    tool_result_entry = {
        "role": "user",
        "content": [{"type": "tool_result", "content": "blurb", "tool_use_id": "1"}],
    }
    expected = (
        "fuzz",
        [{"role": "user", "content": "foo"}, assistant_entry, tool_result_entry],
    )

    actual = _format_messages(messages)
    assert expected == actual
def test_anthropic_api_key_is_secret_string() -> None:
    """Test that a key passed to the constructor is stored as a ``SecretStr``."""
    chat_model = ChatAnthropic(  # type: ignore[call-arg, call-arg]
        model="claude-3-opus-20240229",
        anthropic_api_key="secret-api-key",
    )
    stored_key = chat_model.anthropic_api_key
    assert isinstance(stored_key, SecretStr)
def test_anthropic_api_key_masked_when_passed_from_env(
    monkeypatch: MonkeyPatch, capsys: CaptureFixture
) -> None:
    """Test that the API key is masked when passed from an environment variable.

    Args:
        monkeypatch: pytest fixture used to set the environment variable for
            the duration of this test only.
        capsys: pytest fixture used to capture what ``print`` writes to stdout.
    """
    # The variable name must be exactly "ANTHROPIC_API_KEY". It previously had a
    # trailing space, so the patch set an unrelated variable and the test only
    # passed thanks to the key this module puts in os.environ at import time.
    monkeypatch.setenv("ANTHROPIC_API_KEY", "secret-api-key")
    chat_model = ChatAnthropic(  # type: ignore[call-arg]
        model="claude-3-opus-20240229",
    )

    # Printing the key attribute must show the mask, not the secret itself.
    print(chat_model.anthropic_api_key, end="")  # noqa: T201
    captured = capsys.readouterr()

    assert captured.out == "**********"
def test_anthropic_api_key_masked_when_passed_via_constructor(
    capsys: CaptureFixture,
) -> None:
    """Test that the API key is masked when passed via the constructor."""
    chat_model = ChatAnthropic(  # type: ignore[call-arg, call-arg]
        model="claude-3-opus-20240229",
        anthropic_api_key="secret-api-key",
    )

    # Printing the key attribute must show the mask, not the secret itself.
    print(chat_model.anthropic_api_key, end="")  # noqa: T201
    printed = capsys.readouterr().out

    assert printed == "**********"
def test_anthropic_uses_actual_secret_value_from_secretstr() -> None:
    """Test that ``get_secret_value`` returns the key that was passed in."""
    chat_model = ChatAnthropic(  # type: ignore[call-arg, call-arg]
        model="claude-3-opus-20240229",
        anthropic_api_key="secret-api-key",
    )
    stored_key = cast(SecretStr, chat_model.anthropic_api_key)
    assert stored_key.get_secret_value() == "secret-api-key"
class GetWeather(BaseModel):
    """Get the current weather in a given location"""

    # Single required string argument of the tool schema.
    location: str = Field(
        ...,
        description="The city and state, e.g. San Francisco, CA",
    )
def test_anthropic_bind_tools_tool_choice() -> None:
    """Check the ``tool_choice`` value that ``bind_tools`` stores for each input form.

    A full dict and a bare tool name must both bind the dict form naming the
    tool, while ``"auto"`` and ``"any"`` must bind a dict carrying only that type.
    """
    chat_model = ChatAnthropic(  # type: ignore[call-arg, call-arg]
        model="claude-3-opus-20240229",
        anthropic_api_key="secret-api-key",
    )

    def bound_tool_choice(tool_choice: Any) -> Any:
        # Bind the tool with the given choice and read back what was stored.
        chat_model_with_tools = chat_model.bind_tools(
            [GetWeather], tool_choice=tool_choice
        )
        return cast(RunnableBinding, chat_model_with_tools).kwargs["tool_choice"]

    assert bound_tool_choice({"type": "tool", "name": "GetWeather"}) == {
        "type": "tool",
        "name": "GetWeather",
    }
    assert bound_tool_choice("GetWeather") == {
        "type": "tool",
        "name": "GetWeather",
    }
    assert bound_tool_choice("auto") == {"type": "auto"}
    assert bound_tool_choice("any") == {"type": "any"}