google-vertexai: added langchain_google_vertexai package (#15218)

added langchain_google_vertexai package

---------

Co-authored-by: Erick Friis <erick@langchain.dev>
pull/5066/head
Leonid Kuligin 9 months ago committed by GitHub
parent e1fc4d5b95
commit f73bf4ee54
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -0,0 +1,21 @@
MIT License
Copyright (c) 2023 LangChain, Inc.
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

@ -0,0 +1,59 @@
.PHONY: all format lint test tests integration_tests docker_tests help extended_tests
# Default target executed when no arguments are given to make.
all: help
# Define a variable for the test file path.
TEST_FILE ?= tests/unit_tests/
test:
poetry run pytest $(TEST_FILE)
tests:
poetry run pytest $(TEST_FILE)
######################
# LINTING AND FORMATTING
######################
# Define a variable for Python and notebook files.
PYTHON_FILES=.
MYPY_CACHE=.mypy_cache
lint format: PYTHON_FILES=.
lint_diff format_diff: PYTHON_FILES=$(shell git diff --relative=libs/partners/google-vertexai --name-only --diff-filter=d master | grep -E '\.py$$|\.ipynb$$')
lint_package: PYTHON_FILES=langchain_google_vertexai
lint_tests: PYTHON_FILES=tests
lint_tests: MYPY_CACHE=.mypy_cache_test
lint lint_diff lint_package lint_tests:
poetry run ruff .
poetry run ruff format $(PYTHON_FILES) --diff
poetry run ruff --select I $(PYTHON_FILES)
mkdir $(MYPY_CACHE); poetry run mypy $(PYTHON_FILES) --cache-dir $(MYPY_CACHE)
format format_diff:
poetry run ruff format $(PYTHON_FILES)
poetry run ruff --select I --fix $(PYTHON_FILES)
spell_check:
poetry run codespell --toml pyproject.toml
spell_fix:
poetry run codespell --toml pyproject.toml -w
check_imports: $(shell find langchain_google_vertexai -name '*.py')
poetry run python ./scripts/check_imports.py $^
######################
# HELP
######################
help:
@echo '----'
@echo 'check_imports - check imports'
@echo 'format - run code formatters'
@echo 'lint - run linters'
@echo 'test - run unit tests'
@echo 'tests - run unit tests'
@echo 'test TEST_FILE=<test_file> - run all tests in file'

@ -0,0 +1,100 @@
# langchain-google-vertexai
This package contains the LangChain integrations for Google Cloud generative models.
## Installation
```bash
pip install -U langchain-google-vertexai
```
## Chat Models
`ChatVertexAI` class exposes models .
To use, you should have Google Cloud project with APIs enabled, and configured credentials. Initialize the model as:
```python
from langchain_google_vertexai import ChatVertexAI
llm = ChatVertexAI(model_name="gemini-pro")
llm.invoke("Sing a ballad of LangChain.")
```
You can use other models, e.g. `chat-bison`:
```python
from langchain_google_vertexai import ChatVertexAI
llm = ChatVertexAI(model_name="chat-bison", temperature=0.3)
llm.invoke("Sing a ballad of LangChain.")
```
#### Multimodal inputs
Gemini vision model supports image inputs when providing a single chat message. Example:
```python
from langchain_core.messages import HumanMessage
from langchain_google_vertexai import ChatVertexAI
llm = ChatVertexAI(model_name="gemini-pro-vision")
# example
message = HumanMessage(
content=[
{
"type": "text",
"text": "What's in this image?",
}, # You can optionally provide text parts
{"type": "image_url", "image_url": {"url": "https://picsum.photos/seed/picsum/200/300"}},
]
)
llm.invoke([message])
```
The value of `image_url` can be any of the following:
- A public image URL
- An accessible gcs file (e.g., "gcs://path/to/file.png")
- A local file path
- A base64 encoded image (e.g., ``)
## Embeddings
You can use Google Cloud's embeddings models as:
```
from langchain_google_vertexai import VertexAIEmbeddings
embeddings = VertexAIEmbeddings()
embeddings.embed_query("hello, world!")
```
## LLMs
You can use Google Cloud's generative AI models as Langchain LLMs:
```python
from langchain.prompts import PromptTemplate
from langchain_google_vertexai import VertexAI
template = """Question: {question}
Answer: Let's think step by step."""
prompt = PromptTemplate.from_template(template)
chain = prompt | llm
question = "Who was the president in the year Justin Beiber was born?"
print(chain.invoke({"question": question}))
```
You can use Gemini and Palm models, including code-generations ones:
```python
from langchain_google_vertexai import VertexAI
llm = VertexAI(model_name="code-bison", max_output_tokens=1000, temperature=0.3)
question = "Write a python function that checks if a string is a valid email address"
output = llm(question)
```

@ -0,0 +1,5 @@
from langchain_google_vertexai.chat_models import ChatVertexAI
from langchain_google_vertexai.embeddings import VertexAIEmbeddings
from langchain_google_vertexai.llms import VertexAI, VertexAIModelGarden
__all__ = ["ChatVertexAI", "VertexAIEmbeddings", "VertexAI", "VertexAIModelGarden"]

@ -0,0 +1,88 @@
"""Utilities to init Vertex AI."""
from importlib import metadata
from typing import Any, Callable, Optional, Union
import google.api_core
from google.api_core.gapic_v1.client_info import ClientInfo
from google.cloud import storage # type: ignore
from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain_core.language_models.llms import create_base_retry_decorator
from vertexai.preview.generative_models import Image # type: ignore
def create_retry_decorator(
*,
max_retries: int = 1,
run_manager: Optional[
Union[AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun]
] = None,
) -> Callable[[Any], Any]:
"""Creates a retry decorator for Vertex / Palm LLMs."""
errors = [
google.api_core.exceptions.ResourceExhausted,
google.api_core.exceptions.ServiceUnavailable,
google.api_core.exceptions.Aborted,
google.api_core.exceptions.DeadlineExceeded,
google.api_core.exceptions.GoogleAPIError,
]
decorator = create_base_retry_decorator(
error_types=errors, max_retries=max_retries, run_manager=run_manager
)
return decorator
def raise_vertex_import_error(minimum_expected_version: str = "1.38.0") -> None:
"""Raise ImportError related to Vertex SDK being not available.
Args:
minimum_expected_version: The lowest expected version of the SDK.
Raises:
ImportError: an ImportError that mentions a required version of the SDK.
"""
raise ImportError(
"Please, install or upgrade the google-cloud-aiplatform library: "
f"pip install google-cloud-aiplatform>={minimum_expected_version}"
)
def get_client_info(module: Optional[str] = None) -> "ClientInfo":
r"""Returns a custom user agent header.
Args:
module (Optional[str]):
Optional. The module for a custom user agent header.
Returns:
google.api_core.gapic_v1.client_info.ClientInfo
"""
langchain_version = metadata.version("langchain")
client_library_version = (
f"{langchain_version}-{module}" if module else langchain_version
)
return ClientInfo(
client_library_version=client_library_version,
user_agent=f"langchain/{client_library_version}",
)
def load_image_from_gcs(path: str, project: Optional[str] = None) -> Image:
"""Loads im Image from GCS."""
gcs_client = storage.Client(project=project)
pieces = path.split("/")
blobs = list(gcs_client.list_blobs(pieces[2], prefix="/".join(pieces[3:])))
if len(blobs) > 1:
raise ValueError(f"Found more than one candidate for {path}!")
return Image.from_bytes(blobs[0].download_as_bytes())
def is_codey_model(model_name: str) -> bool:
"""Returns True if the model name is a Codey model."""
return "code" in model_name
def is_gemini_model(model_name: str) -> bool:
"""Returns True if the model name is a Gemini model."""
return model_name is not None and "gemini" in model_name

@ -0,0 +1,366 @@
"""Wrapper around Google VertexAI chat-based models."""
from __future__ import annotations
import base64
import logging
import re
from dataclasses import dataclass, field
from typing import Any, Dict, Iterator, List, Optional, Union, cast
from urllib.parse import urlparse
import requests
from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain_core.language_models.chat_models import (
BaseChatModel,
generate_from_stream,
)
from langchain_core.messages import (
AIMessage,
AIMessageChunk,
BaseMessage,
HumanMessage,
SystemMessage,
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.pydantic_v1 import root_validator
from vertexai.language_models import ( # type: ignore
ChatMessage,
ChatModel,
ChatSession,
CodeChatModel,
CodeChatSession,
InputOutputTextPair,
)
from vertexai.preview.generative_models import ( # type: ignore
Content,
GenerativeModel,
Image,
Part,
)
from langchain_google_vertexai._utils import (
is_codey_model,
is_gemini_model,
load_image_from_gcs,
)
from langchain_google_vertexai.llms import (
_VertexAICommon,
)
logger = logging.getLogger(__name__)
@dataclass
class _ChatHistory:
"""Represents a context and a history of messages."""
history: List[ChatMessage] = field(default_factory=list)
context: Optional[str] = None
def _parse_chat_history(history: List[BaseMessage]) -> _ChatHistory:
"""Parse a sequence of messages into history.
Args:
history: The list of messages to re-create the history of the chat.
Returns:
A parsed chat history.
Raises:
ValueError: If a sequence of message has a SystemMessage not at the
first place.
"""
vertex_messages, context = [], None
for i, message in enumerate(history):
content = cast(str, message.content)
if i == 0 and isinstance(message, SystemMessage):
context = content
elif isinstance(message, AIMessage):
vertex_message = ChatMessage(content=message.content, author="bot")
vertex_messages.append(vertex_message)
elif isinstance(message, HumanMessage):
vertex_message = ChatMessage(content=message.content, author="user")
vertex_messages.append(vertex_message)
else:
raise ValueError(
f"Unexpected message with type {type(message)} at the position {i}."
)
chat_history = _ChatHistory(context=context, history=vertex_messages)
return chat_history
def _is_url(s: str) -> bool:
try:
result = urlparse(s)
return all([result.scheme, result.netloc])
except Exception as e:
logger.debug(f"Unable to parse URL: {e}")
return False
def _parse_chat_history_gemini(
history: List[BaseMessage], project: Optional[str]
) -> List[Content]:
def _convert_to_prompt(part: Union[str, Dict]) -> Part:
if isinstance(part, str):
return Part.from_text(part)
if not isinstance(part, Dict):
raise ValueError(
f"Message's content is expected to be a dict, got {type(part)}!"
)
if part["type"] == "text":
return Part.from_text(part["text"])
elif part["type"] == "image_url":
path = part["image_url"]["url"]
if path.startswith("gs://"):
image = load_image_from_gcs(path=path, project=project)
elif path.startswith("data:image/"):
# extract base64 component from image uri
try:
regexp = r"data:image/\w{2,4};base64,(.*)"
encoded = re.search(regexp, path).group(1) # type: ignore
except AttributeError:
raise ValueError(
"Invalid image uri. It should be in the format "
"data:image/<image_type>;base64,<base64_encoded_image>."
)
image = Image.from_bytes(base64.b64decode(encoded))
elif _is_url(path):
response = requests.get(path)
response.raise_for_status()
image = Image.from_bytes(response.content)
else:
image = Image.load_from_file(path)
else:
raise ValueError("Only text and image_url types are supported!")
return Part.from_image(image)
vertex_messages = []
for i, message in enumerate(history):
if i == 0 and isinstance(message, SystemMessage):
raise ValueError("SystemMessages are not yet supported!")
elif isinstance(message, AIMessage):
role = "model"
elif isinstance(message, HumanMessage):
role = "user"
else:
raise ValueError(
f"Unexpected message with type {type(message)} at the position {i}."
)
raw_content = message.content
if isinstance(raw_content, str):
raw_content = [raw_content]
parts = [_convert_to_prompt(part) for part in raw_content]
vertex_message = Content(role=role, parts=parts)
vertex_messages.append(vertex_message)
return vertex_messages
def _parse_examples(examples: List[BaseMessage]) -> List[InputOutputTextPair]:
if len(examples) % 2 != 0:
raise ValueError(
f"Expect examples to have an even amount of messages, got {len(examples)}."
)
example_pairs = []
input_text = None
for i, example in enumerate(examples):
if i % 2 == 0:
if not isinstance(example, HumanMessage):
raise ValueError(
f"Expected the first message in a part to be from human, got "
f"{type(example)} for the {i}th message."
)
input_text = example.content
if i % 2 == 1:
if not isinstance(example, AIMessage):
raise ValueError(
f"Expected the second message in a part to be from AI, got "
f"{type(example)} for the {i}th message."
)
pair = InputOutputTextPair(
input_text=input_text, output_text=example.content
)
example_pairs.append(pair)
return example_pairs
def _get_question(messages: List[BaseMessage]) -> HumanMessage:
"""Get the human message at the end of a list of input messages to a chat model."""
if not messages:
raise ValueError("You should provide at least one message to start the chat!")
question = messages[-1]
if not isinstance(question, HumanMessage):
raise ValueError(
f"Last message in the list should be from human, got {question.type}."
)
return question
class ChatVertexAI(_VertexAICommon, BaseChatModel):
"""`Vertex AI` Chat large language models API."""
model_name: str = "chat-bison"
"Underlying model name."
examples: Optional[List[BaseMessage]] = None
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that the python package exists in environment."""
is_gemini = is_gemini_model(values["model_name"])
cls._init_vertexai(values)
if is_gemini:
values["client"] = GenerativeModel(model_name=values["model_name"])
else:
if is_codey_model(values["model_name"]):
model_cls = CodeChatModel
else:
model_cls = ChatModel
values["client"] = model_cls.from_pretrained(values["model_name"])
return values
def _generate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
stream: Optional[bool] = None,
**kwargs: Any,
) -> ChatResult:
"""Generate next turn in the conversation.
Args:
messages: The history of the conversation as a list of messages. Code chat
does not support context.
stop: The list of stop words (optional).
run_manager: The CallbackManager for LLM run, it's not used at the moment.
stream: Whether to use the streaming endpoint.
Returns:
The ChatResult that contains outputs generated by the model.
Raises:
ValueError: if the last message in the list is not from human.
"""
should_stream = stream if stream is not None else self.streaming
if should_stream:
stream_iter = self._stream(
messages, stop=stop, run_manager=run_manager, **kwargs
)
return generate_from_stream(stream_iter)
question = _get_question(messages)
params = self._prepare_params(stop=stop, stream=False, **kwargs)
msg_params = {}
if "candidate_count" in params:
msg_params["candidate_count"] = params.pop("candidate_count")
if self._is_gemini_model:
history_gemini = _parse_chat_history_gemini(messages, project=self.project)
message = history_gemini.pop()
chat = self.client.start_chat(history=history_gemini)
response = chat.send_message(message, generation_config=params)
else:
history = _parse_chat_history(messages[:-1])
examples = kwargs.get("examples") or self.examples
if examples:
params["examples"] = _parse_examples(examples)
chat = self._start_chat(history, **params)
response = chat.send_message(question.content, **msg_params)
generations = [
ChatGeneration(message=AIMessage(content=r.text))
for r in response.candidates
]
return ChatResult(generations=generations)
async def _agenerate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
"""Asynchronously generate next turn in the conversation.
Args:
messages: The history of the conversation as a list of messages. Code chat
does not support context.
stop: The list of stop words (optional).
run_manager: The CallbackManager for LLM run, it's not used at the moment.
Returns:
The ChatResult that contains outputs generated by the model.
Raises:
ValueError: if the last message in the list is not from human.
"""
if "stream" in kwargs:
kwargs.pop("stream")
logger.warning("ChatVertexAI does not currently support async streaming.")
params = self._prepare_params(stop=stop, **kwargs)
msg_params = {}
if "candidate_count" in params:
msg_params["candidate_count"] = params.pop("candidate_count")
if self._is_gemini_model:
history_gemini = _parse_chat_history_gemini(messages, project=self.project)
message = history_gemini.pop()
chat = self.client.start_chat(history=history_gemini)
response = await chat.send_message_async(message, generation_config=params)
else:
question = _get_question(messages)
history = _parse_chat_history(messages[:-1])
examples = kwargs.get("examples", None)
if examples:
params["examples"] = _parse_examples(examples)
chat = self._start_chat(history, **params)
response = await chat.send_message_async(question.content, **msg_params)
generations = [
ChatGeneration(message=AIMessage(content=r.text))
for r in response.candidates
]
return ChatResult(generations=generations)
def _stream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
params = self._prepare_params(stop=stop, stream=True, **kwargs)
if self._is_gemini_model:
history_gemini = _parse_chat_history_gemini(messages, project=self.project)
message = history_gemini.pop()
chat = self.client.start_chat(history=history_gemini)
responses = chat.send_message(
message, stream=True, generation_config=params
)
else:
question = _get_question(messages)
history = _parse_chat_history(messages[:-1])
examples = kwargs.get("examples", None)
if examples:
params["examples"] = _parse_examples(examples)
chat = self._start_chat(history, **params)
responses = chat.send_message_streaming(question.content, **params)
for response in responses:
if run_manager:
run_manager.on_llm_new_token(response.text)
yield ChatGenerationChunk(message=AIMessageChunk(content=response.text))
def _start_chat(
self, history: _ChatHistory, **kwargs: Any
) -> Union[ChatSession, CodeChatSession]:
if not self.is_codey_model:
return self.client.start_chat(
context=history.context, message_history=history.history, **kwargs
)
else:
return self.client.start_chat(message_history=history.history, **kwargs)

@ -0,0 +1,336 @@
import logging
import re
import string
import threading
from concurrent.futures import ThreadPoolExecutor, wait
from typing import Any, Dict, List, Literal, Optional, Tuple, Type
from google.api_core.exceptions import (
Aborted,
DeadlineExceeded,
InvalidArgument,
ResourceExhausted,
ServiceUnavailable,
)
from langchain_core.embeddings import Embeddings
from langchain_core.language_models.llms import create_base_retry_decorator
from langchain_core.pydantic_v1 import root_validator
from vertexai.language_models import ( # type: ignore
TextEmbeddingInput,
TextEmbeddingModel,
)
from langchain_google_vertexai.llms import _VertexAICommon
logger = logging.getLogger(__name__)
_MAX_TOKENS_PER_BATCH = 20000
_MAX_BATCH_SIZE = 250
_MIN_BATCH_SIZE = 5
class VertexAIEmbeddings(_VertexAICommon, Embeddings):
"""Google Cloud VertexAI embedding models."""
# Instance context
instance: Dict[str, Any] = {} #: :meta private:
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validates that the python package exists in environment."""
cls._init_vertexai(values)
if values["model_name"] == "textembedding-gecko-default":
logger.warning(
"Model_name will become a required arg for VertexAIEmbeddings "
"starting from Feb-01-2024. Currently the default is set to "
"textembedding-gecko@001"
)
values["model_name"] = "textembedding-gecko@001"
values["client"] = TextEmbeddingModel.from_pretrained(values["model_name"])
return values
def __init__(
self,
# the default value would be removed after Feb-01-2024
model_name: str = "textembedding-gecko-default",
project: Optional[str] = None,
location: str = "us-central1",
request_parallelism: int = 5,
max_retries: int = 6,
credentials: Optional[Any] = None,
**kwargs: Any,
):
"""Initialize the sentence_transformer."""
super().__init__(
project=project,
location=location,
credentials=credentials,
request_parallelism=request_parallelism,
max_retries=max_retries,
model_name=model_name,
**kwargs,
)
self.instance["max_batch_size"] = kwargs.get("max_batch_size", _MAX_BATCH_SIZE)
self.instance["batch_size"] = self.instance["max_batch_size"]
self.instance["min_batch_size"] = kwargs.get("min_batch_size", _MIN_BATCH_SIZE)
self.instance["min_good_batch_size"] = self.instance["min_batch_size"]
self.instance["lock"] = threading.Lock()
self.instance["batch_size_validated"] = False
self.instance["task_executor"] = ThreadPoolExecutor(
max_workers=request_parallelism
)
self.instance[
"embeddings_task_type_supported"
] = not self.client._endpoint_name.endswith("/textembedding-gecko@001")
@staticmethod
def _split_by_punctuation(text: str) -> List[str]:
"""Splits a string by punctuation and whitespace characters."""
split_by = string.punctuation + "\t\n "
pattern = f"([{split_by}])"
# Using re.split to split the text based on the pattern
return [segment for segment in re.split(pattern, text) if segment]
@staticmethod
def _prepare_batches(texts: List[str], batch_size: int) -> List[List[str]]:
"""Splits texts in batches based on current maximum batch size
and maximum tokens per request.
"""
text_index = 0
texts_len = len(texts)
batch_token_len = 0
batches: List[List[str]] = []
current_batch: List[str] = []
if texts_len == 0:
return []
while text_index < texts_len:
current_text = texts[text_index]
# Number of tokens per a text is conservatively estimated
# as 2 times number of words, punctuation and whitespace characters.
# Using `count_tokens` API will make batching too expensive.
# Utilizing a tokenizer, would add a dependency that would not
# necessarily be reused by the application using this class.
current_text_token_cnt = (
len(VertexAIEmbeddings._split_by_punctuation(current_text)) * 2
)
end_of_batch = False
if current_text_token_cnt > _MAX_TOKENS_PER_BATCH:
# Current text is too big even for a single batch.
# Such request will fail, but we still make a batch
# so that the app can get the error from the API.
if len(current_batch) > 0:
# Adding current batch if not empty.
batches.append(current_batch)
current_batch = [current_text]
text_index += 1
end_of_batch = True
elif (
batch_token_len + current_text_token_cnt > _MAX_TOKENS_PER_BATCH
or len(current_batch) == batch_size
):
end_of_batch = True
else:
if text_index == texts_len - 1:
# Last element - even though the batch may be not big,
# we still need to make it.
end_of_batch = True
batch_token_len += current_text_token_cnt
current_batch.append(current_text)
text_index += 1
if end_of_batch:
batches.append(current_batch)
current_batch = []
batch_token_len = 0
return batches
def _get_embeddings_with_retry(
self, texts: List[str], embeddings_type: Optional[str] = None
) -> List[List[float]]:
"""Makes a Vertex AI model request with retry logic."""
errors: List[Type[BaseException]] = [
ResourceExhausted,
ServiceUnavailable,
Aborted,
DeadlineExceeded,
]
retry_decorator = create_base_retry_decorator(
error_types=errors, max_retries=self.max_retries
)
@retry_decorator
def _completion_with_retry(texts_to_process: List[str]) -> Any:
if embeddings_type and self.instance["embeddings_task_type_supported"]:
requests = [
TextEmbeddingInput(text=t, task_type=embeddings_type)
for t in texts_to_process
]
else:
requests = texts_to_process
embeddings = self.client.get_embeddings(requests)
return [embs.values for embs in embeddings]
return _completion_with_retry(texts)
def _prepare_and_validate_batches(
self, texts: List[str], embeddings_type: Optional[str] = None
) -> Tuple[List[List[float]], List[List[str]]]:
"""Prepares text batches with one-time validation of batch size.
Batch size varies between GCP regions and individual project quotas.
# Returns embeddings of the first text batch that went through,
# and text batches for the rest of the texts.
"""
batches = VertexAIEmbeddings._prepare_batches(
texts, self.instance["batch_size"]
)
# If batch size if less or equal to one that went through before,
# then keep batches as they are.
if len(batches[0]) <= self.instance["min_good_batch_size"]:
return [], batches
with self.instance["lock"]:
# If largest possible batch size was validated
# while waiting for the lock, then check for rebuilding
# our batches, and return.
if self.instance["batch_size_validated"]:
if len(batches[0]) <= self.instance["batch_size"]:
return [], batches
else:
return [], VertexAIEmbeddings._prepare_batches(
texts, self.instance["batch_size"]
)
# Figure out largest possible batch size by trying to push
# batches and lowering their size in half after every failure.
first_batch = batches[0]
first_result = []
had_failure = False
while True:
try:
first_result = self._get_embeddings_with_retry(
first_batch, embeddings_type
)
break
except InvalidArgument:
had_failure = True
first_batch_len = len(first_batch)
if first_batch_len == self.instance["min_batch_size"]:
raise
first_batch_len = max(
self.instance["min_batch_size"], int(first_batch_len / 2)
)
first_batch = first_batch[:first_batch_len]
first_batch_len = len(first_batch)
self.instance["min_good_batch_size"] = max(
self.instance["min_good_batch_size"], first_batch_len
)
# If had a failure and recovered
# or went through with the max size, then it's a legit batch size.
if had_failure or first_batch_len == self.instance["max_batch_size"]:
self.instance["batch_size"] = first_batch_len
self.instance["batch_size_validated"] = True
# If batch size was updated,
# rebuild batches with the new batch size
# (texts that went through are excluded here).
if first_batch_len != self.instance["max_batch_size"]:
batches = VertexAIEmbeddings._prepare_batches(
texts[first_batch_len:], self.instance["batch_size"]
)
else:
# Still figuring out max batch size.
batches = batches[1:]
# Returning embeddings of the first text batch that went through,
# and text batches for the rest of texts.
return first_result, batches
def embed(
self,
texts: List[str],
batch_size: int = 0,
embeddings_task_type: Optional[
Literal[
"RETRIEVAL_QUERY",
"RETRIEVAL_DOCUMENT",
"SEMANTIC_SIMILARITY",
"CLASSIFICATION",
"CLUSTERING",
]
] = None,
) -> List[List[float]]:
"""Embed a list of strings.
Args:
texts: List[str] The list of strings to embed.
batch_size: [int] The batch size of embeddings to send to the model.
If zero, then the largest batch size will be detected dynamically
at the first request, starting from 250, down to 5.
embeddings_task_type: [str] optional embeddings task type,
one of the following
RETRIEVAL_QUERY - Text is a query
in a search/retrieval setting.
RETRIEVAL_DOCUMENT - Text is a document
in a search/retrieval setting.
SEMANTIC_SIMILARITY - Embeddings will be used
for Semantic Textual Similarity (STS).
CLASSIFICATION - Embeddings will be used for classification.
CLUSTERING - Embeddings will be used for clustering.
Returns:
List of embeddings, one for each text.
"""
if len(texts) == 0:
return []
embeddings: List[List[float]] = []
first_batch_result: List[List[float]] = []
if batch_size > 0:
# Fixed batch size.
batches = VertexAIEmbeddings._prepare_batches(texts, batch_size)
else:
# Dynamic batch size, starting from 250 at the first call.
first_batch_result, batches = self._prepare_and_validate_batches(
texts, embeddings_task_type
)
# First batch result may have some embeddings already.
# In such case, batches have texts that were not processed yet.
embeddings.extend(first_batch_result)
tasks = []
for batch in batches:
tasks.append(
self.instance["task_executor"].submit(
self._get_embeddings_with_retry,
texts=batch,
embeddings_type=embeddings_task_type,
)
)
if len(tasks) > 0:
wait(tasks)
for t in tasks:
embeddings.extend(t.result())
return embeddings
def embed_documents(
self, texts: List[str], batch_size: int = 0
) -> List[List[float]]:
"""Embed a list of documents.
Args:
texts: List[str] The list of texts to embed.
batch_size: [int] The batch size of embeddings to send to the model.
If zero, then the largest batch size will be detected dynamically
at the first request, starting from 250, down to 5.
Returns:
List of embeddings, one for each text.
"""
return self.embed(texts, batch_size, "RETRIEVAL_DOCUMENT")
def embed_query(self, text: str) -> List[float]:
"""Embed a text.
Args:
text: The text to embed.
Returns:
Embedding for the text.
"""
embeddings = self.embed([text], 1, "RETRIEVAL_QUERY")
return embeddings[0]

@ -0,0 +1,469 @@
from __future__ import annotations
from concurrent.futures import Executor
from typing import Any, ClassVar, Dict, Iterator, List, Optional, Union
import vertexai # type: ignore
from google.api_core.client_options import ClientOptions
from google.cloud.aiplatform.gapic import (
PredictionServiceAsyncClient,
PredictionServiceClient,
)
from google.cloud.aiplatform.models import Prediction
from google.protobuf import json_format
from google.protobuf.struct_pb2 import Value
from langchain_core.callbacks.manager import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain_core.language_models.llms import BaseLLM
from langchain_core.outputs import Generation, GenerationChunk, LLMResult
from langchain_core.pydantic_v1 import BaseModel, Field, root_validator
from vertexai.language_models import ( # type: ignore
CodeGenerationModel,
TextGenerationModel,
)
from vertexai.language_models._language_models import ( # type: ignore
TextGenerationResponse,
)
from vertexai.preview.generative_models import GenerativeModel, Image # type: ignore
from vertexai.preview.language_models import ( # type: ignore
CodeGenerationModel as PreviewCodeGenerationModel,
)
from vertexai.preview.language_models import (
TextGenerationModel as PreviewTextGenerationModel,
)
from langchain_google_vertexai._utils import (
create_retry_decorator,
get_client_info,
is_codey_model,
is_gemini_model,
)
def _completion_with_retry(
llm: VertexAI,
prompt: List[Union[str, Image]],
stream: bool = False,
is_gemini: bool = False,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Any:
"""Use tenacity to retry the completion call."""
retry_decorator = create_retry_decorator(
max_retries=llm.max_retries, run_manager=run_manager
)
@retry_decorator
def _completion_with_retry_inner(
prompt: List[Union[str, Image]], is_gemini: bool = False, **kwargs: Any
) -> Any:
if is_gemini:
return llm.client.generate_content(
prompt, stream=stream, generation_config=kwargs
)
else:
if stream:
return llm.client.predict_streaming(prompt[0], **kwargs)
return llm.client.predict(prompt[0], **kwargs)
return _completion_with_retry_inner(prompt, is_gemini, **kwargs)
async def _acompletion_with_retry(
llm: VertexAI,
prompt: str,
is_gemini: bool = False,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Any:
"""Use tenacity to retry the completion call."""
retry_decorator = create_retry_decorator(
max_retries=llm.max_retries, run_manager=run_manager
)
@retry_decorator
async def _acompletion_with_retry_inner(
prompt: str, is_gemini: bool = False, **kwargs: Any
) -> Any:
if is_gemini:
return await llm.client.generate_content_async(
prompt, generation_config=kwargs
)
return await llm.client.predict_async(prompt, **kwargs)
return await _acompletion_with_retry_inner(prompt, is_gemini, **kwargs)
class _VertexAIBase(BaseModel):
project: Optional[str] = None
"The default GCP project to use when making Vertex API calls."
location: str = "us-central1"
"The default location to use when making API calls."
request_parallelism: int = 5
"The amount of parallelism allowed for requests issued to VertexAI models. "
"Default is 5."
max_retries: int = 6
"""The maximum number of retries to make when generating."""
task_executor: ClassVar[Optional[Executor]] = Field(default=None, exclude=True)
stop: Optional[List[str]] = None
"Optional list of stop words to use when generating."
model_name: Optional[str] = None
"Underlying model name."
class _VertexAICommon(_VertexAIBase):
client: Any = None #: :meta private:
client_preview: Any = None #: :meta private:
model_name: str
"Underlying model name."
temperature: float = 0.0
"Sampling temperature, it controls the degree of randomness in token selection."
max_output_tokens: int = 128
"Token limit determines the maximum amount of text output from one prompt."
top_p: float = 0.95
"Tokens are selected from most probable to least until the sum of their "
"probabilities equals the top-p value. Top-p is ignored for Codey models."
top_k: int = 40
"How the model selects tokens for output, the next token is selected from "
"among the top-k most probable tokens. Top-k is ignored for Codey models."
credentials: Any = Field(default=None, exclude=True)
"The default custom credentials (google.auth.credentials.Credentials) to use "
"when making API calls. If not provided, credentials will be ascertained from "
"the environment."
n: int = 1
"""How many completions to generate for each prompt."""
streaming: bool = False
"""Whether to stream the results or not."""
@property
def _llm_type(self) -> str:
return "vertexai"
@property
def is_codey_model(self) -> bool:
return is_codey_model(self.model_name)
@property
def _is_gemini_model(self) -> bool:
return is_gemini_model(self.model_name)
@property
def _identifying_params(self) -> Dict[str, Any]:
"""Gets the identifying parameters."""
return {**{"model_name": self.model_name}, **self._default_params}
@property
def _default_params(self) -> Dict[str, Any]:
params = {
"temperature": self.temperature,
"max_output_tokens": self.max_output_tokens,
"candidate_count": self.n,
}
if not self.is_codey_model:
params.update(
{
"top_k": self.top_k,
"top_p": self.top_p,
}
)
return params
@classmethod
def _init_vertexai(cls, values: Dict) -> None:
vertexai.init(
project=values.get("project"),
location=values.get("location"),
credentials=values.get("credentials"),
)
return None
def _prepare_params(
self,
stop: Optional[List[str]] = None,
stream: bool = False,
**kwargs: Any,
) -> dict:
stop_sequences = stop or self.stop
params_mapping = {"n": "candidate_count"}
params = {params_mapping.get(k, k): v for k, v in kwargs.items()}
params = {**self._default_params, "stop_sequences": stop_sequences, **params}
if stream or self.streaming:
params.pop("candidate_count")
return params
class VertexAI(_VertexAICommon, BaseLLM):
"""Google Vertex AI large language models."""
model_name: str = "text-bison"
"The name of the Vertex AI large language model."
tuned_model_name: Optional[str] = None
"The name of a tuned model. If provided, model_name is ignored."
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that the python package exists in environment."""
tuned_model_name = values.get("tuned_model_name")
model_name = values["model_name"]
is_gemini = is_gemini_model(values["model_name"])
cls._init_vertexai(values)
if is_codey_model(model_name):
model_cls = CodeGenerationModel
preview_model_cls = PreviewCodeGenerationModel
elif is_gemini:
model_cls = GenerativeModel
preview_model_cls = GenerativeModel
else:
model_cls = TextGenerationModel
preview_model_cls = PreviewTextGenerationModel
if tuned_model_name:
values["client"] = model_cls.get_tuned_model(tuned_model_name)
values["client_preview"] = preview_model_cls.get_tuned_model(
tuned_model_name
)
else:
if is_gemini:
values["client"] = model_cls(model_name=model_name)
values["client_preview"] = preview_model_cls(model_name=model_name)
else:
values["client"] = model_cls.from_pretrained(model_name)
values["client_preview"] = preview_model_cls.from_pretrained(model_name)
if values["streaming"] and values["n"] > 1:
raise ValueError("Only one candidate can be generated with streaming!")
return values
def get_num_tokens(self, text: str) -> int:
"""Get the number of tokens present in the text.
Useful for checking if an input will fit in a model's context window.
Args:
text: The string input to tokenize.
Returns:
The integer number of tokens in the text.
"""
result = self.client_preview.count_tokens([text])
return result.total_tokens
def _response_to_generation(
self, response: TextGenerationResponse
) -> GenerationChunk:
"""Converts a stream response to a generation chunk."""
try:
generation_info = {
"is_blocked": response.is_blocked,
"safety_attributes": response.safety_attributes,
}
except Exception:
generation_info = None
return GenerationChunk(text=response.text, generation_info=generation_info)
def _generate(
self,
prompts: List[str],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
stream: Optional[bool] = None,
**kwargs: Any,
) -> LLMResult:
should_stream = stream if stream is not None else self.streaming
params = self._prepare_params(stop=stop, stream=should_stream, **kwargs)
generations: List[List[Generation]] = []
for prompt in prompts:
if should_stream:
generation = GenerationChunk(text="")
for chunk in self._stream(
prompt, stop=stop, run_manager=run_manager, **kwargs
):
generation += chunk
generations.append([generation])
else:
res = _completion_with_retry(
self,
[prompt],
stream=should_stream,
is_gemini=self._is_gemini_model,
run_manager=run_manager,
**params,
)
generations.append(
[self._response_to_generation(r) for r in res.candidates]
)
return LLMResult(generations=generations)
async def _agenerate(
self,
prompts: List[str],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> LLMResult:
params = self._prepare_params(stop=stop, **kwargs)
generations = []
for prompt in prompts:
res = await _acompletion_with_retry(
self,
prompt,
is_gemini=self._is_gemini_model,
run_manager=run_manager,
**params,
)
generations.append(
[self._response_to_generation(r) for r in res.candidates]
)
return LLMResult(generations=generations)
def _stream(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[GenerationChunk]:
params = self._prepare_params(stop=stop, stream=True, **kwargs)
for stream_resp in _completion_with_retry(
self,
[prompt],
stream=True,
is_gemini=self._is_gemini_model,
run_manager=run_manager,
**params,
):
chunk = self._response_to_generation(stream_resp)
yield chunk
if run_manager:
run_manager.on_llm_new_token(
chunk.text,
chunk=chunk,
verbose=self.verbose,
)
class VertexAIModelGarden(_VertexAIBase, BaseLLM):
"""Large language models served from Vertex AI Model Garden."""
client: Any = None #: :meta private:
async_client: Any = None #: :meta private:
endpoint_id: str
"A name of an endpoint where the model has been deployed."
allowed_model_args: Optional[List[str]] = None
"Allowed optional args to be passed to the model."
prompt_arg: str = "prompt"
result_arg: Optional[str] = "generated_text"
"Set result_arg to None if output of the model is expected to be a string."
"Otherwise, if it's a dict, provided an argument that contains the result."
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that the python package exists in environment."""
if not values["project"]:
raise ValueError(
"A GCP project should be provided to run inference on Model Garden!"
)
client_options = ClientOptions(
api_endpoint=f"{values['location']}-aiplatform.googleapis.com"
)
client_info = get_client_info(module="vertex-ai-model-garden")
values["client"] = PredictionServiceClient(
client_options=client_options, client_info=client_info
)
values["async_client"] = PredictionServiceAsyncClient(
client_options=client_options, client_info=client_info
)
return values
@property
def endpoint_path(self) -> str:
return self.client.endpoint_path(
project=self.project, # type: ignore
location=self.location,
endpoint=self.endpoint_id,
)
@property
def _llm_type(self) -> str:
return "vertexai_model_garden"
def _prepare_request(self, prompts: List[str], **kwargs: Any) -> List["Value"]:
instances = []
for prompt in prompts:
if self.allowed_model_args:
instance = {
k: v for k, v in kwargs.items() if k in self.allowed_model_args
}
else:
instance = {}
instance[self.prompt_arg] = prompt
instances.append(instance)
predict_instances = [
json_format.ParseDict(instance_dict, Value()) for instance_dict in instances
]
return predict_instances
def _generate(
self,
prompts: List[str],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> LLMResult:
"""Run the LLM on the given prompt and input."""
instances = self._prepare_request(prompts, **kwargs)
response = self.client.predict(endpoint=self.endpoint_path, instances=instances)
return self._parse_response(response)
def _parse_response(self, predictions: "Prediction") -> LLMResult:
generations: List[List[Generation]] = []
for result in predictions.predictions:
generations.append(
[
Generation(text=self._parse_prediction(prediction))
for prediction in result
]
)
return LLMResult(generations=generations)
def _parse_prediction(self, prediction: Any) -> str:
if isinstance(prediction, str):
return prediction
if self.result_arg:
try:
return prediction[self.result_arg]
except KeyError:
if isinstance(prediction, str):
error_desc = (
"Provided non-None `result_arg` (result_arg="
f"{self.result_arg}). But got prediction of type "
f"{type(prediction)} instead of dict. Most probably, you"
"need to set `result_arg=None` during VertexAIModelGarden "
"initialization."
)
raise ValueError(error_desc)
else:
raise ValueError(f"{self.result_arg} key not found in prediction!")
return prediction
async def _agenerate(
self,
prompts: List[str],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> LLMResult:
"""Run the LLM on the given prompt and input."""
instances = self._prepare_request(prompts, **kwargs)
response = await self.async_client.predict(
endpoint=self.endpoint_path, instances=instances
)
return self._parse_response(response)

File diff suppressed because it is too large Load Diff

@ -0,0 +1,94 @@
[tool.poetry]
name = "langchain-google-vertexai"
version = "0.0.1"
description = "An integration package connecting GoogleVertexAI and LangChain"
authors = []
readme = "README.md"
[tool.poetry.dependencies]
python = ">=3.8.1,<4.0"
langchain-core = ">=0.1,<0.2"
google-cloud-aiplatform = "1.38.1"
google-cloud-storage = "^2.14.0"
types-requests = "^2.31.0.20231231"
types-protobuf = "^4.24.0.4"
[tool.poetry.group.test]
optional = true
[tool.poetry.group.test.dependencies]
pytest = "^7.3.0"
freezegun = "^1.2.2"
pytest-mock = "^3.10.0"
syrupy = "^4.0.2"
pytest-watcher = "^0.3.4"
pytest-asyncio = "^0.21.1"
langchain-core = {path = "../../core", develop = true}
types-requests = "^2.31.0.20231231"
types-protobuf = "^4.24.0.4"
[tool.poetry.group.codespell]
optional = true
[tool.poetry.group.codespell.dependencies]
codespell = "^2.2.0"
[tool.poetry.group.test_integration]
optional = true
[tool.poetry.group.test_integration.dependencies]
[tool.poetry.group.lint]
optional = true
[tool.poetry.group.lint.dependencies]
ruff = "^0.1.5"
[tool.poetry.group.typing.dependencies]
mypy = "^0.991"
langchain-core = {path = "../../core", develop = true}
[tool.poetry.group.dev]
optional = true
[tool.poetry.group.dev.dependencies]
langchain-core = {path = "../../core", develop = true}
[tool.ruff]
select = [
"E", # pycodestyle
"F", # pyflakes
"I", # isort
]
[tool.mypy]
disallow_untyped_defs = "True"
[tool.coverage.run]
omit = [
"tests/*",
]
[build-system]
requires = ["poetry-core>=1.0.0"]
build-backend = "poetry.core.masonry.api"
[tool.pytest.ini_options]
# --strict-markers will raise errors on unknown marks.
# https://docs.pytest.org/en/7.1.x/how-to/mark.html#raising-errors-on-unknown-marks
#
# https://docs.pytest.org/en/7.1.x/reference/reference.html
# --strict-config any warnings encountered while parsing the `pytest`
# section of the configuration file raise errors.
#
# https://github.com/tophat/syrupy
# --snapshot-warn-unused Prints a warning on unused snapshots rather than fail the test suite.
addopts = "--snapshot-warn-unused --strict-markers --strict-config --durations=5"
# Registering custom markers.
# https://docs.pytest.org/en/7.1.x/example/markers.html#registering-markers
markers = [
"requires: mark tests as requiring a specific library",
"asyncio: mark tests as requiring asyncio",
"compile: mark placeholder test used to compile integration tests without running them",
]
asyncio_mode = "auto"

@ -0,0 +1,17 @@
import sys
import traceback
from importlib.machinery import SourceFileLoader
if __name__ == "__main__":
files = sys.argv[1:]
has_failure = False
for file in files:
try:
SourceFileLoader("x", file).load_module()
except Exception:
has_faillure = True
print(file)
traceback.print_exc()
print()
sys.exit(1 if has_failure else 0)

@ -0,0 +1,27 @@
#!/bin/bash
#
# This script searches for lines starting with "import pydantic" or "from pydantic"
# in tracked files within a Git repository.
#
# Usage: ./scripts/check_pydantic.sh /path/to/repository
# Check if a path argument is provided
if [ $# -ne 1 ]; then
echo "Usage: $0 /path/to/repository"
exit 1
fi
repository_path="$1"
# Search for lines matching the pattern within the specified repository
result=$(git -C "$repository_path" grep -E '^import pydantic|^from pydantic')
# Check if any matching lines were found
if [ -n "$result" ]; then
echo "ERROR: The following lines need to be updated:"
echo "$result"
echo "Please replace the code with an import from langchain_core.pydantic_v1."
echo "For example, replace 'from pydantic import BaseModel'"
echo "with 'from langchain_core.pydantic_v1 import BaseModel'"
exit 1
fi

@ -0,0 +1,17 @@
#!/bin/bash
set -eu
# Initialize a variable to keep track of errors
errors=0
# make sure not importing from langchain or langchain_experimental
git --no-pager grep '^from langchain\.' . && errors=$((errors+1))
git --no-pager grep '^from langchain_experimental\.' . && errors=$((errors+1))
# Decide on an exit status based on the errors
if [ "$errors" -gt 0 ]; then
exit 1
else
exit 0
fi

@ -0,0 +1,176 @@
"""Test ChatGoogleVertexAI chat model."""
import pytest
from langchain_core.messages import (
AIMessage,
AIMessageChunk,
HumanMessage,
SystemMessage,
)
from langchain_core.outputs import LLMResult
from langchain_google_vertexai.chat_models import ChatVertexAI
model_names_to_test = [None, "codechat-bison", "chat-bison", "gemini-pro"]
@pytest.mark.parametrize("model_name", model_names_to_test)
def test_initialization(model_name: str) -> None:
"""Test chat model initialization."""
if model_name:
model = ChatVertexAI(model_name=model_name)
else:
model = ChatVertexAI()
assert model._llm_type == "vertexai"
try:
assert model.model_name == model.client._model_id
except AttributeError:
assert model.model_name == model.client._model_name.split("/")[-1]
@pytest.mark.parametrize("model_name", model_names_to_test)
def test_vertexai_single_call(model_name: str) -> None:
if model_name:
model = ChatVertexAI(model_name=model_name)
else:
model = ChatVertexAI()
message = HumanMessage(content="Hello")
response = model([message])
assert isinstance(response, AIMessage)
assert isinstance(response.content, str)
# mark xfail because Vertex API randomly doesn't respect
# the n/candidate_count parameter
@pytest.mark.xfail
def test_candidates() -> None:
model = ChatVertexAI(model_name="chat-bison@001", temperature=0.3, n=2)
message = HumanMessage(content="Hello")
response = model.generate(messages=[[message]])
assert isinstance(response, LLMResult)
assert len(response.generations) == 1
assert len(response.generations[0]) == 2
@pytest.mark.parametrize("model_name", ["chat-bison@001", "gemini-pro"])
async def test_vertexai_agenerate(model_name: str) -> None:
model = ChatVertexAI(temperature=0, model_name=model_name)
message = HumanMessage(content="Hello")
response = await model.agenerate([[message]])
assert isinstance(response, LLMResult)
assert isinstance(response.generations[0][0].message, AIMessage) # type: ignore
sync_response = model.generate([[message]])
assert response.generations[0][0] == sync_response.generations[0][0]
@pytest.mark.parametrize("model_name", ["chat-bison@001", "gemini-pro"])
def test_vertexai_stream(model_name: str) -> None:
model = ChatVertexAI(temperature=0, model_name=model_name)
message = HumanMessage(content="Hello")
sync_response = model.stream([message])
for chunk in sync_response:
assert isinstance(chunk, AIMessageChunk)
def test_vertexai_single_call_with_context() -> None:
model = ChatVertexAI()
raw_context = (
"My name is Ned. You are my personal assistant. My favorite movies "
"are Lord of the Rings and Hobbit."
)
question = (
"Hello, could you recommend a good movie for me to watch this evening, please?"
)
context = SystemMessage(content=raw_context)
message = HumanMessage(content=question)
response = model([context, message])
assert isinstance(response, AIMessage)
assert isinstance(response.content, str)
def test_multimodal() -> None:
llm = ChatVertexAI(model_name="gemini-pro-vision")
gcs_url = (
"gs://cloud-samples-data/generative-ai/image/"
"320px-Felis_catus-cat_on_snow.jpg"
)
image_message = {
"type": "image_url",
"image_url": {"url": gcs_url},
}
text_message = {
"type": "text",
"text": "What is shown in this image?",
}
message = HumanMessage(content=[text_message, image_message])
output = llm([message])
assert isinstance(output.content, str)
def test_multimodal_history() -> None:
llm = ChatVertexAI(model_name="gemini-pro-vision")
gcs_url = (
"gs://cloud-samples-data/generative-ai/image/"
"320px-Felis_catus-cat_on_snow.jpg"
)
image_message = {
"type": "image_url",
"image_url": {"url": gcs_url},
}
text_message = {
"type": "text",
"text": "What is shown in this image?",
}
message1 = HumanMessage(content=[text_message, image_message])
message2 = AIMessage(
content=(
"This is a picture of a cat in the snow. The cat is a tabby cat, which is "
"a type of cat with a striped coat. The cat is standing in the snow, and "
"its fur is covered in snow."
)
)
message3 = HumanMessage(content="What time of day is it?")
response = llm([message1, message2, message3])
assert isinstance(response, AIMessage)
assert isinstance(response.content, str)
def test_vertexai_single_call_with_examples() -> None:
model = ChatVertexAI()
raw_context = "My name is Ned. You are my personal assistant."
question = "2+2"
text_question, text_answer = "4+4", "8"
inp = HumanMessage(content=text_question)
output = AIMessage(content=text_answer)
context = SystemMessage(content=raw_context)
message = HumanMessage(content=question)
response = model([context, message], examples=[inp, output])
assert isinstance(response, AIMessage)
assert isinstance(response.content, str)
@pytest.mark.parametrize("model_name", model_names_to_test)
def test_vertexai_single_call_with_history(model_name: str) -> None:
if model_name:
model = ChatVertexAI(model_name=model_name)
else:
model = ChatVertexAI()
text_question1, text_answer1 = "How much is 2+2?", "4"
text_question2 = "How much is 3+3?"
message1 = HumanMessage(content=text_question1)
message2 = AIMessage(content=text_answer1)
message3 = HumanMessage(content=text_question2)
response = model([message1, message2, message3])
assert isinstance(response, AIMessage)
assert isinstance(response.content, str)
def test_vertexai_single_call_fails_no_message() -> None:
chat = ChatVertexAI()
with pytest.raises(ValueError) as exc_info:
_ = chat([])
assert (
str(exc_info.value)
== "You should provide at least one message to start the chat!"
)

@ -0,0 +1,7 @@
import pytest
@pytest.mark.compile
def test_placeholder() -> None:
"""Used for compiling integration tests without running any real tests."""
pass

@ -0,0 +1,70 @@
"""Test Vertex AI API wrapper.
Your end-user credentials would be used to make the calls (make sure you've run
`gcloud auth login` first).
"""
import pytest
from langchain_google_vertexai.embeddings import VertexAIEmbeddings
def test_initialization() -> None:
"""Test embedding model initialization."""
VertexAIEmbeddings()
def test_langchain_google_vertexai_embedding_documents() -> None:
documents = ["foo bar"]
model = VertexAIEmbeddings()
output = model.embed_documents(documents)
assert len(output) == 1
assert len(output[0]) == 768
assert model.model_name == model.client._model_id
assert model.model_name == "textembedding-gecko@001"
def test_langchain_google_vertexai_embedding_query() -> None:
document = "foo bar"
model = VertexAIEmbeddings()
output = model.embed_query(document)
assert len(output) == 768
def test_langchain_google_vertexai_large_batches() -> None:
documents = ["foo bar" for _ in range(0, 251)]
model_uscentral1 = VertexAIEmbeddings(location="us-central1")
model_asianortheast1 = VertexAIEmbeddings(location="asia-northeast1")
model_uscentral1.embed_documents(documents)
model_asianortheast1.embed_documents(documents)
assert model_uscentral1.instance["batch_size"] >= 250
assert model_asianortheast1.instance["batch_size"] < 50
def test_langchain_google_vertexai_paginated_texts() -> None:
documents = [
"foo bar",
"foo baz",
"bar foo",
"baz foo",
"bar bar",
"foo foo",
"baz baz",
"baz bar",
]
model = VertexAIEmbeddings()
output = model.embed_documents(documents)
assert len(output) == 8
assert len(output[0]) == 768
assert model.model_name == model.client._model_id
def test_warning(caplog: pytest.LogCaptureFixture) -> None:
_ = VertexAIEmbeddings()
assert len(caplog.records) == 1
record = caplog.records[0]
assert record.levelname == "WARNING"
expected_message = (
"Model_name will become a required arg for VertexAIEmbeddings starting from "
"Feb-01-2024. Currently the default is set to textembedding-gecko@001"
)
assert record.message == expected_message

@ -0,0 +1,175 @@
"""Test Vertex AI API wrapper.
Your end-user credentials would be used to make the calls (make sure you've run
`gcloud auth login` first).
"""
import os
from typing import Optional
import pytest
from langchain_core.outputs import LLMResult
from langchain_google_vertexai.llms import VertexAI, VertexAIModelGarden
model_names_to_test = ["text-bison@001", "gemini-pro"]
model_names_to_test_with_default = [None] + model_names_to_test
@pytest.mark.parametrize(
"model_name",
model_names_to_test_with_default,
)
def test_vertex_initialization(model_name: str) -> None:
llm = VertexAI(model_name=model_name) if model_name else VertexAI()
assert llm._llm_type == "vertexai"
try:
assert llm.model_name == llm.client._model_id
except AttributeError:
assert llm.model_name == llm.client._model_name.split("/")[-1]
@pytest.mark.parametrize(
"model_name",
model_names_to_test_with_default,
)
def test_vertex_call(model_name: str) -> None:
llm = (
VertexAI(model_name=model_name, temperature=0)
if model_name
else VertexAI(temperature=0.0)
)
output = llm("Say foo:")
assert isinstance(output, str)
def test_vertex_generate() -> None:
llm = VertexAI(temperature=0.3, n=2, model_name="text-bison@001")
output = llm.generate(["Say foo:"])
assert isinstance(output, LLMResult)
assert len(output.generations) == 1
assert len(output.generations[0]) == 2
def test_vertex_generate_code() -> None:
llm = VertexAI(temperature=0.3, n=2, model_name="code-bison@001")
output = llm.generate(["generate a python method that says foo:"])
assert isinstance(output, LLMResult)
assert len(output.generations) == 1
assert len(output.generations[0]) == 2
async def test_vertex_agenerate() -> None:
llm = VertexAI(temperature=0)
output = await llm.agenerate(["Please say foo:"])
assert isinstance(output, LLMResult)
@pytest.mark.parametrize(
"model_name",
model_names_to_test_with_default,
)
def test_stream(model_name: str) -> None:
llm = (
VertexAI(temperature=0, model_name=model_name)
if model_name
else VertexAI(temperature=0)
)
for token in llm.stream("I'm Pickle Rick"):
assert isinstance(token, str)
async def test_vertex_consistency() -> None:
llm = VertexAI(temperature=0)
output = llm.generate(["Please say foo:"])
streaming_output = llm.generate(["Please say foo:"], stream=True)
async_output = await llm.agenerate(["Please say foo:"])
assert output.generations[0][0].text == streaming_output.generations[0][0].text
assert output.generations[0][0].text == async_output.generations[0][0].text
@pytest.mark.parametrize(
"endpoint_os_variable_name,result_arg",
[("FALCON_ENDPOINT_ID", "generated_text"), ("LLAMA_ENDPOINT_ID", None)],
)
def test_model_garden(
endpoint_os_variable_name: str, result_arg: Optional[str]
) -> None:
"""In order to run this test, you should provide endpoint names.
Example:
export FALCON_ENDPOINT_ID=...
export LLAMA_ENDPOINT_ID=...
export PROJECT=...
"""
endpoint_id = os.environ[endpoint_os_variable_name]
project = os.environ["PROJECT"]
location = "europe-west4"
llm = VertexAIModelGarden(
endpoint_id=endpoint_id,
project=project,
result_arg=result_arg,
location=location,
)
output = llm("What is the meaning of life?")
assert isinstance(output, str)
assert llm._llm_type == "vertexai_model_garden"
@pytest.mark.parametrize(
"endpoint_os_variable_name,result_arg",
[("FALCON_ENDPOINT_ID", "generated_text"), ("LLAMA_ENDPOINT_ID", None)],
)
def test_model_garden_generate(
endpoint_os_variable_name: str, result_arg: Optional[str]
) -> None:
"""In order to run this test, you should provide endpoint names.
Example:
export FALCON_ENDPOINT_ID=...
export LLAMA_ENDPOINT_ID=...
export PROJECT=...
"""
endpoint_id = os.environ[endpoint_os_variable_name]
project = os.environ["PROJECT"]
location = "europe-west4"
llm = VertexAIModelGarden(
endpoint_id=endpoint_id,
project=project,
result_arg=result_arg,
location=location,
)
output = llm.generate(["What is the meaning of life?", "How much is 2+2"])
assert isinstance(output, LLMResult)
assert len(output.generations) == 2
@pytest.mark.asyncio
@pytest.mark.parametrize(
"endpoint_os_variable_name,result_arg",
[("FALCON_ENDPOINT_ID", "generated_text"), ("LLAMA_ENDPOINT_ID", None)],
)
async def test_model_garden_agenerate(
endpoint_os_variable_name: str, result_arg: Optional[str]
) -> None:
endpoint_id = os.environ[endpoint_os_variable_name]
project = os.environ["PROJECT"]
location = "europe-west4"
llm = VertexAIModelGarden(
endpoint_id=endpoint_id,
project=project,
result_arg=result_arg,
location=location,
)
output = await llm.agenerate(["What is the meaning of life?", "How much is 2+2"])
assert isinstance(output, LLMResult)
assert len(output.generations) == 2
@pytest.mark.parametrize(
"model_name",
model_names_to_test,
)
def test_vertex_call_count_tokens(model_name: str) -> None:
llm = VertexAI(model_name=model_name)
output = llm.get_num_tokens("How are you?")
assert output == 4

@ -0,0 +1,112 @@
"""Test chat model integration."""
from typing import Optional
from unittest.mock import MagicMock, Mock, patch
import pytest
from langchain_core.messages import (
AIMessage,
HumanMessage,
SystemMessage,
)
from vertexai.language_models import ChatMessage, InputOutputTextPair # type: ignore
from langchain_google_vertexai.chat_models import (
ChatVertexAI,
_parse_chat_history,
_parse_examples,
)
def test_parse_examples_correct() -> None:
text_question = (
"Hello, could you recommend a good movie for me to watch this evening, please?"
)
question = HumanMessage(content=text_question)
text_answer = (
"Sure, You might enjoy The Lord of the Rings: The Fellowship of the Ring "
"(2001): This is the first movie in the Lord of the Rings trilogy."
)
answer = AIMessage(content=text_answer)
examples = _parse_examples([question, answer, question, answer])
assert len(examples) == 2
assert examples == [
InputOutputTextPair(input_text=text_question, output_text=text_answer),
InputOutputTextPair(input_text=text_question, output_text=text_answer),
]
def test_parse_examples_failes_wrong_sequence() -> None:
with pytest.raises(ValueError) as exc_info:
_ = _parse_examples([AIMessage(content="a")])
assert (
str(exc_info.value)
== "Expect examples to have an even amount of messages, got 1."
)
@pytest.mark.parametrize("stop", [None, "stop1"])
def test_vertexai_args_passed(stop: Optional[str]) -> None:
response_text = "Goodbye"
user_prompt = "Hello"
prompt_params = {
"max_output_tokens": 1,
"temperature": 10000.0,
"top_k": 10,
"top_p": 0.5,
}
# Mock the library to ensure the args are passed correctly
with patch("vertexai._model_garden._model_garden_models._from_pretrained") as mg:
mock_response = MagicMock()
mock_response.candidates = [Mock(text=response_text)]
mock_chat = MagicMock()
mock_send_message = MagicMock(return_value=mock_response)
mock_chat.send_message = mock_send_message
mock_model = MagicMock()
mock_start_chat = MagicMock(return_value=mock_chat)
mock_model.start_chat = mock_start_chat
mg.return_value = mock_model
model = ChatVertexAI(**prompt_params)
message = HumanMessage(content=user_prompt)
if stop:
response = model([message], stop=[stop])
else:
response = model([message])
assert response.content == response_text
mock_send_message.assert_called_once_with(user_prompt, candidate_count=1)
expected_stop_sequence = [stop] if stop else None
mock_start_chat.assert_called_once_with(
context=None,
message_history=[],
**prompt_params,
stop_sequences=expected_stop_sequence,
)
def test_parse_chat_history_correct() -> None:
text_context = (
"My name is Ned. You are my personal assistant. My "
"favorite movies are Lord of the Rings and Hobbit."
)
context = SystemMessage(content=text_context)
text_question = (
"Hello, could you recommend a good movie for me to watch this evening, please?"
)
question = HumanMessage(content=text_question)
text_answer = (
"Sure, You might enjoy The Lord of the Rings: The Fellowship of the Ring "
"(2001): This is the first movie in the Lord of the Rings trilogy."
)
answer = AIMessage(content=text_answer)
history = _parse_chat_history([context, question, answer, question, answer])
assert history.context == context.content
assert len(history.history) == 4
assert history.history == [
ChatMessage(content=text_question, author="user"),
ChatMessage(content=text_answer, author="bot"),
ChatMessage(content=text_question, author="user"),
ChatMessage(content=text_answer, author="bot"),
]

@ -0,0 +1,7 @@
from langchain_google_vertexai import __all__
EXPECTED_ALL = ["ChatVertexAI", "VertexAIEmbeddings", "VertexAI", "VertexAIModelGarden"]
def test_all_imports() -> None:
assert sorted(EXPECTED_ALL) == sorted(__all__)
Loading…
Cancel
Save