"""Test Vertex AI API wrapper.

In order to run this test, you need to install the Vertex AI SDK:

    pip install "google-cloud-aiplatform>=1.36.0"

Your end-user credentials are used to make the calls (make sure you've run
`gcloud auth login` first).
"""
import os
from typing import Optional

import pytest
from langchain_core.outputs import LLMResult

from langchain_community.llms import VertexAI, VertexAIModelGarden


def test_vertex_initialization() -> None:
    llm = VertexAI()
    assert llm._llm_type == "vertexai"
    assert llm.model_name == llm.client._model_id


def test_vertex_call() -> None:
    llm = VertexAI(temperature=0)
    output = llm("Say foo:")
    assert isinstance(output, str)


@pytest.mark.scheduled
def test_vertex_generate() -> None:
    # n=2 requests two candidate completions per prompt, so the single prompt
    # should yield one generation list containing two candidates.
    llm = VertexAI(temperature=0.3, n=2, model_name="text-bison@001")
    output = llm.generate(["Say foo:"])
    assert isinstance(output, LLMResult)
    assert len(output.generations) == 1
    assert len(output.generations[0]) == 2


@pytest.mark.scheduled
def test_vertex_generate_code() -> None:
    llm = VertexAI(temperature=0.3, n=2, model_name="code-bison@001")
    output = llm.generate(["generate a python method that says foo:"])
    assert isinstance(output, LLMResult)
    assert len(output.generations) == 1
    assert len(output.generations[0]) == 2


@pytest.mark.scheduled
async def test_vertex_agenerate() -> None:
    llm = VertexAI(temperature=0)
    output = await llm.agenerate(["Please say foo:"])
    assert isinstance(output, LLMResult)
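

# llm.stream should yield the completion incrementally as string chunks;
# collecting them lets us check that at least the first chunk is a str.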
@pytest.mark.scheduled
def test_vertex_stream() -> None:
    llm = VertexAI(temperature=0)
    outputs = list(llm.stream("Please say foo:"))
    assert isinstance(outputs[0], str)
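

# At temperature=0 generation is expected to be (near-)deterministic, so the
# sync, streaming, and async code paths should produce identical text for the
# same prompt.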
async def test_vertex_consistency() -> None:
    llm = VertexAI(temperature=0)
    output = llm.generate(["Please say foo:"])
    streaming_output = llm.generate(["Please say foo:"], stream=True)
    async_output = await llm.agenerate(["Please say foo:"])
    assert output.generations[0][0].text == streaming_output.generations[0][0].text
    assert output.generations[0][0].text == async_output.generations[0][0].text
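

# result_arg is expected to name the field of the endpoint's prediction that
# holds the generated text: the Falcon deployment here returns a dict with a
# "generated_text" key, while result_arg=None means the raw prediction is
# used as-is (as with the Llama deployment).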
@pytest.mark.parametrize(
    "endpoint_os_variable_name,result_arg",
    [("FALCON_ENDPOINT_ID", "generated_text"), ("LLAMA_ENDPOINT_ID", None)],
)
def test_model_garden(
    endpoint_os_variable_name: str, result_arg: Optional[str]
) -> None:
    """In order to run this test, you should provide endpoint IDs and a project ID.

    Example:
        export FALCON_ENDPOINT_ID=...
        export LLAMA_ENDPOINT_ID=...
        export PROJECT=...
    """
    endpoint_id = os.environ[endpoint_os_variable_name]
    project = os.environ["PROJECT"]
    location = "europe-west4"
    llm = VertexAIModelGarden(
        endpoint_id=endpoint_id,
        project=project,
        result_arg=result_arg,
        location=location,
    )
    output = llm("What is the meaning of life?")
    assert isinstance(output, str)
    assert llm._llm_type == "vertexai_model_garden"
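

# Batch generation: two prompts go in, so the result should contain two
# generation lists, one per prompt.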
@pytest.mark.parametrize(
    "endpoint_os_variable_name,result_arg",
    [("FALCON_ENDPOINT_ID", "generated_text"), ("LLAMA_ENDPOINT_ID", None)],
)
def test_model_garden_generate(
    endpoint_os_variable_name: str, result_arg: Optional[str]
) -> None:
    """In order to run this test, you should provide endpoint IDs and a project ID.

    Example:
        export FALCON_ENDPOINT_ID=...
        export LLAMA_ENDPOINT_ID=...
        export PROJECT=...
    """
    endpoint_id = os.environ[endpoint_os_variable_name]
    project = os.environ["PROJECT"]
    location = "europe-west4"
    llm = VertexAIModelGarden(
        endpoint_id=endpoint_id,
        project=project,
        result_arg=result_arg,
        location=location,
    )
    output = llm.generate(["What is the meaning of life?", "How much is 2+2"])
    assert isinstance(output, LLMResult)
    assert len(output.generations) == 2


@pytest.mark.asyncio
@pytest.mark.parametrize(
    "endpoint_os_variable_name,result_arg",
    [("FALCON_ENDPOINT_ID", "generated_text"), ("LLAMA_ENDPOINT_ID", None)],
)
async def test_model_garden_agenerate(
    endpoint_os_variable_name: str, result_arg: Optional[str]
) -> None:
    endpoint_id = os.environ[endpoint_os_variable_name]
    project = os.environ["PROJECT"]
    location = "europe-west4"
    llm = VertexAIModelGarden(
        endpoint_id=endpoint_id,
        project=project,
        result_arg=result_arg,
        location=location,
    )
    output = await llm.agenerate(["What is the meaning of life?", "How much is 2+2"])
    assert isinstance(output, LLMResult)
    assert len(output.generations) == 2
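

# get_num_tokens is expected to delegate to the underlying Vertex AI model's
# token counter; the default model tokenizes "How are you?" into 4 tokens.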
def test_vertex_call_count_tokens() -> None:
    llm = VertexAI()
    output = llm.get_num_tokens("How are you?")
    assert output == 4