qianfan generate/agenerate with usage_metadata (#25332)

Author: Chen Xiabin
Date:   2024-08-13 21:24:41 +08:00 (committed by GitHub)
parent ebbe609193
commit 24155aa1ac

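For orientation: the hunks below capture the usage counts from the final streamed chunk and attach them to the aggregated message. A minimal sketch of the UsageMetadata shape involved, with illustrative values (the TypedDict lives in langchain_core.messages.ai; only the three core counters are shown):

    from typing import Optional

    from langchain_core.messages.ai import UsageMetadata

    # Illustrative values only; the endpoint fills these in from the
    # provider's response.
    usage: Optional[UsageMetadata] = UsageMetadata(
        input_tokens=7,    # prompt tokens
        output_tokens=42,  # completion tokens
        total_tokens=49,   # input_tokens + output_tokens
    )
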
@@ -512,6 +512,7 @@ class QianfanChatEndpoint(BaseChatModel):
         if self.streaming:
             completion = ""
             chat_generation_info: Dict = {}
+            usage_metadata: Optional[UsageMetadata] = None
             for chunk in self._stream(messages, stop, run_manager, **kwargs):
                 chat_generation_info = (
                     chunk.generation_info
@@ -519,7 +520,14 @@ class QianfanChatEndpoint(BaseChatModel):
                     else chat_generation_info
                 )
                 completion += chunk.text
-            lc_msg = AIMessage(content=completion, additional_kwargs={})
+                if isinstance(chunk.message, AIMessageChunk):
+                    usage_metadata = chunk.message.usage_metadata
+            lc_msg = AIMessage(
+                content=completion,
+                additional_kwargs={},
+                usage_metadata=usage_metadata,
+            )
             gen = ChatGeneration(
                 message=lc_msg,
                 generation_info=dict(finish_reason="stop"),
@@ -527,7 +535,7 @@ class QianfanChatEndpoint(BaseChatModel):
             return ChatResult(
                 generations=[gen],
                 llm_output={
-                    "token_usage": chat_generation_info.get("usage", {}),
+                    "token_usage": usage_metadata or {},
                     "model_name": self.model,
                 },
             )
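
With the synchronous hunks above applied, a streaming generate call now returns an AIMessage that carries the usage counts from the last chunk. A caller-side sketch, assuming valid Qianfan credentials in the QIANFAN_AK / QIANFAN_SK environment variables:

    from langchain_community.chat_models import QianfanChatEndpoint

    # Assumes QIANFAN_AK / QIANFAN_SK are set in the environment.
    chat = QianfanChatEndpoint(streaming=True)
    msg = chat.invoke("Hello")

    # Before this commit the streaming path dropped the usage counts;
    # now the final chunk's numbers ride along on the aggregated message.
    print(msg.usage_metadata)
    # e.g. {'input_tokens': 1, 'output_tokens': 9, 'total_tokens': 10}
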
@@ -556,6 +564,7 @@ class QianfanChatEndpoint(BaseChatModel):
         if self.streaming:
             completion = ""
             chat_generation_info: Dict = {}
+            usage_metadata: Optional[UsageMetadata] = None
             async for chunk in self._astream(messages, stop, run_manager, **kwargs):
                 chat_generation_info = (
                     chunk.generation_info
@@ -564,7 +573,14 @@ class QianfanChatEndpoint(BaseChatModel):
                     else chat_generation_info
                 )
                 completion += chunk.text
-            lc_msg = AIMessage(content=completion, additional_kwargs={})
+                if isinstance(chunk.message, AIMessageChunk):
+                    usage_metadata = chunk.message.usage_metadata
+            lc_msg = AIMessage(
+                content=completion,
+                additional_kwargs={},
+                usage_metadata=usage_metadata,
+            )
             gen = ChatGeneration(
                 message=lc_msg,
                 generation_info=dict(finish_reason="stop"),
@@ -572,7 +588,7 @@ class QianfanChatEndpoint(BaseChatModel):
             return ChatResult(
                 generations=[gen],
                 llm_output={
-                    "token_usage": chat_generation_info.get("usage", {}),
+                    "token_usage": usage_metadata or {},
                     "model_name": self.model,
                 },
             )
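
The async hunks mirror the synchronous path. The corresponding caller-side sketch, under the same assumptions as above:

    import asyncio

    from langchain_community.chat_models import QianfanChatEndpoint

    async def main() -> None:
        chat = QianfanChatEndpoint(streaming=True)
        msg = await chat.ainvoke("Hello")
        # _agenerate now attaches usage_metadata from the final
        # AIMessageChunk, and llm_output["token_usage"] reports the
        # same counts instead of the raw "usage" dict.
        print(msg.usage_metadata)

    asyncio.run(main())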