From 24155aa1acaad34726fce0650774c7461dbf4640 Mon Sep 17 00:00:00 2001 From: Chen Xiabin <128658978+bimslab@users.noreply.github.com> Date: Tue, 13 Aug 2024 21:24:41 +0800 Subject: [PATCH] qianfan generate/agenerate with usage_metadata (#25332) --- .../chat_models/baidu_qianfan_endpoint.py | 24 +++++++++++++++---- 1 file changed, 20 insertions(+), 4 deletions(-) diff --git a/libs/community/langchain_community/chat_models/baidu_qianfan_endpoint.py b/libs/community/langchain_community/chat_models/baidu_qianfan_endpoint.py index 2390c3fd91..d088054261 100644 --- a/libs/community/langchain_community/chat_models/baidu_qianfan_endpoint.py +++ b/libs/community/langchain_community/chat_models/baidu_qianfan_endpoint.py @@ -512,6 +512,7 @@ class QianfanChatEndpoint(BaseChatModel): if self.streaming: completion = "" chat_generation_info: Dict = {} + usage_metadata: Optional[UsageMetadata] = None for chunk in self._stream(messages, stop, run_manager, **kwargs): chat_generation_info = ( chunk.generation_info @@ -519,7 +520,14 @@ class QianfanChatEndpoint(BaseChatModel): else chat_generation_info ) completion += chunk.text - lc_msg = AIMessage(content=completion, additional_kwargs={}) + if isinstance(chunk.message, AIMessageChunk): + usage_metadata = chunk.message.usage_metadata + + lc_msg = AIMessage( + content=completion, + additional_kwargs={}, + usage_metadata=usage_metadata, + ) gen = ChatGeneration( message=lc_msg, generation_info=dict(finish_reason="stop"), @@ -527,7 +535,7 @@ class QianfanChatEndpoint(BaseChatModel): return ChatResult( generations=[gen], llm_output={ - "token_usage": chat_generation_info.get("usage", {}), + "token_usage": usage_metadata or {}, "model_name": self.model, }, ) @@ -556,6 +564,7 @@ class QianfanChatEndpoint(BaseChatModel): if self.streaming: completion = "" chat_generation_info: Dict = {} + usage_metadata: Optional[UsageMetadata] = None async for chunk in self._astream(messages, stop, run_manager, **kwargs): chat_generation_info = ( chunk.generation_info @@ -564,7 +573,14 @@ class QianfanChatEndpoint(BaseChatModel): ) completion += chunk.text - lc_msg = AIMessage(content=completion, additional_kwargs={}) + if isinstance(chunk.message, AIMessageChunk): + usage_metadata = chunk.message.usage_metadata + + lc_msg = AIMessage( + content=completion, + additional_kwargs={}, + usage_metadata=usage_metadata, + ) gen = ChatGeneration( message=lc_msg, generation_info=dict(finish_reason="stop"), @@ -572,7 +588,7 @@ class QianfanChatEndpoint(BaseChatModel): return ChatResult( generations=[gen], llm_output={ - "token_usage": chat_generation_info.get("usage", {}), + "token_usage": usage_metadata or {}, "model_name": self.model, }, )