From 4a1795190004bbb8badfab6ce1f2a8416c3c2558 Mon Sep 17 00:00:00 2001 From: ccurme Date: Wed, 17 Apr 2024 13:38:24 -0400 Subject: [PATCH] mistral: read tool calls from AIMessage (#20554) Co-authored-by: Eugene Yurtsev --- .../langchain_mistralai/chat_models.py | 44 +++++++++++++++++-- .../integration_tests/test_chat_models.py | 6 +-- .../tests/integration_tests/test_standard.py | 7 +++ .../tests/unit_tests/test_chat_models.py | 10 ++--- 4 files changed, 56 insertions(+), 11 deletions(-) diff --git a/libs/partners/mistralai/langchain_mistralai/chat_models.py b/libs/partners/mistralai/langchain_mistralai/chat_models.py index 3b41e6cf3e..159fd45f76 100644 --- a/libs/partners/mistralai/langchain_mistralai/chat_models.py +++ b/libs/partners/mistralai/langchain_mistralai/chat_models.py @@ -1,5 +1,6 @@ from __future__ import annotations +import json import logging import uuid from operator import itemgetter @@ -42,8 +43,10 @@ from langchain_core.messages import ( ChatMessageChunk, HumanMessage, HumanMessageChunk, + InvalidToolCall, SystemMessage, SystemMessageChunk, + ToolCall, ToolMessage, ) from langchain_core.output_parsers.base import OutputParserLike @@ -223,6 +226,34 @@ def _convert_delta_to_message_chunk( return default_class(content=content) +def _format_tool_call_for_mistral(tool_call: ToolCall) -> dict: + """Format Langchain ToolCall to dict expected by Mistral.""" + result: Dict[str, Any] = { + "function": { + "name": tool_call["name"], + "arguments": json.dumps(tool_call["args"]), + } + } + if _id := tool_call.get("id"): + result["id"] = _id + + return result + + +def _format_invalid_tool_call_for_mistral(invalid_tool_call: InvalidToolCall) -> dict: + """Format Langchain InvalidToolCall to dict expected by Mistral.""" + result: Dict[str, Any] = { + "function": { + "name": invalid_tool_call["name"], + "arguments": invalid_tool_call["args"], + } + } + if _id := invalid_tool_call.get("id"): + result["id"] = _id + + return result + + def _convert_message_to_mistral_chat_message( message: BaseMessage, ) -> Dict: @@ -231,8 +262,15 @@ def _convert_message_to_mistral_chat_message( elif isinstance(message, HumanMessage): return dict(role="user", content=message.content) elif isinstance(message, AIMessage): - if "tool_calls" in message.additional_kwargs: - tool_calls = [] + tool_calls = [] + if message.tool_calls or message.invalid_tool_calls: + for tool_call in message.tool_calls: + tool_calls.append(_format_tool_call_for_mistral(tool_call)) + for invalid_tool_call in message.invalid_tool_calls: + tool_calls.append( + _format_invalid_tool_call_for_mistral(invalid_tool_call) + ) + elif "tool_calls" in message.additional_kwargs: for tc in message.additional_kwargs["tool_calls"]: chunk = { "function": { @@ -244,7 +282,7 @@ def _convert_message_to_mistral_chat_message( chunk["id"] = _id tool_calls.append(chunk) else: - tool_calls = [] + pass return { "role": "assistant", "content": message.content, diff --git a/libs/partners/mistralai/tests/integration_tests/test_chat_models.py b/libs/partners/mistralai/tests/integration_tests/test_chat_models.py index 7dd19a4b9c..4bf576ac53 100644 --- a/libs/partners/mistralai/tests/integration_tests/test_chat_models.py +++ b/libs/partners/mistralai/tests/integration_tests/test_chat_models.py @@ -138,7 +138,7 @@ def test_structured_output() -> None: def test_streaming_structured_output() -> None: - llm = ChatMistralAI(model="mistral-large", temperature=0) + llm = ChatMistralAI(model="mistral-large-latest", temperature=0) class Person(BaseModel): name: str @@ -156,7 +156,7 @@ def test_streaming_structured_output() -> None: def test_tool_call() -> None: - llm = ChatMistralAI(model="mistral-large", temperature=0) + llm = ChatMistralAI(model="mistral-large-latest", temperature=0) class Person(BaseModel): name: str @@ -173,7 +173,7 @@ def test_tool_call() -> None: def test_streaming_tool_call() -> None: - llm = ChatMistralAI(model="mistral-large", temperature=0) + llm = ChatMistralAI(model="mistral-large-latest", temperature=0) class Person(BaseModel): name: str diff --git a/libs/partners/mistralai/tests/integration_tests/test_standard.py b/libs/partners/mistralai/tests/integration_tests/test_standard.py index 5e589955f2..d9b8ff1969 100644 --- a/libs/partners/mistralai/tests/integration_tests/test_standard.py +++ b/libs/partners/mistralai/tests/integration_tests/test_standard.py @@ -13,3 +13,10 @@ class TestMistralStandard(ChatModelIntegrationTests): @pytest.fixture def chat_model_class(self) -> Type[BaseChatModel]: return ChatMistralAI + + @pytest.fixture + def chat_model_params(self) -> dict: + return { + "model": "mistral-large-latest", + "temperature": 0, + } diff --git a/libs/partners/mistralai/tests/unit_tests/test_chat_models.py b/libs/partners/mistralai/tests/unit_tests/test_chat_models.py index 96c637b5a2..ab70d02d45 100644 --- a/libs/partners/mistralai/tests/unit_tests/test_chat_models.py +++ b/libs/partners/mistralai/tests/unit_tests/test_chat_models.py @@ -130,7 +130,7 @@ def test__convert_dict_to_message_tool_call() -> None: raw_tool_call = { "id": "abc123", "function": { - "arguments": '{"name":"Sally","hair_color":"green"}', + "arguments": '{"name": "Sally", "hair_color": "green"}', "name": "GenerateUsername", }, } @@ -153,16 +153,16 @@ def test__convert_dict_to_message_tool_call() -> None: # Test malformed tool call raw_tool_calls = [ { - "id": "abc123", + "id": "def456", "function": { - "arguments": "oops", + "arguments": '{"name": "Sally", "hair_color": "green"}', "name": "GenerateUsername", }, }, { - "id": "def456", + "id": "abc123", "function": { - "arguments": '{"name":"Sally","hair_color":"green"}', + "arguments": "oops", "name": "GenerateUsername", }, },