You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
langchain/libs/partners/anthropic/tests/integration_tests/test_chat_models.py

522 lines
26 KiB
Python

"""Test ChatAnthropic chat model."""
import json
from typing import List, Optional
import pytest
from langchain_core.callbacks import CallbackManager
from langchain_core.messages import (
AIMessage,
AIMessageChunk,
BaseMessage,
BaseMessageChunk,
HumanMessage,
SystemMessage,
ToolMessage,
)
from langchain_core.outputs import ChatGeneration, LLMResult
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.tools import tool
from langchain_anthropic import ChatAnthropic, ChatAnthropicMessages
from tests.unit_tests._utils import FakeCallbackHandler
# Default Anthropic model identifier used by the tests in this module unless a
# test pins a different model explicitly.
MODEL_NAME = "claude-3-sonnet-20240229"
def test_stream() -> None:
    """Stream a completion synchronously and verify usage/response metadata.

    Every chunk must be an ``AIMessageChunk`` with string content. Exactly one
    chunk is expected to report input tokens and exactly one to report output
    tokens, because adding chunks together adds their counts.
    """
    llm = ChatAnthropicMessages(model_name=MODEL_NAME)  # type: ignore[call-arg, call-arg]
    aggregate: Optional[BaseMessageChunk] = None
    input_count_chunks = 0
    output_count_chunks = 0
    for chunk in llm.stream("I'm Pickle Rick"):
        assert isinstance(chunk.content, str)
        if aggregate is None:
            aggregate = chunk
        else:
            aggregate = aggregate + chunk
        assert isinstance(chunk, AIMessageChunk)
        usage = chunk.usage_metadata
        if usage is None:
            continue
        # A chunk that reports input tokens is not also counted as an output chunk.
        if usage.get("input_tokens"):
            input_count_chunks += 1
        elif usage.get("output_tokens"):
            output_count_chunks += 1
    if not (input_count_chunks == 1 and output_count_chunks == 1):
        raise AssertionError(
            "Expected exactly one chunk with input or output token counts. "
            "AIMessageChunk aggregation adds counts. Check that "
            "this is behaving properly."
        )

    # The aggregated message must carry a consistent, non-zero usage summary.
    assert isinstance(aggregate, AIMessageChunk)
    totals = aggregate.usage_metadata
    assert totals is not None
    assert totals["input_tokens"] > 0
    assert totals["output_tokens"] > 0
    assert totals["total_tokens"] > 0
    assert totals["input_tokens"] + totals["output_tokens"] == totals["total_tokens"]
    metadata = aggregate.response_metadata
    assert "stop_reason" in metadata
    assert "stop_sequence" in metadata
async def test_astream() -> None:
    """Stream a completion asynchronously and verify usage reporting.

    Covers four things in order: per-chunk and aggregated usage metadata,
    disabling usage via the ``stream_usage`` constructor argument, disabling it
    via a per-call keyword argument, and the usage fields on the raw events
    produced by the model's async client.
    """
    llm = ChatAnthropicMessages(model_name=MODEL_NAME)  # type: ignore[call-arg, call-arg]
    aggregate: Optional[BaseMessageChunk] = None
    input_count_chunks = 0
    output_count_chunks = 0
    async for chunk in llm.astream("I'm Pickle Rick"):
        assert isinstance(chunk.content, str)
        if aggregate is None:
            aggregate = chunk
        else:
            aggregate = aggregate + chunk
        assert isinstance(chunk, AIMessageChunk)
        usage = chunk.usage_metadata
        if usage is None:
            continue
        # A chunk that reports input tokens is not also counted as an output chunk.
        if usage.get("input_tokens"):
            input_count_chunks += 1
        elif usage.get("output_tokens"):
            output_count_chunks += 1
    if not (input_count_chunks == 1 and output_count_chunks == 1):
        raise AssertionError(
            "Expected exactly one chunk with input or output token counts. "
            "AIMessageChunk aggregation adds counts. Check that "
            "this is behaving properly."
        )

    # The aggregated message must carry a consistent, non-zero usage summary.
    assert isinstance(aggregate, AIMessageChunk)
    totals = aggregate.usage_metadata
    assert totals is not None
    assert totals["input_tokens"] > 0
    assert totals["output_tokens"] > 0
    assert totals["total_tokens"] > 0
    assert totals["input_tokens"] + totals["output_tokens"] == totals["total_tokens"]
    metadata = aggregate.response_metadata
    assert "stop_reason" in metadata
    assert "stop_sequence" in metadata

    # Usage metadata is absent when the model is built with stream_usage=False.
    model = ChatAnthropic(model_name=MODEL_NAME, stream_usage=False)  # type: ignore[call-arg]
    async for chunk in model.astream("hi"):
        assert isinstance(chunk, AIMessageChunk)
        assert chunk.usage_metadata is None

    # A per-call stream_usage=False overrides a model that has it enabled.
    model = ChatAnthropic(model_name=MODEL_NAME)  # type: ignore[call-arg]
    assert model.stream_usage
    async for chunk in model.astream("hi", stream_usage=False):
        assert isinstance(chunk, AIMessageChunk)
        assert chunk.usage_metadata is None

    # Inspect the raw event stream returned by the async client.
    raw_client = model._async_client
    request: dict = {
        "model": "claude-3-haiku-20240307",
        "max_tokens": 1024,
        "messages": [{"role": "user", "content": "hi"}],
        "temperature": 0.0,
    }
    events = await raw_client.messages.create(**request, stream=True)
    async for event in events:
        if event.type == "message_start":
            assert event.message.usage.input_tokens > 1
            # The start event reports one output token; it does not seem to be
            # added to the overall output count, so totals leave it out.
            assert event.message.usage.output_tokens == 1
        elif event.type == "message_delta":
            assert event.usage.output_tokens > 1
        # Every other event type is ignored.
