mirror of
https://github.com/hwchase17/langchain
synced 2024-11-10 01:10:59 +00:00
115 lines
3.9 KiB
Python
115 lines
3.9 KiB
Python
|
"""Test Google PaLM Chat API wrapper."""
|
||
|
|
||
|
import pytest
|
||
|
|
||
|
from langchain.chat_models.google_palm import (
|
||
|
ChatGooglePalm,
|
||
|
ChatGooglePalmError,
|
||
|
_messages_to_prompt_dict,
|
||
|
)
|
||
|
from langchain.schema import (
|
||
|
AIMessage,
|
||
|
HumanMessage,
|
||
|
SystemMessage,
|
||
|
)
|
||
|
|
||
|
|
||
|
def test_messages_to_prompt_dict_with_valid_messages() -> None:
|
||
|
pytest.importorskip("google.generativeai")
|
||
|
result = _messages_to_prompt_dict(
|
||
|
[
|
||
|
SystemMessage(content="Prompt"),
|
||
|
HumanMessage(example=True, content="Human example #1"),
|
||
|
AIMessage(example=True, content="AI example #1"),
|
||
|
HumanMessage(example=True, content="Human example #2"),
|
||
|
AIMessage(example=True, content="AI example #2"),
|
||
|
HumanMessage(content="Real human message"),
|
||
|
AIMessage(content="Real AI message"),
|
||
|
]
|
||
|
)
|
||
|
expected = {
|
||
|
"context": "Prompt",
|
||
|
"examples": [
|
||
|
{"author": "human", "content": "Human example #1"},
|
||
|
{"author": "ai", "content": "AI example #1"},
|
||
|
{"author": "human", "content": "Human example #2"},
|
||
|
{"author": "ai", "content": "AI example #2"},
|
||
|
],
|
||
|
"messages": [
|
||
|
{"author": "human", "content": "Real human message"},
|
||
|
{"author": "ai", "content": "Real AI message"},
|
||
|
],
|
||
|
}
|
||
|
|
||
|
assert result == expected
|
||
|
|
||
|
|
||
|
def test_messages_to_prompt_dict_raises_with_misplaced_system_message() -> None:
|
||
|
pytest.importorskip("google.generativeai")
|
||
|
with pytest.raises(ChatGooglePalmError) as e:
|
||
|
_messages_to_prompt_dict(
|
||
|
[
|
||
|
HumanMessage(content="Real human message"),
|
||
|
SystemMessage(content="Prompt"),
|
||
|
]
|
||
|
)
|
||
|
assert "System message must be first" in str(e)
|
||
|
|
||
|
|
||
|
def test_messages_to_prompt_dict_raises_with_misordered_examples() -> None:
|
||
|
pytest.importorskip("google.generativeai")
|
||
|
with pytest.raises(ChatGooglePalmError) as e:
|
||
|
_messages_to_prompt_dict(
|
||
|
[
|
||
|
AIMessage(example=True, content="AI example #1"),
|
||
|
HumanMessage(example=True, content="Human example #1"),
|
||
|
]
|
||
|
)
|
||
|
assert "AI example message must be immediately preceded" in str(e)
|
||
|
|
||
|
|
||
|
def test_messages_to_prompt_dict_raises_with_mismatched_examples() -> None:
|
||
|
pytest.importorskip("google.generativeai")
|
||
|
with pytest.raises(ChatGooglePalmError) as e:
|
||
|
_messages_to_prompt_dict(
|
||
|
[
|
||
|
HumanMessage(example=True, content="Human example #1"),
|
||
|
AIMessage(example=False, content="AI example #1"),
|
||
|
]
|
||
|
)
|
||
|
assert "Human example message must be immediately followed" in str(e)
|
||
|
|
||
|
|
||
|
def test_messages_to_prompt_dict_raises_with_example_after_real() -> None:
|
||
|
pytest.importorskip("google.generativeai")
|
||
|
with pytest.raises(ChatGooglePalmError) as e:
|
||
|
_messages_to_prompt_dict(
|
||
|
[
|
||
|
HumanMessage(example=False, content="Real message"),
|
||
|
HumanMessage(example=True, content="Human example #1"),
|
||
|
AIMessage(example=True, content="AI example #1"),
|
||
|
]
|
||
|
)
|
||
|
assert "Message examples must come before other" in str(e)
|
||
|
|
||
|
|
||
|
def test_chat_google_raises_with_invalid_temperature() -> None:
|
||
|
pytest.importorskip("google.generativeai")
|
||
|
with pytest.raises(ValueError) as e:
|
||
|
ChatGooglePalm(google_api_key="fake", temperature=2.0)
|
||
|
assert "must be in the range" in str(e)
|
||
|
|
||
|
|
||
|
def test_chat_google_raises_with_invalid_top_p() -> None:
|
||
|
pytest.importorskip("google.generativeai")
|
||
|
with pytest.raises(ValueError) as e:
|
||
|
ChatGooglePalm(google_api_key="fake", top_p=2.0)
|
||
|
assert "must be in the range" in str(e)
|
||
|
|
||
|
|
||
|
def test_chat_google_raises_with_invalid_top_k() -> None:
|
||
|
pytest.importorskip("google.generativeai")
|
||
|
with pytest.raises(ValueError) as e:
|
||
|
ChatGooglePalm(google_api_key="fake", top_k=-5)
|
||
|
assert "must be positive" in str(e)
|