mirror of
https://github.com/hwchase17/langchain
synced 2024-11-13 19:10:52 +00:00
core[patch]: utils for adding/subtracting usage metadata (#27203)
This commit is contained in:
parent
e3920f2320
commit
e3e9ee8398
@ -1,5 +1,6 @@
|
||||
import json
|
||||
from typing import Any, Literal, Optional, Union
|
||||
import operator
|
||||
from typing import Any, Literal, Optional, Union, cast
|
||||
|
||||
from pydantic import model_validator
|
||||
from typing_extensions import NotRequired, Self, TypedDict
|
||||
@ -27,6 +28,7 @@ from langchain_core.messages.tool import (
|
||||
)
|
||||
from langchain_core.utils._merge import merge_dicts, merge_lists
|
||||
from langchain_core.utils.json import parse_partial_json
|
||||
from langchain_core.utils.usage import _dict_int_op
|
||||
|
||||
|
||||
class InputTokenDetails(TypedDict, total=False):
|
||||
@ -432,17 +434,9 @@ def add_ai_message_chunks(
|
||||
|
||||
# Token usage
|
||||
if left.usage_metadata or any(o.usage_metadata is not None for o in others):
|
||||
usage_metadata_: UsageMetadata = left.usage_metadata or UsageMetadata(
|
||||
input_tokens=0, output_tokens=0, total_tokens=0
|
||||
)
|
||||
usage_metadata: Optional[UsageMetadata] = left.usage_metadata
|
||||
for other in others:
|
||||
if other.usage_metadata is not None:
|
||||
usage_metadata_["input_tokens"] += other.usage_metadata["input_tokens"]
|
||||
usage_metadata_["output_tokens"] += other.usage_metadata[
|
||||
"output_tokens"
|
||||
]
|
||||
usage_metadata_["total_tokens"] += other.usage_metadata["total_tokens"]
|
||||
usage_metadata: Optional[UsageMetadata] = usage_metadata_
|
||||
usage_metadata = add_usage(usage_metadata, other.usage_metadata)
|
||||
else:
|
||||
usage_metadata = None
|
||||
|
||||
@ -455,3 +449,115 @@ def add_ai_message_chunks(
|
||||
usage_metadata=usage_metadata,
|
||||
id=left.id,
|
||||
)
|
||||
|
||||
|
||||
def add_usage(
    left: Optional[UsageMetadata], right: Optional[UsageMetadata]
) -> UsageMetadata:
    """Recursively add two UsageMetadata objects.

    Top-level token counts and any nested token-detail counters are summed
    key-by-key; a key present on only one side keeps its value.

    Example:
        .. code-block:: python

            from langchain_core.messages.ai import add_usage

            left = UsageMetadata(
                input_tokens=5,
                output_tokens=0,
                total_tokens=5,
                input_token_details=InputTokenDetails(cache_read=3)
            )
            right = UsageMetadata(
                input_tokens=0,
                output_tokens=10,
                total_tokens=10,
                output_token_details=OutputTokenDetails(reasoning=4)
            )

            add_usage(left, right)

        results in

        .. code-block:: python

            UsageMetadata(
                input_tokens=5,
                output_tokens=10,
                total_tokens=15,
                input_token_details=InputTokenDetails(cache_read=3),
                output_token_details=OutputTokenDetails(reasoning=4)
            )

    """
    # Neither side carries usage: return an explicit all-zero record.
    if not left and not right:
        return UsageMetadata(input_tokens=0, output_tokens=0, total_tokens=0)
    # Exactly one side carries usage: that side is already the sum.
    if not left or not right:
        return cast(UsageMetadata, left or right)

    # Both present: add matching int fields, recursing into nested detail dicts.
    summed = _dict_int_op(cast(dict, left), cast(dict, right), operator.add)
    return UsageMetadata(**cast(UsageMetadata, summed))
|
||||
|
||||
|
||||
def subtract_usage(
    left: Optional[UsageMetadata], right: Optional[UsageMetadata]
) -> UsageMetadata:
    """Recursively subtract two UsageMetadata objects.

    Token counts cannot be negative so the actual operation is max(left - right, 0).

    Example:
        .. code-block:: python

            from langchain_core.messages.ai import subtract_usage

            left = UsageMetadata(
                input_tokens=5,
                output_tokens=10,
                total_tokens=15,
                input_token_details=InputTokenDetails(cache_read=4)
            )
            right = UsageMetadata(
                input_tokens=3,
                output_tokens=8,
                total_tokens=11,
                output_token_details=OutputTokenDetails(reasoning=4)
            )

            subtract_usage(left, right)

        results in

        .. code-block:: python

            UsageMetadata(
                input_tokens=2,
                output_tokens=2,
                total_tokens=4,
                input_token_details=InputTokenDetails(cache_read=4),
                output_token_details=OutputTokenDetails(reasoning=0)
            )

    """
    # Neither side carries usage: return an explicit all-zero record.
    if not left and not right:
        return UsageMetadata(input_tokens=0, output_tokens=0, total_tokens=0)
    # Exactly one side carries usage: nothing to subtract against it.
    if not left or not right:
        return cast(UsageMetadata, left or right)

    def _clamped_sub(minuend: int, subtrahend: int) -> int:
        # Token counts cannot go below zero.
        return max(minuend - subtrahend, 0)

    diff = _dict_int_op(cast(dict, left), cast(dict, right), _clamped_sub)
    return UsageMetadata(**cast(UsageMetadata, diff))
|
||||
|
37
libs/core/langchain_core/utils/usage.py
Normal file
37
libs/core/langchain_core/utils/usage.py
Normal file
@ -0,0 +1,37 @@
|
||||
from typing import Callable
|
||||
|
||||
|
||||
def _dict_int_op(
|
||||
left: dict,
|
||||
right: dict,
|
||||
op: Callable[[int, int], int],
|
||||
*,
|
||||
default: int = 0,
|
||||
depth: int = 0,
|
||||
max_depth: int = 100,
|
||||
) -> dict:
|
||||
if depth >= max_depth:
|
||||
msg = f"{max_depth=} exceeded, unable to combine dicts."
|
||||
raise ValueError(msg)
|
||||
combined: dict = {}
|
||||
for k in set(left).union(right):
|
||||
if isinstance(left.get(k, default), int) and isinstance(
|
||||
right.get(k, default), int
|
||||
):
|
||||
combined[k] = op(left.get(k, default), right.get(k, default))
|
||||
elif isinstance(left.get(k, {}), dict) and isinstance(right.get(k, {}), dict):
|
||||
combined[k] = _dict_int_op(
|
||||
left.get(k, {}),
|
||||
right.get(k, {}),
|
||||
op,
|
||||
default=default,
|
||||
depth=depth + 1,
|
||||
max_depth=max_depth,
|
||||
)
|
||||
else:
|
||||
types = [type(d[k]) for d in (left, right) if k in d]
|
||||
msg = (
|
||||
f"Unknown value types: {types}. Only dict and int values are supported."
|
||||
)
|
||||
raise ValueError(msg)
|
||||
return combined
|
@ -1,5 +1,13 @@
|
||||
from langchain_core.load import dumpd, load
|
||||
from langchain_core.messages import AIMessage, AIMessageChunk
|
||||
from langchain_core.messages.ai import (
|
||||
InputTokenDetails,
|
||||
OutputTokenDetails,
|
||||
UsageMetadata,
|
||||
add_ai_message_chunks,
|
||||
add_usage,
|
||||
subtract_usage,
|
||||
)
|
||||
from langchain_core.messages.tool import invalid_tool_call as create_invalid_tool_call
|
||||
from langchain_core.messages.tool import tool_call as create_tool_call
|
||||
from langchain_core.messages.tool import tool_call_chunk as create_tool_call_chunk
|
||||
@ -92,3 +100,99 @@ def test_serdes_message_chunk() -> None:
|
||||
actual = dumpd(chunk)
|
||||
assert actual == expected
|
||||
assert load(actual) == chunk
|
||||
|
||||
|
||||
def test_add_usage_both_none() -> None:
    """Adding two missing usages yields an explicit all-zero record."""
    zero = UsageMetadata(input_tokens=0, output_tokens=0, total_tokens=0)
    assert add_usage(None, None) == zero
|
||||
|
||||
|
||||
def test_add_usage_one_none() -> None:
    """Adding None to a usage record returns the record unchanged."""
    only = UsageMetadata(input_tokens=10, output_tokens=20, total_tokens=30)
    assert add_usage(only, None) == only
|
||||
|
||||
|
||||
def test_add_usage_both_present() -> None:
    """Top-level token counts are summed field-by-field."""
    first = UsageMetadata(input_tokens=10, output_tokens=20, total_tokens=30)
    second = UsageMetadata(input_tokens=5, output_tokens=10, total_tokens=15)
    expected = UsageMetadata(input_tokens=15, output_tokens=30, total_tokens=45)
    assert add_usage(first, second) == expected
|
||||
|
||||
|
||||
def test_add_usage_with_details() -> None:
    """Nested input/output token-detail counters are summed recursively."""
    first = UsageMetadata(
        input_tokens=10,
        output_tokens=20,
        total_tokens=30,
        input_token_details=InputTokenDetails(audio=5),
        output_token_details=OutputTokenDetails(reasoning=10),
    )
    second = UsageMetadata(
        input_tokens=5,
        output_tokens=10,
        total_tokens=15,
        input_token_details=InputTokenDetails(audio=3),
        output_token_details=OutputTokenDetails(reasoning=5),
    )
    combined = add_usage(first, second)
    assert combined["input_token_details"]["audio"] == 5 + 3
    assert combined["output_token_details"]["reasoning"] == 10 + 5
|
||||
|
||||
|
||||
def test_subtract_usage_both_none() -> None:
    """Subtracting two missing usages yields an explicit all-zero record."""
    zero = UsageMetadata(input_tokens=0, output_tokens=0, total_tokens=0)
    assert subtract_usage(None, None) == zero
|
||||
|
||||
|
||||
def test_subtract_usage_one_none() -> None:
    """Subtracting None from a usage record returns the record unchanged."""
    only = UsageMetadata(input_tokens=10, output_tokens=20, total_tokens=30)
    assert subtract_usage(only, None) == only
|
||||
|
||||
|
||||
def test_subtract_usage_both_present() -> None:
    """Top-level token counts are subtracted field-by-field."""
    minuend = UsageMetadata(input_tokens=10, output_tokens=20, total_tokens=30)
    subtrahend = UsageMetadata(input_tokens=5, output_tokens=10, total_tokens=15)
    expected = UsageMetadata(input_tokens=5, output_tokens=10, total_tokens=15)
    assert subtract_usage(minuend, subtrahend) == expected
|
||||
|
||||
|
||||
def test_subtract_usage_with_negative_result() -> None:
    """Would-be negative counts are clamped to zero, never negative."""
    smaller = UsageMetadata(input_tokens=5, output_tokens=10, total_tokens=15)
    larger = UsageMetadata(input_tokens=10, output_tokens=20, total_tokens=30)
    zero = UsageMetadata(input_tokens=0, output_tokens=0, total_tokens=0)
    assert subtract_usage(smaller, larger) == zero
|
||||
|
||||
|
||||
def test_add_ai_message_chunks_usage() -> None:
    """Merging chunks accumulates usage; chunks without metadata contribute nothing."""
    no_usage = AIMessageChunk(content="", usage_metadata=None)
    plain = AIMessageChunk(
        content="",
        usage_metadata=UsageMetadata(input_tokens=2, output_tokens=3, total_tokens=5),
    )
    detailed = AIMessageChunk(
        content="",
        usage_metadata=UsageMetadata(
            input_tokens=2,
            output_tokens=3,
            total_tokens=5,
            input_token_details=InputTokenDetails(audio=1, cache_read=1),
            output_token_details=OutputTokenDetails(audio=1, reasoning=2),
        ),
    )
    merged = add_ai_message_chunks(no_usage, plain, detailed)
    assert merged == AIMessageChunk(
        content="",
        usage_metadata=UsageMetadata(
            input_tokens=4,
            output_tokens=6,
            total_tokens=10,
            input_token_details=InputTokenDetails(audio=1, cache_read=1),
            output_token_details=OutputTokenDetails(audio=1, reasoning=2),
        ),
    )
|
||||
|
38
libs/core/tests/unit_tests/utils/test_usage.py
Normal file
38
libs/core/tests/unit_tests/utils/test_usage.py
Normal file
@ -0,0 +1,38 @@
|
||||
import pytest
|
||||
|
||||
from langchain_core.utils.usage import _dict_int_op
|
||||
|
||||
|
||||
def test_dict_int_op_add() -> None:
    """Keys missing on one side default to 0 when adding flat dicts."""
    combined = _dict_int_op({"a": 1, "b": 2}, {"b": 3, "c": 4}, lambda x, y: x + y)
    assert combined == {"a": 1, "b": 5, "c": 4}
|
||||
|
||||
|
||||
def test_dict_int_op_subtract() -> None:
    """Clamped subtraction never produces a negative count."""
    combined = _dict_int_op(
        {"a": 5, "b": 10}, {"a": 2, "b": 3, "c": 1}, lambda x, y: max(x - y, 0)
    )
    assert combined == {"a": 3, "b": 7, "c": 0}
|
||||
|
||||
|
||||
def test_dict_int_op_nested() -> None:
    """The operation recurses into nested dict values."""
    combined = _dict_int_op(
        {"a": 1, "b": {"c": 2, "d": 3}},
        {"a": 2, "b": {"c": 1, "e": 4}},
        lambda x, y: x + y,
    )
    assert combined == {"a": 3, "b": {"c": 3, "d": 3, "e": 4}}
|
||||
|
||||
|
||||
def test_dict_int_op_max_depth_exceeded() -> None:
    """Nesting deeper than max_depth raises instead of recursing unbounded."""
    deep_left = {"a": {"b": {"c": 1}}}
    deep_right = {"a": {"b": {"c": 2}}}
    with pytest.raises(ValueError):
        _dict_int_op(deep_left, deep_right, lambda x, y: x + y, max_depth=2)
|
||||
|
||||
|
||||
def test_dict_int_op_invalid_types() -> None:
    """Values that are neither int nor dict are rejected with ValueError."""
    with pytest.raises(ValueError):
        _dict_int_op({"a": 1, "b": "string"}, {"a": 2, "b": 3}, lambda x, y: x + y)
|
Loading…
Reference in New Issue
Block a user