langchain/tests/unit_tests/chat_models/test_google_palm.py

115 lines
3.9 KiB
Python
Raw Normal View History

"""Test Google PaLM Chat API wrapper."""
import pytest
from langchain.chat_models.google_palm import (
ChatGooglePalm,
ChatGooglePalmError,
_messages_to_prompt_dict,
)
from langchain.schema import (
AIMessage,
HumanMessage,
SystemMessage,
)
def test_messages_to_prompt_dict_with_valid_messages() -> None:
pytest.importorskip("google.generativeai")
result = _messages_to_prompt_dict(
[
SystemMessage(content="Prompt"),
HumanMessage(example=True, content="Human example #1"),
AIMessage(example=True, content="AI example #1"),
HumanMessage(example=True, content="Human example #2"),
AIMessage(example=True, content="AI example #2"),
HumanMessage(content="Real human message"),
AIMessage(content="Real AI message"),
]
)
expected = {
"context": "Prompt",
"examples": [
{"author": "human", "content": "Human example #1"},
{"author": "ai", "content": "AI example #1"},
{"author": "human", "content": "Human example #2"},
{"author": "ai", "content": "AI example #2"},
],
"messages": [
{"author": "human", "content": "Real human message"},
{"author": "ai", "content": "Real AI message"},
],
}
assert result == expected
def test_messages_to_prompt_dict_raises_with_misplaced_system_message() -> None:
pytest.importorskip("google.generativeai")
with pytest.raises(ChatGooglePalmError) as e:
_messages_to_prompt_dict(
[
HumanMessage(content="Real human message"),
SystemMessage(content="Prompt"),
]
)
assert "System message must be first" in str(e)
def test_messages_to_prompt_dict_raises_with_misordered_examples() -> None:
pytest.importorskip("google.generativeai")
with pytest.raises(ChatGooglePalmError) as e:
_messages_to_prompt_dict(
[
AIMessage(example=True, content="AI example #1"),
HumanMessage(example=True, content="Human example #1"),
]
)
assert "AI example message must be immediately preceded" in str(e)
def test_messages_to_prompt_dict_raises_with_mismatched_examples() -> None:
pytest.importorskip("google.generativeai")
with pytest.raises(ChatGooglePalmError) as e:
_messages_to_prompt_dict(
[
HumanMessage(example=True, content="Human example #1"),
AIMessage(example=False, content="AI example #1"),
]
)
assert "Human example message must be immediately followed" in str(e)
def test_messages_to_prompt_dict_raises_with_example_after_real() -> None:
pytest.importorskip("google.generativeai")
with pytest.raises(ChatGooglePalmError) as e:
_messages_to_prompt_dict(
[
HumanMessage(example=False, content="Real message"),
HumanMessage(example=True, content="Human example #1"),
AIMessage(example=True, content="AI example #1"),
]
)
assert "Message examples must come before other" in str(e)
def test_chat_google_raises_with_invalid_temperature() -> None:
pytest.importorskip("google.generativeai")
with pytest.raises(ValueError) as e:
ChatGooglePalm(google_api_key="fake", temperature=2.0)
assert "must be in the range" in str(e)
def test_chat_google_raises_with_invalid_top_p() -> None:
pytest.importorskip("google.generativeai")
with pytest.raises(ValueError) as e:
ChatGooglePalm(google_api_key="fake", top_p=2.0)
assert "must be in the range" in str(e)
def test_chat_google_raises_with_invalid_top_k() -> None:
pytest.importorskip("google.generativeai")
with pytest.raises(ValueError) as e:
ChatGooglePalm(google_api_key="fake", top_k=-5)
assert "must be positive" in str(e)