From ad9750403b40def548c848913c2e8e2924f98a27 Mon Sep 17 00:00:00 2001
From: Shuqian
Date: Tue, 9 Apr 2024 22:18:48 +0800
Subject: [PATCH] community[minor]: add bedrock anthropic callback for token
 usage counting (#19864)

**Description:** Add a Bedrock Anthropic callback handler for token usage counting and cost estimation, modeled on the existing OpenAI callback.

---------

Co-authored-by: Massimiliano Pronesti
---
 .../callbacks/bedrock_anthropic_callback.py   | 111 ++++++++++++++++++
 .../langchain_community/callbacks/manager.py  |  30 ++++-
 .../chat_models/bedrock.py                    |   2 +-
 .../callbacks/test_callback_manager.py        |  32 +++++
 .../unit_tests/chat_models/test_bedrock.py    |  29 +++++
 5 files changed, 202 insertions(+), 2 deletions(-)
 create mode 100644 libs/community/langchain_community/callbacks/bedrock_anthropic_callback.py

diff --git a/libs/community/langchain_community/callbacks/bedrock_anthropic_callback.py b/libs/community/langchain_community/callbacks/bedrock_anthropic_callback.py
new file mode 100644
index 0000000000..d146bf8fbe
--- /dev/null
+++ b/libs/community/langchain_community/callbacks/bedrock_anthropic_callback.py
@@ -0,0 +1,111 @@
+import threading
+from typing import Any, Dict, List, Union
+
+from langchain_core.callbacks import BaseCallbackHandler
+from langchain_core.outputs import LLMResult
+
+MODEL_COST_PER_1K_INPUT_TOKENS = {
+    "anthropic.claude-instant-v1": 0.0008,
+    "anthropic.claude-v2": 0.008,
+    "anthropic.claude-v2:1": 0.008,
+    "anthropic.claude-3-sonnet-20240229-v1:0": 0.003,
+    "anthropic.claude-3-haiku-20240307-v1:0": 0.00025,
+}
+
+MODEL_COST_PER_1K_OUTPUT_TOKENS = {
+    "anthropic.claude-instant-v1": 0.0024,
+    "anthropic.claude-v2": 0.024,
+    "anthropic.claude-v2:1": 0.024,
+    "anthropic.claude-3-sonnet-20240229-v1:0": 0.015,
+    "anthropic.claude-3-haiku-20240307-v1:0": 0.00125,
+}
+
+
+def _get_anthropic_claude_token_cost(
+    prompt_tokens: int, completion_tokens: int, model_id: Union[str, None]
+) -> float:
+    """Get the cost of tokens for the Claude model."""
+    if model_id not in MODEL_COST_PER_1K_INPUT_TOKENS:
+        raise ValueError(
+            f"Unknown model: {model_id}. Please provide a valid Anthropic model name. "
+ "Known models are: " + ", ".join(MODEL_COST_PER_1K_INPUT_TOKENS.keys()) + ) + return (prompt_tokens / 1000) * MODEL_COST_PER_1K_INPUT_TOKENS[model_id] + ( + completion_tokens / 1000 + ) * MODEL_COST_PER_1K_OUTPUT_TOKENS[model_id] + + +class BedrockAnthropicTokenUsageCallbackHandler(BaseCallbackHandler): + """Callback Handler that tracks bedrock anthropic info.""" + + total_tokens: int = 0 + prompt_tokens: int = 0 + completion_tokens: int = 0 + successful_requests: int = 0 + total_cost: float = 0.0 + + def __init__(self) -> None: + super().__init__() + self._lock = threading.Lock() + + def __repr__(self) -> str: + return ( + f"Tokens Used: {self.total_tokens}\n" + f"\tPrompt Tokens: {self.prompt_tokens}\n" + f"\tCompletion Tokens: {self.completion_tokens}\n" + f"Successful Requests: {self.successful_requests}\n" + f"Total Cost (USD): ${self.total_cost}" + ) + + @property + def always_verbose(self) -> bool: + """Whether to call verbose callbacks even if verbose is False.""" + return True + + def on_llm_start( + self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any + ) -> None: + """Print out the prompts.""" + pass + + def on_llm_new_token(self, token: str, **kwargs: Any) -> None: + """Print out the token.""" + pass + + def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: + """Collect token usage.""" + if response.llm_output is None: + return None + + if "usage" not in response.llm_output: + with self._lock: + self.successful_requests += 1 + return None + + # compute tokens and cost for this request + token_usage = response.llm_output["usage"] + completion_tokens = token_usage.get("completion_tokens", 0) + prompt_tokens = token_usage.get("prompt_tokens", 0) + total_tokens = token_usage.get("total_tokens", 0) + model_id = response.llm_output.get("model_id", None) + total_cost = _get_anthropic_claude_token_cost( + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + model_id=model_id, + ) + + # update shared state behind lock + with self._lock: + self.total_cost += total_cost + self.total_tokens += total_tokens + self.prompt_tokens += prompt_tokens + self.completion_tokens += completion_tokens + self.successful_requests += 1 + + def __copy__(self) -> "BedrockAnthropicTokenUsageCallbackHandler": + """Return a copy of the callback handler.""" + return self + + def __deepcopy__(self, memo: Any) -> "BedrockAnthropicTokenUsageCallbackHandler": + """Return a deep copy of the callback handler.""" + return self diff --git a/libs/community/langchain_community/callbacks/manager.py b/libs/community/langchain_community/callbacks/manager.py index ec03a82345..f5b4530ea2 100644 --- a/libs/community/langchain_community/callbacks/manager.py +++ b/libs/community/langchain_community/callbacks/manager.py @@ -10,6 +10,9 @@ from typing import ( from langchain_core.tracers.context import register_configure_hook +from langchain_community.callbacks.bedrock_anthropic_callback import ( + BedrockAnthropicTokenUsageCallbackHandler, +) from langchain_community.callbacks.openai_info import OpenAICallbackHandler from langchain_community.callbacks.tracers.comet import CometTracer from langchain_community.callbacks.tracers.wandb import WandbTracer @@ -19,7 +22,10 @@ logger = logging.getLogger(__name__) openai_callback_var: ContextVar[Optional[OpenAICallbackHandler]] = ContextVar( "openai_callback", default=None ) -wandb_tracing_callback_var: ContextVar[Optional[WandbTracer]] = ContextVar( # noqa: E501 +bedrock_anthropic_callback_var: (ContextVar)[ + 
+    Optional[BedrockAnthropicTokenUsageCallbackHandler]
+] = ContextVar("bedrock_anthropic_callback", default=None)
+wandb_tracing_callback_var: ContextVar[Optional[WandbTracer]] = ContextVar(
     "tracing_wandb_callback", default=None
 )
 comet_tracing_callback_var: ContextVar[Optional[CometTracer]] = ContextVar(  # noqa: E501
@@ -27,6 +33,7 @@ comet_tracing_callback_var: ContextVar[Optional[CometTracer]] = ContextVar(  # n
 )
 
 register_configure_hook(openai_callback_var, True)
+register_configure_hook(bedrock_anthropic_callback_var, True)
 register_configure_hook(
     wandb_tracing_callback_var, True, WandbTracer, "LANGCHAIN_WANDB_TRACING"
 )
@@ -53,6 +60,27 @@ def get_openai_callback() -> Generator[OpenAICallbackHandler, None, None]:
     openai_callback_var.set(None)
 
 
+@contextmanager
+def get_bedrock_anthropic_callback() -> (
+    Generator[BedrockAnthropicTokenUsageCallbackHandler, None, None]
+):
+    """Get the Bedrock Anthropic callback handler in a context manager,
+    which conveniently exposes token and cost information.
+
+    Returns:
+        BedrockAnthropicTokenUsageCallbackHandler:
+            The Bedrock Anthropic callback handler.
+
+    Example:
+        >>> with get_bedrock_anthropic_callback() as cb:
+        ...     # Use the Bedrock Anthropic callback handler
+    """
+    cb = BedrockAnthropicTokenUsageCallbackHandler()
+    bedrock_anthropic_callback_var.set(cb)
+    yield cb
+    bedrock_anthropic_callback_var.set(None)
+
+
 @contextmanager
 def wandb_tracing_enabled(
     session_name: str = "default",
diff --git a/libs/community/langchain_community/chat_models/bedrock.py b/libs/community/langchain_community/chat_models/bedrock.py
index 4cb73455f2..587db5ee35 100644
--- a/libs/community/langchain_community/chat_models/bedrock.py
+++ b/libs/community/langchain_community/chat_models/bedrock.py
@@ -308,7 +308,7 @@ class BedrockChat(BaseChatModel, BedrockBase):
         final_output = {}
         for output in llm_outputs:
             output = output or {}
-            usage = output.pop("usage", {})
+            usage = output.get("usage", {})
             for token_type, token_count in usage.items():
                 final_usage[token_type] += token_count
             final_output.update(output)
diff --git a/libs/community/tests/unit_tests/callbacks/test_callback_manager.py b/libs/community/tests/unit_tests/callbacks/test_callback_manager.py
index 353a72c85b..cf308c9304 100644
--- a/libs/community/tests/unit_tests/callbacks/test_callback_manager.py
+++ b/libs/community/tests/unit_tests/callbacks/test_callback_manager.py
@@ -7,6 +7,7 @@ from langchain_core.outputs import LLMResult
 from langchain_core.tracers.langchain import LangChainTracer, wait_for_all_tracers
 
 from langchain_community.callbacks import get_openai_callback
+from langchain_community.callbacks.manager import get_bedrock_anthropic_callback
 from langchain_community.llms.openai import BaseOpenAI
 
 
@@ -77,6 +78,37 @@ def test_callback_manager_configure_context_vars(
         )
         mngr.on_llm_start({}, ["prompt"])[0].on_llm_end(response)
 
+        # The callback handler has been updated
+        assert cb.successful_requests == 1
+        assert cb.total_tokens == 3
+        assert cb.prompt_tokens == 2
+        assert cb.completion_tokens == 1
+        assert cb.total_cost > 0
+
+    with get_bedrock_anthropic_callback() as cb:
+        # This is a new empty callback handler
+        assert cb.successful_requests == 0
+        assert cb.total_tokens == 0
+
+        # configure adds this bedrock anthropic cb,
+        # but doesn't modify the group manager
+        mngr = CallbackManager.configure(group_manager)
+        assert mngr.handlers == [tracer, cb]
+        assert group_manager.handlers == [tracer]
+
+        response = LLMResult(
+            generations=[],
+            llm_output={
+                "usage": {
+                    "prompt_tokens": 2,
"completion_tokens": 1, + "total_tokens": 3, + }, + "model_id": "anthropic.claude-instant-v1", + }, + ) + mngr.on_llm_start({}, ["prompt"])[0].on_llm_end(response) + # The callback handler has been updated assert cb.successful_requests == 1 assert cb.total_tokens == 3 diff --git a/libs/community/tests/unit_tests/chat_models/test_bedrock.py b/libs/community/tests/unit_tests/chat_models/test_bedrock.py index b515c99e5a..c78dceadaa 100644 --- a/libs/community/tests/unit_tests/chat_models/test_bedrock.py +++ b/libs/community/tests/unit_tests/chat_models/test_bedrock.py @@ -58,3 +58,32 @@ def test_different_models_bedrock(model_id: str) -> None: # should not throw an error model.invoke("hello there") + + +def test_bedrock_combine_llm_output() -> None: + model_id = "anthropic.claude-3-haiku-20240307-v1:0" + client = MagicMock() + llm_outputs = [ + { + "model_id": "anthropic.claude-3-haiku-20240307-v1:0", + "usage": { + "completion_tokens": 1, + "prompt_tokens": 2, + "total_tokens": 3, + }, + }, + { + "model_id": "anthropic.claude-3-haiku-20240307-v1:0", + "usage": { + "completion_tokens": 1, + "prompt_tokens": 2, + "total_tokens": 3, + }, + }, + ] + model = BedrockChat(model_id=model_id, client=client) + final_output = model._combine_llm_outputs(llm_outputs) + assert final_output["model_id"] == model_id + assert final_output["usage"]["completion_tokens"] == 2 + assert final_output["usage"]["prompt_tokens"] == 4 + assert final_output["usage"]["total_tokens"] == 6