core[patch]: utils for adding/subtracting usage metadata (#27203)

This commit is contained in:
Bagatur 2024-10-08 13:15:33 -07:00 committed by GitHub
parent e3920f2320
commit e3e9ee8398
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 296 additions and 11 deletions

View File

@ -1,5 +1,6 @@
import json
from typing import Any, Literal, Optional, Union
import operator
from typing import Any, Literal, Optional, Union, cast
from pydantic import model_validator
from typing_extensions import NotRequired, Self, TypedDict
@ -27,6 +28,7 @@ from langchain_core.messages.tool import (
)
from langchain_core.utils._merge import merge_dicts, merge_lists
from langchain_core.utils.json import parse_partial_json
from langchain_core.utils.usage import _dict_int_op
class InputTokenDetails(TypedDict, total=False):
@ -432,17 +434,9 @@ def add_ai_message_chunks(
# Token usage
if left.usage_metadata or any(o.usage_metadata is not None for o in others):
usage_metadata_: UsageMetadata = left.usage_metadata or UsageMetadata(
input_tokens=0, output_tokens=0, total_tokens=0
)
usage_metadata: Optional[UsageMetadata] = left.usage_metadata
for other in others:
if other.usage_metadata is not None:
usage_metadata_["input_tokens"] += other.usage_metadata["input_tokens"]
usage_metadata_["output_tokens"] += other.usage_metadata[
"output_tokens"
]
usage_metadata_["total_tokens"] += other.usage_metadata["total_tokens"]
usage_metadata: Optional[UsageMetadata] = usage_metadata_
usage_metadata = add_usage(usage_metadata, other.usage_metadata)
else:
usage_metadata = None
@ -455,3 +449,115 @@ def add_ai_message_chunks(
usage_metadata=usage_metadata,
id=left.id,
)
def add_usage(
left: Optional[UsageMetadata], right: Optional[UsageMetadata]
) -> UsageMetadata:
"""Recursively add two UsageMetadata objects.
Example:
.. code-block:: python
from langchain_core.messages.ai import add_usage
left = UsageMetadata(
input_tokens=5,
output_tokens=0,
total_tokens=5,
input_token_details=InputTokenDetails(cache_read=3)
)
right = UsageMetadata(
input_tokens=0,
output_tokens=10,
total_tokens=10,
output_token_details=OutputTokenDetails(reasoning=4)
)
add_usage(left, right)
results in
.. code-block:: python
UsageMetadata(
input_tokens=5,
output_tokens=10,
total_tokens=15,
input_token_details=InputTokenDetails(cache_read=3),
output_token_details=OutputTokenDetails(reasoning=4)
)
"""
if not (left or right):
return UsageMetadata(input_tokens=0, output_tokens=0, total_tokens=0)
if not (left and right):
return cast(UsageMetadata, left or right)
return UsageMetadata(
**cast(
UsageMetadata,
_dict_int_op(
cast(dict, left),
cast(dict, right),
operator.add,
),
)
)
def subtract_usage(
left: Optional[UsageMetadata], right: Optional[UsageMetadata]
) -> UsageMetadata:
"""Recursively subtract two UsageMetadata objects.
Token counts cannot be negative so the actual operation is max(left - right, 0).
Example:
.. code-block:: python
from langchain_core.messages.ai import subtract_usage
left = UsageMetadata(
input_tokens=5,
output_tokens=10,
total_tokens=15,
input_token_details=InputTokenDetails(cache_read=4)
)
right = UsageMetadata(
input_tokens=3,
output_tokens=8,
total_tokens=11,
output_token_details=OutputTokenDetails(reasoning=4)
)
subtract_usage(left, right)
results in
.. code-block:: python
UsageMetadata(
input_tokens=2,
output_tokens=2,
total_tokens=4,
input_token_details=InputTokenDetails(cache_read=4),
output_token_details=OutputTokenDetails(reasoning=0)
)
"""
if not (left or right):
return UsageMetadata(input_tokens=0, output_tokens=0, total_tokens=0)
if not (left and right):
return cast(UsageMetadata, left or right)
return UsageMetadata(
**cast(
UsageMetadata,
_dict_int_op(
cast(dict, left),
cast(dict, right),
(lambda le, ri: max(le - ri, 0)),
),
)
)

View File

@ -0,0 +1,37 @@
from typing import Callable
def _dict_int_op(
left: dict,
right: dict,
op: Callable[[int, int], int],
*,
default: int = 0,
depth: int = 0,
max_depth: int = 100,
) -> dict:
if depth >= max_depth:
msg = f"{max_depth=} exceeded, unable to combine dicts."
raise ValueError(msg)
combined: dict = {}
for k in set(left).union(right):
if isinstance(left.get(k, default), int) and isinstance(
right.get(k, default), int
):
combined[k] = op(left.get(k, default), right.get(k, default))
elif isinstance(left.get(k, {}), dict) and isinstance(right.get(k, {}), dict):
combined[k] = _dict_int_op(
left.get(k, {}),
right.get(k, {}),
op,
default=default,
depth=depth + 1,
max_depth=max_depth,
)
else:
types = [type(d[k]) for d in (left, right) if k in d]
msg = (
f"Unknown value types: {types}. Only dict and int values are supported."
)
raise ValueError(msg)
return combined

View File

@ -1,5 +1,13 @@
from langchain_core.load import dumpd, load
from langchain_core.messages import AIMessage, AIMessageChunk
from langchain_core.messages.ai import (
InputTokenDetails,
OutputTokenDetails,
UsageMetadata,
add_ai_message_chunks,
add_usage,
subtract_usage,
)
from langchain_core.messages.tool import invalid_tool_call as create_invalid_tool_call
from langchain_core.messages.tool import tool_call as create_tool_call
from langchain_core.messages.tool import tool_call_chunk as create_tool_call_chunk
@ -92,3 +100,99 @@ def test_serdes_message_chunk() -> None:
actual = dumpd(chunk)
assert actual == expected
assert load(actual) == chunk
def test_add_usage_both_none() -> None:
result = add_usage(None, None)
assert result == UsageMetadata(input_tokens=0, output_tokens=0, total_tokens=0)
def test_add_usage_one_none() -> None:
usage = UsageMetadata(input_tokens=10, output_tokens=20, total_tokens=30)
result = add_usage(usage, None)
assert result == usage
def test_add_usage_both_present() -> None:
usage1 = UsageMetadata(input_tokens=10, output_tokens=20, total_tokens=30)
usage2 = UsageMetadata(input_tokens=5, output_tokens=10, total_tokens=15)
result = add_usage(usage1, usage2)
assert result == UsageMetadata(input_tokens=15, output_tokens=30, total_tokens=45)
def test_add_usage_with_details() -> None:
usage1 = UsageMetadata(
input_tokens=10,
output_tokens=20,
total_tokens=30,
input_token_details=InputTokenDetails(audio=5),
output_token_details=OutputTokenDetails(reasoning=10),
)
usage2 = UsageMetadata(
input_tokens=5,
output_tokens=10,
total_tokens=15,
input_token_details=InputTokenDetails(audio=3),
output_token_details=OutputTokenDetails(reasoning=5),
)
result = add_usage(usage1, usage2)
assert result["input_token_details"]["audio"] == 8
assert result["output_token_details"]["reasoning"] == 15
def test_subtract_usage_both_none() -> None:
result = subtract_usage(None, None)
assert result == UsageMetadata(input_tokens=0, output_tokens=0, total_tokens=0)
def test_subtract_usage_one_none() -> None:
usage = UsageMetadata(input_tokens=10, output_tokens=20, total_tokens=30)
result = subtract_usage(usage, None)
assert result == usage
def test_subtract_usage_both_present() -> None:
usage1 = UsageMetadata(input_tokens=10, output_tokens=20, total_tokens=30)
usage2 = UsageMetadata(input_tokens=5, output_tokens=10, total_tokens=15)
result = subtract_usage(usage1, usage2)
assert result == UsageMetadata(input_tokens=5, output_tokens=10, total_tokens=15)
def test_subtract_usage_with_negative_result() -> None:
usage1 = UsageMetadata(input_tokens=5, output_tokens=10, total_tokens=15)
usage2 = UsageMetadata(input_tokens=10, output_tokens=20, total_tokens=30)
result = subtract_usage(usage1, usage2)
assert result == UsageMetadata(input_tokens=0, output_tokens=0, total_tokens=0)
def test_add_ai_message_chunks_usage() -> None:
chunks = [
AIMessageChunk(content="", usage_metadata=None),
AIMessageChunk(
content="",
usage_metadata=UsageMetadata(
input_tokens=2, output_tokens=3, total_tokens=5
),
),
AIMessageChunk(
content="",
usage_metadata=UsageMetadata(
input_tokens=2,
output_tokens=3,
total_tokens=5,
input_token_details=InputTokenDetails(audio=1, cache_read=1),
output_token_details=OutputTokenDetails(audio=1, reasoning=2),
),
),
]
combined = add_ai_message_chunks(*chunks)
assert combined == AIMessageChunk(
content="",
usage_metadata=UsageMetadata(
input_tokens=4,
output_tokens=6,
total_tokens=10,
input_token_details=InputTokenDetails(audio=1, cache_read=1),
output_token_details=OutputTokenDetails(audio=1, reasoning=2),
),
)

View File

@ -0,0 +1,38 @@
import pytest
from langchain_core.utils.usage import _dict_int_op
def test_dict_int_op_add() -> None:
left = {"a": 1, "b": 2}
right = {"b": 3, "c": 4}
result = _dict_int_op(left, right, lambda x, y: x + y)
assert result == {"a": 1, "b": 5, "c": 4}
def test_dict_int_op_subtract() -> None:
left = {"a": 5, "b": 10}
right = {"a": 2, "b": 3, "c": 1}
result = _dict_int_op(left, right, lambda x, y: max(x - y, 0))
assert result == {"a": 3, "b": 7, "c": 0}
def test_dict_int_op_nested() -> None:
left = {"a": 1, "b": {"c": 2, "d": 3}}
right = {"a": 2, "b": {"c": 1, "e": 4}}
result = _dict_int_op(left, right, lambda x, y: x + y)
assert result == {"a": 3, "b": {"c": 3, "d": 3, "e": 4}}
def test_dict_int_op_max_depth_exceeded() -> None:
left = {"a": {"b": {"c": 1}}}
right = {"a": {"b": {"c": 2}}}
with pytest.raises(ValueError):
_dict_int_op(left, right, lambda x, y: x + y, max_depth=2)
def test_dict_int_op_invalid_types() -> None:
left = {"a": 1, "b": "string"}
right = {"a": 2, "b": 3}
with pytest.raises(ValueError):
_dict_int_op(left, right, lambda x, y: x + y)