From 0d495f3f63368e726802843977bce6c4f81fa99d Mon Sep 17 00:00:00 2001
From: Bagatur <22008038+baskaryan@users.noreply.github.com>
Date: Thu, 6 Jun 2024 08:51:34 -0700
Subject: [PATCH] anthropic: stream token usage (#20180)
open to other ideas
---------
Co-authored-by: Chester Curme
---
.../langchain_anthropic/chat_models.py | 70 ++++++++++++++++-
.../integration_tests/test_chat_models.py | 78 ++++++++++++++++++-
2 files changed, 145 insertions(+), 3 deletions(-)
diff --git a/libs/partners/anthropic/langchain_anthropic/chat_models.py b/libs/partners/anthropic/langchain_anthropic/chat_models.py
index 91a6e31a2f..62f158b647 100644
--- a/libs/partners/anthropic/langchain_anthropic/chat_models.py
+++ b/libs/partners/anthropic/langchain_anthropic/chat_models.py
@@ -43,6 +43,7 @@ from langchain_core.messages import (
ToolCall,
ToolMessage,
)
+from langchain_core.messages.ai import UsageMetadata
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator
from langchain_core.runnables import (
@@ -653,14 +654,20 @@ class ChatAnthropic(BaseChatModel):
message_chunk = AIMessageChunk(
content=message.content,
tool_call_chunks=tool_call_chunks, # type: ignore[arg-type]
+ usage_metadata=message.usage_metadata,
)
yield ChatGenerationChunk(message=message_chunk)
else:
yield cast(ChatGenerationChunk, result.generations[0])
return
+ full_generation_info: dict = {}
with self._client.messages.stream(**params) as stream:
for text in stream.text_stream:
- chunk = ChatGenerationChunk(message=AIMessageChunk(content=text))
+ chunk, full_generation_info = _make_chat_generation_chunk(
+ text,
+ stream.current_message_snapshot.model_dump(),
+ full_generation_info,
+ )
if run_manager:
run_manager.on_llm_new_token(text, chunk=chunk)
yield chunk
@@ -692,14 +699,20 @@ class ChatAnthropic(BaseChatModel):
message_chunk = AIMessageChunk(
content=message.content,
tool_call_chunks=tool_call_chunks, # type: ignore[arg-type]
+ usage_metadata=message.usage_metadata,
)
yield ChatGenerationChunk(message=message_chunk)
else:
yield cast(ChatGenerationChunk, result.generations[0])
return
+ full_generation_info: dict = {}
async with self._async_client.messages.stream(**params) as stream:
async for text in stream.text_stream:
- chunk = ChatGenerationChunk(message=AIMessageChunk(content=text))
+ chunk, full_generation_info = _make_chat_generation_chunk(
+ text,
+ stream.current_message_snapshot.model_dump(),
+ full_generation_info,
+ )
if run_manager:
await run_manager.on_llm_new_token(text, chunk=chunk)
yield chunk
@@ -1064,6 +1077,59 @@ def _lc_tool_calls_to_anthropic_tool_use_blocks(
return blocks
+def _make_chat_generation_chunk(
+ text: str, message_dump: dict, full_generation_info: dict
+) -> Tuple[ChatGenerationChunk, dict]:
+ """Collect metadata and make ChatGenerationChunk.
+
+ Walks the key/value pairs of ``message_dump``, turns its ``usage`` entry
+ into ``UsageMetadata`` for the chunk, and records other keys in
+ ``full_generation_info`` so that they are not re-processed on later calls
+ for the same stream.
+
+ Args:
+ text: text of the message chunk
+ message_dump: dict with metadata of the message chunk
+ full_generation_info: dict collecting metadata for full stream; it is
+ mutated in place and the same object is returned
+
+ Returns:
+ Tuple with ChatGenerationChunk and updated full_generation_info
+ """
+ # NOTE(review): indentation is collapsed in this patch view, so the nesting
+ # described in the comments below is inferred from the if/elif/else
+ # keywords -- confirm against the applied file.
+ # Metadata attached to this particular chunk only.
+ generation_info = {}
+ # Stays None when message_dump carries no "usage" key.
+ usage_metadata: Optional[UsageMetadata] = None
+ for k, v in message_dump.items():
+ # Skip the message body fields, and skip keys already recorded for this
+ # stream -- except "usage" and "stop_reason", which are never skipped here.
+ if k in ("content", "role", "type") or (
+ k in full_generation_info and k not in ("usage", "stop_reason")
+ ):
+ continue
+ elif k == "usage":
+ # Missing counts default to 0.
+ input_tokens = v.get("input_tokens", 0)
+ output_tokens = v.get("output_tokens", 0)
+ if "usage" not in full_generation_info:
+ # First usage seen for this stream: remember it and report both
+ # counts unchanged.
+ full_generation_info[k] = v
+ usage_metadata = UsageMetadata(
+ input_tokens=input_tokens,
+ output_tokens=output_tokens,
+ total_tokens=input_tokens + output_tokens,
+ )
+ else:
+ # NOTE(review): the stored usage is not refreshed in this branch, so
+ # "seen" always refers to the first usage dict recorded -- confirm
+ # that this is intended.
+ seen_input_tokens = full_generation_info[k].get("input_tokens", 0)
+ # Anthropic returns the same input token count for each message in a
+ # stream. To avoid double counting, we only count the input tokens
+ # once. After that, we set the input tokens to zero.
+ new_input_tokens = 0 if seen_input_tokens else input_tokens
+ # output_tokens is passed through unchanged on every call.
+ usage_metadata = UsageMetadata(
+ input_tokens=new_input_tokens,
+ output_tokens=output_tokens,
+ total_tokens=new_input_tokens + output_tokens,
+ )
+ else:
+ # Any other key: record it for the stream and (presumably within this
+ # same branch) attach it to this chunk's generation_info.
+ full_generation_info[k] = v
+ generation_info[k] = v
+ return (
+ ChatGenerationChunk(
+ message=AIMessageChunk(content=text, usage_metadata=usage_metadata),
+ generation_info=generation_info,
+ ),
+ full_generation_info,
+ )
+
+
@deprecated(since="0.1.0", removal="0.3.0", alternative="ChatAnthropic")
class ChatAnthropicMessages(ChatAnthropic):
pass
diff --git a/libs/partners/anthropic/tests/integration_tests/test_chat_models.py b/libs/partners/anthropic/tests/integration_tests/test_chat_models.py
index cee2cf70cf..3ce3cbab67 100644
--- a/libs/partners/anthropic/tests/integration_tests/test_chat_models.py
+++ b/libs/partners/anthropic/tests/integration_tests/test_chat_models.py
@@ -1,7 +1,7 @@
"""Test ChatAnthropic chat model."""
import json
-from typing import List
+from typing import List, Optional
import pytest
from langchain_core.callbacks import CallbackManager
@@ -9,6 +9,7 @@ from langchain_core.messages import (
AIMessage,
AIMessageChunk,
BaseMessage,
+ BaseMessageChunk,
HumanMessage,
SystemMessage,
ToolMessage,
@@ -28,16 +29,80 @@ def test_stream() -> None:
"""Test streaming tokens from Anthropic."""
llm = ChatAnthropicMessages(model_name=MODEL_NAME) # type: ignore[call-arg, call-arg]
+ full: Optional[BaseMessageChunk] = None
+ chunks_with_input_token_counts = 0
for token in llm.stream("I'm Pickle Rick"):
assert isinstance(token.content, str)
+ full = token if full is None else full + token
+ assert isinstance(token, AIMessageChunk)
+ if token.usage_metadata is not None and token.usage_metadata.get(
+ "input_tokens"
+ ):
+ chunks_with_input_token_counts += 1
+ if chunks_with_input_token_counts != 1:
+ raise AssertionError(
+ "Expected exactly one chunk with input token counts. "
+ "AIMessageChunk aggregation adds counts. Check that "
+ "this is behaving properly."
+ )
+ # check token usage is populated
+ assert isinstance(full, AIMessageChunk)
+ assert full.usage_metadata is not None
+ assert full.usage_metadata["input_tokens"] > 0
+ assert full.usage_metadata["output_tokens"] > 0
+ assert full.usage_metadata["total_tokens"] > 0
+ assert (
+ full.usage_metadata["input_tokens"] + full.usage_metadata["output_tokens"]
+ == full.usage_metadata["total_tokens"]
+ )
async def test_astream() -> None:
"""Test streaming tokens from Anthropic."""
llm = ChatAnthropicMessages(model_name=MODEL_NAME) # type: ignore[call-arg, call-arg]
+ # `full` accumulates every streamed chunk via `+`; the counter tracks how
+ # many individual chunks report a non-zero input token count.
+ full: Optional[BaseMessageChunk] = None
+ chunks_with_input_token_counts = 0
async for token in llm.astream("I'm Pickle Rick"):
assert isinstance(token.content, str)
+ full = token if full is None else full + token
+ assert isinstance(token, AIMessageChunk)
+ if token.usage_metadata is not None and token.usage_metadata.get(
+ "input_tokens"
+ ):
+ chunks_with_input_token_counts += 1
+ # Exactly one chunk may carry input tokens, because adding chunks together
+ # adds their counts (see the error message below).
+ if chunks_with_input_token_counts != 1:
+ raise AssertionError(
+ "Expected exactly one chunk with input token counts. "
+ "AIMessageChunk aggregation adds counts. Check that "
+ "this is behaving properly."
+ )
+ # check token usage is populated
+ assert isinstance(full, AIMessageChunk)
+ assert full.usage_metadata is not None
+ assert full.usage_metadata["input_tokens"] > 0
+ assert full.usage_metadata["output_tokens"] > 0
+ assert full.usage_metadata["total_tokens"] > 0
+ # The aggregated total must equal the sum of its two parts.
+ assert (
+ full.usage_metadata["input_tokens"] + full.usage_metadata["output_tokens"]
+ == full.usage_metadata["total_tokens"]
+ )
+
+ # Check assumption that each chunk has identical input token counts.
+ # This assumption is baked into _make_chat_generation_chunk.
+ # This part calls the client held in llm._async_client directly instead of
+ # going through llm.astream.
+ params: dict = {
+ "model": MODEL_NAME,
+ "max_tokens": 1024,
+ "messages": [{"role": "user", "content": "I'm Pickle Rick"}],
+ }
+ all_input_tokens = set()
+ async with llm._async_client.messages.stream(**params) as stream:
+ async for _ in stream.text_stream:
+ message_dump = stream.current_message_snapshot.model_dump()
+ # The walrus test skips snapshots whose input token count is missing
+ # or zero, so only positive counts are collected.
+ if input_tokens := message_dump.get("usage", {}).get("input_tokens"):
+ assert input_tokens > 0
+ all_input_tokens.add(input_tokens)
+ # One distinct value means every collected snapshot reported the same count.
+ assert len(all_input_tokens) == 1
async def test_abatch() -> None:
@@ -268,6 +333,17 @@ def test_tool_use() -> None:
assert isinstance(tool_call_chunk["args"], str)
assert "location" in json.loads(tool_call_chunk["args"])
+ # Check usage metadata
+ assert gathered.usage_metadata is not None
+ assert gathered.usage_metadata["input_tokens"] > 0
+ assert gathered.usage_metadata["output_tokens"] > 0
+ assert gathered.usage_metadata["total_tokens"] > 0
+ assert (
+ gathered.usage_metadata["input_tokens"]
+ + gathered.usage_metadata["output_tokens"]
+ == gathered.usage_metadata["total_tokens"]
+ )
+
def test_anthropic_with_empty_text_block() -> None:
"""Anthropic SDK can return an empty text block."""