From e08879147b17e45fc3ecaf1387a3c43b2534b173 Mon Sep 17 00:00:00 2001 From: ccurme Date: Thu, 6 Jun 2024 12:05:08 -0400 Subject: [PATCH] Revert "anthropic: stream token usage" (#22624) Reverts langchain-ai/langchain#20180 --- .../langchain_anthropic/chat_models.py | 70 +---------------- .../integration_tests/test_chat_models.py | 78 +------------------ 2 files changed, 3 insertions(+), 145 deletions(-) diff --git a/libs/partners/anthropic/langchain_anthropic/chat_models.py b/libs/partners/anthropic/langchain_anthropic/chat_models.py index 62f158b647..91a6e31a2f 100644 --- a/libs/partners/anthropic/langchain_anthropic/chat_models.py +++ b/libs/partners/anthropic/langchain_anthropic/chat_models.py @@ -43,7 +43,6 @@ from langchain_core.messages import ( ToolCall, ToolMessage, ) -from langchain_core.messages.ai import UsageMetadata from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator from langchain_core.runnables import ( @@ -654,20 +653,14 @@ class ChatAnthropic(BaseChatModel): message_chunk = AIMessageChunk( content=message.content, tool_call_chunks=tool_call_chunks, # type: ignore[arg-type] - usage_metadata=message.usage_metadata, ) yield ChatGenerationChunk(message=message_chunk) else: yield cast(ChatGenerationChunk, result.generations[0]) return - full_generation_info: dict = {} with self._client.messages.stream(**params) as stream: for text in stream.text_stream: - chunk, full_generation_info = _make_chat_generation_chunk( - text, - stream.current_message_snapshot.model_dump(), - full_generation_info, - ) + chunk = ChatGenerationChunk(message=AIMessageChunk(content=text)) if run_manager: run_manager.on_llm_new_token(text, chunk=chunk) yield chunk @@ -699,20 +692,14 @@ class ChatAnthropic(BaseChatModel): message_chunk = AIMessageChunk( content=message.content, tool_call_chunks=tool_call_chunks, # type: ignore[arg-type] - usage_metadata=message.usage_metadata, ) yield ChatGenerationChunk(message=message_chunk) else: yield cast(ChatGenerationChunk, result.generations[0]) return - full_generation_info: dict = {} async with self._async_client.messages.stream(**params) as stream: async for text in stream.text_stream: - chunk, full_generation_info = _make_chat_generation_chunk( - text, - stream.current_message_snapshot.model_dump(), - full_generation_info, - ) + chunk = ChatGenerationChunk(message=AIMessageChunk(content=text)) if run_manager: await run_manager.on_llm_new_token(text, chunk=chunk) yield chunk @@ -1077,59 +1064,6 @@ def _lc_tool_calls_to_anthropic_tool_use_blocks( return blocks -def _make_chat_generation_chunk( - text: str, message_dump: dict, full_generation_info: dict -) -> Tuple[ChatGenerationChunk, dict]: - """Collect metadata and make ChatGenerationChunk. - - Args: - text: text of the message chunk - message_dump: dict with metadata of the message chunk - full_generation_info: dict collecting metadata for full stream - - Returns: - Tuple with ChatGenerationChunk and updated full_generation_info - """ - generation_info = {} - usage_metadata: Optional[UsageMetadata] = None - for k, v in message_dump.items(): - if k in ("content", "role", "type") or ( - k in full_generation_info and k not in ("usage", "stop_reason") - ): - continue - elif k == "usage": - input_tokens = v.get("input_tokens", 0) - output_tokens = v.get("output_tokens", 0) - if "usage" not in full_generation_info: - full_generation_info[k] = v - usage_metadata = UsageMetadata( - input_tokens=input_tokens, - output_tokens=output_tokens, - total_tokens=input_tokens + output_tokens, - ) - else: - seen_input_tokens = full_generation_info[k].get("input_tokens", 0) - # Anthropic returns the same input token count for each message in a - # stream. To avoid double counting, we only count the input tokens - # once. After that, we set the input tokens to zero. - new_input_tokens = 0 if seen_input_tokens else input_tokens - usage_metadata = UsageMetadata( - input_tokens=new_input_tokens, - output_tokens=output_tokens, - total_tokens=new_input_tokens + output_tokens, - ) - else: - full_generation_info[k] = v - generation_info[k] = v - return ( - ChatGenerationChunk( - message=AIMessageChunk(content=text, usage_metadata=usage_metadata), - generation_info=generation_info, - ), - full_generation_info, - ) - - @deprecated(since="0.1.0", removal="0.3.0", alternative="ChatAnthropic") class ChatAnthropicMessages(ChatAnthropic): pass diff --git a/libs/partners/anthropic/tests/integration_tests/test_chat_models.py b/libs/partners/anthropic/tests/integration_tests/test_chat_models.py index 3ce3cbab67..cee2cf70cf 100644 --- a/libs/partners/anthropic/tests/integration_tests/test_chat_models.py +++ b/libs/partners/anthropic/tests/integration_tests/test_chat_models.py @@ -1,7 +1,7 @@ """Test ChatAnthropic chat model.""" import json -from typing import List, Optional +from typing import List import pytest from langchain_core.callbacks import CallbackManager @@ -9,7 +9,6 @@ from langchain_core.messages import ( AIMessage, AIMessageChunk, BaseMessage, - BaseMessageChunk, HumanMessage, SystemMessage, ToolMessage, @@ -29,80 +28,16 @@ def test_stream() -> None: """Test streaming tokens from Anthropic.""" llm = ChatAnthropicMessages(model_name=MODEL_NAME) # type: ignore[call-arg, call-arg] - full: Optional[BaseMessageChunk] = None - chunks_with_input_token_counts = 0 for token in llm.stream("I'm Pickle Rick"): assert isinstance(token.content, str) - full = token if full is None else full + token - assert isinstance(token, AIMessageChunk) - if token.usage_metadata is not None and token.usage_metadata.get( - "input_tokens" - ): - chunks_with_input_token_counts += 1 - if chunks_with_input_token_counts != 1: - raise AssertionError( - "Expected exactly one chunk with input token counts. " - "AIMessageChunk aggregation adds counts. Check that " - "this is behaving properly." - ) - # check token usage is populated - assert isinstance(full, AIMessageChunk) - assert full.usage_metadata is not None - assert full.usage_metadata["input_tokens"] > 0 - assert full.usage_metadata["output_tokens"] > 0 - assert full.usage_metadata["total_tokens"] > 0 - assert ( - full.usage_metadata["input_tokens"] + full.usage_metadata["output_tokens"] - == full.usage_metadata["total_tokens"] - ) async def test_astream() -> None: """Test streaming tokens from Anthropic.""" llm = ChatAnthropicMessages(model_name=MODEL_NAME) # type: ignore[call-arg, call-arg] - full: Optional[BaseMessageChunk] = None - chunks_with_input_token_counts = 0 async for token in llm.astream("I'm Pickle Rick"): assert isinstance(token.content, str) - full = token if full is None else full + token - assert isinstance(token, AIMessageChunk) - if token.usage_metadata is not None and token.usage_metadata.get( - "input_tokens" - ): - chunks_with_input_token_counts += 1 - if chunks_with_input_token_counts != 1: - raise AssertionError( - "Expected exactly one chunk with input token counts. " - "AIMessageChunk aggregation adds counts. Check that " - "this is behaving properly." - ) - # check token usage is populated - assert isinstance(full, AIMessageChunk) - assert full.usage_metadata is not None - assert full.usage_metadata["input_tokens"] > 0 - assert full.usage_metadata["output_tokens"] > 0 - assert full.usage_metadata["total_tokens"] > 0 - assert ( - full.usage_metadata["input_tokens"] + full.usage_metadata["output_tokens"] - == full.usage_metadata["total_tokens"] - ) - - # Check assumption that each chunk has identical input token counts. - # This assumption is baked into _make_chat_generation_chunk. - params: dict = { - "model": MODEL_NAME, - "max_tokens": 1024, - "messages": [{"role": "user", "content": "I'm Pickle Rick"}], - } - all_input_tokens = set() - async with llm._async_client.messages.stream(**params) as stream: - async for _ in stream.text_stream: - message_dump = stream.current_message_snapshot.model_dump() - if input_tokens := message_dump.get("usage", {}).get("input_tokens"): - assert input_tokens > 0 - all_input_tokens.add(input_tokens) - assert len(all_input_tokens) == 1 async def test_abatch() -> None: @@ -333,17 +268,6 @@ def test_tool_use() -> None: assert isinstance(tool_call_chunk["args"], str) assert "location" in json.loads(tool_call_chunk["args"]) - # Check usage metadata - assert gathered.usage_metadata is not None - assert gathered.usage_metadata["input_tokens"] > 0 - assert gathered.usage_metadata["output_tokens"] > 0 - assert gathered.usage_metadata["total_tokens"] > 0 - assert ( - gathered.usage_metadata["input_tokens"] - + gathered.usage_metadata["output_tokens"] - == gathered.usage_metadata["total_tokens"] - ) - def test_anthropic_with_empty_text_block() -> None: """Anthropic SDK can return an empty text block."""