async def test_abatch() -> None:
    """Run two prompts through ``abatch`` and check each reply has string content."""
    llm = ChatAnthropicMessages(model_name=MODEL_NAME)  # type: ignore[call-arg, call-arg]
    replies = await llm.abatch(["I'm Pickle Rick", "I'm not Pickle Rick"])
    for reply in replies:
        assert isinstance(reply.content, str)
async def test_abatch_tags() -> None:
    """Run two prompts through ``abatch`` with a tagged config."""
    llm = ChatAnthropicMessages(model_name=MODEL_NAME)  # type: ignore[call-arg, call-arg]
    replies = await llm.abatch(
        ["I'm Pickle Rick", "I'm not Pickle Rick"],
        config={"tags": ["foo"]},
    )
    for reply in replies:
        assert isinstance(reply.content, str)
async def test_async_tool_use() -> None:
    """Exercise tool calling through ``ainvoke`` and ``astream``.

    Binds a single ``get_weather`` tool described by a raw schema dict, then
    checks that a one-shot call yields exactly one parsed tool call and that
    the aggregated stream yields exactly one tool-call chunk whose JSON
    arguments contain ``location``.
    """
    llm = ChatAnthropic(  # type: ignore[call-arg]
        model=MODEL_NAME,
    )
    llm_with_tools = llm.bind_tools(
        [
            {
                "name": "get_weather",
                "description": "Get weather report for a city",
                "input_schema": {
                    "type": "object",
                    "properties": {"location": {"type": "string"}},
                },
            }
        ]
    )
    response = await llm_with_tools.ainvoke("what's the weather in san francisco, ca")
    assert isinstance(response, AIMessage)
    assert isinstance(response.content, list)
    assert isinstance(response.tool_calls, list)
    assert len(response.tool_calls) == 1
    tool_call = response.tool_calls[0]
    assert tool_call["name"] == "get_weather"
    assert isinstance(tool_call["args"], dict)
    assert "location" in tool_call["args"]

    # Test streaming
    chunks: List[BaseMessage] = []
    # Starts as None so the name is always bound, mirroring test_stream.
    gathered: Optional[BaseMessage] = None
    async for chunk in llm_with_tools.astream(
        "what's the weather in san francisco, ca"
    ):
        chunks.append(chunk)
        # Accumulate the chunks into a single message via ``+``.
        gathered = chunk if gathered is None else gathered + chunk  # type: ignore
    assert len(chunks) > 1
    assert isinstance(gathered, AIMessageChunk)
    assert isinstance(gathered.tool_call_chunks, list)
    assert len(gathered.tool_call_chunks) == 1
    tool_call_chunk = gathered.tool_call_chunks[0]
    assert tool_call_chunk["name"] == "get_weather"
    # Streamed arguments arrive as a JSON string rather than a parsed dict.
    assert isinstance(tool_call_chunk["args"], str)
    assert "location" in json.loads(tool_call_chunk["args"])
