community, openai: support nested dicts (#26414)

needed for thinking tokens

---------

Co-authored-by: Erick Friis <erick@langchain.dev>
Author: Harrison Chase, 2024-09-12 21:47:47 -07:00, committed by GitHub
parent c0dd293f10
commit 28ad244e77
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 61 additions and 76 deletions

View File

@ -5,6 +5,7 @@ from __future__ import annotations
import logging
import os
import sys
import warnings
from typing import (
TYPE_CHECKING,
Any,
@ -146,6 +147,33 @@ def _convert_delta_to_message_chunk(
return default_class(content=content) # type: ignore[call-arg]
def _update_token_usage(
overall_token_usage: Union[int, dict], new_usage: Union[int, dict]
) -> Union[int, dict]:
# Token usage is either ints or dictionaries
# `reasoning_tokens` is nested inside `completion_tokens_details`
if isinstance(new_usage, int):
if not isinstance(overall_token_usage, int):
raise ValueError(
f"Got different types for token usage: "
f"{type(new_usage)} and {type(overall_token_usage)}"
)
return new_usage + overall_token_usage
elif isinstance(new_usage, dict):
if not isinstance(overall_token_usage, dict):
raise ValueError(
f"Got different types for token usage: "
f"{type(new_usage)} and {type(overall_token_usage)}"
)
return {
k: _update_token_usage(overall_token_usage.get(k, 0), v)
for k, v in new_usage.items()
}
else:
warnings.warn(f"Unexpected type for token usage: {type(new_usage)}")
return new_usage
@deprecated(
since="0.0.10", removal="1.0", alternative_import="langchain_openai.ChatOpenAI"
)
@ -374,7 +402,9 @@ class ChatOpenAI(BaseChatModel):
if token_usage is not None:
for k, v in token_usage.items():
if k in overall_token_usage:
overall_token_usage[k] += v
overall_token_usage[k] = _update_token_usage(
overall_token_usage[k], v
)
else:
overall_token_usage[k] = v
if system_fingerprint is None:

View File

@ -335,6 +335,33 @@ def _convert_chunk_to_generation_chunk(
return generation_chunk
def _update_token_usage(
overall_token_usage: Union[int, dict], new_usage: Union[int, dict]
) -> Union[int, dict]:
# Token usage is either ints or dictionaries
# `reasoning_tokens` is nested inside `completion_tokens_details`
if isinstance(new_usage, int):
if not isinstance(overall_token_usage, int):
raise ValueError(
f"Got different types for token usage: "
f"{type(new_usage)} and {type(overall_token_usage)}"
)
return new_usage + overall_token_usage
elif isinstance(new_usage, dict):
if not isinstance(overall_token_usage, dict):
raise ValueError(
f"Got different types for token usage: "
f"{type(new_usage)} and {type(overall_token_usage)}"
)
return {
k: _update_token_usage(overall_token_usage.get(k, 0), v)
for k, v in new_usage.items()
}
else:
warnings.warn(f"Unexpected type for token usage: {type(new_usage)}")
return new_usage
class _FunctionCall(TypedDict):
name: str
@ -561,7 +588,9 @@ class BaseChatOpenAI(BaseChatModel):
if token_usage is not None:
for k, v in token_usage.items():
if k in overall_token_usage:
overall_token_usage[k] += v
overall_token_usage[k] = _update_token_usage(
overall_token_usage[k], v
)
else:
overall_token_usage[k] = v
if system_fingerprint is None:

View File

@ -1,43 +1,6 @@
# serializer version: 1
# name: TestOpenAIStandard.test_serdes[serialized]
dict({
'graph': dict({
'edges': list([
dict({
'source': 0,
'target': 1,
}),
dict({
'source': 1,
'target': 2,
}),
]),
'nodes': list([
dict({
'data': 'AzureChatOpenAIInput',
'id': 0,
'type': 'schema',
}),
dict({
'data': dict({
'id': list([
'langchain',
'chat_models',
'azure_openai',
'AzureChatOpenAI',
]),
'name': 'AzureChatOpenAI',
}),
'id': 1,
'type': 'runnable',
}),
dict({
'data': 'AzureChatOpenAIOutput',
'id': 2,
'type': 'schema',
}),
]),
}),
'id': list([
'langchain',
'chat_models',

View File

@ -1,43 +1,6 @@
# serializer version: 1
# name: TestOpenAIStandard.test_serdes[serialized]
dict({
'graph': dict({
'edges': list([
dict({
'source': 0,
'target': 1,
}),
dict({
'source': 1,
'target': 2,
}),
]),
'nodes': list([
dict({
'data': 'ChatOpenAIInput',
'id': 0,
'type': 'schema',
}),
dict({
'data': dict({
'id': list([
'langchain',
'chat_models',
'openai',
'ChatOpenAI',
]),
'name': 'ChatOpenAI',
}),
'id': 1,
'type': 'runnable',
}),
dict({
'data': 'ChatOpenAIOutput',
'id': 2,
'type': 'schema',
}),
]),
}),
'id': list([
'langchain',
'chat_models',