|
|
|
@ -9,7 +9,7 @@ from langchain_core.outputs import LLMResult
|
|
|
|
|
|
|
|
|
|
from langchain_google_genai.llms import GoogleGenerativeAI
|
|
|
|
|
|
|
|
|
|
model_names = [None, "models/text-bison-001", "gemini-pro"]
|
|
|
|
|
model_names = ["models/text-bison-001", "gemini-pro"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.parametrize(
|
|
|
|
@ -37,10 +37,7 @@ def test_google_generativeai_call(model_name: str) -> None:
|
|
|
|
|
)
|
|
|
|
|
def test_google_generativeai_generate(model_name: str) -> None:
|
|
|
|
|
n = 1 if model_name == "gemini-pro" else 2
|
|
|
|
|
if model_name:
|
|
|
|
|
llm = GoogleGenerativeAI(temperature=0.3, n=n, model=model_name)
|
|
|
|
|
else:
|
|
|
|
|
llm = GoogleGenerativeAI(temperature=0.3, n=n)
|
|
|
|
|
llm = GoogleGenerativeAI(temperature=0.3, n=n, model=model_name)
|
|
|
|
|
output = llm.generate(["Say foo:"])
|
|
|
|
|
assert isinstance(output, LLMResult)
|
|
|
|
|
assert len(output.generations) == 1
|
|
|
|
@ -48,7 +45,7 @@ def test_google_generativeai_generate(model_name: str) -> None:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_google_generativeai_get_num_tokens() -> None:
|
|
|
|
|
llm = GoogleGenerativeAI()
|
|
|
|
|
llm = GoogleGenerativeAI(model="models/text-bison-001")
|
|
|
|
|
output = llm.get_num_tokens("How are you?")
|
|
|
|
|
assert output == 4
|
|
|
|
|
|
|
|
|
|