def test_batch() -> None:
    """Run two prompts through ``batch`` and check each reply has string content."""
    llm = ChatAnthropicMessages(model_name=MODEL_NAME)  # type: ignore[call-arg, call-arg]
    replies = llm.batch(["I'm Pickle Rick", "I'm not Pickle Rick"])
    for reply in replies:
        assert isinstance(reply.content, str)
async def test_ainvoke() -> None:
    """Call ``ainvoke`` with a tagged config and expect string content."""
    llm = ChatAnthropicMessages(model_name=MODEL_NAME)  # type: ignore[call-arg, call-arg]
    reply = await llm.ainvoke(
        "I'm Pickle Rick",
        config={"tags": ["foo"]},
    )
    assert isinstance(reply.content, str)
def test_invoke() -> None:
    """Call ``invoke`` with a tagged config and expect string content."""
    llm = ChatAnthropicMessages(model_name=MODEL_NAME)  # type: ignore[call-arg, call-arg]
    reply = llm.invoke("I'm Pickle Rick", config={"tags": ["foo"]})
    assert isinstance(reply.content, str)
def test_system_invoke() -> None:
    """Invoke a prompt-to-model chain whose prompt includes a system message."""
    system_text = (
        "You are an expert cartographer. If asked, you are a cartographer. "
        "STAY IN CHARACTER"
    )
    prompt = ChatPromptTemplate.from_messages(
        [
            ("system", system_text),
            ("human", "Are you a mathematician?"),
        ]
    )
    llm = ChatAnthropicMessages(model_name=MODEL_NAME)  # type: ignore[call-arg, call-arg]
    # The template has no variables, so the chain takes an empty mapping.
    reply = (prompt | llm).invoke({})
    assert isinstance(reply.content, str)
def test_anthropic_call() -> None:
    """Send a single human message and expect a plain-text ``AIMessage`` back."""
    chat = ChatAnthropic(model=MODEL_NAME)  # type: ignore[call-arg]
    reply = chat.invoke([HumanMessage(content="Hello")])
    assert isinstance(reply, AIMessage)
    assert isinstance(reply.content, str)
def test_anthropic_generate() -> None:
    """Call ``generate`` and check the result shape and that inputs are unchanged."""
    chat = ChatAnthropic(model=MODEL_NAME)  # type: ignore[call-arg]
    chat_messages: List[List[BaseMessage]] = [
        [HumanMessage(content="How many toes do dogs have?")]
    ]
    # Shallow snapshot of each conversation, compared after the call.
    snapshot = [list(conversation) for conversation in chat_messages]
    result: LLMResult = chat.generate(chat_messages)
    assert isinstance(result, LLMResult)
    for generation in result.generations[0]:
        assert isinstance(generation, ChatGeneration)
        assert isinstance(generation.text, str)
        assert generation.text == generation.message.content
    assert chat_messages == snapshot
def test_anthropic_streaming() -> None:
    """Stream a reply to one human message and type-check every chunk."""
    chat = ChatAnthropic(model=MODEL_NAME)  # type: ignore[call-arg]
    for chunk in chat.stream([HumanMessage(content="Hello")]):
        assert isinstance(chunk, AIMessageChunk)
        assert isinstance(chunk.content, str)
def test_anthropic_streaming_callback() -> None:
    """Stream synchronously and check the fake handler's stream counter exceeds 1."""
    handler = FakeCallbackHandler()
    chat = ChatAnthropic(  # type: ignore[call-arg]
        model=MODEL_NAME,
        callback_manager=CallbackManager([handler]),
        verbose=True,
    )
    prompt = [HumanMessage(content="Write me a sentence with 10 words.")]
    for chunk in chat.stream(prompt):
        assert isinstance(chunk, AIMessageChunk)
        assert isinstance(chunk.content, str)
    # Checked only after the stream has been fully consumed.
    assert handler.llm_streams > 1
