mirror of
https://github.com/hwchase17/langchain
synced 2024-11-10 01:10:59 +00:00
community, openai: support nested dicts (#26414)
needed for thinking tokens --------- Co-authored-by: Erick Friis <erick@langchain.dev>
This commit is contained in:
parent
c0dd293f10
commit
28ad244e77
@ -5,6 +5,7 @@ from __future__ import annotations
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import warnings
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
@ -146,6 +147,33 @@ def _convert_delta_to_message_chunk(
|
||||
return default_class(content=content) # type: ignore[call-arg]
|
||||
|
||||
|
||||
def _update_token_usage(
|
||||
overall_token_usage: Union[int, dict], new_usage: Union[int, dict]
|
||||
) -> Union[int, dict]:
|
||||
# Token usage is either ints or dictionaries
|
||||
# `reasoning_tokens` is nested inside `completion_tokens_details`
|
||||
if isinstance(new_usage, int):
|
||||
if not isinstance(overall_token_usage, int):
|
||||
raise ValueError(
|
||||
f"Got different types for token usage: "
|
||||
f"{type(new_usage)} and {type(overall_token_usage)}"
|
||||
)
|
||||
return new_usage + overall_token_usage
|
||||
elif isinstance(new_usage, dict):
|
||||
if not isinstance(overall_token_usage, dict):
|
||||
raise ValueError(
|
||||
f"Got different types for token usage: "
|
||||
f"{type(new_usage)} and {type(overall_token_usage)}"
|
||||
)
|
||||
return {
|
||||
k: _update_token_usage(overall_token_usage.get(k, 0), v)
|
||||
for k, v in new_usage.items()
|
||||
}
|
||||
else:
|
||||
warnings.warn(f"Unexpected type for token usage: {type(new_usage)}")
|
||||
return new_usage
|
||||
|
||||
|
||||
@deprecated(
|
||||
since="0.0.10", removal="1.0", alternative_import="langchain_openai.ChatOpenAI"
|
||||
)
|
||||
@ -374,7 +402,9 @@ class ChatOpenAI(BaseChatModel):
|
||||
if token_usage is not None:
|
||||
for k, v in token_usage.items():
|
||||
if k in overall_token_usage:
|
||||
overall_token_usage[k] += v
|
||||
overall_token_usage[k] = _update_token_usage(
|
||||
overall_token_usage[k], v
|
||||
)
|
||||
else:
|
||||
overall_token_usage[k] = v
|
||||
if system_fingerprint is None:
|
||||
|
@ -335,6 +335,33 @@ def _convert_chunk_to_generation_chunk(
|
||||
return generation_chunk
|
||||
|
||||
|
||||
def _update_token_usage(
|
||||
overall_token_usage: Union[int, dict], new_usage: Union[int, dict]
|
||||
) -> Union[int, dict]:
|
||||
# Token usage is either ints or dictionaries
|
||||
# `reasoning_tokens` is nested inside `completion_tokens_details`
|
||||
if isinstance(new_usage, int):
|
||||
if not isinstance(overall_token_usage, int):
|
||||
raise ValueError(
|
||||
f"Got different types for token usage: "
|
||||
f"{type(new_usage)} and {type(overall_token_usage)}"
|
||||
)
|
||||
return new_usage + overall_token_usage
|
||||
elif isinstance(new_usage, dict):
|
||||
if not isinstance(overall_token_usage, dict):
|
||||
raise ValueError(
|
||||
f"Got different types for token usage: "
|
||||
f"{type(new_usage)} and {type(overall_token_usage)}"
|
||||
)
|
||||
return {
|
||||
k: _update_token_usage(overall_token_usage.get(k, 0), v)
|
||||
for k, v in new_usage.items()
|
||||
}
|
||||
else:
|
||||
warnings.warn(f"Unexpected type for token usage: {type(new_usage)}")
|
||||
return new_usage
|
||||
|
||||
|
||||
class _FunctionCall(TypedDict):
|
||||
name: str
|
||||
|
||||
@ -561,7 +588,9 @@ class BaseChatOpenAI(BaseChatModel):
|
||||
if token_usage is not None:
|
||||
for k, v in token_usage.items():
|
||||
if k in overall_token_usage:
|
||||
overall_token_usage[k] += v
|
||||
overall_token_usage[k] = _update_token_usage(
|
||||
overall_token_usage[k], v
|
||||
)
|
||||
else:
|
||||
overall_token_usage[k] = v
|
||||
if system_fingerprint is None:
|
||||
|
@ -1,43 +1,6 @@
|
||||
# serializer version: 1
|
||||
# name: TestOpenAIStandard.test_serdes[serialized]
|
||||
dict({
|
||||
'graph': dict({
|
||||
'edges': list([
|
||||
dict({
|
||||
'source': 0,
|
||||
'target': 1,
|
||||
}),
|
||||
dict({
|
||||
'source': 1,
|
||||
'target': 2,
|
||||
}),
|
||||
]),
|
||||
'nodes': list([
|
||||
dict({
|
||||
'data': 'AzureChatOpenAIInput',
|
||||
'id': 0,
|
||||
'type': 'schema',
|
||||
}),
|
||||
dict({
|
||||
'data': dict({
|
||||
'id': list([
|
||||
'langchain',
|
||||
'chat_models',
|
||||
'azure_openai',
|
||||
'AzureChatOpenAI',
|
||||
]),
|
||||
'name': 'AzureChatOpenAI',
|
||||
}),
|
||||
'id': 1,
|
||||
'type': 'runnable',
|
||||
}),
|
||||
dict({
|
||||
'data': 'AzureChatOpenAIOutput',
|
||||
'id': 2,
|
||||
'type': 'schema',
|
||||
}),
|
||||
]),
|
||||
}),
|
||||
'id': list([
|
||||
'langchain',
|
||||
'chat_models',
|
||||
|
@ -1,43 +1,6 @@
|
||||
# serializer version: 1
|
||||
# name: TestOpenAIStandard.test_serdes[serialized]
|
||||
dict({
|
||||
'graph': dict({
|
||||
'edges': list([
|
||||
dict({
|
||||
'source': 0,
|
||||
'target': 1,
|
||||
}),
|
||||
dict({
|
||||
'source': 1,
|
||||
'target': 2,
|
||||
}),
|
||||
]),
|
||||
'nodes': list([
|
||||
dict({
|
||||
'data': 'ChatOpenAIInput',
|
||||
'id': 0,
|
||||
'type': 'schema',
|
||||
}),
|
||||
dict({
|
||||
'data': dict({
|
||||
'id': list([
|
||||
'langchain',
|
||||
'chat_models',
|
||||
'openai',
|
||||
'ChatOpenAI',
|
||||
]),
|
||||
'name': 'ChatOpenAI',
|
||||
}),
|
||||
'id': 1,
|
||||
'type': 'runnable',
|
||||
}),
|
||||
dict({
|
||||
'data': 'ChatOpenAIOutput',
|
||||
'id': 2,
|
||||
'type': 'schema',
|
||||
}),
|
||||
]),
|
||||
}),
|
||||
'id': list([
|
||||
'langchain',
|
||||
'chat_models',
|
||||
|
Loading…
Reference in New Issue
Block a user