mirror of
https://github.com/hwchase17/langchain
synced 2024-11-10 01:10:59 +00:00
mistral[patch]: translate tool call IDs to mistral compatible format (#24668)
Mistral appears to have added validation for the format of its tool call IDs: `{"object":"error","message":"Tool call id was abc123 but must be a-z, A-Z, 0-9, with a length of 9.","type":"invalid_request_error","param":null,"code":null}` This breaks compatibility of messages from other providers. Here we add a function that converts any string to a Mistral-valid tool call ID, and apply it to incoming messages.
This commit is contained in:
parent
38d30e285a
commit
dfbd12b384
@ -1,7 +1,9 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
import uuid
|
||||
from operator import itemgetter
|
||||
from typing import (
|
||||
@ -77,6 +79,9 @@ from langchain_core.utils.pydantic import is_basemodel_subclass
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Mistral enforces a specific pattern for tool call IDs
|
||||
TOOL_CALL_ID_PATTERN = re.compile(r"^[a-zA-Z0-9]{9}$")
|
||||
|
||||
|
||||
def _create_retry_decorator(
|
||||
llm: ChatMistralAI,
|
||||
@ -92,6 +97,39 @@ def _create_retry_decorator(
|
||||
)
|
||||
|
||||
|
||||
def _is_valid_mistral_tool_call_id(tool_call_id: str) -> bool:
|
||||
"""Check if tool call ID is nine character string consisting of a-z, A-Z, 0-9"""
|
||||
return bool(TOOL_CALL_ID_PATTERN.match(tool_call_id))
|
||||
|
||||
|
||||
def _base62_encode(num: int) -> str:
|
||||
"""Encodes a number in base62 and ensures result is of a specified length."""
|
||||
base62 = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
|
||||
if num == 0:
|
||||
return base62[0]
|
||||
arr = []
|
||||
base = len(base62)
|
||||
while num:
|
||||
num, rem = divmod(num, base)
|
||||
arr.append(base62[rem])
|
||||
arr.reverse()
|
||||
return "".join(arr)
|
||||
|
||||
|
||||
def _convert_tool_call_id_to_mistral_compatible(tool_call_id: str) -> str:
|
||||
"""Convert a tool call ID to a Mistral-compatible format"""
|
||||
if _is_valid_mistral_tool_call_id(tool_call_id):
|
||||
return tool_call_id
|
||||
else:
|
||||
hash_bytes = hashlib.sha256(tool_call_id.encode()).digest()
|
||||
hash_int = int.from_bytes(hash_bytes, byteorder="big")
|
||||
base62_str = _base62_encode(hash_int)
|
||||
if len(base62_str) >= 9:
|
||||
return base62_str[:9]
|
||||
else:
|
||||
return base62_str.rjust(9, "0")
|
||||
|
||||
|
||||
def _convert_mistral_chat_message_to_message(
|
||||
_message: Dict,
|
||||
) -> BaseMessage:
|
||||
@ -246,7 +284,7 @@ def _format_tool_call_for_mistral(tool_call: ToolCall) -> dict:
|
||||
}
|
||||
}
|
||||
if _id := tool_call.get("id"):
|
||||
result["id"] = _id
|
||||
result["id"] = _convert_tool_call_id_to_mistral_compatible(_id)
|
||||
|
||||
return result
|
||||
|
||||
@ -260,7 +298,7 @@ def _format_invalid_tool_call_for_mistral(invalid_tool_call: InvalidToolCall) ->
|
||||
}
|
||||
}
|
||||
if _id := invalid_tool_call.get("id"):
|
||||
result["id"] = _id
|
||||
result["id"] = _convert_tool_call_id_to_mistral_compatible(_id)
|
||||
|
||||
return result
|
||||
|
||||
|
@ -21,6 +21,8 @@ from langchain_mistralai.chat_models import ( # type: ignore[import]
|
||||
ChatMistralAI,
|
||||
_convert_message_to_mistral_chat_message,
|
||||
_convert_mistral_chat_message_to_message,
|
||||
_convert_tool_call_id_to_mistral_compatible,
|
||||
_is_valid_mistral_tool_call_id,
|
||||
)
|
||||
|
||||
os.environ["MISTRAL_API_KEY"] = "foo"
|
||||
@ -128,7 +130,7 @@ async def test_astream_with_callback() -> None:
|
||||
|
||||
def test__convert_dict_to_message_tool_call() -> None:
|
||||
raw_tool_call = {
|
||||
"id": "abc123",
|
||||
"id": "ssAbar4Dr",
|
||||
"function": {
|
||||
"arguments": '{"name": "Sally", "hair_color": "green"}',
|
||||
"name": "GenerateUsername",
|
||||
@ -143,7 +145,7 @@ def test__convert_dict_to_message_tool_call() -> None:
|
||||
ToolCall(
|
||||
name="GenerateUsername",
|
||||
args={"name": "Sally", "hair_color": "green"},
|
||||
id="abc123",
|
||||
id="ssAbar4Dr",
|
||||
type="tool_call",
|
||||
)
|
||||
],
|
||||
@ -154,14 +156,14 @@ def test__convert_dict_to_message_tool_call() -> None:
|
||||
# Test malformed tool call
|
||||
raw_tool_calls = [
|
||||
{
|
||||
"id": "def456",
|
||||
"id": "pL5rEGzxe",
|
||||
"function": {
|
||||
"arguments": '{"name": "Sally", "hair_color": "green"}',
|
||||
"name": "GenerateUsername",
|
||||
},
|
||||
},
|
||||
{
|
||||
"id": "abc123",
|
||||
"id": "ssAbar4Dr",
|
||||
"function": {
|
||||
"arguments": "oops",
|
||||
"name": "GenerateUsername",
|
||||
@ -178,7 +180,7 @@ def test__convert_dict_to_message_tool_call() -> None:
|
||||
name="GenerateUsername",
|
||||
args="oops",
|
||||
error="Function GenerateUsername arguments:\n\noops\n\nare not valid JSON. Received JSONDecodeError Expecting value: line 1 column 1 (char 0)", # noqa: E501
|
||||
id="abc123",
|
||||
id="ssAbar4Dr",
|
||||
type="invalid_tool_call",
|
||||
),
|
||||
],
|
||||
@ -186,7 +188,7 @@ def test__convert_dict_to_message_tool_call() -> None:
|
||||
ToolCall(
|
||||
name="GenerateUsername",
|
||||
args={"name": "Sally", "hair_color": "green"},
|
||||
id="def456",
|
||||
id="pL5rEGzxe",
|
||||
type="tool_call",
|
||||
),
|
||||
],
|
||||
@ -201,3 +203,18 @@ def test_custom_token_counting() -> None:
|
||||
|
||||
llm = ChatMistralAI(custom_get_token_ids=token_encoder)
|
||||
assert llm.get_token_ids("foo") == [1, 2, 3]
|
||||
|
||||
|
||||
def test_tool_id_conversion() -> None:
|
||||
assert _is_valid_mistral_tool_call_id("ssAbar4Dr")
|
||||
assert not _is_valid_mistral_tool_call_id("abc123")
|
||||
assert not _is_valid_mistral_tool_call_id("call_JIIjI55tTipFFzpcP8re3BpM")
|
||||
|
||||
result_map = {
|
||||
"ssAbar4Dr": "ssAbar4Dr",
|
||||
"abc123": "pL5rEGzxe",
|
||||
"call_JIIjI55tTipFFzpcP8re3BpM": "8kxAQvoED",
|
||||
}
|
||||
for input_id, expected_output in result_map.items():
|
||||
assert _convert_tool_call_id_to_mistral_compatible(input_id) == expected_output
|
||||
assert _is_valid_mistral_tool_call_id(expected_output)
|
||||
|
Loading…
Reference in New Issue
Block a user