mistral, openai: allow anthropic-style messages in message histories (#20565)

This commit is contained in:
ccurme 2024-04-17 15:55:45 -04:00 committed by GitHub
parent 7a7851aa06
commit 2238490069
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 76 additions and 21 deletions

View File

@ -283,9 +283,16 @@ def _convert_message_to_mistral_chat_message(
tool_calls.append(chunk)
else:
pass
if tool_calls and message.content:
# Assistant message must have either content or tool_calls, but not both.
# Some providers may not support tool_calls in the same message as content.
# This is done to ensure compatibility with messages from other providers.
content: Any = ""
else:
content = message.content
return {
"role": "assistant",
"content": message.content,
"content": content,
"tool_calls": tool_calls,
}
elif isinstance(message, SystemMessage):

View File

@ -148,6 +148,22 @@ def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
return ChatMessage(content=_dict.get("content", ""), role=role, id=id_)
def _format_message_content(content: Any) -> Any:
"""Format message content."""
if content and isinstance(content, list):
# Remove unexpected block types
formatted_content = []
for block in content:
if isinstance(block, dict) and "type" in block and block["type"] != "text":
continue
else:
formatted_content.append(block)
else:
formatted_content = content
return formatted_content
def _convert_message_to_dict(message: BaseMessage) -> dict:
"""Convert a LangChain message to a dictionary.
@ -158,7 +174,7 @@ def _convert_message_to_dict(message: BaseMessage) -> dict:
The dictionary.
"""
message_dict: Dict[str, Any] = {
"content": message.content,
"content": _format_message_content(message.content),
}
if (name := message.name or message.additional_kwargs.get("name")) is not None:
message_dict["name"] = name

View File

@ -117,12 +117,13 @@ class ChatModelIntegrationTests(ABC):
assert isinstance(result.content, str)
assert len(result.content) > 0
def test_tool_message(
def test_tool_message_histories(
self,
chat_model_class: Type[BaseChatModel],
chat_model_params: dict,
chat_model_has_tool_calling: bool,
) -> None:
"""Test that message histories are compatible across providers."""
if not chat_model_has_tool_calling:
pytest.skip("Test requires tool calling.")
model = chat_model_class(**chat_model_params)
@ -130,24 +131,55 @@ class ChatModelIntegrationTests(ABC):
function_name = "my_adder_tool"
function_args = {"a": "1", "b": "2"}
messages = [
HumanMessage(content="What is 1 + 2"),
AIMessage(
content="",
tool_calls=[
{
"name": function_name,
"args": function_args,
"id": "abc123",
},
],
),
ToolMessage(
name=function_name,
content=json.dumps({"result": 3}),
tool_call_id="abc123",
),
]
human_message = HumanMessage(content="What is 1 + 2")
tool_message = ToolMessage(
name=function_name,
content=json.dumps({"result": 3}),
tool_call_id="abc123",
)
# String content (e.g., OpenAI)
string_content_msg = AIMessage(
content="",
tool_calls=[
{
"name": function_name,
"args": function_args,
"id": "abc123",
},
],
)
messages = [
human_message,
string_content_msg,
tool_message,
]
result = model_with_tools.invoke(messages)
assert isinstance(result, AIMessage)
# List content (e.g., Anthropic)
list_content_msg = AIMessage(
content=[
{"type": "text", "text": "some text"},
{
"type": "tool_use",
"id": "abc123",
"name": function_name,
"input": function_args,
},
],
tool_calls=[
{
"name": function_name,
"args": function_args,
"id": "abc123",
},
],
)
messages = [
human_message,
list_content_msg,
tool_message,
]
result = model_with_tools.invoke(messages)
assert isinstance(result, AIMessage)