forked from Archives/langchain
921894960b
- Add langchain.llms.GooglePalm for text completion, - Add langchain.chat_models.ChatGooglePalm for chat completion, - Add langchain.embeddings.GooglePalmEmbeddings for sentence embeddings, - Add example field to HumanMessage and AIMessage so that users can feed in examples into the PaLM Chat API, - Add system and unit tests. Note async completion for the Text API is not yet supported and will be included in a future PR. Happy for feedback on any aspect of this PR, especially our choice of adding an example field to Human and AI Message objects to enable passing example messages to the API.
115 lines
3.9 KiB
Python
115 lines
3.9 KiB
Python
"""Test Google PaLM Chat API wrapper."""
|
|
|
|
import pytest
|
|
|
|
from langchain.chat_models.google_palm import (
|
|
ChatGooglePalm,
|
|
ChatGooglePalmError,
|
|
_messages_to_prompt_dict,
|
|
)
|
|
from langchain.schema import (
|
|
AIMessage,
|
|
HumanMessage,
|
|
SystemMessage,
|
|
)
|
|
|
|
|
|
def test_messages_to_prompt_dict_with_valid_messages() -> None:
    """A well-formed history is split into context, examples and messages.

    The leading SystemMessage becomes ``context``; messages flagged with
    ``example=True`` go to ``examples``; the remaining messages go to
    ``messages``. Human/AI authors are rendered as "human"/"ai".
    """
    # Skip when the optional google.generativeai package is not installed.
    pytest.importorskip("google.generativeai")

    history = [
        SystemMessage(content="Prompt"),
        HumanMessage(example=True, content="Human example #1"),
        AIMessage(example=True, content="AI example #1"),
        HumanMessage(example=True, content="Human example #2"),
        AIMessage(example=True, content="AI example #2"),
        HumanMessage(content="Real human message"),
        AIMessage(content="Real AI message"),
    ]

    expected_examples = [
        {"author": "human", "content": "Human example #1"},
        {"author": "ai", "content": "AI example #1"},
        {"author": "human", "content": "Human example #2"},
        {"author": "ai", "content": "AI example #2"},
    ]
    expected_messages = [
        {"author": "human", "content": "Real human message"},
        {"author": "ai", "content": "Real AI message"},
    ]

    assert _messages_to_prompt_dict(history) == {
        "context": "Prompt",
        "examples": expected_examples,
        "messages": expected_messages,
    }
|
|
|
|
|
|
def test_messages_to_prompt_dict_raises_with_misplaced_system_message() -> None:
    """A SystemMessage that is not the first message is rejected."""
    # Skip when the optional google.generativeai package is not installed.
    pytest.importorskip("google.generativeai")

    # `match` is searched against str(exc_info.value), i.e. the actual
    # exception message; str() of the ExceptionInfo itself is only its repr.
    with pytest.raises(ChatGooglePalmError, match="System message must be first"):
        _messages_to_prompt_dict(
            [
                HumanMessage(content="Real human message"),
                SystemMessage(content="Prompt"),
            ]
        )
|
|
|
|
|
|
def test_messages_to_prompt_dict_raises_with_misordered_examples() -> None:
    """An AI example that is not preceded by a human example is rejected."""
    # Skip when the optional google.generativeai package is not installed.
    pytest.importorskip("google.generativeai")

    # `match` checks the exception message itself rather than the repr of
    # the ExceptionInfo wrapper.
    with pytest.raises(
        ChatGooglePalmError,
        match="AI example message must be immediately preceded",
    ):
        _messages_to_prompt_dict(
            [
                AIMessage(example=True, content="AI example #1"),
                HumanMessage(example=True, content="Human example #1"),
            ]
        )
|
|
|
|
|
|
def test_messages_to_prompt_dict_raises_with_mismatched_examples() -> None:
    """A human example followed by a non-example AI message is rejected."""
    # Skip when the optional google.generativeai package is not installed.
    pytest.importorskip("google.generativeai")

    # The AI message is deliberately NOT flagged as an example, so the human
    # example is left without a matching AI example reply.
    with pytest.raises(
        ChatGooglePalmError,
        match="Human example message must be immediately followed",
    ):
        _messages_to_prompt_dict(
            [
                HumanMessage(example=True, content="Human example #1"),
                AIMessage(example=False, content="AI example #1"),
            ]
        )
|
|
|
|
|
|
def test_messages_to_prompt_dict_raises_with_example_after_real() -> None:
    """Example messages appearing after a real message are rejected."""
    # Skip when the optional google.generativeai package is not installed.
    pytest.importorskip("google.generativeai")

    # `match` checks the exception message itself rather than the repr of
    # the ExceptionInfo wrapper.
    with pytest.raises(
        ChatGooglePalmError,
        match="Message examples must come before other",
    ):
        _messages_to_prompt_dict(
            [
                HumanMessage(example=False, content="Real message"),
                HumanMessage(example=True, content="Human example #1"),
                AIMessage(example=True, content="AI example #1"),
            ]
        )
|
|
|
|
|
|
def test_chat_google_raises_with_invalid_temperature() -> None:
    """Constructing ChatGooglePalm with an out-of-range temperature fails."""
    # Skip when the optional google.generativeai package is not installed.
    pytest.importorskip("google.generativeai")

    # `match` is searched against the exception's own message.
    with pytest.raises(ValueError, match="must be in the range"):
        ChatGooglePalm(google_api_key="fake", temperature=2.0)
|
|
|
|
|
|
def test_chat_google_raises_with_invalid_top_p() -> None:
    """Constructing ChatGooglePalm with an out-of-range top_p fails."""
    # Skip when the optional google.generativeai package is not installed.
    pytest.importorskip("google.generativeai")

    # `match` is searched against the exception's own message.
    with pytest.raises(ValueError, match="must be in the range"):
        ChatGooglePalm(google_api_key="fake", top_p=2.0)
|
|
|
|
|
|
def test_chat_google_raises_with_invalid_top_k() -> None:
    """Constructing ChatGooglePalm with a negative top_k fails."""
    # Skip when the optional google.generativeai package is not installed.
    pytest.importorskip("google.generativeai")

    # `match` is searched against the exception's own message.
    with pytest.raises(ValueError, match="must be positive"):
        ChatGooglePalm(google_api_key="fake", top_k=-5)
|