You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
langchain/libs/community/tests/integration_tests/llms/test_google_palm.py

77 lines
2.3 KiB
Python

"""Test Google GenerativeAI API wrapper.
Note: This test must be run with the GOOGLE_API_KEY environment variable set to a
valid API key.
"""
from pathlib import Path
import pytest
from langchain_core.outputs import LLMResult
from langchain_community.llms.google_palm import GooglePalm
from langchain_community.llms.loading import load_llm
model_names = [None, "models/text-bison-001", "gemini-pro"]
@pytest.mark.parametrize(
"model_name",
model_names,
)
def test_google_generativeai_call(model_name: str) -> None:
"""Test valid call to Google GenerativeAI text API."""
if model_name:
llm = GooglePalm(max_output_tokens=10, model_name=model_name)
else:
llm = GooglePalm(max_output_tokens=10)
output = llm("Say foo:")
assert isinstance(output, str)
assert llm._llm_type == "google_palm"
if model_name and "gemini" in model_name:
assert llm.client.model_name == "models/gemini-pro"
else:
assert llm.model_name == "models/text-bison-001"
@pytest.mark.parametrize(
"model_name",
model_names,
)
def test_google_generativeai_generate(model_name: str) -> None:
n = 1 if model_name == "gemini-pro" else 2
if model_name:
llm = GooglePalm(temperature=0.3, n=n, model_name=model_name)
else:
llm = GooglePalm(temperature=0.3, n=n)
output = llm.generate(["Say foo:"])
assert isinstance(output, LLMResult)
assert len(output.generations) == 1
assert len(output.generations[0]) == n
def test_google_generativeai_get_num_tokens() -> None:
llm = GooglePalm()
output = llm.get_num_tokens("How are you?")
assert output == 4
async def test_google_generativeai_agenerate() -> None:
llm = GooglePalm(temperature=0, model_name="gemini-pro")
output = await llm.agenerate(["Please say foo:"])
assert isinstance(output, LLMResult)
def test_generativeai_stream() -> None:
llm = GooglePalm(temperature=0, model_name="gemini-pro")
outputs = list(llm.stream("Please say foo:"))
assert isinstance(outputs[0], str)
def test_saving_loading_llm(tmp_path: Path) -> None:
"""Test saving/loading a Google PaLM LLM."""
llm = GooglePalm(max_output_tokens=10)
llm.save(file_path=tmp_path / "google_palm.yaml")
loaded_llm = load_llm(tmp_path / "google_palm.yaml")
assert loaded_llm == llm