langchain/libs/partners/mistralai/tests/unit_tests/test_chat_models.py

201 lines
5.9 KiB
Python
Raw Normal View History

"""Test MistralAI Chat API wrapper."""
import os
from typing import Any, AsyncGenerator, Dict, Generator, List, cast
from unittest.mock import patch
import pytest
from langchain_core.callbacks.base import BaseCallbackHandler
from langchain_core.messages import (
AIMessage,
BaseMessage,
ChatMessage,
HumanMessage,
core[minor], ...: add tool calls message (#18947) core[minor], langchain[patch], openai[minor], anthropic[minor], fireworks[minor], groq[minor], mistralai[minor] ```python class ToolCall(TypedDict): name: str args: Dict[str, Any] id: Optional[str] class InvalidToolCall(TypedDict): name: Optional[str] args: Optional[str] id: Optional[str] error: Optional[str] class ToolCallChunk(TypedDict): name: Optional[str] args: Optional[str] id: Optional[str] index: Optional[int] class AIMessage(BaseMessage): ... tool_calls: List[ToolCall] = [] invalid_tool_calls: List[InvalidToolCall] = [] ... class AIMessageChunk(AIMessage, BaseMessageChunk): ... tool_call_chunks: Optional[List[ToolCallChunk]] = None ... ``` Important considerations: - Parsing logic occurs within different providers; - ~Changing output type is a breaking change for anyone doing explicit type checking;~ - ~Langsmith rendering will need to be updated: https://github.com/langchain-ai/langchainplus/pull/3561~ - ~Langserve will need to be updated~ - Adding chunks: - ~AIMessage + ToolCallsMessage = ToolCallsMessage if either has non-null .tool_calls.~ - Tool call chunks are appended, merging when having equal values of `index`. - additional_kwargs accumulate the normal way. - During streaming: - ~Messages can change types (e.g., from AIMessageChunk to AIToolCallsMessageChunk)~ - Output parsers parse additional_kwargs (during .invoke they read off tool calls). Packages outside of `partners/`: - https://github.com/langchain-ai/langchain-cohere/pull/7 - https://github.com/langchain-ai/langchain-google/pull/123/files --------- Co-authored-by: Chester Curme <chester.curme@gmail.com>
2024-04-09 23:41:42 +00:00
InvalidToolCall,
SystemMessage,
core[minor], ...: add tool calls message (#18947) core[minor], langchain[patch], openai[minor], anthropic[minor], fireworks[minor], groq[minor], mistralai[minor] ```python class ToolCall(TypedDict): name: str args: Dict[str, Any] id: Optional[str] class InvalidToolCall(TypedDict): name: Optional[str] args: Optional[str] id: Optional[str] error: Optional[str] class ToolCallChunk(TypedDict): name: Optional[str] args: Optional[str] id: Optional[str] index: Optional[int] class AIMessage(BaseMessage): ... tool_calls: List[ToolCall] = [] invalid_tool_calls: List[InvalidToolCall] = [] ... class AIMessageChunk(AIMessage, BaseMessageChunk): ... tool_call_chunks: Optional[List[ToolCallChunk]] = None ... ``` Important considerations: - Parsing logic occurs within different providers; - ~Changing output type is a breaking change for anyone doing explicit type checking;~ - ~Langsmith rendering will need to be updated: https://github.com/langchain-ai/langchainplus/pull/3561~ - ~Langserve will need to be updated~ - Adding chunks: - ~AIMessage + ToolCallsMessage = ToolCallsMessage if either has non-null .tool_calls.~ - Tool call chunks are appended, merging when having equal values of `index`. - additional_kwargs accumulate the normal way. - During streaming: - ~Messages can change types (e.g., from AIMessageChunk to AIToolCallsMessageChunk)~ - Output parsers parse additional_kwargs (during .invoke they read off tool calls). Packages outside of `partners/`: - https://github.com/langchain-ai/langchain-cohere/pull/7 - https://github.com/langchain-ai/langchain-google/pull/123/files --------- Co-authored-by: Chester Curme <chester.curme@gmail.com>
2024-04-09 23:41:42 +00:00
ToolCall,
)
from langchain_core.pydantic_v1 import SecretStr
from langchain_mistralai.chat_models import ( # type: ignore[import]
ChatMistralAI,
_convert_message_to_mistral_chat_message,
core[minor], ...: add tool calls message (#18947) core[minor], langchain[patch], openai[minor], anthropic[minor], fireworks[minor], groq[minor], mistralai[minor] ```python class ToolCall(TypedDict): name: str args: Dict[str, Any] id: Optional[str] class InvalidToolCall(TypedDict): name: Optional[str] args: Optional[str] id: Optional[str] error: Optional[str] class ToolCallChunk(TypedDict): name: Optional[str] args: Optional[str] id: Optional[str] index: Optional[int] class AIMessage(BaseMessage): ... tool_calls: List[ToolCall] = [] invalid_tool_calls: List[InvalidToolCall] = [] ... class AIMessageChunk(AIMessage, BaseMessageChunk): ... tool_call_chunks: Optional[List[ToolCallChunk]] = None ... ``` Important considerations: - Parsing logic occurs within different providers; - ~Changing output type is a breaking change for anyone doing explicit type checking;~ - ~Langsmith rendering will need to be updated: https://github.com/langchain-ai/langchainplus/pull/3561~ - ~Langserve will need to be updated~ - Adding chunks: - ~AIMessage + ToolCallsMessage = ToolCallsMessage if either has non-null .tool_calls.~ - Tool call chunks are appended, merging when having equal values of `index`. - additional_kwargs accumulate the normal way. - During streaming: - ~Messages can change types (e.g., from AIMessageChunk to AIToolCallsMessageChunk)~ - Output parsers parse additional_kwargs (during .invoke they read off tool calls). Packages outside of `partners/`: - https://github.com/langchain-ai/langchain-cohere/pull/7 - https://github.com/langchain-ai/langchain-google/pull/123/files --------- Co-authored-by: Chester Curme <chester.curme@gmail.com>
2024-04-09 23:41:42 +00:00
_convert_mistral_chat_message_to_message,
)
os.environ["MISTRAL_API_KEY"] = "foo"
def test_mistralai_model_param() -> None:
llm = ChatMistralAI(model="foo")
assert llm.model == "foo"
def test_mistralai_initialization() -> None:
"""Test ChatMistralAI initialization."""
# Verify that ChatMistralAI can be initialized using a secret key provided
# as a parameter rather than an environment variable.
for model in [
ChatMistralAI(model="test", mistral_api_key="test"),
ChatMistralAI(model="test", api_key="test"),
]:
assert cast(SecretStr, model.mistral_api_key).get_secret_value() == "test"
@pytest.mark.parametrize(
("message", "expected"),
[
(
SystemMessage(content="Hello"),
dict(role="system", content="Hello"),
),
(
HumanMessage(content="Hello"),
dict(role="user", content="Hello"),
),
(
AIMessage(content="Hello"),
dict(role="assistant", content="Hello"),
),
(
ChatMessage(role="assistant", content="Hello"),
dict(role="assistant", content="Hello"),
),
],
)
def test_convert_message_to_mistral_chat_message(
message: BaseMessage, expected: Dict
) -> None:
result = _convert_message_to_mistral_chat_message(message)
assert result == expected
def _make_completion_response_from_token(token: str) -> Dict:
return dict(
id="abc123",
model="fake_model",
choices=[
dict(
index=0,
delta=dict(content=token),
finish_reason=None,
)
],
)
def mock_chat_stream(*args: Any, **kwargs: Any) -> Generator:
def it() -> Generator:
for token in ["Hello", " how", " can", " I", " help", "?"]:
yield _make_completion_response_from_token(token)
return it()
async def mock_chat_astream(*args: Any, **kwargs: Any) -> AsyncGenerator:
async def it() -> AsyncGenerator:
for token in ["Hello", " how", " can", " I", " help", "?"]:
yield _make_completion_response_from_token(token)
return it()
class MyCustomHandler(BaseCallbackHandler):
last_token: str = ""
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
self.last_token = token
@patch(
"langchain_mistralai.chat_models.ChatMistralAI.completion_with_retry",
new=mock_chat_stream,
)
def test_stream_with_callback() -> None:
callback = MyCustomHandler()
chat = ChatMistralAI(callbacks=[callback])
for token in chat.stream("Hello"):
assert callback.last_token == token.content
@patch("langchain_mistralai.chat_models.acompletion_with_retry", new=mock_chat_astream)
async def test_astream_with_callback() -> None:
callback = MyCustomHandler()
chat = ChatMistralAI(callbacks=[callback])
async for token in chat.astream("Hello"):
assert callback.last_token == token.content
core[minor], ...: add tool calls message (#18947) core[minor], langchain[patch], openai[minor], anthropic[minor], fireworks[minor], groq[minor], mistralai[minor] ```python class ToolCall(TypedDict): name: str args: Dict[str, Any] id: Optional[str] class InvalidToolCall(TypedDict): name: Optional[str] args: Optional[str] id: Optional[str] error: Optional[str] class ToolCallChunk(TypedDict): name: Optional[str] args: Optional[str] id: Optional[str] index: Optional[int] class AIMessage(BaseMessage): ... tool_calls: List[ToolCall] = [] invalid_tool_calls: List[InvalidToolCall] = [] ... class AIMessageChunk(AIMessage, BaseMessageChunk): ... tool_call_chunks: Optional[List[ToolCallChunk]] = None ... ``` Important considerations: - Parsing logic occurs within different providers; - ~Changing output type is a breaking change for anyone doing explicit type checking;~ - ~Langsmith rendering will need to be updated: https://github.com/langchain-ai/langchainplus/pull/3561~ - ~Langserve will need to be updated~ - Adding chunks: - ~AIMessage + ToolCallsMessage = ToolCallsMessage if either has non-null .tool_calls.~ - Tool call chunks are appended, merging when having equal values of `index`. - additional_kwargs accumulate the normal way. - During streaming: - ~Messages can change types (e.g., from AIMessageChunk to AIToolCallsMessageChunk)~ - Output parsers parse additional_kwargs (during .invoke they read off tool calls). Packages outside of `partners/`: - https://github.com/langchain-ai/langchain-cohere/pull/7 - https://github.com/langchain-ai/langchain-google/pull/123/files --------- Co-authored-by: Chester Curme <chester.curme@gmail.com>
2024-04-09 23:41:42 +00:00
def test__convert_dict_to_message_tool_call() -> None:
raw_tool_call = {
"id": "abc123",
core[minor], ...: add tool calls message (#18947) core[minor], langchain[patch], openai[minor], anthropic[minor], fireworks[minor], groq[minor], mistralai[minor] ```python class ToolCall(TypedDict): name: str args: Dict[str, Any] id: Optional[str] class InvalidToolCall(TypedDict): name: Optional[str] args: Optional[str] id: Optional[str] error: Optional[str] class ToolCallChunk(TypedDict): name: Optional[str] args: Optional[str] id: Optional[str] index: Optional[int] class AIMessage(BaseMessage): ... tool_calls: List[ToolCall] = [] invalid_tool_calls: List[InvalidToolCall] = [] ... class AIMessageChunk(AIMessage, BaseMessageChunk): ... tool_call_chunks: Optional[List[ToolCallChunk]] = None ... ``` Important considerations: - Parsing logic occurs within different providers; - ~Changing output type is a breaking change for anyone doing explicit type checking;~ - ~Langsmith rendering will need to be updated: https://github.com/langchain-ai/langchainplus/pull/3561~ - ~Langserve will need to be updated~ - Adding chunks: - ~AIMessage + ToolCallsMessage = ToolCallsMessage if either has non-null .tool_calls.~ - Tool call chunks are appended, merging when having equal values of `index`. - additional_kwargs accumulate the normal way. - During streaming: - ~Messages can change types (e.g., from AIMessageChunk to AIToolCallsMessageChunk)~ - Output parsers parse additional_kwargs (during .invoke they read off tool calls). Packages outside of `partners/`: - https://github.com/langchain-ai/langchain-cohere/pull/7 - https://github.com/langchain-ai/langchain-google/pull/123/files --------- Co-authored-by: Chester Curme <chester.curme@gmail.com>
2024-04-09 23:41:42 +00:00
"function": {
"arguments": '{"name": "Sally", "hair_color": "green"}',
core[minor], ...: add tool calls message (#18947) core[minor], langchain[patch], openai[minor], anthropic[minor], fireworks[minor], groq[minor], mistralai[minor] ```python class ToolCall(TypedDict): name: str args: Dict[str, Any] id: Optional[str] class InvalidToolCall(TypedDict): name: Optional[str] args: Optional[str] id: Optional[str] error: Optional[str] class ToolCallChunk(TypedDict): name: Optional[str] args: Optional[str] id: Optional[str] index: Optional[int] class AIMessage(BaseMessage): ... tool_calls: List[ToolCall] = [] invalid_tool_calls: List[InvalidToolCall] = [] ... class AIMessageChunk(AIMessage, BaseMessageChunk): ... tool_call_chunks: Optional[List[ToolCallChunk]] = None ... ``` Important considerations: - Parsing logic occurs within different providers; - ~Changing output type is a breaking change for anyone doing explicit type checking;~ - ~Langsmith rendering will need to be updated: https://github.com/langchain-ai/langchainplus/pull/3561~ - ~Langserve will need to be updated~ - Adding chunks: - ~AIMessage + ToolCallsMessage = ToolCallsMessage if either has non-null .tool_calls.~ - Tool call chunks are appended, merging when having equal values of `index`. - additional_kwargs accumulate the normal way. - During streaming: - ~Messages can change types (e.g., from AIMessageChunk to AIToolCallsMessageChunk)~ - Output parsers parse additional_kwargs (during .invoke they read off tool calls). Packages outside of `partners/`: - https://github.com/langchain-ai/langchain-cohere/pull/7 - https://github.com/langchain-ai/langchain-google/pull/123/files --------- Co-authored-by: Chester Curme <chester.curme@gmail.com>
2024-04-09 23:41:42 +00:00
"name": "GenerateUsername",
},
}
message = {"role": "assistant", "content": "", "tool_calls": [raw_tool_call]}
result = _convert_mistral_chat_message_to_message(message)
expected_output = AIMessage(
content="",
additional_kwargs={"tool_calls": [raw_tool_call]},
tool_calls=[
ToolCall(
name="GenerateUsername",
args={"name": "Sally", "hair_color": "green"},
id="abc123",
core[minor], ...: add tool calls message (#18947) core[minor], langchain[patch], openai[minor], anthropic[minor], fireworks[minor], groq[minor], mistralai[minor] ```python class ToolCall(TypedDict): name: str args: Dict[str, Any] id: Optional[str] class InvalidToolCall(TypedDict): name: Optional[str] args: Optional[str] id: Optional[str] error: Optional[str] class ToolCallChunk(TypedDict): name: Optional[str] args: Optional[str] id: Optional[str] index: Optional[int] class AIMessage(BaseMessage): ... tool_calls: List[ToolCall] = [] invalid_tool_calls: List[InvalidToolCall] = [] ... class AIMessageChunk(AIMessage, BaseMessageChunk): ... tool_call_chunks: Optional[List[ToolCallChunk]] = None ... ``` Important considerations: - Parsing logic occurs within different providers; - ~Changing output type is a breaking change for anyone doing explicit type checking;~ - ~Langsmith rendering will need to be updated: https://github.com/langchain-ai/langchainplus/pull/3561~ - ~Langserve will need to be updated~ - Adding chunks: - ~AIMessage + ToolCallsMessage = ToolCallsMessage if either has non-null .tool_calls.~ - Tool call chunks are appended, merging when having equal values of `index`. - additional_kwargs accumulate the normal way. - During streaming: - ~Messages can change types (e.g., from AIMessageChunk to AIToolCallsMessageChunk)~ - Output parsers parse additional_kwargs (during .invoke they read off tool calls). Packages outside of `partners/`: - https://github.com/langchain-ai/langchain-cohere/pull/7 - https://github.com/langchain-ai/langchain-google/pull/123/files --------- Co-authored-by: Chester Curme <chester.curme@gmail.com>
2024-04-09 23:41:42 +00:00
)
],
)
assert result == expected_output
assert _convert_message_to_mistral_chat_message(expected_output) == message
# Test malformed tool call
raw_tool_calls = [
{
"id": "def456",
core[minor], ...: add tool calls message (#18947) core[minor], langchain[patch], openai[minor], anthropic[minor], fireworks[minor], groq[minor], mistralai[minor] ```python class ToolCall(TypedDict): name: str args: Dict[str, Any] id: Optional[str] class InvalidToolCall(TypedDict): name: Optional[str] args: Optional[str] id: Optional[str] error: Optional[str] class ToolCallChunk(TypedDict): name: Optional[str] args: Optional[str] id: Optional[str] index: Optional[int] class AIMessage(BaseMessage): ... tool_calls: List[ToolCall] = [] invalid_tool_calls: List[InvalidToolCall] = [] ... class AIMessageChunk(AIMessage, BaseMessageChunk): ... tool_call_chunks: Optional[List[ToolCallChunk]] = None ... ``` Important considerations: - Parsing logic occurs within different providers; - ~Changing output type is a breaking change for anyone doing explicit type checking;~ - ~Langsmith rendering will need to be updated: https://github.com/langchain-ai/langchainplus/pull/3561~ - ~Langserve will need to be updated~ - Adding chunks: - ~AIMessage + ToolCallsMessage = ToolCallsMessage if either has non-null .tool_calls.~ - Tool call chunks are appended, merging when having equal values of `index`. - additional_kwargs accumulate the normal way. - During streaming: - ~Messages can change types (e.g., from AIMessageChunk to AIToolCallsMessageChunk)~ - Output parsers parse additional_kwargs (during .invoke they read off tool calls). Packages outside of `partners/`: - https://github.com/langchain-ai/langchain-cohere/pull/7 - https://github.com/langchain-ai/langchain-google/pull/123/files --------- Co-authored-by: Chester Curme <chester.curme@gmail.com>
2024-04-09 23:41:42 +00:00
"function": {
"arguments": '{"name": "Sally", "hair_color": "green"}',
core[minor], ...: add tool calls message (#18947) core[minor], langchain[patch], openai[minor], anthropic[minor], fireworks[minor], groq[minor], mistralai[minor] ```python class ToolCall(TypedDict): name: str args: Dict[str, Any] id: Optional[str] class InvalidToolCall(TypedDict): name: Optional[str] args: Optional[str] id: Optional[str] error: Optional[str] class ToolCallChunk(TypedDict): name: Optional[str] args: Optional[str] id: Optional[str] index: Optional[int] class AIMessage(BaseMessage): ... tool_calls: List[ToolCall] = [] invalid_tool_calls: List[InvalidToolCall] = [] ... class AIMessageChunk(AIMessage, BaseMessageChunk): ... tool_call_chunks: Optional[List[ToolCallChunk]] = None ... ``` Important considerations: - Parsing logic occurs within different providers; - ~Changing output type is a breaking change for anyone doing explicit type checking;~ - ~Langsmith rendering will need to be updated: https://github.com/langchain-ai/langchainplus/pull/3561~ - ~Langserve will need to be updated~ - Adding chunks: - ~AIMessage + ToolCallsMessage = ToolCallsMessage if either has non-null .tool_calls.~ - Tool call chunks are appended, merging when having equal values of `index`. - additional_kwargs accumulate the normal way. - During streaming: - ~Messages can change types (e.g., from AIMessageChunk to AIToolCallsMessageChunk)~ - Output parsers parse additional_kwargs (during .invoke they read off tool calls). Packages outside of `partners/`: - https://github.com/langchain-ai/langchain-cohere/pull/7 - https://github.com/langchain-ai/langchain-google/pull/123/files --------- Co-authored-by: Chester Curme <chester.curme@gmail.com>
2024-04-09 23:41:42 +00:00
"name": "GenerateUsername",
},
},
{
"id": "abc123",
core[minor], ...: add tool calls message (#18947) core[minor], langchain[patch], openai[minor], anthropic[minor], fireworks[minor], groq[minor], mistralai[minor] ```python class ToolCall(TypedDict): name: str args: Dict[str, Any] id: Optional[str] class InvalidToolCall(TypedDict): name: Optional[str] args: Optional[str] id: Optional[str] error: Optional[str] class ToolCallChunk(TypedDict): name: Optional[str] args: Optional[str] id: Optional[str] index: Optional[int] class AIMessage(BaseMessage): ... tool_calls: List[ToolCall] = [] invalid_tool_calls: List[InvalidToolCall] = [] ... class AIMessageChunk(AIMessage, BaseMessageChunk): ... tool_call_chunks: Optional[List[ToolCallChunk]] = None ... ``` Important considerations: - Parsing logic occurs within different providers; - ~Changing output type is a breaking change for anyone doing explicit type checking;~ - ~Langsmith rendering will need to be updated: https://github.com/langchain-ai/langchainplus/pull/3561~ - ~Langserve will need to be updated~ - Adding chunks: - ~AIMessage + ToolCallsMessage = ToolCallsMessage if either has non-null .tool_calls.~ - Tool call chunks are appended, merging when having equal values of `index`. - additional_kwargs accumulate the normal way. - During streaming: - ~Messages can change types (e.g., from AIMessageChunk to AIToolCallsMessageChunk)~ - Output parsers parse additional_kwargs (during .invoke they read off tool calls). Packages outside of `partners/`: - https://github.com/langchain-ai/langchain-cohere/pull/7 - https://github.com/langchain-ai/langchain-google/pull/123/files --------- Co-authored-by: Chester Curme <chester.curme@gmail.com>
2024-04-09 23:41:42 +00:00
"function": {
"arguments": "oops",
core[minor], ...: add tool calls message (#18947) core[minor], langchain[patch], openai[minor], anthropic[minor], fireworks[minor], groq[minor], mistralai[minor] ```python class ToolCall(TypedDict): name: str args: Dict[str, Any] id: Optional[str] class InvalidToolCall(TypedDict): name: Optional[str] args: Optional[str] id: Optional[str] error: Optional[str] class ToolCallChunk(TypedDict): name: Optional[str] args: Optional[str] id: Optional[str] index: Optional[int] class AIMessage(BaseMessage): ... tool_calls: List[ToolCall] = [] invalid_tool_calls: List[InvalidToolCall] = [] ... class AIMessageChunk(AIMessage, BaseMessageChunk): ... tool_call_chunks: Optional[List[ToolCallChunk]] = None ... ``` Important considerations: - Parsing logic occurs within different providers; - ~Changing output type is a breaking change for anyone doing explicit type checking;~ - ~Langsmith rendering will need to be updated: https://github.com/langchain-ai/langchainplus/pull/3561~ - ~Langserve will need to be updated~ - Adding chunks: - ~AIMessage + ToolCallsMessage = ToolCallsMessage if either has non-null .tool_calls.~ - Tool call chunks are appended, merging when having equal values of `index`. - additional_kwargs accumulate the normal way. - During streaming: - ~Messages can change types (e.g., from AIMessageChunk to AIToolCallsMessageChunk)~ - Output parsers parse additional_kwargs (during .invoke they read off tool calls). Packages outside of `partners/`: - https://github.com/langchain-ai/langchain-cohere/pull/7 - https://github.com/langchain-ai/langchain-google/pull/123/files --------- Co-authored-by: Chester Curme <chester.curme@gmail.com>
2024-04-09 23:41:42 +00:00
"name": "GenerateUsername",
},
},
]
message = {"role": "assistant", "content": "", "tool_calls": raw_tool_calls}
result = _convert_mistral_chat_message_to_message(message)
expected_output = AIMessage(
content="",
additional_kwargs={"tool_calls": raw_tool_calls},
invalid_tool_calls=[
InvalidToolCall(
name="GenerateUsername",
args="oops",
error="Function GenerateUsername arguments:\n\noops\n\nare not valid JSON. Received JSONDecodeError Expecting value: line 1 column 1 (char 0)", # noqa: E501
id="abc123",
core[minor], ...: add tool calls message (#18947) core[minor], langchain[patch], openai[minor], anthropic[minor], fireworks[minor], groq[minor], mistralai[minor] ```python class ToolCall(TypedDict): name: str args: Dict[str, Any] id: Optional[str] class InvalidToolCall(TypedDict): name: Optional[str] args: Optional[str] id: Optional[str] error: Optional[str] class ToolCallChunk(TypedDict): name: Optional[str] args: Optional[str] id: Optional[str] index: Optional[int] class AIMessage(BaseMessage): ... tool_calls: List[ToolCall] = [] invalid_tool_calls: List[InvalidToolCall] = [] ... class AIMessageChunk(AIMessage, BaseMessageChunk): ... tool_call_chunks: Optional[List[ToolCallChunk]] = None ... ``` Important considerations: - Parsing logic occurs within different providers; - ~Changing output type is a breaking change for anyone doing explicit type checking;~ - ~Langsmith rendering will need to be updated: https://github.com/langchain-ai/langchainplus/pull/3561~ - ~Langserve will need to be updated~ - Adding chunks: - ~AIMessage + ToolCallsMessage = ToolCallsMessage if either has non-null .tool_calls.~ - Tool call chunks are appended, merging when having equal values of `index`. - additional_kwargs accumulate the normal way. - During streaming: - ~Messages can change types (e.g., from AIMessageChunk to AIToolCallsMessageChunk)~ - Output parsers parse additional_kwargs (during .invoke they read off tool calls). Packages outside of `partners/`: - https://github.com/langchain-ai/langchain-cohere/pull/7 - https://github.com/langchain-ai/langchain-google/pull/123/files --------- Co-authored-by: Chester Curme <chester.curme@gmail.com>
2024-04-09 23:41:42 +00:00
),
],
tool_calls=[
ToolCall(
name="GenerateUsername",
args={"name": "Sally", "hair_color": "green"},
id="def456",
core[minor], ...: add tool calls message (#18947) core[minor], langchain[patch], openai[minor], anthropic[minor], fireworks[minor], groq[minor], mistralai[minor] ```python class ToolCall(TypedDict): name: str args: Dict[str, Any] id: Optional[str] class InvalidToolCall(TypedDict): name: Optional[str] args: Optional[str] id: Optional[str] error: Optional[str] class ToolCallChunk(TypedDict): name: Optional[str] args: Optional[str] id: Optional[str] index: Optional[int] class AIMessage(BaseMessage): ... tool_calls: List[ToolCall] = [] invalid_tool_calls: List[InvalidToolCall] = [] ... class AIMessageChunk(AIMessage, BaseMessageChunk): ... tool_call_chunks: Optional[List[ToolCallChunk]] = None ... ``` Important considerations: - Parsing logic occurs within different providers; - ~Changing output type is a breaking change for anyone doing explicit type checking;~ - ~Langsmith rendering will need to be updated: https://github.com/langchain-ai/langchainplus/pull/3561~ - ~Langserve will need to be updated~ - Adding chunks: - ~AIMessage + ToolCallsMessage = ToolCallsMessage if either has non-null .tool_calls.~ - Tool call chunks are appended, merging when having equal values of `index`. - additional_kwargs accumulate the normal way. - During streaming: - ~Messages can change types (e.g., from AIMessageChunk to AIToolCallsMessageChunk)~ - Output parsers parse additional_kwargs (during .invoke they read off tool calls). Packages outside of `partners/`: - https://github.com/langchain-ai/langchain-cohere/pull/7 - https://github.com/langchain-ai/langchain-google/pull/123/files --------- Co-authored-by: Chester Curme <chester.curme@gmail.com>
2024-04-09 23:41:42 +00:00
),
],
)
assert result == expected_output
assert _convert_message_to_mistral_chat_message(expected_output) == message
def test_custom_token_counting() -> None:
def token_encoder(text: str) -> List[int]:
return [1, 2, 3]
llm = ChatMistralAI(custom_get_token_ids=token_encoder)
assert llm.get_token_ids("foo") == [1, 2, 3]