community: fix issue of the existence of numeric object in additional_kwargs a… (#24863)

- **Description:** A previous PR broke the code in `baidu_qianfan_endpoint.py`, causing streaming to malfunction.
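
For context, here is a minimal sketch (not part of this PR) of the failure mode: during streaming, `AIMessageChunk` objects are concatenated, and their `additional_kwargs` are combined by langchain-core's dict-merging helper, which cannot merge two unequal numeric values.

```python
# Illustrative only: shows why bare numeric values in additional_kwargs
# break chunk accumulation during streaming. Key names and values are
# made up for the example.
from langchain_core.messages import AIMessageChunk

left = AIMessageChunk(content="Hello", additional_kwargs={"sentence_id": 1})
right = AIMessageChunk(content=" world", additional_kwargs={"sentence_id": 2})

try:
    merged = left + right  # chunk concatenation merges additional_kwargs
except TypeError as exc:
    # Two different ints under the same key cannot be merged, so this raises.
    print(f"streaming merge failed: {exc}")
```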
Dobiichi-Origami 2024-08-05 22:15:55 +08:00 committed by GitHub
parent cda79dbb6c
commit c5cb52a3c6

@@ -34,6 +34,7 @@ from langchain_core.messages import (
ToolMessage,
)
from langchain_core.messages.ai import UsageMetadata
from langchain_core.messages.tool import tool_call_chunk
from langchain_core.output_parsers.base import OutputParserLike
from langchain_core.output_parsers.openai_tools import (
JsonOutputKeyToolsParser,
@@ -104,13 +105,18 @@ def _convert_dict_to_message(_dict: Mapping[str, Any]) -> AIMessage:
# align to api sample, which affects the llm function_call output
additional_kwargs["function_call"].pop("thoughts")
# DO NOT ADD ANY NUMERIC OBJECT TO `msg_additional_kwargs` AND `additional_kwargs`
# ALONG WITH THEIRS SUB-CONTAINERS !!!
# OR IT WILL RAISE A DEADLY EXCEPTION FROM `merge_dict`
# Do not add any numeric objects to `msg_additional_kwargs` or `additional_kwargs`!
# Their sub-containers are off limits too!
# Otherwise `merge_dict` will raise an error and the code will not run
additional_kwargs = {**_dict.get("body", {}), **additional_kwargs}
msg_additional_kwargs = dict(
finish_reason=additional_kwargs.get("finish_reason", ""),
request_id=additional_kwargs["id"],
object=additional_kwargs.get("object", ""),
search_info=additional_kwargs.get("search_info", []),
usage=additional_kwargs.get("usage", None),
)
if additional_kwargs.get("function_call", {}):
@@ -125,22 +131,20 @@ def _convert_dict_to_message(_dict: Mapping[str, Any]) -> AIMessage:
}
]
if usage := additional_kwargs.get("usage", None):
return AIMessage(
content=content,
additional_kwargs=msg_additional_kwargs,
usage_metadata=UsageMetadata(
input_tokens=usage.get("prompt_tokens", 0),
output_tokens=usage.get("completion_tokens", 0),
total_tokens=usage.get("total_tokens", 0),
),
)
return AIMessage(
ret = AIMessage(
content=content,
additional_kwargs=msg_additional_kwargs,
)
if usage := additional_kwargs.get("usage", None):
ret.usage_metadata = UsageMetadata(
input_tokens=usage.get("prompt_tokens", 0),
output_tokens=usage.get("completion_tokens", 0),
total_tokens=usage.get("total_tokens", 0),
)
return ret
class QianfanChatEndpoint(BaseChatModel):
"""Baidu Qianfan chat model integration.
@@ -613,6 +617,15 @@ class QianfanChatEndpoint(BaseChatModel):
role="assistant",
additional_kwargs=additional_kwargs,
usage_metadata=msg.usage_metadata,
tool_call_chunks=[
tool_call_chunk(
name=tc["name"],
args=json.dumps(tc["args"]),
id=tc["id"],
index=None,
)
for tc in msg.tool_calls
],
),
generation_info=msg.additional_kwargs,
)
@@ -641,6 +654,15 @@ class QianfanChatEndpoint(BaseChatModel):
role="assistant",
additional_kwargs=additional_kwargs,
usage_metadata=msg.usage_metadata,
tool_call_chunks=[
tool_call_chunk(
name=tc["name"],
args=json.dumps(tc["args"]),
id=tc["id"],
index=None,
)
for tc in msg.tool_calls
],
),
generation_info=msg.additional_kwargs,
)
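
A hedged usage sketch of the streaming path this commit repairs (the model name is illustrative, and Qianfan credentials are assumed to be configured for the client). The tool-call portions of the diff additionally surface `tool_call_chunks` on each streamed chunk when tools are bound.

```python
# Hypothetical usage sketch, not part of this commit. Assumes Qianfan
# credentials are already configured; the model name is illustrative.
from langchain_community.chat_models import QianfanChatEndpoint

llm = QianfanChatEndpoint(model="ERNIE-3.5-8K", streaming=True)

accumulated = None
for chunk in llm.stream("Introduce Baidu Qianfan in one sentence."):
    # With this fix, additional_kwargs on each chunk no longer carries bare
    # numeric values, so accumulating chunks does not raise during merging.
    accumulated = chunk if accumulated is None else accumulated + chunk

print(accumulated.content)
```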