async def test_anthropic_async_streaming_callback() -> None:
    """Stream asynchronously and check the fake handler's stream counter exceeds 1."""
    handler = FakeCallbackHandler()
    chat = ChatAnthropic(  # type: ignore[call-arg]
        model=MODEL_NAME,
        callback_manager=CallbackManager([handler]),
        verbose=True,
    )
    chat_messages: List[BaseMessage] = [
        HumanMessage(content="How many toes do dogs have?")
    ]
    async for chunk in chat.astream(chat_messages):
        assert isinstance(chunk, AIMessageChunk)
        assert isinstance(chunk.content, str)
    # Checked only after the stream has been fully consumed.
    assert handler.llm_streams > 1
def test_anthropic_multimodal() -> None:
    """Send an image part plus a text question and expect a plain-text reply."""
    # NOTE(review): the image URL below is an empty string. The "langchain
    # logo" comment and the E501 suppression suggest a long URL (presumably a
    # base64 data URL) was stripped from this copy; confirm against upstream.
    image_part = {
        "type": "image_url",
        "image_url": {
            # langchain logo
            "url": "",  # noqa: E501
        },
    }
    question_part = {"type": "text", "text": "What is this a logo for?"}
    chat = ChatAnthropic(model=MODEL_NAME)  # type: ignore[call-arg]
    reply = chat.invoke([HumanMessage(content=[image_part, question_part])])
    assert isinstance(reply, AIMessage)
    assert isinstance(reply.content, str)
def test_streaming() -> None:
    """Check that ``generate`` with ``streaming=True`` reaches the stream callback."""
    handler = FakeCallbackHandler()
    llm = ChatAnthropicMessages(  # type: ignore[call-arg, call-arg]
        model_name=MODEL_NAME,
        streaming=True,
        callback_manager=CallbackManager([handler]),
    )
    result = llm.generate([[HumanMessage(content="I'm Pickle Rick")]])
    assert handler.llm_streams > 0
    assert isinstance(result, LLMResult)
async def test_astreaming() -> None:
    """Check that ``agenerate`` with ``streaming=True`` reaches the stream callback."""
    handler = FakeCallbackHandler()
    llm = ChatAnthropicMessages(  # type: ignore[call-arg, call-arg]
        model_name=MODEL_NAME,
        streaming=True,
        callback_manager=CallbackManager([handler]),
    )
    result = await llm.agenerate([[HumanMessage(content="I'm Pickle Rick")]])
    assert handler.llm_streams > 0
    assert isinstance(result, LLMResult)
def test_tool_use() -> None:
    """Exercise tool calling through ``invoke`` and ``stream``, then round-trip.

    Binds a single ``get_weather`` tool described by a raw schema dict and
    checks three things: a one-shot call yields exactly one parsed tool call;
    the aggregated stream yields two content blocks, one of them a
    ``tool_use`` block, plus one parsed tool call; and the aggregated message
    together with a ``ToolMessage`` result can be streamed back to the model.
    """
    llm = ChatAnthropic(model=MODEL_NAME)  # type: ignore[call-arg]
    llm_with_tools = llm.bind_tools(
        [
            {
                "name": "get_weather",
                "description": "Get weather report for a city",
                "input_schema": {
                    "type": "object",
                    "properties": {"location": {"type": "string"}},
                },
            }
        ]
    )
    response = llm_with_tools.invoke("what's the weather in san francisco, ca")
    assert isinstance(response, AIMessage)
    assert isinstance(response.content, list)
    assert isinstance(response.tool_calls, list)
    assert len(response.tool_calls) == 1
    tool_call = response.tool_calls[0]
    assert tool_call["name"] == "get_weather"
    assert isinstance(tool_call["args"], dict)
    assert "location" in tool_call["args"]

    # Test streaming
    # Named ``query`` rather than ``input`` so the builtin is not shadowed.
    query = "how are you? what's the weather in san francisco, ca"
    chunks: List[BaseMessage] = []
    # Starts as None so the name is always bound, mirroring test_stream.
    gathered: Optional[BaseMessage] = None
    for chunk in llm_with_tools.stream(query):
        chunks.append(chunk)
        # Accumulate the chunks into a single message via ``+``.
        gathered = chunk if gathered is None else gathered + chunk  # type: ignore
    assert len(chunks) > 1
    # Narrow the type before any attribute of the aggregate is inspected.
    assert isinstance(gathered, AIMessageChunk)
    assert isinstance(gathered.content, list)
    assert len(gathered.content) == 2
    tool_use_block = None
    for content_block in gathered.content:
        assert isinstance(content_block, dict)
        if content_block["type"] == "tool_use":
            tool_use_block = content_block
            break
    assert tool_use_block is not None
    assert tool_use_block["name"] == "get_weather"
    assert "location" in json.loads(tool_use_block["partial_json"])
    assert isinstance(gathered.tool_calls, list)
    assert len(gathered.tool_calls) == 1
    tool_call = gathered.tool_calls[0]
    assert tool_call["name"] == "get_weather"
    assert isinstance(tool_call["args"], dict)
    assert "location" in tool_call["args"]
    assert tool_call["id"] is not None

    # Test passing response back to model
    stream = llm_with_tools.stream(
        [
            query,
            gathered,
            ToolMessage(content="sunny and warm", tool_call_id=tool_call["id"]),
        ]
    )
    # Separate names keep the message sent as history distinct from the
    # aggregate of the follow-up stream.
    followup_chunks: List[BaseMessage] = []
    followup: Optional[BaseMessage] = None
    for chunk in stream:
        followup_chunks.append(chunk)
        # Still summed so chunk addition is exercised on the follow-up reply.
        followup = chunk if followup is None else followup + chunk  # type: ignore
    assert len(followup_chunks) > 1
