2024-02-22 15:36:16 +00:00
|
|
|
"""Test Groq Chat API wrapper."""
|
|
|
|
|
|
|
|
import json
|
|
|
|
import os
|
|
|
|
from typing import Any
|
|
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
|
|
|
|
|
|
import langchain_core.load as lc_load
|
|
|
|
import pytest
|
|
|
|
from langchain_core.messages import (
|
|
|
|
AIMessage,
|
|
|
|
FunctionMessage,
|
|
|
|
HumanMessage,
|
core[minor], ...: add tool calls message (#18947)
core[minor], langchain[patch], openai[minor], anthropic[minor], fireworks[minor], groq[minor], mistralai[minor]
```python
class ToolCall(TypedDict):
name: str
args: Dict[str, Any]
id: Optional[str]
class InvalidToolCall(TypedDict):
name: Optional[str]
args: Optional[str]
id: Optional[str]
error: Optional[str]
class ToolCallChunk(TypedDict):
name: Optional[str]
args: Optional[str]
id: Optional[str]
index: Optional[int]
class AIMessage(BaseMessage):
...
tool_calls: List[ToolCall] = []
invalid_tool_calls: List[InvalidToolCall] = []
...
class AIMessageChunk(AIMessage, BaseMessageChunk):
...
tool_call_chunks: Optional[List[ToolCallChunk]] = None
...
```
Important considerations:
- Parsing logic occurs within different providers;
- ~Changing output type is a breaking change for anyone doing explicit
type checking;~
- ~Langsmith rendering will need to be updated:
https://github.com/langchain-ai/langchainplus/pull/3561~
- ~Langserve will need to be updated~
- Adding chunks:
- ~AIMessage + ToolCallsMessage = ToolCallsMessage if either has
non-null .tool_calls.~
- Tool call chunks are appended, merging when having equal values of
`index`.
- additional_kwargs accumulate the normal way.
- During streaming:
- ~Messages can change types (e.g., from AIMessageChunk to
AIToolCallsMessageChunk)~
- Output parsers parse additional_kwargs (during .invoke they read off
tool calls).
Packages outside of `partners/`:
- https://github.com/langchain-ai/langchain-cohere/pull/7
- https://github.com/langchain-ai/langchain-google/pull/123/files
---------
Co-authored-by: Chester Curme <chester.curme@gmail.com>
2024-04-09 23:41:42 +00:00
|
|
|
InvalidToolCall,
|
2024-02-22 15:36:16 +00:00
|
|
|
SystemMessage,
|
core[minor], ...: add tool calls message (#18947)
core[minor], langchain[patch], openai[minor], anthropic[minor], fireworks[minor], groq[minor], mistralai[minor]
```python
class ToolCall(TypedDict):
name: str
args: Dict[str, Any]
id: Optional[str]
class InvalidToolCall(TypedDict):
name: Optional[str]
args: Optional[str]
id: Optional[str]
error: Optional[str]
class ToolCallChunk(TypedDict):
name: Optional[str]
args: Optional[str]
id: Optional[str]
index: Optional[int]
class AIMessage(BaseMessage):
...
tool_calls: List[ToolCall] = []
invalid_tool_calls: List[InvalidToolCall] = []
...
class AIMessageChunk(AIMessage, BaseMessageChunk):
...
tool_call_chunks: Optional[List[ToolCallChunk]] = None
...
```
Important considerations:
- Parsing logic occurs within different providers;
- ~Changing output type is a breaking change for anyone doing explicit
type checking;~
- ~Langsmith rendering will need to be updated:
https://github.com/langchain-ai/langchainplus/pull/3561~
- ~Langserve will need to be updated~
- Adding chunks:
- ~AIMessage + ToolCallsMessage = ToolCallsMessage if either has
non-null .tool_calls.~
- Tool call chunks are appended, merging when having equal values of
`index`.
- additional_kwargs accumulate the normal way.
- During streaming:
- ~Messages can change types (e.g., from AIMessageChunk to
AIToolCallsMessageChunk)~
- Output parsers parse additional_kwargs (during .invoke they read off
tool calls).
Packages outside of `partners/`:
- https://github.com/langchain-ai/langchain-cohere/pull/7
- https://github.com/langchain-ai/langchain-google/pull/123/files
---------
Co-authored-by: Chester Curme <chester.curme@gmail.com>
2024-04-09 23:41:42 +00:00
|
|
|
ToolCall,
|
2024-02-22 15:36:16 +00:00
|
|
|
)
|
|
|
|
|
|
|
|
from langchain_groq.chat_models import ChatGroq, _convert_dict_to_message
|
|
|
|
|
2024-04-03 21:40:20 +00:00
|
|
|
# Ensure an API key is always present so ChatGroq() can be constructed in
# tests without hitting the real environment; never overwrite a real key.
os.environ.setdefault("GROQ_API_KEY", "fake-key")
|
2024-02-22 15:36:16 +00:00
|
|
|
|
|
|
|
|
|
|
|
def test_groq_model_param() -> None:
    """Both the ``model`` and ``model_name`` constructor aliases set ``model_name``."""
    for alias in ("model", "model_name"):
        chat = ChatGroq(**{alias: "foo"})
        assert chat.model_name == "foo"
|
|
|
|
|
|
|
|
|
|
|
|
def test_function_message_dict_to_function_message() -> None:
    """A ``function``-role dict converts to a FunctionMessage with name/content intact."""
    payload = json.dumps({"result": "Example #1"})
    function_name = "test_function"
    message_dict = {
        "role": "function",
        "name": function_name,
        "content": payload,
    }
    converted = _convert_dict_to_message(message_dict)
    assert isinstance(converted, FunctionMessage)
    assert converted.name == function_name
    assert converted.content == payload
|
|
|
|
|
|
|
|
|
|
|
|
def test__convert_dict_to_message_human() -> None:
    """A ``user``-role dict converts to an equivalent HumanMessage."""
    converted = _convert_dict_to_message({"role": "user", "content": "foo"})
    assert converted == HumanMessage(content="foo")
|
|
|
|
|
|
|
|
|
|
|
|
def test__convert_dict_to_message_ai() -> None:
    """An ``assistant``-role dict converts to an equivalent AIMessage."""
    converted = _convert_dict_to_message({"role": "assistant", "content": "foo"})
    assert converted == AIMessage(content="foo")
|
|
|
|
|
|
|
|
|
core[minor], ...: add tool calls message (#18947)
core[minor], langchain[patch], openai[minor], anthropic[minor], fireworks[minor], groq[minor], mistralai[minor]
```python
class ToolCall(TypedDict):
name: str
args: Dict[str, Any]
id: Optional[str]
class InvalidToolCall(TypedDict):
name: Optional[str]
args: Optional[str]
id: Optional[str]
error: Optional[str]
class ToolCallChunk(TypedDict):
name: Optional[str]
args: Optional[str]
id: Optional[str]
index: Optional[int]
class AIMessage(BaseMessage):
...
tool_calls: List[ToolCall] = []
invalid_tool_calls: List[InvalidToolCall] = []
...
class AIMessageChunk(AIMessage, BaseMessageChunk):
...
tool_call_chunks: Optional[List[ToolCallChunk]] = None
...
```
Important considerations:
- Parsing logic occurs within different providers;
- ~Changing output type is a breaking change for anyone doing explicit
type checking;~
- ~Langsmith rendering will need to be updated:
https://github.com/langchain-ai/langchainplus/pull/3561~
- ~Langserve will need to be updated~
- Adding chunks:
- ~AIMessage + ToolCallsMessage = ToolCallsMessage if either has
non-null .tool_calls.~
- Tool call chunks are appended, merging when having equal values of
`index`.
- additional_kwargs accumulate the normal way.
- During streaming:
- ~Messages can change types (e.g., from AIMessageChunk to
AIToolCallsMessageChunk)~
- Output parsers parse additional_kwargs (during .invoke they read off
tool calls).
Packages outside of `partners/`:
- https://github.com/langchain-ai/langchain-cohere/pull/7
- https://github.com/langchain-ai/langchain-google/pull/123/files
---------
Co-authored-by: Chester Curme <chester.curme@gmail.com>
2024-04-09 23:41:42 +00:00
|
|
|
def test__convert_dict_to_message_tool_call() -> None:
    """Assistant dicts carrying tool calls convert into AIMessage tool-call fields.

    First checks a well-formed call: JSON ``arguments`` are parsed and surfaced
    in ``tool_calls`` (while the raw payload is preserved in
    ``additional_kwargs``). Then checks a mix of one malformed call (non-JSON
    arguments land in ``invalid_tool_calls`` with a parse-error message) and
    one valid call.
    """
    # Well-formed tool call: arguments are valid JSON.
    raw_tool_call = {
        "id": "call_wm0JY6CdwOMZ4eTxHWUThDNz",
        "function": {
            "arguments": '{"name":"Sally","hair_color":"green"}',
            "name": "GenerateUsername",
        },
        "type": "function",
    }
    message = {"role": "assistant", "content": None, "tool_calls": [raw_tool_call]}
    result = _convert_dict_to_message(message)
    expected_output = AIMessage(
        # None content is normalized to the empty string.
        content="",
        additional_kwargs={"tool_calls": [raw_tool_call]},
        tool_calls=[
            ToolCall(
                name="GenerateUsername",
                args={"name": "Sally", "hair_color": "green"},
                id="call_wm0JY6CdwOMZ4eTxHWUThDNz",
            )
        ],
    )
    assert result == expected_output

    # Test malformed tool call
    raw_tool_calls = [
        {
            "id": "call_wm0JY6CdwOMZ4eTxHWUThDNz",
            "function": {
                # Not valid JSON — should be routed to invalid_tool_calls.
                "arguments": "oops",
                "name": "GenerateUsername",
            },
            "type": "function",
        },
        {
            "id": "call_abc123",
            "function": {
                "arguments": '{"name":"Sally","hair_color":"green"}',
                "name": "GenerateUsername",
            },
            "type": "function",
        },
    ]
    message = {"role": "assistant", "content": None, "tool_calls": raw_tool_calls}
    result = _convert_dict_to_message(message)
    expected_output = AIMessage(
        content="",
        # The raw payload is kept verbatim even when parsing fails.
        additional_kwargs={"tool_calls": raw_tool_calls},
        invalid_tool_calls=[
            InvalidToolCall(
                name="GenerateUsername",
                args="oops",
                id="call_wm0JY6CdwOMZ4eTxHWUThDNz",
                error="Function GenerateUsername arguments:\n\noops\n\nare not valid JSON. Received JSONDecodeError Expecting value: line 1 column 1 (char 0)",  # noqa: E501
            ),
        ],
        tool_calls=[
            ToolCall(
                name="GenerateUsername",
                args={"name": "Sally", "hair_color": "green"},
                id="call_abc123",
            ),
        ],
    )
    assert result == expected_output
|
|
|
|
|
|
|
|
|
2024-02-22 15:36:16 +00:00
|
|
|
def test__convert_dict_to_message_system() -> None:
    """A ``system``-role dict converts to an equivalent SystemMessage."""
    converted = _convert_dict_to_message({"role": "system", "content": "foo"})
    assert converted == SystemMessage(content="foo")
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.fixture
def mock_completion() -> dict:
    """Return a minimal canned chat-completion payload shaped like a Groq response."""
    assistant_choice = {
        "index": 0,
        "message": {
            "role": "assistant",
            "content": "Bar Baz",
        },
        "finish_reason": "stop",
    }
    return {
        "id": "chatcmpl-7fcZavknQda3SQ",
        "object": "chat.completion",
        "created": 1689989000,
        "model": "test-model",
        "choices": [assistant_choice],
    }
|
|
|
|
|
|
|
|
|
|
|
|
def test_groq_invoke(mock_completion: dict) -> None:
    """Test that ``invoke`` routes through the sync client and wraps the result.

    A stub ``create`` records that it was called and returns the canned
    completion; the response must surface as exactly an ``AIMessage``.
    """
    llm = ChatGroq()
    mock_client = MagicMock()
    completed = False

    def mock_create(*args: Any, **kwargs: Any) -> Any:
        # Flag that the client was actually exercised, then return the fixture.
        nonlocal completed
        completed = True
        return mock_completion

    mock_client.create = mock_create
    with patch.object(
        llm,
        "client",
        mock_client,
    ):
        res = llm.invoke("bar")
        assert res.content == "Bar Baz"
        # `is` rather than `==`: assert the exact type, not a subclass such as
        # AIMessageChunk (fixes the E721 `type(...) ==` comparison smell).
        assert type(res) is AIMessage
    assert completed
|
|
|
|
|
|
|
|
|
|
|
|
async def test_groq_ainvoke(mock_completion: dict) -> None:
    """Test that ``ainvoke`` routes through the async client and wraps the result.

    Mirrors ``test_groq_invoke`` with an async stub ``create``.
    """
    llm = ChatGroq()
    mock_client = AsyncMock()
    completed = False

    async def mock_create(*args: Any, **kwargs: Any) -> Any:
        # Flag that the async client was actually exercised.
        nonlocal completed
        completed = True
        return mock_completion

    mock_client.create = mock_create
    with patch.object(
        llm,
        "async_client",
        mock_client,
    ):
        res = await llm.ainvoke("bar")
        assert res.content == "Bar Baz"
        # `is` rather than `==`: assert the exact type, not a subclass such as
        # AIMessageChunk (fixes the E721 `type(...) ==` comparison smell).
        assert type(res) is AIMessage
    assert completed
|
|
|
|
|
|
|
|
|
|
|
|
def test_chat_groq_extra_kwargs() -> None:
    """Test extra kwargs to chat groq."""
    # Unknown constructor kwargs should be captured in model_kwargs, with a
    # warning naming the offending parameter.
    with pytest.warns(UserWarning) as captured:
        llm = ChatGroq(foo=3, max_tokens=10)
        assert llm.max_tokens == 10
        assert llm.model_kwargs == {"foo": 3}
    assert len(captured) == 1
    assert type(captured[0].message) is UserWarning
    assert "foo is not default parameter" in captured[0].message.args[0]

    # An explicit model_kwargs dict is merged with the stray kwargs.
    with pytest.warns(UserWarning) as captured:
        llm = ChatGroq(foo=3, model_kwargs={"bar": 2})
        assert llm.model_kwargs == {"foo": 3, "bar": 2}
    assert len(captured) == 1
    assert type(captured[0].message) is UserWarning
    assert "foo is not default parameter" in captured[0].message.args[0]

    # Supplying the same key both ways is an error.
    with pytest.raises(ValueError):
        ChatGroq(foo=3, model_kwargs={"foo": 2})

    # Explicit constructor parameters may not be smuggled in via model_kwargs.
    with pytest.raises(ValueError):
        ChatGroq(model_kwargs={"temperature": 0.2})

    # Neither may "model" itself.
    with pytest.raises(ValueError):
        ChatGroq(model_kwargs={"model": "test-model"})
|
|
|
|
|
|
|
|
|
|
|
|
def test_chat_groq_invalid_streaming_params() -> None:
    """Test that an error is raised if streaming is invoked with n>1."""
    with pytest.raises(ValueError):
        ChatGroq(max_tokens=10, streaming=True, temperature=0, n=5)
|
|
|
|
|
|
|
|
|
|
|
|
def test_chat_groq_secret() -> None:
    """The API key must never appear in the model's string representation."""
    hidden_value = "secretKey"
    visible_value = "safe"
    chat = ChatGroq(api_key=hidden_value, model_kwargs={"not_secret": visible_value})
    rendered = str(chat)
    assert visible_value in rendered
    assert hidden_value not in rendered
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.filterwarnings("ignore:The function `loads` is in beta")
def test_groq_serialization() -> None:
    """Round-trip a ChatGroq through dumps/loads without leaking its secret."""
    api_key1 = "top secret"
    api_key2 = "topest secret"
    original = ChatGroq(api_key=api_key1, temperature=0.5)
    serialized = lc_load.dumps(original)
    restored = lc_load.loads(
        serialized,
        valid_namespaces=["langchain_groq"],
        secrets_map={"GROQ_API_KEY": api_key2},
    )

    assert type(restored) is ChatGroq

    # The key must come from the secrets map, never from the dump itself.
    assert original.groq_api_key is not None
    assert original.groq_api_key.get_secret_value() not in serialized
    assert restored.groq_api_key is not None
    assert restored.groq_api_key.get_secret_value() == api_key2

    # A non-secret field survives the round trip.
    assert original.temperature == restored.temperature

    # A field left as None survives the round trip too.
    assert original.groq_api_base == restored.groq_api_base
|