rfc: add AIMessage.token_usage

pull/20522/head
Bagatur 1 month ago
parent 77eba10f47
commit 3038dc1588

libs/core/langchain_core/messages/ai.py:

@@ -1,4 +1,4 @@
-from typing import Any, Dict, List, Literal
+from typing import Any, Dict, List, Literal, Optional, TypedDict

 from langchain_core.messages.base import (
     BaseMessage,
@@ -19,6 +19,11 @@ from langchain_core.utils.json import (
 )


+class TokenUsage(TypedDict):
+    prompt_tokens: int
+    completion_tokens: int
+
+
 class AIMessage(BaseMessage):
     """Message from an AI."""
@@ -31,6 +36,7 @@ class AIMessage(BaseMessage):
     """If provided, tool calls associated with the message."""
     invalid_tool_calls: List[InvalidToolCall] = []
     """If provided, tool calls with parsing errors associated with the message."""
+    token_usage: Optional[TokenUsage] = None

     type: Literal["ai"] = "ai"
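
For illustration, a minimal sketch of constructing a message with the proposed field (hypothetical token counts):

    from langchain_core.messages.ai import AIMessage, TokenUsage

    # Hypothetical values; TokenUsage is the TypedDict added above.
    msg = AIMessage(
        content="Hello!",
        token_usage=TokenUsage(prompt_tokens=12, completion_tokens=3),
    )
    assert msg.token_usage["completion_tokens"] == 3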
@@ -167,12 +173,22 @@ class AIMessageChunk(AIMessage, BaseMessageChunk):
             else:
                 tool_call_chunks = []

+            if self.token_usage and other.token_usage:
+                token_usage = TokenUsage(
+                    prompt_tokens=self.token_usage["prompt_tokens"],
+                    completion_tokens=self.token_usage["completion_tokens"]
+                    + other.token_usage["completion_tokens"],
+                )
+            else:
+                token_usage = self.token_usage or other.token_usage
+
             return self.__class__(
                 example=self.example,
                 content=content,
                 additional_kwargs=additional_kwargs,
                 tool_call_chunks=tool_call_chunks,
                 response_metadata=response_metadata,
                 token_usage=token_usage,
                 id=self.id,
             )
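
The merge rule above keeps prompt_tokens from the left-hand chunk and sums completion_tokens, since the prompt is only counted once per stream. A minimal sketch of the resulting semantics (hypothetical values):

    from langchain_core.messages.ai import AIMessageChunk, TokenUsage

    left = AIMessageChunk(
        content="Hel", token_usage=TokenUsage(prompt_tokens=12, completion_tokens=1)
    )
    right = AIMessageChunk(
        content="lo", token_usage=TokenUsage(prompt_tokens=12, completion_tokens=1)
    )
    merged = left + right
    # prompt_tokens taken from the left chunk; completion_tokens summed.
    assert merged.token_usage == {"prompt_tokens": 12, "completion_tokens": 2}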

libs/partners/openai/langchain_openai/chat_models/base.py:

@@ -58,6 +58,7 @@ from langchain_core.messages import (
     ToolMessage,
     ToolMessageChunk,
 )
+from langchain_core.messages.ai import TokenUsage
 from langchain_core.output_parsers import (
     JsonOutputParser,
     PydanticOutputParser,
@@ -512,24 +513,30 @@ class ChatOpenAI(BaseChatModel):
             choice = chunk["choices"][0]
             if choice["delta"] is None:
                 continue
-            chunk = _convert_delta_to_message_chunk(
+            msg_chunk = _convert_delta_to_message_chunk(
                 choice["delta"], default_chunk_class
             )
+            if chunk.get("usage"):
+                token_usage = TokenUsage(
+                    prompt_tokens=chunk["usage"]["prompt_tokens"],
+                    completion_tokens=chunk["usage"]["completion_tokens"],
+                )
+                msg_chunk.token_usage = token_usage
             generation_info = {}
             if finish_reason := choice.get("finish_reason"):
                 generation_info["finish_reason"] = finish_reason
             logprobs = choice.get("logprobs")
             if logprobs:
                 generation_info["logprobs"] = logprobs
-            default_chunk_class = chunk.__class__
-            chunk = ChatGenerationChunk(
-                message=chunk, generation_info=generation_info or None
+            default_chunk_class = msg_chunk.__class__
+            gen_chunk = ChatGenerationChunk(
+                message=msg_chunk, generation_info=generation_info or None
             )
             if run_manager:
                 run_manager.on_llm_new_token(
-                    chunk.text, chunk=chunk, logprobs=logprobs
+                    gen_chunk.text, chunk=gen_chunk, logprobs=logprobs
                 )
-            yield chunk
+            yield gen_chunk

     def _generate(
         self,
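
Downstream, a caller could fold the streamed chunks together and read the field off the merged result. A sketch, assuming the provider actually attaches a usage payload to some chunk (OpenAI only includes streaming usage when it is explicitly requested):

    from langchain_openai import ChatOpenAI

    llm = ChatOpenAI(model="gpt-3.5-turbo")
    total = None
    for chunk in llm.stream("hi"):
        # AIMessageChunk.__add__ carries token_usage through the merge.
        total = chunk if total is None else total + chunk
    if total.token_usage is not None:
        print(total.token_usage)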
@@ -574,7 +581,15 @@ class ChatOpenAI(BaseChatModel):
             raise ValueError(response.get("error"))

         for res in response["choices"]:
-            message = _convert_dict_to_message(res["message"])
+            message: AIMessage = cast(
+                AIMessage, _convert_dict_to_message(res["message"])
+            )
+            if response.get("usage"):
+                token_usage = TokenUsage(
+                    prompt_tokens=response["usage"]["prompt_tokens"],
+                    completion_tokens=response["usage"]["completion_tokens"],
+                )
+                message.token_usage = token_usage
             generation_info = dict(finish_reason=res.get("finish_reason"))
             if "logprobs" in res:
                 generation_info["logprobs"] = res["logprobs"]
@@ -583,9 +598,8 @@ class ChatOpenAI(BaseChatModel):
                 generation_info=generation_info,
             )
             generations.append(gen)
-        token_usage = response.get("usage", {})
         llm_output = {
-            "token_usage": token_usage,
+            "token_usage": response.get("usage", {}),
             "model_name": self.model_name,
             "system_fingerprint": response.get("system_fingerprint", ""),
         }
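
On the non-streaming path, this makes usage reachable directly on the returned message rather than only through llm_output. A sketch with hypothetical output:

    from langchain_openai import ChatOpenAI

    msg = ChatOpenAI(model="gpt-3.5-turbo").invoke("hi")
    print(msg.token_usage)  # e.g. {"prompt_tokens": 8, "completion_tokens": 9}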