def test_anthropic_with_empty_text_block() -> None:
    """Invoke the model with a history whose AI turn holds an empty text block.

    The test passes as long as the call does not raise; the reply itself is
    not inspected.
    """

    @tool
    def type_letter(letter: str) -> str:
        """Type the given letter."""
        return "OK"

    call_id = "toolu_01V6d6W32QGGSmQm4BT98EKk"
    instructions = (
        "Repeat the given string using the provided tools. Do not write "
        "anything else or provide any explanations. For example, "
        "if the string is 'abc', you must print the "
        "letters 'a', 'b', and 'c' one at a time and in that order. "
    )
    # AI turn made of an empty text block followed by a tool_use block.
    prior_ai_turn = AIMessage(
        content=[
            {"text": "", "type": "text"},
            {
                "id": call_id,
                "input": {"letter": "d"},
                "name": "type_letter",
                "type": "tool_use",
            },
        ],
        tool_calls=[
            {
                "name": "type_letter",
                "args": {"letter": "d"},
                "id": call_id,
                "type": "tool_call",
            },
        ],
    )
    history = [
        SystemMessage(content=instructions),
        HumanMessage(content="dog"),
        prior_ai_turn,
        ToolMessage(content="OK", tool_call_id=call_id),
    ]
    tool_model = ChatAnthropic(model="claude-3-opus-20240229", temperature=0).bind_tools(  # type: ignore[call-arg]
        [type_letter]
    )
    tool_model.invoke(history)
def test_with_structured_output() -> None:
    """Request structured output from a raw tool schema and expect a dict reply."""
    weather_schema = {
        "name": "get_weather",
        "description": "Get weather report for a city",
        "input_schema": {
            "type": "object",
            "properties": {"location": {"type": "string"}},
        },
    }
    llm = ChatAnthropic(  # type: ignore[call-arg]
        model="claude-3-opus-20240229",
    )
    structured_llm = llm.with_structured_output(weather_schema)
    answer = structured_llm.invoke("what's the weather in san francisco, ca")
    assert isinstance(answer, dict)
    # The location must be present and truthy.
    assert answer["location"]
# Schema model passed to ``bind_tools`` by the tool-choice test in this module.
# NOTE(review): the docstring and field description below are presumably
# forwarded to the model as the tool's description — treat them as runtime
# text and confirm before rewording.
class GetWeather(BaseModel):
    """Get the current weather in a given location"""

    # Required field (``...`` default) naming the place to look up.
    location: str = Field(..., description="The city and state, e.g. San Francisco, CA")
@pytest.mark.parametrize("tool_choice", ["GetWeather", "auto", "any"])
def test_anthropic_bind_tools_tool_choice(tool_choice: str) -> None:
    """Bind ``GetWeather`` with each ``tool_choice`` and expect an ``AIMessage``."""
    base_model = ChatAnthropic(  # type: ignore[call-arg]
        model=MODEL_NAME,
    )
    bound_model = base_model.bind_tools(
        [GetWeather],
        tool_choice=tool_choice,
    )
    reply = bound_model.invoke("what's the weather in ny and la")
    assert isinstance(reply, AIMessage)