langchain/libs/partners/groq/tests/unit_tests/test_chat_models.py

278 lines
8.2 KiB
Python
Raw Normal View History

"""Test Groq Chat API wrapper."""
import json
import os
from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch
import langchain_core.load as lc_load
import pytest
from langchain_core.messages import (
AIMessage,
FunctionMessage,
HumanMessage,
core[minor], ...: add tool calls message (#18947) core[minor], langchain[patch], openai[minor], anthropic[minor], fireworks[minor], groq[minor], mistralai[minor] ```python class ToolCall(TypedDict): name: str args: Dict[str, Any] id: Optional[str] class InvalidToolCall(TypedDict): name: Optional[str] args: Optional[str] id: Optional[str] error: Optional[str] class ToolCallChunk(TypedDict): name: Optional[str] args: Optional[str] id: Optional[str] index: Optional[int] class AIMessage(BaseMessage): ... tool_calls: List[ToolCall] = [] invalid_tool_calls: List[InvalidToolCall] = [] ... class AIMessageChunk(AIMessage, BaseMessageChunk): ... tool_call_chunks: Optional[List[ToolCallChunk]] = None ... ``` Important considerations: - Parsing logic occurs within different providers; - ~Changing output type is a breaking change for anyone doing explicit type checking;~ - ~Langsmith rendering will need to be updated: https://github.com/langchain-ai/langchainplus/pull/3561~ - ~Langserve will need to be updated~ - Adding chunks: - ~AIMessage + ToolCallsMessage = ToolCallsMessage if either has non-null .tool_calls.~ - Tool call chunks are appended, merging when having equal values of `index`. - additional_kwargs accumulate the normal way. - During streaming: - ~Messages can change types (e.g., from AIMessageChunk to AIToolCallsMessageChunk)~ - Output parsers parse additional_kwargs (during .invoke they read off tool calls). Packages outside of `partners/`: - https://github.com/langchain-ai/langchain-cohere/pull/7 - https://github.com/langchain-ai/langchain-google/pull/123/files --------- Co-authored-by: Chester Curme <chester.curme@gmail.com>
2024-04-09 23:41:42 +00:00
InvalidToolCall,
SystemMessage,
core[minor], ...: add tool calls message (#18947) core[minor], langchain[patch], openai[minor], anthropic[minor], fireworks[minor], groq[minor], mistralai[minor] ```python class ToolCall(TypedDict): name: str args: Dict[str, Any] id: Optional[str] class InvalidToolCall(TypedDict): name: Optional[str] args: Optional[str] id: Optional[str] error: Optional[str] class ToolCallChunk(TypedDict): name: Optional[str] args: Optional[str] id: Optional[str] index: Optional[int] class AIMessage(BaseMessage): ... tool_calls: List[ToolCall] = [] invalid_tool_calls: List[InvalidToolCall] = [] ... class AIMessageChunk(AIMessage, BaseMessageChunk): ... tool_call_chunks: Optional[List[ToolCallChunk]] = None ... ``` Important considerations: - Parsing logic occurs within different providers; - ~Changing output type is a breaking change for anyone doing explicit type checking;~ - ~Langsmith rendering will need to be updated: https://github.com/langchain-ai/langchainplus/pull/3561~ - ~Langserve will need to be updated~ - Adding chunks: - ~AIMessage + ToolCallsMessage = ToolCallsMessage if either has non-null .tool_calls.~ - Tool call chunks are appended, merging when having equal values of `index`. - additional_kwargs accumulate the normal way. - During streaming: - ~Messages can change types (e.g., from AIMessageChunk to AIToolCallsMessageChunk)~ - Output parsers parse additional_kwargs (during .invoke they read off tool calls). Packages outside of `partners/`: - https://github.com/langchain-ai/langchain-cohere/pull/7 - https://github.com/langchain-ai/langchain-google/pull/123/files --------- Co-authored-by: Chester Curme <chester.curme@gmail.com>
2024-04-09 23:41:42 +00:00
ToolCall,
)
from langchain_groq.chat_models import ChatGroq, _convert_dict_to_message
if "GROQ_API_KEY" not in os.environ:
os.environ["GROQ_API_KEY"] = "fake-key"
def test_groq_model_param() -> None:
llm = ChatGroq(model="foo")
assert llm.model_name == "foo"
llm = ChatGroq(model_name="foo")
assert llm.model_name == "foo"
def test_function_message_dict_to_function_message() -> None:
content = json.dumps({"result": "Example #1"})
name = "test_function"
result = _convert_dict_to_message(
{
"role": "function",
"name": name,
"content": content,
}
)
assert isinstance(result, FunctionMessage)
assert result.name == name
assert result.content == content
def test__convert_dict_to_message_human() -> None:
message = {"role": "user", "content": "foo"}
result = _convert_dict_to_message(message)
expected_output = HumanMessage(content="foo")
assert result == expected_output
def test__convert_dict_to_message_ai() -> None:
message = {"role": "assistant", "content": "foo"}
result = _convert_dict_to_message(message)
expected_output = AIMessage(content="foo")
assert result == expected_output
core[minor], ...: add tool calls message (#18947) core[minor], langchain[patch], openai[minor], anthropic[minor], fireworks[minor], groq[minor], mistralai[minor] ```python class ToolCall(TypedDict): name: str args: Dict[str, Any] id: Optional[str] class InvalidToolCall(TypedDict): name: Optional[str] args: Optional[str] id: Optional[str] error: Optional[str] class ToolCallChunk(TypedDict): name: Optional[str] args: Optional[str] id: Optional[str] index: Optional[int] class AIMessage(BaseMessage): ... tool_calls: List[ToolCall] = [] invalid_tool_calls: List[InvalidToolCall] = [] ... class AIMessageChunk(AIMessage, BaseMessageChunk): ... tool_call_chunks: Optional[List[ToolCallChunk]] = None ... ``` Important considerations: - Parsing logic occurs within different providers; - ~Changing output type is a breaking change for anyone doing explicit type checking;~ - ~Langsmith rendering will need to be updated: https://github.com/langchain-ai/langchainplus/pull/3561~ - ~Langserve will need to be updated~ - Adding chunks: - ~AIMessage + ToolCallsMessage = ToolCallsMessage if either has non-null .tool_calls.~ - Tool call chunks are appended, merging when having equal values of `index`. - additional_kwargs accumulate the normal way. - During streaming: - ~Messages can change types (e.g., from AIMessageChunk to AIToolCallsMessageChunk)~ - Output parsers parse additional_kwargs (during .invoke they read off tool calls). Packages outside of `partners/`: - https://github.com/langchain-ai/langchain-cohere/pull/7 - https://github.com/langchain-ai/langchain-google/pull/123/files --------- Co-authored-by: Chester Curme <chester.curme@gmail.com>
2024-04-09 23:41:42 +00:00
def test__convert_dict_to_message_tool_call() -> None:
raw_tool_call = {
"id": "call_wm0JY6CdwOMZ4eTxHWUThDNz",
"function": {
"arguments": '{"name":"Sally","hair_color":"green"}',
"name": "GenerateUsername",
},
"type": "function",
}
message = {"role": "assistant", "content": None, "tool_calls": [raw_tool_call]}
result = _convert_dict_to_message(message)
expected_output = AIMessage(
content="",
additional_kwargs={"tool_calls": [raw_tool_call]},
tool_calls=[
ToolCall(
name="GenerateUsername",
args={"name": "Sally", "hair_color": "green"},
id="call_wm0JY6CdwOMZ4eTxHWUThDNz",
)
],
)
assert result == expected_output
# Test malformed tool call
raw_tool_calls = [
{
"id": "call_wm0JY6CdwOMZ4eTxHWUThDNz",
"function": {
"arguments": "oops",
"name": "GenerateUsername",
},
"type": "function",
},
{
"id": "call_abc123",
"function": {
"arguments": '{"name":"Sally","hair_color":"green"}',
"name": "GenerateUsername",
},
"type": "function",
},
]
message = {"role": "assistant", "content": None, "tool_calls": raw_tool_calls}
result = _convert_dict_to_message(message)
expected_output = AIMessage(
content="",
additional_kwargs={"tool_calls": raw_tool_calls},
invalid_tool_calls=[
InvalidToolCall(
name="GenerateUsername",
args="oops",
id="call_wm0JY6CdwOMZ4eTxHWUThDNz",
error="Function GenerateUsername arguments:\n\noops\n\nare not valid JSON. Received JSONDecodeError Expecting value: line 1 column 1 (char 0)", # noqa: E501
),
],
tool_calls=[
ToolCall(
name="GenerateUsername",
args={"name": "Sally", "hair_color": "green"},
id="call_abc123",
),
],
)
assert result == expected_output
def test__convert_dict_to_message_system() -> None:
message = {"role": "system", "content": "foo"}
result = _convert_dict_to_message(message)
expected_output = SystemMessage(content="foo")
assert result == expected_output
@pytest.fixture
def mock_completion() -> dict:
return {
"id": "chatcmpl-7fcZavknQda3SQ",
"object": "chat.completion",
"created": 1689989000,
"model": "test-model",
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": "Bar Baz",
},
"finish_reason": "stop",
}
],
}
def test_groq_invoke(mock_completion: dict) -> None:
llm = ChatGroq()
mock_client = MagicMock()
completed = False
def mock_create(*args: Any, **kwargs: Any) -> Any:
nonlocal completed
completed = True
return mock_completion
mock_client.create = mock_create
with patch.object(
llm,
"client",
mock_client,
):
res = llm.invoke("bar")
assert res.content == "Bar Baz"
assert type(res) == AIMessage
assert completed
async def test_groq_ainvoke(mock_completion: dict) -> None:
llm = ChatGroq()
mock_client = AsyncMock()
completed = False
async def mock_create(*args: Any, **kwargs: Any) -> Any:
nonlocal completed
completed = True
return mock_completion
mock_client.create = mock_create
with patch.object(
llm,
"async_client",
mock_client,
):
res = await llm.ainvoke("bar")
assert res.content == "Bar Baz"
assert type(res) == AIMessage
assert completed
def test_chat_groq_extra_kwargs() -> None:
"""Test extra kwargs to chat groq."""
# Check that foo is saved in extra_kwargs.
with pytest.warns(UserWarning) as record:
llm = ChatGroq(foo=3, max_tokens=10)
assert llm.max_tokens == 10
assert llm.model_kwargs == {"foo": 3}
assert len(record) == 1
assert type(record[0].message) is UserWarning
assert "foo is not default parameter" in record[0].message.args[0]
# Test that if extra_kwargs are provided, they are added to it.
with pytest.warns(UserWarning) as record:
llm = ChatGroq(foo=3, model_kwargs={"bar": 2})
assert llm.model_kwargs == {"foo": 3, "bar": 2}
assert len(record) == 1
assert type(record[0].message) is UserWarning
assert "foo is not default parameter" in record[0].message.args[0]
# Test that if provided twice it errors
with pytest.raises(ValueError):
ChatGroq(foo=3, model_kwargs={"foo": 2})
# Test that if explicit param is specified in kwargs it errors
with pytest.raises(ValueError):
ChatGroq(model_kwargs={"temperature": 0.2})
# Test that "model" cannot be specified in kwargs
with pytest.raises(ValueError):
ChatGroq(model_kwargs={"model": "test-model"})
def test_chat_groq_invalid_streaming_params() -> None:
"""Test that an error is raised if streaming is invoked with n>1."""
with pytest.raises(ValueError):
ChatGroq(
max_tokens=10,
streaming=True,
temperature=0,
n=5,
)
def test_chat_groq_secret() -> None:
"""Test that secret is not printed"""
secret = "secretKey"
not_secret = "safe"
llm = ChatGroq(api_key=secret, model_kwargs={"not_secret": not_secret})
stringified = str(llm)
assert not_secret in stringified
assert secret not in stringified
@pytest.mark.filterwarnings("ignore:The function `loads` is in beta")
def test_groq_serialization() -> None:
"""Test that ChatGroq can be successfully serialized and deserialized"""
api_key1 = "top secret"
api_key2 = "topest secret"
llm = ChatGroq(api_key=api_key1, temperature=0.5)
dump = lc_load.dumps(llm)
llm2 = lc_load.loads(
dump,
valid_namespaces=["langchain_groq"],
secrets_map={"GROQ_API_KEY": api_key2},
)
assert type(llm2) is ChatGroq
# Ensure api key wasn't dumped and instead was read from secret map.
assert llm.groq_api_key is not None
assert llm.groq_api_key.get_secret_value() not in dump
assert llm2.groq_api_key is not None
assert llm2.groq_api_key.get_secret_value() == api_key2
# Ensure a non-secret field was preserved
assert llm.temperature == llm2.temperature
# Ensure a None was preserved
assert llm.groq_api_base == llm2.groq_api_base