You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
langchain/libs/partners/google-genai/tests/integration_tests/test_chat_models.py

189 lines
8.2 KiB
Python

"""Test ChatGoogleGenerativeAI chat model."""
import pytest
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
from langchain_google_genai.chat_models import (
ChatGoogleGenerativeAI,
ChatGoogleGenerativeAIError,
)
# Text-only Gemini model used by most tests in this module.
_MODEL = "gemini-pro" # TODO: Use nano when it's available.
# Vision-capable model used by the multimodal (image) tests.
_VISION_MODEL = "gemini-pro-vision"
# Small base64-encoded PNG (20x20 RGB, per its IHDR header) that the multimodal
# tests embed as a ``data:image/png;base64,`` URL in an ``image_url`` part.
_B64_string = """iVBORw0KGgoAAAANSUhEUgAAABQAAAAUCAIAAAAC64paAAABhGlDQ1BJQ0MgUHJvZmlsZQAAeJx9kT1Iw0AcxV8/xCIVQTuIKGSoTi2IijhqFYpQIdQKrTqYXPoFTRqSFBdHwbXg4Mdi1cHFWVcHV0EQ/ABxdXFSdJES/5cUWsR4cNyPd/ced+8Af6PCVDM4DqiaZaSTCSGbWxW6XxHECPoRQ0hipj4niil4jq97+Ph6F+dZ3uf+HL1K3mSATyCeZbphEW8QT29aOud94ggrSQrxOXHMoAsSP3JddvmNc9FhP8+MGJn0PHGEWCh2sNzBrGSoxFPEUUXVKN+fdVnhvMVZrdRY6578heG8trLMdZrDSGIRSxAhQEYNZVRgIU6rRoqJNO0nPPxDjl8kl0yuMhg5FlCFCsnxg//B727NwuSEmxROAF0vtv0xCnTvAs26bX8f23bzBAg8A1da219tADOfpNfbWvQI6NsGLq7bmrwHXO4Ag0+6ZEiOFKDpLxSA9zP6phwwcAv0rLm9tfZx+gBkqKvUDXBwCIwVKXvd492hzt7+PdPq7wdzbXKn5swsVgAAA8lJREFUeJx90dtPHHUUB/Dz+81vZhb2wrDI3soUKBSRcisF21iqqCRNY01NTE0k8aHpi0k18VJfjOFvUF9M44MmGrHFQqSQiKSmFloL5c4CXW6Fhb0vO3ufvczMzweiBGI9+eW8ffI95/yQqqrwv4UxBgCfJ9w/2NfSVB+Nyn6/r+vdLo7H6FkYY6yoABR2PJujj34MSo/d/nHeVLYbydmIp/bEO0fEy/+NMcbTU4/j4Vs6Lr0ccKeYuUKWS4ABVCVHmRdszbfvTgfjR8kz5Jjs+9RREl9Zy2lbVK9wU3/kWLJLCXnqza1bfVe7b9jLbIeTMcYu13Jg/aMiPrCwVFcgtDiMhnxwJ/zXVDwSdVCVMRV7nqzl2i9e/fKrw8mqSp84e2sFj3Oj8/SrF/MaicmyYhAaXu58NPAbeAeyzY0NLecmh2+ODN3BewYBAkAY43giI3kebrnsRmvV9z2D4ciOa3EBAf31Tp9sMgdxMTFm6j74/Ogb70VCYQKAAIDCXkOAIC6pkYBWdwwnpHEdf6L9dJtJKPh95DZhzFKMEWRAGL927XpWTmMA+s8DAOBYAoR483l/iHZ/8bXoODl8b9UfyH72SXepzbyRJNvjFGHKMlhvMBze+cH9+4lEuOOlU2X1tVkFTU7Om03q080NDGXV1cflRpHwaaoiiiildB8jhDLZ7HDfz2Yidba6Vn2L4fhzFrNRKy5OZ2QOZ1U5W8VtqlVH/iUHcM933zZYWS7Wtj66zZr65bzGJQt0glHgudi9XVzEl4vKw2kUPhO020oPYI1qYc+2Xc0bRXFwTLY0VXa2VibD/lBaIXm1UChN5JSRUcQQ1Tk/47Cf3x8bY7y17Y17PVYTG1UkLPBFcqik7Zoa9JcLYoHBqHhXNgd6gS1k9EJ1TQ2l9EDy1saErmQ2kGpwGC2MLOtCM8nZEV1K0tKJtEksSm26J/rHg2zzmabKisq939nHzqUH7efzd4f/nPGW6NP8ybNFrOsWQhpoCuuhnJ4hAnPhFam01K4oQMjBg/mzBjVhuvw2O++KKT+BIVxJKzQECBDLF2qu2WTMmCovtDQ1f8iyoGkUADBCCGPsdnvTW2OtFm01VeB06msvdWlpPZU0wJRG85ns84umU3k+VyxeEcWqvYUBAGsUrbvme4be99HFeisP/pwUOIZaOqQX31ISgrKmZhLHtXNXuJq68orrr5/9mBCglCLAGGPyy81votEbcjlKLrC9E8mhH3wdHRdcyyvjidSlxjftPJpD+o25JYvRHGFoZDdks1mBQhxJu9uxvwEiXuHnHbLd1AAAAABJRU5ErkJggg==""" # noqa: E501
def test_chat_google_genai_stream() -> None:
    """Stream a short prompt from Gemini and check every chunk carries text."""
    chat = ChatGoogleGenerativeAI(model=_MODEL)
    prompt = "This is a test. Say 'foo'"
    for chunk in chat.stream(prompt):
        assert isinstance(chunk.content, str)
