"""Unit tests for ``ChatTogether`` and the ``langchain_openai`` message converters."""

import json
from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch

import pytest
from langchain_core.messages import (
    AIMessage,
    FunctionMessage,
    HumanMessage,
    SystemMessage,
    ToolMessage,
)
from langchain_openai.chat_models.base import (
    _convert_dict_to_message,
    _convert_message_to_dict,
)

from langchain_together import ChatTogether


def test_initialization() -> None:
    """Constructing ``ChatTogether`` with no arguments must not raise."""
    ChatTogether()


def test_together_model_param() -> None:
    """``model=`` and ``model_name=`` both end up in ``model_name``.

    Also checks that the LangSmith params report ``"together"`` as provider.
    """
    from_model = ChatTogether(model="foo")
    assert from_model.model_name == "foo"

    from_model_name = ChatTogether(model_name="foo")
    assert from_model_name.model_name == "foo"

    assert from_model_name._get_ls_params()["ls_provider"] == "together"


def test_function_dict_to_message_function_message() -> None:
    """A ``"function"``-role dict is converted into a ``FunctionMessage``."""
    fn_name = "test_function"
    payload = json.dumps({"result": "Example #1"})
    raw = {"role": "function", "name": fn_name, "content": payload}

    converted = _convert_dict_to_message(raw)

    assert isinstance(converted, FunctionMessage)
    assert converted.name == fn_name
    assert converted.content == payload


def test_convert_dict_to_message_human() -> None:
    """A ``"user"`` dict and ``HumanMessage`` convert into each other."""
    raw = {"role": "user", "content": "foo"}
    expected = HumanMessage(content="foo")

    converted = _convert_dict_to_message(raw)

    assert converted == expected
    assert _convert_message_to_dict(expected) == raw


def test__convert_dict_to_message_human_with_name() -> None:
    """The ``name`` key is kept in both directions for a ``"user"`` message."""
    raw = {"role": "user", "content": "foo", "name": "test"}
    expected = HumanMessage(content="foo", name="test")

    converted = _convert_dict_to_message(raw)

    assert converted == expected
    assert _convert_message_to_dict(expected) == raw


def test_convert_dict_to_message_ai() -> None:
    """An ``"assistant"`` dict and ``AIMessage`` convert into each other."""
    raw = {"role": "assistant", "content": "foo"}
    expected = AIMessage(content="foo")

    converted = _convert_dict_to_message(raw)

    assert converted == expected
    assert _convert_message_to_dict(expected) == raw


def test_convert_dict_to_message_ai_with_name() -> None:
    """The ``name`` key is kept in both directions for an ``"assistant"`` message."""
    raw = {"role": "assistant", "content": "foo", "name": "test"}
    expected = AIMessage(content="foo", name="test")

    converted = _convert_dict_to_message(raw)

    assert converted == expected
    assert _convert_message_to_dict(expected) == raw


def test_convert_dict_to_message_system() -> None:
    """A ``"system"`` dict and ``SystemMessage`` convert into each other."""
    raw = {"role": "system", "content": "foo"}
    expected = SystemMessage(content="foo")

    converted = _convert_dict_to_message(raw)

    assert converted == expected
    assert _convert_message_to_dict(expected) == raw


def test_convert_dict_to_message_system_with_name() -> None:
    """The ``name`` key is kept in both directions for a ``"system"`` message."""
    raw = {"role": "system", "content": "foo", "name": "test"}
    expected = SystemMessage(content="foo", name="test")

    converted = _convert_dict_to_message(raw)

    assert converted == expected
    assert _convert_message_to_dict(expected) == raw


def test_convert_dict_to_message_tool() -> None:
    """A ``"tool"`` dict and ``ToolMessage`` convert into each other."""
    raw = {"role": "tool", "content": "foo", "tool_call_id": "bar"}
    expected = ToolMessage(content="foo", tool_call_id="bar")

    converted = _convert_dict_to_message(raw)

    assert converted == expected
    assert _convert_message_to_dict(expected) == raw


@pytest.fixture
def mock_completion() -> dict:
    """Return a canned chat-completion payload with a single assistant choice."""
    assistant_message = {
        "role": "assistant",
        "content": "Bab",
        "name": "KimSolar",
    }
    only_choice = {
        "index": 0,
        "message": assistant_message,
        "finish_reason": "stop",
    }
    return {
        "id": "chatcmpl-7fcZavknQda3SQ",
        "object": "chat.completion",
        "created": 1689989000,
        "model": "meta-llama/Llama-3-8b-chat-hf",
        "choices": [only_choice],
    }


def test_together_invoke(mock_completion: dict) -> None:
    """``invoke`` calls ``client.create`` and returns the completion content."""
    llm = ChatTogether()
    was_called = False

    def fake_create(*args: Any, **kwargs: Any) -> Any:
        # Record the call so the test can prove the patched client was used.
        nonlocal was_called
        was_called = True
        return mock_completion

    stub_client = MagicMock()
    stub_client.create = fake_create

    with patch.object(llm, "client", stub_client):
        reply = llm.invoke("bab")
        assert reply.content == "Bab"

    assert was_called


async def test_together_ainvoke(mock_completion: dict) -> None:
    """``ainvoke`` awaits ``async_client.create`` and returns the content."""
    llm = ChatTogether()
    was_called = False

    async def fake_create(*args: Any, **kwargs: Any) -> Any:
        # Record the call so the test can prove the patched client was used.
        nonlocal was_called
        was_called = True
        return mock_completion

    stub_client = AsyncMock()
    stub_client.create = fake_create

    with patch.object(llm, "async_client", stub_client):
        reply = await llm.ainvoke("bab")
        assert reply.content == "Bab"

    assert was_called


def test_together_invoke_name(mock_completion: dict) -> None:
    """A message ``name`` is sent in the request and returned on the reply."""
    llm = ChatTogether()
    stub_client = MagicMock()
    stub_client.create.return_value = mock_completion

    with patch.object(llm, "client", stub_client):
        outgoing = [HumanMessage(content="Foo", name="Zorba")]
        reply = llm.invoke(outgoing)

        positional, keyword = stub_client.create.call_args
        # The request must be passed purely through keyword arguments.
        assert len(positional) == 0

        sent_messages = keyword["messages"]
        assert len(sent_messages) == 1
        sent = sent_messages[0]
        assert sent["role"] == "user"
        assert sent["content"] == "Foo"
        assert sent["name"] == "Zorba"

        # The reply carries both the content and the name from the fixture.
        assert reply.content == "Bab"
        assert reply.name == "KimSolar"