You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
langchain/libs/community/tests/unit_tests/chat_models/test_edenai.py

71 lines
2.1 KiB
Python

"""Test EdenAI Chat API wrapper."""
from typing import List
import pytest
from langchain_core.messages import (
BaseMessage,
HumanMessage,
SystemMessage,
ToolMessage,
)
from langchain_community.chat_models.edenai import (
_extract_edenai_tool_results_from_messages,
_format_edenai_messages,
_message_role,
)
@pytest.mark.parametrize(
    ("messages", "expected"),
    [
        (
            [
                SystemMessage(content="Translate the text from English to French"),
                HumanMessage(content="Hello how are you today?"),
            ],
            {
                "text": "Hello how are you today?",
                "previous_history": [],
                "chatbot_global_action": "Translate the text from English to French",
                "tool_results": [],
            },
        )
    ],
)
def test_edenai_messages_formatting(
    messages: List[BaseMessage], expected: dict
) -> None:
    """Check that ``_format_edenai_messages`` builds the expected payload.

    The single parametrized case pairs a system message with a human message
    and expects the human content under ``"text"``, the system content under
    ``"chatbot_global_action"``, and empty ``"previous_history"`` and
    ``"tool_results"`` lists.

    Args:
        messages: Messages handed to the formatter.
        expected: Mapping the formatter's result must compare equal to.
    """
    # ``expected`` is a dict payload, not a string, so it is annotated as such.
    result = _format_edenai_messages(messages)
    assert result == expected
@pytest.mark.parametrize(
    ("role", "role_response"),
    [
        ("ai", "assistant"),
        ("human", "user"),
        ("chat", "user"),
    ],
)
def test_edenai_message_role(role: str, role_response) -> None:  # type: ignore[no-untyped-def]
    """Check that ``_message_role`` returns the expected role for each input.

    Args:
        role: Role string passed to ``_message_role``.
        role_response: Value the call is expected to return.
    """
    mapped_role = _message_role(role)
    assert mapped_role == role_response
def test_extract_edenai_tool_results_mixed_messages() -> None:
    """Check tool-result extraction when tool and other messages are interleaved.

    The input holds a tool message (``id1``) between two non-tool messages,
    followed by two consecutive tool messages (``id2``, ``id3``). The test
    expects only the last two to come back as tool results, with the first
    three messages returned unchanged as the remaining messages.
    """
    filler = BaseMessage(content="content", type="other message")
    conversation = [
        filler,
        ToolMessage(tool_call_id="id1", content="result1"),
        filler,
        ToolMessage(tool_call_id="id2", content="result2"),
        ToolMessage(tool_call_id="id3", content="result3"),
    ]

    extracted, remaining = _extract_edenai_tool_results_from_messages(conversation)

    assert extracted == [
        {"id": "id2", "result": "result2"},
        {"id": "id3", "result": "result3"},
    ]
    # A fresh ToolMessage is built here so equality is checked between
    # distinct instances, exactly as in the input/expectation pair above.
    assert remaining == [
        filler,
        ToolMessage(tool_call_id="id1", content="result1"),
        filler,
    ]