"""Test ChatAnthropic chat model.""" import json from typing import List, Optional import pytest from langchain_core.callbacks import CallbackManager from langchain_core.messages import ( AIMessage, AIMessageChunk, BaseMessage, BaseMessageChunk, HumanMessage, SystemMessage, ToolMessage, ) from langchain_core.outputs import ChatGeneration, LLMResult from langchain_core.prompts import ChatPromptTemplate from langchain_core.tools import tool from pydantic import BaseModel, Field from langchain_anthropic import ChatAnthropic, ChatAnthropicMessages from tests.unit_tests._utils import FakeCallbackHandler MODEL_NAME = "claude-3-sonnet-20240229" def test_stream() -> None: """Test streaming tokens from Anthropic.""" llm = ChatAnthropicMessages(model_name=MODEL_NAME) # type: ignore[call-arg, call-arg] full: Optional[BaseMessageChunk] = None chunks_with_input_token_counts = 0 chunks_with_output_token_counts = 0 for token in llm.stream("I'm Pickle Rick"): assert isinstance(token.content, str) full = token if full is None else full + token assert isinstance(token, AIMessageChunk) if token.usage_metadata is not None: if token.usage_metadata.get("input_tokens"): chunks_with_input_token_counts += 1 elif token.usage_metadata.get("output_tokens"): chunks_with_output_token_counts += 1 if chunks_with_input_token_counts != 1 or chunks_with_output_token_counts != 1: raise AssertionError( "Expected exactly one chunk with input or output token counts. " "AIMessageChunk aggregation adds counts. Check that " "this is behaving properly." ) # check token usage is populated assert isinstance(full, AIMessageChunk) assert full.usage_metadata is not None assert full.usage_metadata["input_tokens"] > 0 assert full.usage_metadata["output_tokens"] > 0 assert full.usage_metadata["total_tokens"] > 0 assert ( full.usage_metadata["input_tokens"] + full.usage_metadata["output_tokens"] == full.usage_metadata["total_tokens"] ) assert "stop_reason" in full.response_metadata assert "stop_sequence" in full.response_metadata async def test_astream() -> None: """Test streaming tokens from Anthropic.""" llm = ChatAnthropicMessages(model_name=MODEL_NAME) # type: ignore[call-arg, call-arg] full: Optional[BaseMessageChunk] = None chunks_with_input_token_counts = 0 chunks_with_output_token_counts = 0 async for token in llm.astream("I'm Pickle Rick"): assert isinstance(token.content, str) full = token if full is None else full + token assert isinstance(token, AIMessageChunk) if token.usage_metadata is not None: if token.usage_metadata.get("input_tokens"): chunks_with_input_token_counts += 1 elif token.usage_metadata.get("output_tokens"): chunks_with_output_token_counts += 1 if chunks_with_input_token_counts != 1 or chunks_with_output_token_counts != 1: raise AssertionError( "Expected exactly one chunk with input or output token counts. " "AIMessageChunk aggregation adds counts. Check that " "this is behaving properly." ) # check token usage is populated assert isinstance(full, AIMessageChunk) assert full.usage_metadata is not None assert full.usage_metadata["input_tokens"] > 0 assert full.usage_metadata["output_tokens"] > 0 assert full.usage_metadata["total_tokens"] > 0 assert ( full.usage_metadata["input_tokens"] + full.usage_metadata["output_tokens"] == full.usage_metadata["total_tokens"] ) assert "stop_reason" in full.response_metadata assert "stop_sequence" in full.response_metadata # test usage metadata can be excluded model = ChatAnthropic(model_name=MODEL_NAME, stream_usage=False) # type: ignore[call-arg] async for token in model.astream("hi"): assert isinstance(token, AIMessageChunk) assert token.usage_metadata is None # check we override with kwarg model = ChatAnthropic(model_name=MODEL_NAME) # type: ignore[call-arg] assert model.stream_usage async for token in model.astream("hi", stream_usage=False): assert isinstance(token, AIMessageChunk) assert token.usage_metadata is None # Check expected raw API output async_client = model._async_client params: dict = { "model": "claude-3-haiku-20240307", "max_tokens": 1024, "messages": [{"role": "user", "content": "hi"}], "temperature": 0.0, } stream = await async_client.messages.create(**params, stream=True) async for event in stream: if event.type == "message_start": assert event.message.usage.input_tokens > 1 # Note: this single output token included in message start event # does not appear to contribute to overall output token counts. It # is excluded from the total token count. assert event.message.usage.output_tokens == 1 elif event.type == "message_delta": assert event.usage.output_tokens > 1 else: pass async def test_abatch() -> None: """Test streaming tokens from ChatAnthropicMessages.""" llm = ChatAnthropicMessages(model_name=MODEL_NAME) # type: ignore[call-arg, call-arg] result = await llm.abatch(["I'm Pickle Rick", "I'm not Pickle Rick"]) for token in result: assert isinstance(token.content, str) async def test_abatch_tags() -> None: """Test batch tokens from ChatAnthropicMessages.""" llm = ChatAnthropicMessages(model_name=MODEL_NAME) # type: ignore[call-arg, call-arg] result = await llm.abatch( ["I'm Pickle Rick", "I'm not Pickle Rick"], config={"tags": ["foo"]} ) for token in result: assert isinstance(token.content, str) async def test_async_tool_use() -> None: llm = ChatAnthropic( # type: ignore[call-arg] model=MODEL_NAME, ) llm_with_tools = llm.bind_tools( [ { "name": "get_weather", "description": "Get weather report for a city", "input_schema": { "type": "object", "properties": {"location": {"type": "string"}}, }, } ] ) response = await llm_with_tools.ainvoke("what's the weather in san francisco, ca") assert isinstance(response, AIMessage) assert isinstance(response.content, list) assert isinstance(response.tool_calls, list) assert len(response.tool_calls) == 1 tool_call = response.tool_calls[0] assert tool_call["name"] == "get_weather" assert isinstance(tool_call["args"], dict) assert "location" in tool_call["args"] # Test streaming first = True chunks = [] # type: ignore async for chunk in llm_with_tools.astream( "what's the weather in san francisco, ca" ): chunks = chunks + [chunk] if first: gathered = chunk first = False else: gathered = gathered + chunk # type: ignore assert len(chunks) > 1 assert isinstance(gathered, AIMessageChunk) assert isinstance(gathered.tool_call_chunks, list) assert len(gathered.tool_call_chunks) == 1 tool_call_chunk = gathered.tool_call_chunks[0] assert tool_call_chunk["name"] == "get_weather" assert isinstance(tool_call_chunk["args"], str) assert "location" in json.loads(tool_call_chunk["args"]) def test_batch() -> None: """Test batch tokens from ChatAnthropicMessages.""" llm = ChatAnthropicMessages(model_name=MODEL_NAME) # type: ignore[call-arg, call-arg] result = llm.batch(["I'm Pickle Rick", "I'm not Pickle Rick"]) for token in result: assert isinstance(token.content, str) async def test_ainvoke() -> None: """Test invoke tokens from ChatAnthropicMessages.""" llm = ChatAnthropicMessages(model_name=MODEL_NAME) # type: ignore[call-arg, call-arg] result = await llm.ainvoke("I'm Pickle Rick", config={"tags": ["foo"]}) assert isinstance(result.content, str) def test_invoke() -> None: """Test invoke tokens from ChatAnthropicMessages.""" llm = ChatAnthropicMessages(model_name=MODEL_NAME) # type: ignore[call-arg, call-arg] result = llm.invoke("I'm Pickle Rick", config=dict(tags=["foo"])) assert isinstance(result.content, str) def test_system_invoke() -> None: """Test invoke tokens with a system message""" llm = ChatAnthropicMessages(model_name=MODEL_NAME) # type: ignore[call-arg, call-arg] prompt = ChatPromptTemplate.from_messages( [ ( "system", "You are an expert cartographer. If asked, you are a cartographer. " "STAY IN CHARACTER", ), ("human", "Are you a mathematician?"), ] ) chain = prompt | llm result = chain.invoke({}) assert isinstance(result.content, str) def test_anthropic_call() -> None: """Test valid call to anthropic.""" chat = ChatAnthropic(model=MODEL_NAME) # type: ignore[call-arg] message = HumanMessage(content="Hello") response = chat.invoke([message]) assert isinstance(response, AIMessage) assert isinstance(response.content, str) def test_anthropic_generate() -> None: """Test generate method of anthropic.""" chat = ChatAnthropic(model=MODEL_NAME) # type: ignore[call-arg] chat_messages: List[List[BaseMessage]] = [ [HumanMessage(content="How many toes do dogs have?")] ] messages_copy = [messages.copy() for messages in chat_messages] result: LLMResult = chat.generate(chat_messages) assert isinstance(result, LLMResult) for response in result.generations[0]: assert isinstance(response, ChatGeneration) assert isinstance(response.text, str) assert response.text == response.message.content assert chat_messages == messages_copy def test_anthropic_streaming() -> None: """Test streaming tokens from anthropic.""" chat = ChatAnthropic(model=MODEL_NAME) # type: ignore[call-arg] message = HumanMessage(content="Hello") response = chat.stream([message]) for token in response: assert isinstance(token, AIMessageChunk) assert isinstance(token.content, str) def test_anthropic_streaming_callback() -> None: """Test that streaming correctly invokes on_llm_new_token callback.""" callback_handler = FakeCallbackHandler() callback_manager = CallbackManager([callback_handler]) chat = ChatAnthropic( # type: ignore[call-arg] model=MODEL_NAME, callback_manager=callback_manager, verbose=True, ) message = HumanMessage(content="Write me a sentence with 10 words.") for token in chat.stream([message]): assert isinstance(token, AIMessageChunk) assert isinstance(token.content, str) assert callback_handler.llm_streams > 1 async def test_anthropic_async_streaming_callback() -> None: """Test that streaming correctly invokes on_llm_new_token callback.""" callback_handler = FakeCallbackHandler() callback_manager = CallbackManager([callback_handler]) chat = ChatAnthropic( # type: ignore[call-arg] model=MODEL_NAME, callback_manager=callback_manager, verbose=True, ) chat_messages: List[BaseMessage] = [ HumanMessage(content="How many toes do dogs have?") ] async for token in chat.astream(chat_messages): assert isinstance(token, AIMessageChunk) assert isinstance(token.content, str) assert callback_handler.llm_streams > 1 def test_anthropic_multimodal() -> None: """Test that multimodal inputs are handled correctly.""" chat = ChatAnthropic(model=MODEL_NAME) # type: ignore[call-arg] messages = [ HumanMessage( content=[ { "type": "image_url", "image_url": { # langchain logo "url": "", # noqa: E501 }, }, {"type": "text", "text": "What is this a logo for?"}, ] ) ] response = chat.invoke(messages) assert isinstance(response, AIMessage) assert isinstance(response.content, str) def test_streaming() -> None: """Test streaming tokens from Anthropic.""" callback_handler = FakeCallbackHandler() callback_manager = CallbackManager([callback_handler]) llm = ChatAnthropicMessages( # type: ignore[call-arg, call-arg] model_name=MODEL_NAME, streaming=True, callback_manager=callback_manager ) response = llm.generate([[HumanMessage(content="I'm Pickle Rick")]]) assert callback_handler.llm_streams > 0 assert isinstance(response, LLMResult) async def test_astreaming() -> None: """Test streaming tokens from Anthropic.""" callback_handler = FakeCallbackHandler() callback_manager = CallbackManager([callback_handler]) llm = ChatAnthropicMessages( # type: ignore[call-arg, call-arg] model_name=MODEL_NAME, streaming=True, callback_manager=callback_manager ) response = await llm.agenerate([[HumanMessage(content="I'm Pickle Rick")]]) assert callback_handler.llm_streams > 0 assert isinstance(response, LLMResult) def test_tool_use() -> None: llm = ChatAnthropic(model=MODEL_NAME) # type: ignore[call-arg] llm_with_tools = llm.bind_tools( [ { "name": "get_weather", "description": "Get weather report for a city", "input_schema": { "type": "object", "properties": {"location": {"type": "string"}}, }, } ] ) response = llm_with_tools.invoke("what's the weather in san francisco, ca") assert isinstance(response, AIMessage) assert isinstance(response.content, list) assert isinstance(response.tool_calls, list) assert len(response.tool_calls) == 1 tool_call = response.tool_calls[0] assert tool_call["name"] == "get_weather" assert isinstance(tool_call["args"], dict) assert "location" in tool_call["args"] # Test streaming input = "how are you? what's the weather in san francisco, ca" first = True chunks = [] # type: ignore for chunk in llm_with_tools.stream(input): chunks = chunks + [chunk] if first: gathered = chunk first = False else: gathered = gathered + chunk # type: ignore assert len(chunks) > 1 assert isinstance(gathered.content, list) assert len(gathered.content) == 2 tool_use_block = None for content_block in gathered.content: assert isinstance(content_block, dict) if content_block["type"] == "tool_use": tool_use_block = content_block break assert tool_use_block is not None assert tool_use_block["name"] == "get_weather" assert "location" in json.loads(tool_use_block["partial_json"]) assert isinstance(gathered, AIMessageChunk) assert isinstance(gathered.tool_calls, list) assert len(gathered.tool_calls) == 1 tool_call = gathered.tool_calls[0] assert tool_call["name"] == "get_weather" assert isinstance(tool_call["args"], dict) assert "location" in tool_call["args"] assert tool_call["id"] is not None # Test passing response back to model stream = llm_with_tools.stream( [ input, gathered, ToolMessage(content="sunny and warm", tool_call_id=tool_call["id"]), ] ) chunks = [] # type: ignore first = True for chunk in stream: chunks = chunks + [chunk] if first: gathered = chunk first = False else: gathered = gathered + chunk # type: ignore assert len(chunks) > 1 def test_anthropic_with_empty_text_block() -> None: """Anthropic SDK can return an empty text block.""" @tool def type_letter(letter: str) -> str: """Type the given letter.""" return "OK" model = ChatAnthropic(model="claude-3-opus-20240229", temperature=0).bind_tools( # type: ignore[call-arg] [type_letter] ) messages = [ SystemMessage( content="Repeat the given string using the provided tools. Do not write " "anything else or provide any explanations. For example, " "if the string is 'abc', you must print the " "letters 'a', 'b', and 'c' one at a time and in that order. " ), HumanMessage(content="dog"), AIMessage( content=[ {"text": "", "type": "text"}, { "id": "toolu_01V6d6W32QGGSmQm4BT98EKk", "input": {"letter": "d"}, "name": "type_letter", "type": "tool_use", }, ], tool_calls=[ { "name": "type_letter", "args": {"letter": "d"}, "id": "toolu_01V6d6W32QGGSmQm4BT98EKk", "type": "tool_call", }, ], ), ToolMessage(content="OK", tool_call_id="toolu_01V6d6W32QGGSmQm4BT98EKk"), ] model.invoke(messages) def test_with_structured_output() -> None: llm = ChatAnthropic( # type: ignore[call-arg] model="claude-3-opus-20240229", ) structured_llm = llm.with_structured_output( { "name": "get_weather", "description": "Get weather report for a city", "input_schema": { "type": "object", "properties": {"location": {"type": "string"}}, }, } ) response = structured_llm.invoke("what's the weather in san francisco, ca") assert isinstance(response, dict) assert response["location"] class GetWeather(BaseModel): """Get the current weather in a given location""" location: str = Field(..., description="The city and state, e.g. San Francisco, CA") @pytest.mark.parametrize("tool_choice", ["GetWeather", "auto", "any"]) def test_anthropic_bind_tools_tool_choice(tool_choice: str) -> None: chat_model = ChatAnthropic( # type: ignore[call-arg] model=MODEL_NAME, ) chat_model_with_tools = chat_model.bind_tools([GetWeather], tool_choice=tool_choice) response = chat_model_with_tools.invoke("what's the weather in ny and la") assert isinstance(response, AIMessage)