Revert "anthropic: stream token usage" (#22624)

Reverts langchain-ai/langchain#20180
pull/22625/head
ccurme 1 month ago committed by GitHub
parent 0d495f3f63
commit e08879147b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -43,7 +43,6 @@ from langchain_core.messages import (
ToolCall,
ToolMessage,
)
from langchain_core.messages.ai import UsageMetadata
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator
from langchain_core.runnables import (
@ -654,20 +653,14 @@ class ChatAnthropic(BaseChatModel):
message_chunk = AIMessageChunk(
content=message.content,
tool_call_chunks=tool_call_chunks, # type: ignore[arg-type]
usage_metadata=message.usage_metadata,
)
yield ChatGenerationChunk(message=message_chunk)
else:
yield cast(ChatGenerationChunk, result.generations[0])
return
full_generation_info: dict = {}
with self._client.messages.stream(**params) as stream:
for text in stream.text_stream:
chunk, full_generation_info = _make_chat_generation_chunk(
text,
stream.current_message_snapshot.model_dump(),
full_generation_info,
)
chunk = ChatGenerationChunk(message=AIMessageChunk(content=text))
if run_manager:
run_manager.on_llm_new_token(text, chunk=chunk)
yield chunk
@ -699,20 +692,14 @@ class ChatAnthropic(BaseChatModel):
message_chunk = AIMessageChunk(
content=message.content,
tool_call_chunks=tool_call_chunks, # type: ignore[arg-type]
usage_metadata=message.usage_metadata,
)
yield ChatGenerationChunk(message=message_chunk)
else:
yield cast(ChatGenerationChunk, result.generations[0])
return
full_generation_info: dict = {}
async with self._async_client.messages.stream(**params) as stream:
async for text in stream.text_stream:
chunk, full_generation_info = _make_chat_generation_chunk(
text,
stream.current_message_snapshot.model_dump(),
full_generation_info,
)
chunk = ChatGenerationChunk(message=AIMessageChunk(content=text))
if run_manager:
await run_manager.on_llm_new_token(text, chunk=chunk)
yield chunk
@ -1077,59 +1064,6 @@ def _lc_tool_calls_to_anthropic_tool_use_blocks(
return blocks
def _make_chat_generation_chunk(
text: str, message_dump: dict, full_generation_info: dict
) -> Tuple[ChatGenerationChunk, dict]:
"""Collect metadata and make ChatGenerationChunk.
Args:
text: text of the message chunk
message_dump: dict with metadata of the message chunk
full_generation_info: dict collecting metadata for full stream
Returns:
Tuple with ChatGenerationChunk and updated full_generation_info
"""
generation_info = {}
usage_metadata: Optional[UsageMetadata] = None
for k, v in message_dump.items():
if k in ("content", "role", "type") or (
k in full_generation_info and k not in ("usage", "stop_reason")
):
continue
elif k == "usage":
input_tokens = v.get("input_tokens", 0)
output_tokens = v.get("output_tokens", 0)
if "usage" not in full_generation_info:
full_generation_info[k] = v
usage_metadata = UsageMetadata(
input_tokens=input_tokens,
output_tokens=output_tokens,
total_tokens=input_tokens + output_tokens,
)
else:
seen_input_tokens = full_generation_info[k].get("input_tokens", 0)
# Anthropic returns the same input token count for each message in a
# stream. To avoid double counting, we only count the input tokens
# once. After that, we set the input tokens to zero.
new_input_tokens = 0 if seen_input_tokens else input_tokens
usage_metadata = UsageMetadata(
input_tokens=new_input_tokens,
output_tokens=output_tokens,
total_tokens=new_input_tokens + output_tokens,
)
else:
full_generation_info[k] = v
generation_info[k] = v
return (
ChatGenerationChunk(
message=AIMessageChunk(content=text, usage_metadata=usage_metadata),
generation_info=generation_info,
),
full_generation_info,
)
@deprecated(since="0.1.0", removal="0.3.0", alternative="ChatAnthropic")
class ChatAnthropicMessages(ChatAnthropic):
pass

@ -1,7 +1,7 @@
"""Test ChatAnthropic chat model."""
import json
from typing import List, Optional
from typing import List
import pytest
from langchain_core.callbacks import CallbackManager
@ -9,7 +9,6 @@ from langchain_core.messages import (
AIMessage,
AIMessageChunk,
BaseMessage,
BaseMessageChunk,
HumanMessage,
SystemMessage,
ToolMessage,
@ -29,80 +28,16 @@ def test_stream() -> None:
"""Test streaming tokens from Anthropic."""
llm = ChatAnthropicMessages(model_name=MODEL_NAME) # type: ignore[call-arg, call-arg]
full: Optional[BaseMessageChunk] = None
chunks_with_input_token_counts = 0
for token in llm.stream("I'm Pickle Rick"):
assert isinstance(token.content, str)
full = token if full is None else full + token
assert isinstance(token, AIMessageChunk)
if token.usage_metadata is not None and token.usage_metadata.get(
"input_tokens"
):
chunks_with_input_token_counts += 1
if chunks_with_input_token_counts != 1:
raise AssertionError(
"Expected exactly one chunk with input token counts. "
"AIMessageChunk aggregation adds counts. Check that "
"this is behaving properly."
)
# check token usage is populated
assert isinstance(full, AIMessageChunk)
assert full.usage_metadata is not None
assert full.usage_metadata["input_tokens"] > 0
assert full.usage_metadata["output_tokens"] > 0
assert full.usage_metadata["total_tokens"] > 0
assert (
full.usage_metadata["input_tokens"] + full.usage_metadata["output_tokens"]
== full.usage_metadata["total_tokens"]
)
async def test_astream() -> None:
"""Test streaming tokens from Anthropic."""
llm = ChatAnthropicMessages(model_name=MODEL_NAME) # type: ignore[call-arg, call-arg]
full: Optional[BaseMessageChunk] = None
chunks_with_input_token_counts = 0
async for token in llm.astream("I'm Pickle Rick"):
assert isinstance(token.content, str)
full = token if full is None else full + token
assert isinstance(token, AIMessageChunk)
if token.usage_metadata is not None and token.usage_metadata.get(
"input_tokens"
):
chunks_with_input_token_counts += 1
if chunks_with_input_token_counts != 1:
raise AssertionError(
"Expected exactly one chunk with input token counts. "
"AIMessageChunk aggregation adds counts. Check that "
"this is behaving properly."
)
# check token usage is populated
assert isinstance(full, AIMessageChunk)
assert full.usage_metadata is not None
assert full.usage_metadata["input_tokens"] > 0
assert full.usage_metadata["output_tokens"] > 0
assert full.usage_metadata["total_tokens"] > 0
assert (
full.usage_metadata["input_tokens"] + full.usage_metadata["output_tokens"]
== full.usage_metadata["total_tokens"]
)
# Check assumption that each chunk has identical input token counts.
# This assumption is baked into _make_chat_generation_chunk.
params: dict = {
"model": MODEL_NAME,
"max_tokens": 1024,
"messages": [{"role": "user", "content": "I'm Pickle Rick"}],
}
all_input_tokens = set()
async with llm._async_client.messages.stream(**params) as stream:
async for _ in stream.text_stream:
message_dump = stream.current_message_snapshot.model_dump()
if input_tokens := message_dump.get("usage", {}).get("input_tokens"):
assert input_tokens > 0
all_input_tokens.add(input_tokens)
assert len(all_input_tokens) == 1
async def test_abatch() -> None:
@ -333,17 +268,6 @@ def test_tool_use() -> None:
assert isinstance(tool_call_chunk["args"], str)
assert "location" in json.loads(tool_call_chunk["args"])
# Check usage metadata
assert gathered.usage_metadata is not None
assert gathered.usage_metadata["input_tokens"] > 0
assert gathered.usage_metadata["output_tokens"] > 0
assert gathered.usage_metadata["total_tokens"] > 0
assert (
gathered.usage_metadata["input_tokens"]
+ gathered.usage_metadata["output_tokens"]
== gathered.usage_metadata["total_tokens"]
)
def test_anthropic_with_empty_text_block() -> None:
"""Anthropic SDK can return an empty text block."""

Loading…
Cancel
Save