2023-12-13 19:57:59 +00:00
|
|
|
"""Test chat model integration."""
|
2023-12-19 02:23:14 +00:00
|
|
|
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
|
2023-12-14 01:05:31 +00:00
|
|
|
from langchain_core.pydantic_v1 import SecretStr
|
|
|
|
from pytest import CaptureFixture
|
2023-12-13 19:57:59 +00:00
|
|
|
|
2023-12-19 02:23:14 +00:00
|
|
|
from langchain_google_genai.chat_models import (
|
|
|
|
ChatGoogleGenerativeAI,
|
|
|
|
_parse_chat_history,
|
|
|
|
)
|
2023-12-13 19:57:59 +00:00
|
|
|
|
|
|
|
|
|
|
|
def test_integration_initialization() -> None:
    """Construct the chat model with the candidate count given as `n` and as `candidate_count`.

    Both spellings are passed alongside the same sampling settings; the test
    passes as long as neither constructor call raises.
    """
    common_kwargs = {
        "model": "gemini-nano",
        "google_api_key": "...",
        "top_k": 2,
        "top_p": 1,
        "temperature": 0.7,
    }
    for count_kwarg in ({"n": 2}, {"candidate_count": 2}):
        ChatGoogleGenerativeAI(**common_kwargs, **count_kwarg)
|
2023-12-14 01:05:31 +00:00
|
|
|
|
|
|
|
|
|
|
|
def test_api_key_is_string() -> None:
    """The constructor's API key is stored as a ``SecretStr`` wrapping the same value.

    Checking only the wrapper type would let a constructor that drops or
    alters the secret slip through, so the unwrapped value is compared too.
    """
    chat = ChatGoogleGenerativeAI(model="gemini-nano", google_api_key="secret-api-key")
    assert isinstance(chat.google_api_key, SecretStr)
    # The secret must survive wrapping unchanged.
    assert chat.google_api_key.get_secret_value() == "secret-api-key"
|
|
|
|
|
|
|
|
|
|
|
|
def test_api_key_masked_when_passed_via_constructor(capsys: CaptureFixture) -> None:
    """Printing the stored API key writes only the mask, never the raw secret."""
    llm = ChatGoogleGenerativeAI(model="gemini-nano", google_api_key="secret-api-key")
    # end="" keeps a trailing newline out of the captured stdout.
    print(llm.google_api_key, end="")
    assert capsys.readouterr().out == "**********"
|
2023-12-19 02:23:14 +00:00
|
|
|
|
|
|
|
|
|
|
|
def test_parse_history() -> None:
    """Check conversion of a LangChain message list into chat-history entries.

    With ``convert_system_message_to_human=True`` the system prompt is expected
    to be merged into the first user turn. The four input messages
    (system, human, AI, human) therefore yield three history entries, and
    each one is checked.
    """
    system_input = "You're supposed to answer math questions."
    text_question1, text_answer1 = "How much is 2+2?", "4"
    text_question2 = "How much is 3+3?"
    system_message = SystemMessage(content=system_input)
    message1 = HumanMessage(content=text_question1)
    message2 = AIMessage(content=text_answer1)
    message3 = HumanMessage(content=text_question2)
    messages = [system_message, message1, message2, message3]
    history = _parse_chat_history(messages, convert_system_message_to_human=True)
    assert len(history) == 3
    # The system prompt is prepended as an extra part of the first user turn.
    assert history[0] == {
        "role": "user",
        "parts": [{"text": system_input}, {"text": text_question1}],
    }
    assert history[1] == {"role": "model", "parts": [{"text": text_answer1}]}
    # The trailing human message must become its own user turn.
    assert history[2] == {"role": "user", "parts": [{"text": text_question2}]}
|