diff --git a/libs/partners/mistralai/langchain_mistralai/chat_models.py b/libs/partners/mistralai/langchain_mistralai/chat_models.py index 159fd45f76..c3bf669e86 100644 --- a/libs/partners/mistralai/langchain_mistralai/chat_models.py +++ b/libs/partners/mistralai/langchain_mistralai/chat_models.py @@ -283,9 +283,16 @@ def _convert_message_to_mistral_chat_message( tool_calls.append(chunk) else: pass + if tool_calls and message.content: + # Assistant message must have either content or tool_calls, but not both. + # Some providers may not support tool_calls in the same message as content. + # This is done to ensure compatibility with messages from other providers. + content: Any = "" + else: + content = message.content return { "role": "assistant", - "content": message.content, + "content": content, "tool_calls": tool_calls, } elif isinstance(message, SystemMessage): diff --git a/libs/partners/openai/langchain_openai/chat_models/base.py b/libs/partners/openai/langchain_openai/chat_models/base.py index 50308b6327..7c132a3cb7 100644 --- a/libs/partners/openai/langchain_openai/chat_models/base.py +++ b/libs/partners/openai/langchain_openai/chat_models/base.py @@ -148,6 +148,22 @@ def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage: return ChatMessage(content=_dict.get("content", ""), role=role, id=id_) +def _format_message_content(content: Any) -> Any: + """Format message content.""" + if content and isinstance(content, list): + # Remove unexpected block types + formatted_content = [] + for block in content: + if isinstance(block, dict) and "type" in block and block["type"] != "text": + continue + else: + formatted_content.append(block) + else: + formatted_content = content + + return formatted_content + + def _convert_message_to_dict(message: BaseMessage) -> dict: """Convert a LangChain message to a dictionary. @@ -158,7 +174,7 @@ def _convert_message_to_dict(message: BaseMessage) -> dict: The dictionary. """ message_dict: Dict[str, Any] = { - "content": message.content, + "content": _format_message_content(message.content), } if (name := message.name or message.additional_kwargs.get("name")) is not None: message_dict["name"] = name diff --git a/libs/standard-tests/langchain_standard_tests/integration_tests/chat_models.py b/libs/standard-tests/langchain_standard_tests/integration_tests/chat_models.py index e1a99772d6..734283f729 100644 --- a/libs/standard-tests/langchain_standard_tests/integration_tests/chat_models.py +++ b/libs/standard-tests/langchain_standard_tests/integration_tests/chat_models.py @@ -117,12 +117,13 @@ class ChatModelIntegrationTests(ABC): assert isinstance(result.content, str) assert len(result.content) > 0 - def test_tool_message( + def test_tool_message_histories( self, chat_model_class: Type[BaseChatModel], chat_model_params: dict, chat_model_has_tool_calling: bool, ) -> None: + """Test that message histories are compatible across providers.""" if not chat_model_has_tool_calling: pytest.skip("Test requires tool calling.") model = chat_model_class(**chat_model_params) @@ -130,24 +131,55 @@ class ChatModelIntegrationTests(ABC): function_name = "my_adder_tool" function_args = {"a": "1", "b": "2"} - messages = [ - HumanMessage(content="What is 1 + 2"), - AIMessage( - content="", - tool_calls=[ - { - "name": function_name, - "args": function_args, - "id": "abc123", - }, - ], - ), - ToolMessage( - name=function_name, - content=json.dumps({"result": 3}), - tool_call_id="abc123", - ), - ] + human_message = HumanMessage(content="What is 1 + 2") + tool_message = ToolMessage( + name=function_name, + content=json.dumps({"result": 3}), + tool_call_id="abc123", + ) + # String content (e.g., OpenAI) + string_content_msg = AIMessage( + content="", + tool_calls=[ + { + "name": function_name, + "args": function_args, + "id": "abc123", + }, + ], + ) + messages = [ + human_message, + string_content_msg, + tool_message, + ] + result = model_with_tools.invoke(messages) + assert isinstance(result, AIMessage) + + # List content (e.g., Anthropic) + list_content_msg = AIMessage( + content=[ + {"type": "text", "text": "some text"}, + { + "type": "tool_use", + "id": "abc123", + "name": function_name, + "input": function_args, + }, + ], + tool_calls=[ + { + "name": function_name, + "args": function_args, + "id": "abc123", + }, + ], + ) + messages = [ + human_message, + list_content_msg, + tool_message, + ] result = model_with_tools.invoke(messages) assert isinstance(result, AIMessage)