async def test_chat_google_genai_astream() -> None:
    """Asynchronously stream a short prompt and check every chunk carries text."""
    chat = ChatGoogleGenerativeAI(model=_MODEL)
    prompt = "This is a test. Say 'foo'"
    async for chunk in chat.astream(prompt):
        assert isinstance(chunk.content, str)
async def test_chat_google_genai_abatch() -> None:
    """Test async batch calls to ChatGoogleGenerativeAI.

    Sends two prompts through ``abatch`` and checks that exactly one reply
    with string content comes back per prompt.
    """
    llm = ChatGoogleGenerativeAI(model=_MODEL)
    prompts = ["This is a test. Say 'foo'", "This is a test, say 'bar'"]
    result = await llm.abatch(prompts)
    # Without a length check an empty result list would pass the loop vacuously.
    assert len(result) == len(prompts)
    for message in result:
        assert isinstance(message.content, str)
async def test_chat_google_genai_abatch_tags() -> None:
    """Run an async batch with a ``tags`` run config and check each reply is text."""
    chat = ChatGoogleGenerativeAI(model=_MODEL)
    inputs = ["This is a test", "This is another test"]
    run_config = {"tags": ["foo"]}
    replies = await chat.abatch(inputs, config=run_config)
    for reply in replies:
        assert isinstance(reply.content, str)
def test_chat_google_genai_batch() -> None:
    """Run a synchronous batch of two prompts and check each reply is text."""
    chat = ChatGoogleGenerativeAI(model=_MODEL)
    replies = chat.batch(
        [
            "This is a test. Say 'foo'",
            "This is a test, say 'bar'",
        ]
    )
    for reply in replies:
        assert isinstance(reply.content, str)
async def test_chat_google_genai_ainvoke() -> None:
    """Invoke the model asynchronously with a ``tags`` config; reply must be text."""
    chat = ChatGoogleGenerativeAI(model=_MODEL)
    reply = await chat.ainvoke(
        "This is a test. Say 'foo'",
        config={"tags": ["foo"]},
    )
    assert isinstance(reply.content, str)
def test_chat_google_genai_invoke() -> None:
    """Invoke synchronously with tags and a per-call ``generation_config``.

    The reply must be a string and must not begin with a leading space.
    """
    chat = ChatGoogleGenerativeAI(model=_MODEL)
    run_config = {"tags": ["foo"]}
    sampling = {"top_k": 2, "top_p": 1, "temperature": 0.7}
    reply = chat.invoke(
        "This is a test. Say 'foo'",
        config=run_config,
        generation_config=sampling,
    )
    text = reply.content
    assert isinstance(text, str)
    assert not text.startswith(" ")
def test_chat_google_genai_invoke_multimodal() -> None:
    """Send a text + inline base64 PNG prompt to the vision model.

    Checks that both ``invoke`` and ``stream`` produce non-empty text.
    """
    messages: list = [
        HumanMessage(
            content=[
                {
                    "type": "text",
                    "text": "Guess what's in this picture! You have 3 guesses.",
                },
                {
                    "type": "image_url",
                    "image_url": "data:image/png;base64," + _B64_string,
                },
            ]
        ),
    ]
    llm = ChatGoogleGenerativeAI(model=_VISION_MODEL)
    response = llm.invoke(messages)
    assert isinstance(response.content, str)
    assert len(response.content.strip()) > 0

    # Try streaming. Nothing guarantees that each individual chunk contains
    # non-whitespace text, so only the concatenated stream is required to be
    # non-empty (this also fails if the stream yields no chunks at all).
    streamed_parts = []
    for chunk in llm.stream(messages):
        assert isinstance(chunk.content, str)
        streamed_parts.append(chunk.content)
    assert len("".join(streamed_parts).strip()) > 0
