mirror of
https://github.com/hwchase17/langchain
synced 2024-11-18 09:25:54 +00:00
mistral, openai: allow anthropic-style messages in message histories (#20565)
This commit is contained in:
parent
7a7851aa06
commit
2238490069
@ -283,9 +283,16 @@ def _convert_message_to_mistral_chat_message(
|
||||
tool_calls.append(chunk)
|
||||
else:
|
||||
pass
|
||||
if tool_calls and message.content:
|
||||
# Assistant message must have either content or tool_calls, but not both.
|
||||
# Some providers may not support tool_calls in the same message as content.
|
||||
# This is done to ensure compatibility with messages from other providers.
|
||||
content: Any = ""
|
||||
else:
|
||||
content = message.content
|
||||
return {
|
||||
"role": "assistant",
|
||||
"content": message.content,
|
||||
"content": content,
|
||||
"tool_calls": tool_calls,
|
||||
}
|
||||
elif isinstance(message, SystemMessage):
|
||||
|
@ -148,6 +148,22 @@ def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
|
||||
return ChatMessage(content=_dict.get("content", ""), role=role, id=id_)
|
||||
|
||||
|
||||
def _format_message_content(content: Any) -> Any:
|
||||
"""Format message content."""
|
||||
if content and isinstance(content, list):
|
||||
# Remove unexpected block types
|
||||
formatted_content = []
|
||||
for block in content:
|
||||
if isinstance(block, dict) and "type" in block and block["type"] != "text":
|
||||
continue
|
||||
else:
|
||||
formatted_content.append(block)
|
||||
else:
|
||||
formatted_content = content
|
||||
|
||||
return formatted_content
|
||||
|
||||
|
||||
def _convert_message_to_dict(message: BaseMessage) -> dict:
|
||||
"""Convert a LangChain message to a dictionary.
|
||||
|
||||
@ -158,7 +174,7 @@ def _convert_message_to_dict(message: BaseMessage) -> dict:
|
||||
The dictionary.
|
||||
"""
|
||||
message_dict: Dict[str, Any] = {
|
||||
"content": message.content,
|
||||
"content": _format_message_content(message.content),
|
||||
}
|
||||
if (name := message.name or message.additional_kwargs.get("name")) is not None:
|
||||
message_dict["name"] = name
|
||||
|
@ -117,12 +117,13 @@ class ChatModelIntegrationTests(ABC):
|
||||
assert isinstance(result.content, str)
|
||||
assert len(result.content) > 0
|
||||
|
||||
def test_tool_message(
|
||||
def test_tool_message_histories(
|
||||
self,
|
||||
chat_model_class: Type[BaseChatModel],
|
||||
chat_model_params: dict,
|
||||
chat_model_has_tool_calling: bool,
|
||||
) -> None:
|
||||
"""Test that message histories are compatible across providers."""
|
||||
if not chat_model_has_tool_calling:
|
||||
pytest.skip("Test requires tool calling.")
|
||||
model = chat_model_class(**chat_model_params)
|
||||
@ -130,24 +131,55 @@ class ChatModelIntegrationTests(ABC):
|
||||
function_name = "my_adder_tool"
|
||||
function_args = {"a": "1", "b": "2"}
|
||||
|
||||
messages = [
|
||||
HumanMessage(content="What is 1 + 2"),
|
||||
AIMessage(
|
||||
content="",
|
||||
tool_calls=[
|
||||
{
|
||||
"name": function_name,
|
||||
"args": function_args,
|
||||
"id": "abc123",
|
||||
},
|
||||
],
|
||||
),
|
||||
ToolMessage(
|
||||
name=function_name,
|
||||
content=json.dumps({"result": 3}),
|
||||
tool_call_id="abc123",
|
||||
),
|
||||
]
|
||||
human_message = HumanMessage(content="What is 1 + 2")
|
||||
tool_message = ToolMessage(
|
||||
name=function_name,
|
||||
content=json.dumps({"result": 3}),
|
||||
tool_call_id="abc123",
|
||||
)
|
||||
|
||||
# String content (e.g., OpenAI)
|
||||
string_content_msg = AIMessage(
|
||||
content="",
|
||||
tool_calls=[
|
||||
{
|
||||
"name": function_name,
|
||||
"args": function_args,
|
||||
"id": "abc123",
|
||||
},
|
||||
],
|
||||
)
|
||||
messages = [
|
||||
human_message,
|
||||
string_content_msg,
|
||||
tool_message,
|
||||
]
|
||||
result = model_with_tools.invoke(messages)
|
||||
assert isinstance(result, AIMessage)
|
||||
|
||||
# List content (e.g., Anthropic)
|
||||
list_content_msg = AIMessage(
|
||||
content=[
|
||||
{"type": "text", "text": "some text"},
|
||||
{
|
||||
"type": "tool_use",
|
||||
"id": "abc123",
|
||||
"name": function_name,
|
||||
"input": function_args,
|
||||
},
|
||||
],
|
||||
tool_calls=[
|
||||
{
|
||||
"name": function_name,
|
||||
"args": function_args,
|
||||
"id": "abc123",
|
||||
},
|
||||
],
|
||||
)
|
||||
messages = [
|
||||
human_message,
|
||||
list_content_msg,
|
||||
tool_message,
|
||||
]
|
||||
result = model_with_tools.invoke(messages)
|
||||
assert isinstance(result, AIMessage)
|
||||
|
Loading…
Reference in New Issue
Block a user