From 28ad244e775610d4826fbfe34c62dba826cf26c9 Mon Sep 17 00:00:00 2001
From: Harrison Chase
Date: Thu, 12 Sep 2024 21:47:47 -0700
Subject: [PATCH] community, openai: support nested dicts (#26414)

needed for thinking tokens

---------

Co-authored-by: Erick Friis
---
 .../langchain_community/chat_models/openai.py | 32 +++++++++++++++-
 .../langchain_openai/chat_models/base.py      | 31 +++++++++++++++-
 .../__snapshots__/test_azure_standard.ambr    | 37 -------------------
 .../__snapshots__/test_base_standard.ambr     | 37 -------------------
 4 files changed, 61 insertions(+), 76 deletions(-)

diff --git a/libs/community/langchain_community/chat_models/openai.py b/libs/community/langchain_community/chat_models/openai.py
index 54326d0acc..cc320de50c 100644
--- a/libs/community/langchain_community/chat_models/openai.py
+++ b/libs/community/langchain_community/chat_models/openai.py
@@ -5,6 +5,7 @@ from __future__ import annotations
 import logging
 import os
 import sys
+import warnings
 from typing import (
     TYPE_CHECKING,
     Any,
@@ -146,6 +147,33 @@ def _convert_delta_to_message_chunk(
     return default_class(content=content)  # type: ignore[call-arg]
 
 
+def _update_token_usage(
+    overall_token_usage: Union[int, dict], new_usage: Union[int, dict]
+) -> Union[int, dict]:
+    # Token usage is either ints or dictionaries
+    # `reasoning_tokens` is nested inside `completion_tokens_details`
+    if isinstance(new_usage, int):
+        if not isinstance(overall_token_usage, int):
+            raise ValueError(
+                f"Got different types for token usage: "
+                f"{type(new_usage)} and {type(overall_token_usage)}"
+            )
+        return new_usage + overall_token_usage
+    elif isinstance(new_usage, dict):
+        if not isinstance(overall_token_usage, dict):
+            raise ValueError(
+                f"Got different types for token usage: "
+                f"{type(new_usage)} and {type(overall_token_usage)}"
+            )
+        return {
+            k: _update_token_usage(overall_token_usage.get(k, 0), v)
+            for k, v in new_usage.items()
+        }
+    else:
+        warnings.warn(f"Unexpected type for token usage: {type(new_usage)}")
+        return new_usage
+
+
 @deprecated(
     since="0.0.10", removal="1.0", alternative_import="langchain_openai.ChatOpenAI"
 )
@@ -374,7 +402,9 @@ class ChatOpenAI(BaseChatModel):
             if token_usage is not None:
                 for k, v in token_usage.items():
                     if k in overall_token_usage:
-                        overall_token_usage[k] += v
+                        overall_token_usage[k] = _update_token_usage(
+                            overall_token_usage[k], v
+                        )
                     else:
                         overall_token_usage[k] = v
             if system_fingerprint is None:
diff --git a/libs/partners/openai/langchain_openai/chat_models/base.py b/libs/partners/openai/langchain_openai/chat_models/base.py
index b67bcafe60..61d9141df4 100644
--- a/libs/partners/openai/langchain_openai/chat_models/base.py
+++ b/libs/partners/openai/langchain_openai/chat_models/base.py
@@ -335,6 +335,33 @@ def _convert_chunk_to_generation_chunk(
     return generation_chunk
 
 
+def _update_token_usage(
+    overall_token_usage: Union[int, dict], new_usage: Union[int, dict]
+) -> Union[int, dict]:
+    # Token usage is either ints or dictionaries
+    # `reasoning_tokens` is nested inside `completion_tokens_details`
+    if isinstance(new_usage, int):
+        if not isinstance(overall_token_usage, int):
+            raise ValueError(
+                f"Got different types for token usage: "
+                f"{type(new_usage)} and {type(overall_token_usage)}"
+            )
+        return new_usage + overall_token_usage
+    elif isinstance(new_usage, dict):
+        if not isinstance(overall_token_usage, dict):
+            raise ValueError(
+                f"Got different types for token usage: "
+                f"{type(new_usage)} and {type(overall_token_usage)}"
+            )
+        return {
+            k: _update_token_usage(overall_token_usage.get(k, 0), v)
+            for k, v in new_usage.items()
+        }
+    else:
+        warnings.warn(f"Unexpected type for token usage: {type(new_usage)}")
+        return new_usage
+
+
 class _FunctionCall(TypedDict):
     name: str
 
@@ -561,7 +588,9 @@ class BaseChatOpenAI(BaseChatModel):
             if token_usage is not None:
                 for k, v in token_usage.items():
                     if k in overall_token_usage:
-                        overall_token_usage[k] += v
+                        overall_token_usage[k] = _update_token_usage(
+                            overall_token_usage[k], v
+                        )
                     else:
                         overall_token_usage[k] = v
             if system_fingerprint is None:
diff --git a/libs/partners/openai/tests/unit_tests/chat_models/__snapshots__/test_azure_standard.ambr b/libs/partners/openai/tests/unit_tests/chat_models/__snapshots__/test_azure_standard.ambr
index 8233a605fd..b2a33640aa 100644
--- a/libs/partners/openai/tests/unit_tests/chat_models/__snapshots__/test_azure_standard.ambr
+++ b/libs/partners/openai/tests/unit_tests/chat_models/__snapshots__/test_azure_standard.ambr
@@ -1,43 +1,6 @@
 # serializer version: 1
 # name: TestOpenAIStandard.test_serdes[serialized]
   dict({
-    'graph': dict({
-      'edges': list([
-        dict({
-          'source': 0,
-          'target': 1,
-        }),
-        dict({
-          'source': 1,
-          'target': 2,
-        }),
-      ]),
-      'nodes': list([
-        dict({
-          'data': 'AzureChatOpenAIInput',
-          'id': 0,
-          'type': 'schema',
-        }),
-        dict({
-          'data': dict({
-            'id': list([
-              'langchain',
-              'chat_models',
-              'azure_openai',
-              'AzureChatOpenAI',
-            ]),
-            'name': 'AzureChatOpenAI',
-          }),
-          'id': 1,
-          'type': 'runnable',
-        }),
-        dict({
-          'data': 'AzureChatOpenAIOutput',
-          'id': 2,
-          'type': 'schema',
-        }),
-      ]),
-    }),
     'id': list([
       'langchain',
       'chat_models',
diff --git a/libs/partners/openai/tests/unit_tests/chat_models/__snapshots__/test_base_standard.ambr b/libs/partners/openai/tests/unit_tests/chat_models/__snapshots__/test_base_standard.ambr
index 02e27256e5..dfcdefa88a 100644
--- a/libs/partners/openai/tests/unit_tests/chat_models/__snapshots__/test_base_standard.ambr
+++ b/libs/partners/openai/tests/unit_tests/chat_models/__snapshots__/test_base_standard.ambr
@@ -1,43 +1,6 @@
 # serializer version: 1
 # name: TestOpenAIStandard.test_serdes[serialized]
   dict({
-    'graph': dict({
-      'edges': list([
-        dict({
-          'source': 0,
-          'target': 1,
-        }),
-        dict({
-          'source': 1,
-          'target': 2,
-        }),
-      ]),
-      'nodes': list([
-        dict({
-          'data': 'ChatOpenAIInput',
-          'id': 0,
-          'type': 'schema',
-        }),
-        dict({
-          'data': dict({
-            'id': list([
-              'langchain',
-              'chat_models',
-              'openai',
-              'ChatOpenAI',
-            ]),
-            'name': 'ChatOpenAI',
-          }),
-          'id': 1,
-          'type': 'runnable',
-        }),
-        dict({
-          'data': 'ChatOpenAIOutput',
-          'id': 2,
-          'type': 'schema',
-        }),
-      ]),
-    }),
     'id': list([
      'langchain',
      'chat_models',
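Note (not part of the patch): a minimal sketch of how the recursive merge introduced above is expected to behave when streamed usage payloads contain nested dicts such as `completion_tokens_details`. The usage dict shapes below are assumptions chosen for illustration, not output captured from the API.

# Illustrative only: mirrors the aggregation loop in _combine_llm_outputs,
# using the private helper added by this patch.
from langchain_openai.chat_models.base import _update_token_usage

# Hypothetical per-chunk usage payloads (shapes assumed for demonstration).
chunk_1 = {
    "prompt_tokens": 5,
    "completion_tokens": 10,
    "completion_tokens_details": {"reasoning_tokens": 4},
}
chunk_2 = {
    "prompt_tokens": 0,
    "completion_tokens": 7,
    "completion_tokens_details": {"reasoning_tokens": 2},
}

overall: dict = {}
for usage in (chunk_1, chunk_2):
    for k, v in usage.items():
        if k in overall:
            # Ints are summed; dicts are merged key by key, recursively.
            overall[k] = _update_token_usage(overall[k], v)
        else:
            overall[k] = v

# Expected result:
# {"prompt_tokens": 5, "completion_tokens": 17,
#  "completion_tokens_details": {"reasoning_tokens": 6}}
print(overall)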