def test_chat_google_genai_invoke_multimodal_too_many_messages() -> None:
    """Check the vision model rejects a multi-turn chat that includes an image.

    The vision model only supports a single turn, so prior human/AI history
    ahead of the image prompt should raise ``ChatGoogleGenerativeAIError``.
    """
    text_part = {
        "type": "text",
        "text": "I'm doing great! Guess what's in this picture!",
    }
    image_part = {
        "type": "image_url",
        "image_url": "data:image/png;base64," + _B64_string,
    }
    history: list = [
        HumanMessage(content="Hi there"),
        AIMessage(content="Hi, how are you?"),
        HumanMessage(content=[text_part, image_part]),
    ]
    vision_llm = ChatGoogleGenerativeAI(model=_VISION_MODEL)
    with pytest.raises(ChatGoogleGenerativeAIError):
        vision_llm.invoke(history)
def test_chat_google_genai_invoke_multimodal_invalid_model() -> None:
    """Check that the text-only model rejects a prompt containing an image.

    Image input needs the vision model, so sending one to ``_MODEL`` should
    raise ``ChatGoogleGenerativeAIError``.
    """
    data_url = "data:image/png;base64," + _B64_string
    image_prompt: list = [
        HumanMessage(
            content=[
                {
                    "type": "text",
                    "text": "I'm doing great! Guess what's in this picture!",
                },
                {"type": "image_url", "image_url": data_url},
            ]
        ),
    ]
    text_llm = ChatGoogleGenerativeAI(model=_MODEL)
    with pytest.raises(ChatGoogleGenerativeAIError):
        text_llm.invoke(image_prompt)
def test_chat_google_genai_single_call_with_history() -> None:
    """Test one call whose input carries an earlier human/AI exchange as history.

    Uses ``invoke`` (the Runnable entry point the other tests in this module
    use) instead of calling the model object directly.
    """
    model = ChatGoogleGenerativeAI(model=_MODEL)
    text_question1, text_answer1 = "How much is 2+2?", "4"
    text_question2 = "How much is 3+3?"
    message1 = HumanMessage(content=text_question1)
    message2 = AIMessage(content=text_answer1)
    message3 = HumanMessage(content=text_question2)
    response = model.invoke([message1, message2, message3])
    assert isinstance(response, AIMessage)
    assert isinstance(response.content, str)
def test_chat_google_genai_system_message_error() -> None:
    """Test that a leading SystemMessage raises when it is not converted.

    The model is built without ``convert_system_message_to_human=True``, so a
    history that starts with a system message should raise ``ValueError``.
    """
    model = ChatGoogleGenerativeAI(model=_MODEL)
    text_question1, text_answer1 = "How much is 2+2?", "4"
    text_question2 = "How much is 3+3?"
    system_message = SystemMessage(content="You're supposed to answer math questions.")
    message1 = HumanMessage(content=text_question1)
    message2 = AIMessage(content=text_answer1)
    message3 = HumanMessage(content=text_question2)
    # Use ``invoke`` like the rest of the module rather than calling the model
    # object directly.
    with pytest.raises(ValueError):
        model.invoke([system_message, message1, message2, message3])
def test_chat_google_genai_system_message() -> None:
    """Test that a leading SystemMessage is accepted when conversion is enabled.

    With ``convert_system_message_to_human=True`` the same history that the
    error test uses should produce an ``AIMessage`` with string content.
    """
    model = ChatGoogleGenerativeAI(model=_MODEL, convert_system_message_to_human=True)
    text_question1, text_answer1 = "How much is 2+2?", "4"
    text_question2 = "How much is 3+3?"
    system_message = SystemMessage(content="You're supposed to answer math questions.")
    message1 = HumanMessage(content=text_question1)
    message2 = AIMessage(content=text_answer1)
    message3 = HumanMessage(content=text_question2)
    # Use ``invoke`` like the rest of the module rather than calling the model
    # object directly.
    response = model.invoke([system_message, message1, message2, message3])
    assert isinstance(response, AIMessage)
    assert isinstance(response.content, str)