From 0d495f3f63368e726802843977bce6c4f81fa99d Mon Sep 17 00:00:00 2001
From: Bagatur <22008038+baskaryan@users.noreply.github.com>
Date: Thu, 6 Jun 2024 08:51:34 -0700
Subject: [PATCH] anthropic: stream token usage (#20180)

Populate usage_metadata on AIMessageChunks streamed from Anthropic, so
token usage is reported during streaming and not only on invoke. Open
to other ideas.

[Screenshot 2024-04-08 at 5.34.08 PM]
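
A minimal sketch of the intended usage (the model name and the printed
token counts below are illustrative, not actual output):

    from langchain_anthropic import ChatAnthropic

    llm = ChatAnthropic(model="claude-3-haiku-20240307")
    full = None
    for chunk in llm.stream("I'm Pickle Rick"):
        # Chunks now carry usage_metadata; adding chunks sums the counts.
        full = chunk if full is None else full + chunk
    print(full.usage_metadata)
    # e.g. {'input_tokens': 13, 'output_tokens': 31, 'total_tokens': 44}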

---------

Co-authored-by: Chester Curme
---
 .../langchain_anthropic/chat_models.py        | 79 +++++++++++++++++-
 .../integration_tests/test_chat_models.py     | 78 ++++++++++++++++-
 2 files changed, 154 insertions(+), 3 deletions(-)

diff --git a/libs/partners/anthropic/langchain_anthropic/chat_models.py b/libs/partners/anthropic/langchain_anthropic/chat_models.py
index 91a6e31a2f..62f158b647 100644
--- a/libs/partners/anthropic/langchain_anthropic/chat_models.py
+++ b/libs/partners/anthropic/langchain_anthropic/chat_models.py
@@ -43,6 +43,7 @@ from langchain_core.messages import (
     ToolCall,
     ToolMessage,
 )
+from langchain_core.messages.ai import UsageMetadata
 from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
 from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator
 from langchain_core.runnables import (
@@ -653,14 +654,20 @@ class ChatAnthropic(BaseChatModel):
                 message_chunk = AIMessageChunk(
                     content=message.content,
                     tool_call_chunks=tool_call_chunks,  # type: ignore[arg-type]
+                    usage_metadata=message.usage_metadata,
                 )
                 yield ChatGenerationChunk(message=message_chunk)
             else:
                 yield cast(ChatGenerationChunk, result.generations[0])
             return
+        full_generation_info: dict = {}
         with self._client.messages.stream(**params) as stream:
             for text in stream.text_stream:
-                chunk = ChatGenerationChunk(message=AIMessageChunk(content=text))
+                chunk, full_generation_info = _make_chat_generation_chunk(
+                    text,
+                    stream.current_message_snapshot.model_dump(),
+                    full_generation_info,
+                )
                 if run_manager:
                     run_manager.on_llm_new_token(text, chunk=chunk)
                 yield chunk
@@ -692,14 +699,20 @@ class ChatAnthropic(BaseChatModel):
                 message_chunk = AIMessageChunk(
                     content=message.content,
                     tool_call_chunks=tool_call_chunks,  # type: ignore[arg-type]
+                    usage_metadata=message.usage_metadata,
                 )
                 yield ChatGenerationChunk(message=message_chunk)
             else:
                 yield cast(ChatGenerationChunk, result.generations[0])
             return
+        full_generation_info: dict = {}
         async with self._async_client.messages.stream(**params) as stream:
             async for text in stream.text_stream:
-                chunk = ChatGenerationChunk(message=AIMessageChunk(content=text))
+                chunk, full_generation_info = _make_chat_generation_chunk(
+                    text,
+                    stream.current_message_snapshot.model_dump(),
+                    full_generation_info,
+                )
                 if run_manager:
                     await run_manager.on_llm_new_token(text, chunk=chunk)
                 yield chunk
@@ -1064,6 +1077,68 @@ def _lc_tool_calls_to_anthropic_tool_use_blocks(
     return blocks


+def _make_chat_generation_chunk(
+    text: str, message_dump: dict, full_generation_info: dict
+) -> Tuple[ChatGenerationChunk, dict]:
+    """Collect metadata and build a ChatGenerationChunk.
+
+    Args:
+        text: Text of the message chunk.
+        message_dump: Dict with metadata of the message chunk.
+        full_generation_info: Dict collecting metadata for the full stream.
+
+    Returns:
+        Tuple with the ChatGenerationChunk and the updated full_generation_info.
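+
+    Example:
+        Anthropic repeats the full input token count on every message
+        snapshot in a stream. Given illustrative snapshots whose usage is
+        {"input_tokens": 10, "output_tokens": 1} and then
+        {"input_tokens": 10, "output_tokens": 5}, the first chunk is
+        emitted with input_tokens=10 and the second with input_tokens=0,
+        so summing usage_metadata across chunks does not double count
+        the prompt.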
+ ) + # check token usage is populated + assert isinstance(full, AIMessageChunk) + assert full.usage_metadata is not None + assert full.usage_metadata["input_tokens"] > 0 + assert full.usage_metadata["output_tokens"] > 0 + assert full.usage_metadata["total_tokens"] > 0 + assert ( + full.usage_metadata["input_tokens"] + full.usage_metadata["output_tokens"] + == full.usage_metadata["total_tokens"] + ) async def test_astream() -> None: """Test streaming tokens from Anthropic.""" llm = ChatAnthropicMessages(model_name=MODEL_NAME) # type: ignore[call-arg, call-arg] + full: Optional[BaseMessageChunk] = None + chunks_with_input_token_counts = 0 async for token in llm.astream("I'm Pickle Rick"): assert isinstance(token.content, str) + full = token if full is None else full + token + assert isinstance(token, AIMessageChunk) + if token.usage_metadata is not None and token.usage_metadata.get( + "input_tokens" + ): + chunks_with_input_token_counts += 1 + if chunks_with_input_token_counts != 1: + raise AssertionError( + "Expected exactly one chunk with input token counts. " + "AIMessageChunk aggregation adds counts. Check that " + "this is behaving properly." + ) + # check token usage is populated + assert isinstance(full, AIMessageChunk) + assert full.usage_metadata is not None + assert full.usage_metadata["input_tokens"] > 0 + assert full.usage_metadata["output_tokens"] > 0 + assert full.usage_metadata["total_tokens"] > 0 + assert ( + full.usage_metadata["input_tokens"] + full.usage_metadata["output_tokens"] + == full.usage_metadata["total_tokens"] + ) + + # Check assumption that each chunk has identical input token counts. + # This assumption is baked into _make_chat_generation_chunk. + params: dict = { + "model": MODEL_NAME, + "max_tokens": 1024, + "messages": [{"role": "user", "content": "I'm Pickle Rick"}], + } + all_input_tokens = set() + async with llm._async_client.messages.stream(**params) as stream: + async for _ in stream.text_stream: + message_dump = stream.current_message_snapshot.model_dump() + if input_tokens := message_dump.get("usage", {}).get("input_tokens"): + assert input_tokens > 0 + all_input_tokens.add(input_tokens) + assert len(all_input_tokens) == 1 async def test_abatch() -> None: @@ -268,6 +333,17 @@ def test_tool_use() -> None: assert isinstance(tool_call_chunk["args"], str) assert "location" in json.loads(tool_call_chunk["args"]) + # Check usage metadata + assert gathered.usage_metadata is not None + assert gathered.usage_metadata["input_tokens"] > 0 + assert gathered.usage_metadata["output_tokens"] > 0 + assert gathered.usage_metadata["total_tokens"] > 0 + assert ( + gathered.usage_metadata["input_tokens"] + + gathered.usage_metadata["output_tokens"] + == gathered.usage_metadata["total_tokens"] + ) + def test_anthropic_with_empty_text_block() -> None: """Anthropic SDK can return an empty text block."""