"""Test Google PaLM Chat API wrapper.""" import pytest from langchain.chat_models.google_palm import ( ChatGooglePalm, ChatGooglePalmError, _messages_to_prompt_dict, ) from langchain.schema import ( AIMessage, HumanMessage, SystemMessage, ) def test_messages_to_prompt_dict_with_valid_messages() -> None: pytest.importorskip("google.generativeai") result = _messages_to_prompt_dict( [ SystemMessage(content="Prompt"), HumanMessage(example=True, content="Human example #1"), AIMessage(example=True, content="AI example #1"), HumanMessage(example=True, content="Human example #2"), AIMessage(example=True, content="AI example #2"), HumanMessage(content="Real human message"), AIMessage(content="Real AI message"), ] ) expected = { "context": "Prompt", "examples": [ {"author": "human", "content": "Human example #1"}, {"author": "ai", "content": "AI example #1"}, {"author": "human", "content": "Human example #2"}, {"author": "ai", "content": "AI example #2"}, ], "messages": [ {"author": "human", "content": "Real human message"}, {"author": "ai", "content": "Real AI message"}, ], } assert result == expected def test_messages_to_prompt_dict_raises_with_misplaced_system_message() -> None: pytest.importorskip("google.generativeai") with pytest.raises(ChatGooglePalmError) as e: _messages_to_prompt_dict( [ HumanMessage(content="Real human message"), SystemMessage(content="Prompt"), ] ) assert "System message must be first" in str(e) def test_messages_to_prompt_dict_raises_with_misordered_examples() -> None: pytest.importorskip("google.generativeai") with pytest.raises(ChatGooglePalmError) as e: _messages_to_prompt_dict( [ AIMessage(example=True, content="AI example #1"), HumanMessage(example=True, content="Human example #1"), ] ) assert "AI example message must be immediately preceded" in str(e) def test_messages_to_prompt_dict_raises_with_mismatched_examples() -> None: pytest.importorskip("google.generativeai") with pytest.raises(ChatGooglePalmError) as e: _messages_to_prompt_dict( [ HumanMessage(example=True, content="Human example #1"), AIMessage(example=False, content="AI example #1"), ] ) assert "Human example message must be immediately followed" in str(e) def test_messages_to_prompt_dict_raises_with_example_after_real() -> None: pytest.importorskip("google.generativeai") with pytest.raises(ChatGooglePalmError) as e: _messages_to_prompt_dict( [ HumanMessage(example=False, content="Real message"), HumanMessage(example=True, content="Human example #1"), AIMessage(example=True, content="AI example #1"), ] ) assert "Message examples must come before other" in str(e) def test_chat_google_raises_with_invalid_temperature() -> None: pytest.importorskip("google.generativeai") with pytest.raises(ValueError) as e: ChatGooglePalm(google_api_key="fake", temperature=2.0) assert "must be in the range" in str(e) def test_chat_google_raises_with_invalid_top_p() -> None: pytest.importorskip("google.generativeai") with pytest.raises(ValueError) as e: ChatGooglePalm(google_api_key="fake", top_p=2.0) assert "must be in the range" in str(e) def test_chat_google_raises_with_invalid_top_k() -> None: pytest.importorskip("google.generativeai") with pytest.raises(ValueError) as e: ChatGooglePalm(google_api_key="fake", top_k=-5) assert "must be positive" in str(e)