google-genai, google-vertexai: move to langchain-google (#17899)

These packages have moved to
https://github.com/langchain-ai/langchain-google

Left tombstone readmes incase anyone ends up at the "Source Code" link
from old pypi releases. Can keep these around for a few months.
pull/17904/head
Erick Friis 7 months ago committed by GitHub
parent 3b5bdbfee8
commit 248c5b84ee
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -1,21 +0,0 @@
MIT License
Copyright (c) 2023 LangChain, Inc.
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

@ -1,61 +0,0 @@
.PHONY: all format lint test tests integration_tests help
# Default target executed when no arguments are given to make.
all: help
# Define a variable for the test file path.
TEST_FILE ?= tests/unit_tests/
test:
poetry run pytest $(TEST_FILE)
tests:
poetry run pytest $(TEST_FILE)
check_imports: $(shell find langchain_google_genai -name '*.py')
poetry run python ./scripts/check_imports.py $^
integration_tests:
poetry run pytest tests/integration_tests
######################
# LINTING AND FORMATTING
######################
# Define a variable for Python and notebook files.
PYTHON_FILES=.
MYPY_CACHE=.mypy_cache
lint format: PYTHON_FILES=.
lint_diff format_diff: PYTHON_FILES=$(shell git diff --name-only --diff-filter=d master | grep -E '\.py$$|\.ipynb$$')
lint_package: PYTHON_FILES=langchain_google_genai
lint_tests: PYTHON_FILES=tests
lint_tests: MYPY_CACHE=.mypy_cache_test
lint lint_diff lint_package lint_tests:
./scripts/check_pydantic.sh .
./scripts/lint_imports.sh
poetry run ruff .
[ "$(PYTHON_FILES)" = "" ] || poetry run ruff format $(PYTHON_FILES) --diff
[ "$(PYTHON_FILES)" = "" ] || poetry run mypy $(PYTHON_FILES)
format format_diff:
poetry run ruff format $(PYTHON_FILES)
poetry run ruff --select I --fix $(PYTHON_FILES)
spell_check:
poetry run codespell --toml pyproject.toml
spell_fix:
poetry run codespell --toml pyproject.toml -w
######################
# HELP
######################
help:
@echo '----'
@echo 'format - run code formatters'
@echo 'lint - run linters'
@echo 'test - run unit tests'
@echo 'tests - run unit tests'
@echo 'test TEST_FILE=<test_file> - run all tests in file'

@ -1,78 +1,3 @@
# langchain-google-genai
This package has moved!
This package contains the LangChain integrations for Gemini through their generative-ai SDK.
## Installation
```bash
pip install -U langchain-google-genai
```
### Image utilities
To use image utility methods, like loading images from GCS urls, install with extras group 'images':
```bash
pip install -e "langchain-google-genai[images]"
```
## Chat Models
This package contains the `ChatGoogleGenerativeAI` class, which is the recommended way to interface with the Google Gemini series of models.
To use, install the requirements, and configure your environment.
```bash
export GOOGLE_API_KEY=your-api-key
```
Then initialize
```python
from langchain_google_genai import ChatGoogleGenerativeAI
llm = ChatGoogleGenerativeAI(model="gemini-pro")
llm.invoke("Sing a ballad of LangChain.")
```
#### Multimodal inputs
Gemini vision model supports image inputs when providing a single chat message. Example:
```
from langchain_core.messages import HumanMessage
from langchain_google_genai import ChatGoogleGenerativeAI
llm = ChatGoogleGenerativeAI(model="gemini-pro-vision")
# example
message = HumanMessage(
content=[
{
"type": "text",
"text": "What's in this image?",
}, # You can optionally provide text parts
{"type": "image_url", "image_url": "https://picsum.photos/seed/picsum/200/300"},
]
)
llm.invoke([message])
```
The value of `image_url` can be any of the following:
- A public image URL
- An accessible gcs file (e.g., "gcs://path/to/file.png")
- A local file path
- A base64 encoded image (e.g., ``)
- A PIL image
## Embeddings
This package also adds support for google's embeddings models.
```
from langchain_google_genai import GoogleGenerativeAIEmbeddings
embeddings = GoogleGenerativeAIEmbeddings(model="models/embedding-001")
embeddings.embed_query("hello, world!")
```
https://github.com/langchain-ai/langchain-google/tree/main/libs/genai

@ -1,69 +0,0 @@
"""**LangChain Google Generative AI Integration**
This module integrates Google's Generative AI models, specifically the Gemini series, with the LangChain framework. It provides classes for interacting with chat models and generating embeddings, leveraging Google's advanced AI capabilities.
**Chat Models**
The `ChatGoogleGenerativeAI` class is the primary interface for interacting with Google's Gemini chat models. It allows users to send and receive messages using a specified Gemini model, suitable for various conversational AI applications.
**LLMs**
The `GoogleGenerativeAI` class is the primary interface for interacting with Google's Gemini LLMs. It allows users to generate text using a specified Gemini model.
**Embeddings**
The `GoogleGenerativeAIEmbeddings` class provides functionalities to generate embeddings using Google's models.
These embeddings can be used for a range of NLP tasks, including semantic analysis, similarity comparisons, and more.
**Installation**
To install the package, use pip:
```python
pip install -U langchain-google-genai
```
## Using Chat Models
After setting up your environment with the required API key, you can interact with the Google Gemini models.
```python
from langchain_google_genai import ChatGoogleGenerativeAI
llm = ChatGoogleGenerativeAI(model="gemini-pro")
llm.invoke("Sing a ballad of LangChain.")
```
## Using LLMs
The package also supports generating text with Google's models.
```python
from langchain_google_genai import GoogleGenerativeAI
llm = GoogleGenerativeAI(model="gemini-pro")
llm.invoke("Once upon a time, a library called LangChain")
```
## Embedding Generation
The package also supports creating embeddings with Google's models, useful for textual similarity and other NLP applications.
```python
from langchain_google_genai import GoogleGenerativeAIEmbeddings
embeddings = GoogleGenerativeAIEmbeddings(model="models/embedding-001")
embeddings.embed_query("hello, world!")
```
""" # noqa: E501
from langchain_google_genai._enums import HarmBlockThreshold, HarmCategory
from langchain_google_genai.chat_models import ChatGoogleGenerativeAI
from langchain_google_genai.embeddings import GoogleGenerativeAIEmbeddings
from langchain_google_genai.llms import GoogleGenerativeAI
__all__ = [
"ChatGoogleGenerativeAI",
"GoogleGenerativeAIEmbeddings",
"GoogleGenerativeAI",
"HarmBlockThreshold",
"HarmCategory",
]

@ -1,4 +0,0 @@
class GoogleGenerativeAIError(Exception):
"""
Custom exception class for errors associated with the `Google GenAI` API.
"""

@ -1,6 +0,0 @@
from google.generativeai.types.safety_types import ( # type: ignore
HarmBlockThreshold,
HarmCategory,
)
__all__ = ["HarmBlockThreshold", "HarmCategory"]

@ -1,116 +0,0 @@
from __future__ import annotations
from typing import (
Dict,
List,
Type,
Union,
)
import google.ai.generativelanguage as glm
from langchain_core.pydantic_v1 import BaseModel
from langchain_core.tools import BaseTool
from langchain_core.utils.json_schema import dereference_refs
FunctionCallType = Union[BaseTool, Type[BaseModel], Dict]
TYPE_ENUM = {
"string": glm.Type.STRING,
"number": glm.Type.NUMBER,
"integer": glm.Type.INTEGER,
"boolean": glm.Type.BOOLEAN,
"array": glm.Type.ARRAY,
"object": glm.Type.OBJECT,
}
def convert_to_genai_function_declarations(
function_calls: List[FunctionCallType],
) -> List[glm.Tool]:
return [
glm.Tool(
function_declarations=[_convert_to_genai_function(fc)],
)
for fc in function_calls
]
def _convert_to_genai_function(fc: FunctionCallType) -> glm.FunctionDeclaration:
if isinstance(fc, BaseTool):
return _convert_tool_to_genai_function(fc)
elif isinstance(fc, type) and issubclass(fc, BaseModel):
return _convert_pydantic_to_genai_function(fc)
elif isinstance(fc, dict):
return glm.FunctionDeclaration(
name=fc["name"],
description=fc.get("description"),
parameters={
"properties": {
k: {
"type_": TYPE_ENUM[v["type"]],
"description": v.get("description"),
}
for k, v in fc["parameters"]["properties"].items()
},
"required": fc["parameters"].get("required", []),
"type_": TYPE_ENUM[fc["parameters"]["type"]],
},
)
else:
raise ValueError(f"Unsupported function call type {fc}")
def _convert_tool_to_genai_function(tool: BaseTool) -> glm.FunctionDeclaration:
if tool.args_schema:
schema = dereference_refs(tool.args_schema.schema())
schema.pop("definitions", None)
return glm.FunctionDeclaration(
name=tool.name or schema["title"],
description=tool.description or schema["description"],
parameters={
"properties": {
k: {
"type_": TYPE_ENUM[v["type"]],
"description": v.get("description"),
}
for k, v in schema["properties"].items()
},
"required": schema["required"],
"type_": TYPE_ENUM[schema["type"]],
},
)
else:
return glm.FunctionDeclaration(
name=tool.name,
description=tool.description,
parameters={
"properties": {
"__arg1": {"type_": TYPE_ENUM["string"]},
},
"required": ["__arg1"],
"type_": TYPE_ENUM["object"],
},
)
def _convert_pydantic_to_genai_function(
pydantic_model: Type[BaseModel],
) -> glm.FunctionDeclaration:
schema = dereference_refs(pydantic_model.schema())
schema.pop("definitions", None)
return glm.FunctionDeclaration(
name=schema["title"],
description=schema.get("description", ""),
parameters={
"properties": {
k: {
"type_": TYPE_ENUM[v["type"]],
"description": v.get("description"),
}
for k, v in schema["properties"].items()
},
"required": schema["required"],
"type_": TYPE_ENUM[schema["type"]],
},
)

@ -1,676 +0,0 @@
from __future__ import annotations
import base64
import json
import logging
import os
from io import BytesIO
from typing import (
Any,
AsyncIterator,
Callable,
Dict,
Iterator,
List,
Mapping,
Optional,
Sequence,
Tuple,
Union,
cast,
)
from urllib.parse import urlparse
import google.ai.generativelanguage as glm
import google.api_core
# TODO: remove ignore once the google package is published with types
import google.generativeai as genai # type: ignore[import]
import proto # type: ignore[import]
import requests
from langchain_core.callbacks.manager import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import (
AIMessage,
AIMessageChunk,
BaseMessage,
FunctionMessage,
HumanMessage,
SystemMessage,
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.pydantic_v1 import SecretStr, root_validator
from langchain_core.utils import get_from_dict_or_env
from tenacity import (
before_sleep_log,
retry,
retry_if_exception_type,
stop_after_attempt,
wait_exponential,
)
from langchain_google_genai._common import GoogleGenerativeAIError
from langchain_google_genai._function_utils import (
convert_to_genai_function_declarations,
)
from langchain_google_genai.llms import GoogleModelFamily, _BaseGoogleGenerativeAI
IMAGE_TYPES: Tuple = ()
try:
import PIL
from PIL.Image import Image
IMAGE_TYPES = IMAGE_TYPES + (Image,)
except ImportError:
PIL = None # type: ignore
Image = None # type: ignore
logger = logging.getLogger(__name__)
class ChatGoogleGenerativeAIError(GoogleGenerativeAIError):
"""
Custom exception class for errors associated with the `Google GenAI` API.
This exception is raised when there are specific issues related to the
Google genai API usage in the ChatGoogleGenerativeAI class, such as unsupported
message types or roles.
"""
def _create_retry_decorator() -> Callable[[Any], Any]:
"""
Creates and returns a preconfigured tenacity retry decorator.
The retry decorator is configured to handle specific Google API exceptions
such as ResourceExhausted and ServiceUnavailable. It uses an exponential
backoff strategy for retries.
Returns:
Callable[[Any], Any]: A retry decorator configured for handling specific
Google API exceptions.
"""
multiplier = 2
min_seconds = 1
max_seconds = 60
max_retries = 10
return retry(
reraise=True,
stop=stop_after_attempt(max_retries),
wait=wait_exponential(multiplier=multiplier, min=min_seconds, max=max_seconds),
retry=(
retry_if_exception_type(google.api_core.exceptions.ResourceExhausted)
| retry_if_exception_type(google.api_core.exceptions.ServiceUnavailable)
| retry_if_exception_type(google.api_core.exceptions.GoogleAPIError)
),
before_sleep=before_sleep_log(logger, logging.WARNING),
)
def _chat_with_retry(generation_method: Callable, **kwargs: Any) -> Any:
"""
Executes a chat generation method with retry logic using tenacity.
This function is a wrapper that applies a retry mechanism to a provided
chat generation function. It is useful for handling intermittent issues
like network errors or temporary service unavailability.
Args:
generation_method (Callable): The chat generation method to be executed.
**kwargs (Any): Additional keyword arguments to pass to the generation method.
Returns:
Any: The result from the chat generation method.
"""
retry_decorator = _create_retry_decorator()
@retry_decorator
def _chat_with_retry(**kwargs: Any) -> Any:
try:
return generation_method(**kwargs)
# Do not retry for these errors.
except google.api_core.exceptions.FailedPrecondition as exc:
if "location is not supported" in exc.message:
error_msg = (
"Your location is not supported by google-generativeai "
"at the moment. Try to use ChatVertexAI LLM from "
"langchain_google_vertexai."
)
raise ValueError(error_msg)
except google.api_core.exceptions.InvalidArgument as e:
raise ChatGoogleGenerativeAIError(
f"Invalid argument provided to Gemini: {e}"
) from e
except Exception as e:
raise e
return _chat_with_retry(**kwargs)
async def _achat_with_retry(generation_method: Callable, **kwargs: Any) -> Any:
"""
Executes a chat generation method with retry logic using tenacity.
This function is a wrapper that applies a retry mechanism to a provided
chat generation function. It is useful for handling intermittent issues
like network errors or temporary service unavailability.
Args:
generation_method (Callable): The chat generation method to be executed.
**kwargs (Any): Additional keyword arguments to pass to the generation method.
Returns:
Any: The result from the chat generation method.
"""
retry_decorator = _create_retry_decorator()
from google.api_core.exceptions import InvalidArgument # type: ignore
@retry_decorator
async def _achat_with_retry(**kwargs: Any) -> Any:
try:
return await generation_method(**kwargs)
except InvalidArgument as e:
# Do not retry for these errors.
raise ChatGoogleGenerativeAIError(
f"Invalid argument provided to Gemini: {e}"
) from e
except Exception as e:
raise e
return await _achat_with_retry(**kwargs)
def _is_openai_parts_format(part: dict) -> bool:
return "type" in part
def _is_vision_model(model: str) -> bool:
return "vision" in model
def _is_url(s: str) -> bool:
try:
result = urlparse(s)
return all([result.scheme, result.netloc])
except Exception as e:
logger.debug(f"Unable to parse URL: {e}")
return False
def _is_b64(s: str) -> bool:
return s.startswith("data:image")
def _load_image_from_gcs(path: str, project: Optional[str] = None) -> Image:
try:
from google.cloud import storage # type: ignore[attr-defined]
except ImportError:
raise ImportError(
"google-cloud-storage is required to load images from GCS."
" Install it with `pip install google-cloud-storage`"
)
if PIL is None:
raise ImportError(
"PIL is required to load images. Please install it "
"with `pip install pillow`"
)
gcs_client = storage.Client(project=project)
pieces = path.split("/")
blobs = list(gcs_client.list_blobs(pieces[2], prefix="/".join(pieces[3:])))
if len(blobs) > 1:
raise ValueError(f"Found more than one candidate for {path}!")
img_bytes = blobs[0].download_as_bytes()
return PIL.Image.open(BytesIO(img_bytes))
def _url_to_pil(image_source: str) -> Image:
if PIL is None:
raise ImportError(
"PIL is required to load images. Please install it "
"with `pip install pillow`"
)
try:
if isinstance(image_source, IMAGE_TYPES):
return image_source # type: ignore[return-value]
elif _is_url(image_source):
if image_source.startswith("gs://"):
return _load_image_from_gcs(image_source)
response = requests.get(image_source)
response.raise_for_status()
return PIL.Image.open(BytesIO(response.content))
elif _is_b64(image_source):
_, encoded = image_source.split(",", 1)
data = base64.b64decode(encoded)
return PIL.Image.open(BytesIO(data))
elif os.path.exists(image_source):
return PIL.Image.open(image_source)
else:
raise ValueError(
"The provided string is not a valid URL, base64, or file path."
)
except Exception as e:
raise ValueError(f"Unable to process the provided image source: {e}")
def _convert_to_parts(
raw_content: Union[str, Sequence[Union[str, dict]]],
) -> List[genai.types.PartType]:
"""Converts a list of LangChain messages into a google parts."""
parts = []
content = [raw_content] if isinstance(raw_content, str) else raw_content
for part in content:
if isinstance(part, str):
parts.append(genai.types.PartDict(text=part))
elif isinstance(part, Mapping):
# OpenAI Format
if _is_openai_parts_format(part):
if part["type"] == "text":
parts.append({"text": part["text"]})
elif part["type"] == "image_url":
img_url = part["image_url"]
if isinstance(img_url, dict):
if "url" not in img_url:
raise ValueError(
f"Unrecognized message image format: {img_url}"
)
img_url = img_url["url"]
parts.append({"inline_data": _url_to_pil(img_url)})
else:
raise ValueError(f"Unrecognized message part type: {part['type']}")
else:
# Yolo
logger.warning(
"Unrecognized message part format. Assuming it's a text part."
)
parts.append(part)
else:
# TODO: Maybe some of Google's native stuff
# would hit this branch.
raise ChatGoogleGenerativeAIError(
"Gemini only supports text and inline_data parts."
)
return parts
def _parse_chat_history(
input_messages: Sequence[BaseMessage], convert_system_message_to_human: bool = False
) -> List[genai.types.ContentDict]:
messages: List[genai.types.MessageDict] = []
raw_system_message: Optional[SystemMessage] = None
for i, message in enumerate(input_messages):
if (
i == 0
and isinstance(message, SystemMessage)
and not convert_system_message_to_human
):
raise ValueError(
"""SystemMessages are not yet supported!
To automatically convert the leading SystemMessage to a HumanMessage,
set `convert_system_message_to_human` to True. Example:
llm = ChatGoogleGenerativeAI(model="gemini-pro", convert_system_message_to_human=True)
"""
)
elif i == 0 and isinstance(message, SystemMessage):
raw_system_message = message
continue
elif isinstance(message, AIMessage):
role = "model"
raw_function_call = message.additional_kwargs.get("function_call")
if raw_function_call:
function_call = glm.FunctionCall(
{
"name": raw_function_call["name"],
"args": json.loads(raw_function_call["arguments"]),
}
)
parts = [glm.Part(function_call=function_call)]
else:
parts = _convert_to_parts(message.content)
elif isinstance(message, HumanMessage):
role = "user"
parts = _convert_to_parts(message.content)
elif isinstance(message, FunctionMessage):
role = "user"
response: Any
if not isinstance(message.content, str):
response = message.content
else:
try:
response = json.loads(message.content)
except json.JSONDecodeError:
response = message.content # leave as str representation
parts = [
glm.Part(
function_response=glm.FunctionResponse(
name=message.name,
response=(
{"output": response}
if not isinstance(response, dict)
else response
),
)
)
]
else:
raise ValueError(
f"Unexpected message with type {type(message)} at the position {i}."
)
if raw_system_message:
if role == "model":
raise ValueError(
"SystemMessage should be followed by a HumanMessage and "
"not by AIMessage."
)
parts = _convert_to_parts(raw_system_message.content) + parts
raw_system_message = None
messages.append({"role": role, "parts": parts})
return messages
def _parse_response_candidate(
response_candidate: glm.Candidate, stream: bool
) -> AIMessage:
first_part = response_candidate.content.parts[0]
if first_part.function_call:
function_call = proto.Message.to_dict(first_part.function_call)
function_call["arguments"] = json.dumps(function_call.pop("args", {}))
return (AIMessageChunk if stream else AIMessage)(
content="", additional_kwargs={"function_call": function_call}
)
else:
parts = response_candidate.content.parts
if len(parts) == 1 and parts[0].text:
content: Union[str, List[Union[str, Dict]]] = parts[0].text
else:
content = [proto.Message.to_dict(part) for part in parts]
return (AIMessageChunk if stream else AIMessage)(
content=content, additional_kwargs={}
)
def _response_to_result(
response: glm.GenerateContentResponse,
stream: bool = False,
) -> ChatResult:
"""Converts a PaLM API response into a LangChain ChatResult."""
llm_output = {"prompt_feedback": proto.Message.to_dict(response.prompt_feedback)}
generations: List[ChatGeneration] = []
for candidate in response.candidates:
generation_info = {}
if candidate.finish_reason:
generation_info["finish_reason"] = candidate.finish_reason.name
generation_info["safety_ratings"] = [
proto.Message.to_dict(safety_rating, use_integers_for_enums=False)
for safety_rating in candidate.safety_ratings
]
generations.append(
(ChatGenerationChunk if stream else ChatGeneration)(
message=_parse_response_candidate(candidate, stream=stream),
generation_info=generation_info,
)
)
if not response.candidates:
# Likely a "prompt feedback" violation (e.g., toxic input)
# Raising an error would be different than how OpenAI handles it,
# so we'll just log a warning and continue with an empty message.
logger.warning(
"Gemini produced an empty response. Continuing with empty message\n"
f"Feedback: {response.prompt_feedback}"
)
generations = [
(ChatGenerationChunk if stream else ChatGeneration)(
message=(AIMessageChunk if stream else AIMessage)(content=""),
generation_info={},
)
]
return ChatResult(generations=generations, llm_output=llm_output)
class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
"""`Google Generative AI` Chat models API.
To use, you must have either:
1. The ``GOOGLE_API_KEY``` environment variable set with your API key, or
2. Pass your API key using the google_api_key kwarg to the ChatGoogle
constructor.
Example:
.. code-block:: python
from langchain_google_genai import ChatGoogleGenerativeAI
chat = ChatGoogleGenerativeAI(model="gemini-pro")
chat.invoke("Write me a ballad about LangChain")
"""
client: Any #: :meta private:
convert_system_message_to_human: bool = False
"""Whether to merge any leading SystemMessage into the following HumanMessage.
Gemini does not support system messages; any unsupported messages will
raise an error."""
class Config:
allow_population_by_field_name = True
@property
def lc_secrets(self) -> Dict[str, str]:
return {"google_api_key": "GOOGLE_API_KEY"}
@property
def _llm_type(self) -> str:
return "chat-google-generative-ai"
@classmethod
def is_lc_serializable(self) -> bool:
return True
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validates params and passes them to google-generativeai package."""
google_api_key = get_from_dict_or_env(
values, "google_api_key", "GOOGLE_API_KEY"
)
if isinstance(google_api_key, SecretStr):
google_api_key = google_api_key.get_secret_value()
genai.configure(
api_key=google_api_key,
transport=values.get("transport"),
client_options=values.get("client_options"),
)
if (
values.get("temperature") is not None
and not 0 <= values["temperature"] <= 1
):
raise ValueError("temperature must be in the range [0.0, 1.0]")
if values.get("top_p") is not None and not 0 <= values["top_p"] <= 1:
raise ValueError("top_p must be in the range [0.0, 1.0]")
if values.get("top_k") is not None and values["top_k"] <= 0:
raise ValueError("top_k must be positive")
model = values["model"]
values["client"] = genai.GenerativeModel(model_name=model)
return values
@property
def _identifying_params(self) -> Dict[str, Any]:
"""Get the identifying parameters."""
return {
"model": self.model,
"temperature": self.temperature,
"top_k": self.top_k,
"n": self.n,
"safety_settings": self.safety_settings,
}
def _prepare_params(
self, stop: Optional[List[str]], **kwargs: Any
) -> Dict[str, Any]:
gen_config = {
k: v
for k, v in {
"candidate_count": self.n,
"temperature": self.temperature,
"stop_sequences": stop,
"max_output_tokens": self.max_output_tokens,
"top_k": self.top_k,
"top_p": self.top_p,
}.items()
if v is not None
}
if "generation_config" in kwargs:
gen_config = {**gen_config, **kwargs.pop("generation_config")}
params = {"generation_config": gen_config, **kwargs}
return params
def _generate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
params, chat, message = self._prepare_chat(
messages,
stop=stop,
**kwargs,
)
response: genai.types.GenerateContentResponse = _chat_with_retry(
content=message,
**params,
generation_method=chat.send_message,
)
return _response_to_result(response)
async def _agenerate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
params, chat, message = self._prepare_chat(
messages,
stop=stop,
**kwargs,
)
response: genai.types.GenerateContentResponse = await _achat_with_retry(
content=message,
**params,
generation_method=chat.send_message_async,
)
return _response_to_result(response)
def _stream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
params, chat, message = self._prepare_chat(
messages,
stop=stop,
**kwargs,
)
response: genai.types.GenerateContentResponse = _chat_with_retry(
content=message,
**params,
generation_method=chat.send_message,
stream=True,
)
for chunk in response:
_chat_result = _response_to_result(chunk, stream=True)
gen = cast(ChatGenerationChunk, _chat_result.generations[0])
if run_manager:
run_manager.on_llm_new_token(gen.text)
yield gen
async def _astream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> AsyncIterator[ChatGenerationChunk]:
params, chat, message = self._prepare_chat(
messages,
stop=stop,
**kwargs,
)
async for chunk in await _achat_with_retry(
content=message,
**params,
generation_method=chat.send_message_async,
stream=True,
):
_chat_result = _response_to_result(chunk, stream=True)
gen = cast(ChatGenerationChunk, _chat_result.generations[0])
if run_manager:
await run_manager.on_llm_new_token(gen.text)
yield gen
def _prepare_chat(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
**kwargs: Any,
) -> Tuple[Dict[str, Any], genai.ChatSession, genai.types.ContentDict]:
client = self.client
functions = kwargs.pop("functions", None)
safety_settings = kwargs.pop("safety_settings", self.safety_settings)
if functions or safety_settings:
tools = (
convert_to_genai_function_declarations(functions) if functions else None
)
client = genai.GenerativeModel(
model_name=self.model, tools=tools, safety_settings=safety_settings
)
params = self._prepare_params(stop, **kwargs)
history = _parse_chat_history(
messages,
convert_system_message_to_human=self.convert_system_message_to_human,
)
message = history.pop()
chat = client.start_chat(history=history)
return params, chat, message
def get_num_tokens(self, text: str) -> int:
"""Get the number of tokens present in the text.
Useful for checking if an input will fit in a model's context window.
Args:
text: The string input to tokenize.
Returns:
The integer number of tokens in the text.
"""
if self._model_family == GoogleModelFamily.GEMINI:
result = self.client.count_tokens(text)
token_count = result.total_tokens
else:
result = self.client.count_text_tokens(model=self.model, prompt=text)
token_count = result["token_count"]
return token_count

@ -1,115 +0,0 @@
from typing import Dict, List, Optional
# TODO: remove ignore once the google package is published with types
import google.generativeai as genai # type: ignore[import]
from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator
from langchain_core.utils import get_from_dict_or_env
from langchain_google_genai._common import GoogleGenerativeAIError
class GoogleGenerativeAIEmbeddings(BaseModel, Embeddings):
"""`Google Generative AI Embeddings`.
To use, you must have either:
1. The ``GOOGLE_API_KEY``` environment variable set with your API key, or
2. Pass your API key using the google_api_key kwarg to the ChatGoogle
constructor.
Example:
.. code-block:: python
from langchain_google_genai import GoogleGenerativeAIEmbeddings
embeddings = GoogleGenerativeAIEmbeddings(model="models/embedding-001")
embeddings.embed_query("What's our Q1 revenue?")
"""
model: str = Field(
...,
description="The name of the embedding model to use. "
"Example: models/embedding-001",
)
task_type: Optional[str] = Field(
None,
description="The task type. Valid options include: "
"task_type_unspecified, retrieval_query, retrieval_document, "
"semantic_similarity, classification, and clustering",
)
google_api_key: Optional[SecretStr] = Field(
None,
description="The Google API key to use. If not provided, "
"the GOOGLE_API_KEY environment variable will be used.",
)
client_options: Optional[Dict] = Field(
None,
description=(
"A dictionary of client options to pass to the Google API client, "
"such as `api_endpoint`."
),
)
transport: Optional[str] = Field(
None,
description="A string, one of: [`rest`, `grpc`, `grpc_asyncio`].",
)
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validates params and passes them to google-generativeai package."""
google_api_key = get_from_dict_or_env(
values, "google_api_key", "GOOGLE_API_KEY"
)
if isinstance(google_api_key, SecretStr):
google_api_key = google_api_key.get_secret_value()
genai.configure(
api_key=google_api_key,
transport=values.get("transport"),
client_options=values.get("client_options"),
)
return values
def _embed(
self, texts: List[str], task_type: str, title: Optional[str] = None
) -> List[List[float]]:
task_type = self.task_type or "retrieval_document"
try:
result = genai.embed_content(
model=self.model,
content=texts,
task_type=task_type,
title=title,
)
except Exception as e:
raise GoogleGenerativeAIError(f"Error embedding content: {e}") from e
return result["embedding"]
def embed_documents(
self, texts: List[str], batch_size: int = 5
) -> List[List[float]]:
"""Embed a list of strings. Vertex AI currently
sets a max batch size of 5 strings.
Args:
texts: List[str] The list of strings to embed.
batch_size: [int] The batch size of embeddings to send to the model
Returns:
List of embeddings, one for each text.
"""
task_type = self.task_type or "retrieval_document"
return self._embed(texts, task_type=task_type)
def embed_query(self, text: str) -> List[float]:
"""Embed a text.
Args:
text: The text to embed.
Returns:
Embedding for the text.
"""
task_type = self.task_type or "retrieval_query"
return self._embed([text], task_type=task_type)[0]

@ -1,350 +0,0 @@
from __future__ import annotations
from enum import Enum, auto
from typing import Any, Callable, Dict, Iterator, List, Optional, Union
import google.api_core
import google.generativeai as genai # type: ignore[import]
from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain_core.language_models import LanguageModelInput
from langchain_core.language_models.llms import BaseLLM, create_base_retry_decorator
from langchain_core.outputs import Generation, GenerationChunk, LLMResult
from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator
from langchain_core.utils import get_from_dict_or_env
from langchain_google_genai._enums import (
HarmBlockThreshold,
HarmCategory,
)
class GoogleModelFamily(str, Enum):
GEMINI = auto()
PALM = auto()
@classmethod
def _missing_(cls, value: Any) -> Optional["GoogleModelFamily"]:
if "gemini" in value.lower():
return GoogleModelFamily.GEMINI
elif "text-bison" in value.lower():
return GoogleModelFamily.PALM
return None
def _create_retry_decorator(
llm: BaseLLM,
*,
max_retries: int = 1,
run_manager: Optional[
Union[AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun]
] = None,
) -> Callable[[Any], Any]:
"""Creates a retry decorator for Vertex / Palm LLMs."""
errors = [
google.api_core.exceptions.ResourceExhausted,
google.api_core.exceptions.ServiceUnavailable,
google.api_core.exceptions.Aborted,
google.api_core.exceptions.DeadlineExceeded,
google.api_core.exceptions.GoogleAPIError,
]
decorator = create_base_retry_decorator(
error_types=errors, max_retries=max_retries, run_manager=run_manager
)
return decorator
def _completion_with_retry(
llm: GoogleGenerativeAI,
prompt: LanguageModelInput,
is_gemini: bool = False,
stream: bool = False,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Any:
"""Use tenacity to retry the completion call."""
retry_decorator = _create_retry_decorator(
llm, max_retries=llm.max_retries, run_manager=run_manager
)
@retry_decorator
def _completion_with_retry(
prompt: LanguageModelInput, is_gemini: bool, stream: bool, **kwargs: Any
) -> Any:
generation_config = kwargs.get("generation_config", {})
error_msg = (
"Your location is not supported by google-generativeai at the moment. "
"Try to use VertexAI LLM from langchain_google_vertexai"
)
try:
if is_gemini:
return llm.client.generate_content(
contents=prompt,
stream=stream,
generation_config=generation_config,
safety_settings=kwargs.pop("safety_settings", None),
)
return llm.client.generate_text(prompt=prompt, **kwargs)
except google.api_core.exceptions.FailedPrecondition as exc:
if "location is not supported" in exc.message:
raise ValueError(error_msg)
return _completion_with_retry(
prompt=prompt, is_gemini=is_gemini, stream=stream, **kwargs
)
def _strip_erroneous_leading_spaces(text: str) -> str:
"""Strip erroneous leading spaces from text.
The PaLM API will sometimes erroneously return a single leading space in all
lines > 1. This function strips that space.
"""
has_leading_space = all(not line or line[0] == " " for line in text.split("\n")[1:])
if has_leading_space:
return text.replace("\n ", "\n")
else:
return text
class _BaseGoogleGenerativeAI(BaseModel):
"""Base class for Google Generative AI LLMs"""
model: str = Field(
...,
description="""The name of the model to use.
Supported examples:
- gemini-pro
- models/text-bison-001""",
)
"""Model name to use."""
google_api_key: Optional[SecretStr] = None
temperature: float = 0.7
"""Run inference with this temperature. Must by in the closed interval
[0.0, 1.0]."""
top_p: Optional[float] = None
"""Decode using nucleus sampling: consider the smallest set of tokens whose
probability sum is at least top_p. Must be in the closed interval [0.0, 1.0]."""
top_k: Optional[int] = None
"""Decode using top-k sampling: consider the set of top_k most probable tokens.
Must be positive."""
max_output_tokens: Optional[int] = None
"""Maximum number of tokens to include in a candidate. Must be greater than zero.
If unset, will default to 64."""
n: int = 1
"""Number of chat completions to generate for each prompt. Note that the API may
not return the full n completions if duplicates are generated."""
max_retries: int = 6
"""The maximum number of retries to make when generating."""
client_options: Optional[Dict] = Field(
None,
description=(
"A dictionary of client options to pass to the Google API client, "
"such as `api_endpoint`."
),
)
transport: Optional[str] = Field(
None,
description="A string, one of: [`rest`, `grpc`, `grpc_asyncio`].",
)
safety_settings: Optional[Dict[HarmCategory, HarmBlockThreshold]] = None
"""The default safety settings to use for all generations.
For example:
from google.generativeai.types.safety_types import HarmBlockThreshold, HarmCategory
safety_settings = {
HarmCategory.HARM_CATEGORY_UNSPECIFIED: HarmBlockThreshold.BLOCK_NONE,
HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_ONLY_HIGH,
HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_LOW_AND_ABOVE,
HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
}
""" # noqa: E501
@property
def lc_secrets(self) -> Dict[str, str]:
return {"google_api_key": "GOOGLE_API_KEY"}
@property
def _model_family(self) -> str:
return GoogleModelFamily(self.model)
@property
def _identifying_params(self) -> Dict[str, Any]:
"""Get the identifying parameters."""
return {
"model": self.model,
"temperature": self.temperature,
"top_p": self.top_p,
"top_k": self.top_k,
"max_output_tokens": self.max_output_tokens,
"candidate_count": self.n,
}
class GoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseLLM):
"""Google GenerativeAI models.
Example:
.. code-block:: python
from langchain_google_genai import GoogleGenerativeAI
llm = GoogleGenerativeAI(model="gemini-pro")
"""
client: Any #: :meta private:
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validates params and passes them to google-generativeai package."""
google_api_key = get_from_dict_or_env(
values, "google_api_key", "GOOGLE_API_KEY"
)
model_name = values["model"]
safety_settings = values["safety_settings"]
if isinstance(google_api_key, SecretStr):
google_api_key = google_api_key.get_secret_value()
genai.configure(
api_key=google_api_key,
transport=values.get("transport"),
client_options=values.get("client_options"),
)
if safety_settings and (
not GoogleModelFamily(model_name) == GoogleModelFamily.GEMINI
):
raise ValueError("Safety settings are only supported for Gemini models")
if GoogleModelFamily(model_name) == GoogleModelFamily.GEMINI:
values["client"] = genai.GenerativeModel(
model_name=model_name, safety_settings=safety_settings
)
else:
values["client"] = genai
if values["temperature"] is not None and not 0 <= values["temperature"] <= 1:
raise ValueError("temperature must be in the range [0.0, 1.0]")
if values["top_p"] is not None and not 0 <= values["top_p"] <= 1:
raise ValueError("top_p must be in the range [0.0, 1.0]")
if values["top_k"] is not None and values["top_k"] <= 0:
raise ValueError("top_k must be positive")
if values["max_output_tokens"] is not None and values["max_output_tokens"] <= 0:
raise ValueError("max_output_tokens must be greater than zero")
return values
def _generate(
self,
prompts: List[str],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> LLMResult:
generations: List[List[Generation]] = []
generation_config = {
"stop_sequences": stop,
"temperature": self.temperature,
"top_p": self.top_p,
"top_k": self.top_k,
"max_output_tokens": self.max_output_tokens,
"candidate_count": self.n,
}
for prompt in prompts:
if self._model_family == GoogleModelFamily.GEMINI:
res = _completion_with_retry(
self,
prompt=prompt,
stream=False,
is_gemini=True,
run_manager=run_manager,
generation_config=generation_config,
safety_settings=kwargs.pop("safety_settings", None),
)
candidates = [
"".join([p.text for p in c.content.parts]) for c in res.candidates
]
generations.append([Generation(text=c) for c in candidates])
else:
res = _completion_with_retry(
self,
model=self.model,
prompt=prompt,
stream=False,
is_gemini=False,
run_manager=run_manager,
**generation_config,
)
prompt_generations = []
for candidate in res.candidates:
raw_text = candidate["output"]
stripped_text = _strip_erroneous_leading_spaces(raw_text)
prompt_generations.append(Generation(text=stripped_text))
generations.append(prompt_generations)
return LLMResult(generations=generations)
def _stream(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[GenerationChunk]:
generation_config = kwargs.get("generation_config", {})
if stop:
generation_config["stop_sequences"] = stop
for stream_resp in _completion_with_retry(
self,
prompt,
stream=True,
is_gemini=True,
run_manager=run_manager,
generation_config=generation_config,
safety_settings=kwargs.pop("safety_settings", None),
**kwargs,
):
chunk = GenerationChunk(text=stream_resp.text)
yield chunk
if run_manager:
run_manager.on_llm_new_token(
stream_resp.text,
chunk=chunk,
verbose=self.verbose,
)
@property
def _llm_type(self) -> str:
"""Return type of llm."""
return "google_palm"
def get_num_tokens(self, text: str) -> int:
"""Get the number of tokens present in the text.
Useful for checking if an input will fit in a model's context window.
Args:
text: The string input to tokenize.
Returns:
The integer number of tokens in the text.
"""
if self._model_family == GoogleModelFamily.GEMINI:
result = self.client.count_tokens(text)
token_count = result.total_tokens
else:
result = self.client.count_text_tokens(model=self.model, prompt=text)
token_count = result["token_count"]
return token_count

File diff suppressed because it is too large Load Diff

@ -1,107 +0,0 @@
[tool.poetry]
name = "langchain-google-genai"
version = "0.0.9"
description = "An integration package connecting Google's genai package and LangChain"
authors = []
readme = "README.md"
repository = "https://github.com/langchain-ai/langchain"
license = "MIT"
[tool.poetry.urls]
"Source Code" = "https://github.com/langchain-ai/langchain/tree/master/libs/partners/google-genai"
[tool.poetry.dependencies]
python = ">=3.9,<4.0"
langchain-core = "^0.1"
google-generativeai = "^0.3.1"
pillow = { version = "^10.1.0", optional = true }
[tool.poetry.extras]
images = ["pillow"]
[tool.poetry.group.test]
optional = true
[tool.poetry.group.test.dependencies]
pytest = "^7.3.0"
freezegun = "^1.2.2"
pytest-mock = "^3.10.0"
syrupy = "^4.0.2"
pytest-watcher = "^0.3.4"
pytest-asyncio = "^0.21.1"
langchain-core = { path = "../../core", develop = true }
numpy = "^1.26.2"
[tool.poetry.group.codespell]
optional = true
[tool.poetry.group.codespell.dependencies]
codespell = "^2.2.0"
[tool.poetry.group.test_integration]
optional = true
[tool.poetry.group.test_integration.dependencies]
pillow = "^10.1.0"
[tool.poetry.group.lint]
optional = true
[tool.poetry.group.lint.dependencies]
ruff = "^0.1.5"
[tool.poetry.group.typing.dependencies]
mypy = "^0.991"
langchain-core = { path = "../../core", develop = true }
types-requests = "^2.28.11.5"
types-google-cloud-ndb = "^2.2.0.1"
types-pillow = "^10.1.0.2"
[tool.poetry.group.dev]
optional = true
[tool.poetry.group.dev.dependencies]
langchain-core = { path = "../../core", develop = true }
pillow = "^10.1.0"
types-requests = "^2.31.0.10"
types-pillow = "^10.1.0.2"
types-google-cloud-ndb = "^2.2.0.1"
[tool.ruff.lint]
select = [
"E", # pycodestyle
"F", # pyflakes
"I", # isort
"T201", # print
]
[tool.mypy]
disallow_untyped_defs = "True"
[tool.coverage.run]
omit = ["tests/*"]
[build-system]
requires = ["poetry-core>=1.0.0"]
build-backend = "poetry.core.masonry.api"
[tool.pytest.ini_options]
# --strict-markers will raise errors on unknown marks.
# https://docs.pytest.org/en/7.1.x/how-to/mark.html#raising-errors-on-unknown-marks
#
# https://docs.pytest.org/en/7.1.x/reference/reference.html
# --strict-config any warnings encountered while parsing the `pytest`
# section of the configuration file raise errors.
#
# https://github.com/tophat/syrupy
# --snapshot-warn-unused Prints a warning on unused snapshots rather than fail the test suite.
addopts = "--snapshot-warn-unused --strict-markers --strict-config --durations=5"
# Registering custom markers.
# https://docs.pytest.org/en/7.1.x/example/markers.html#registering-markers
markers = [
"requires: mark tests as requiring a specific library",
"asyncio: mark tests as requiring asyncio",
"compile: mark placeholder test used to compile integration tests without running them",
]
asyncio_mode = "auto"

@ -1,17 +0,0 @@
import sys
import traceback
from importlib.machinery import SourceFileLoader
if __name__ == "__main__":
files = sys.argv[1:]
has_failure = False
for file in files:
try:
SourceFileLoader("x", file).load_module()
except Exception:
has_faillure = True
print(file) # noqa: T201
traceback.print_exc()
print() # noqa: T201
sys.exit(1 if has_failure else 0)

@ -1,27 +0,0 @@
#!/bin/bash
#
# This script searches for lines starting with "import pydantic" or "from pydantic"
# in tracked files within a Git repository.
#
# Usage: ./scripts/check_pydantic.sh /path/to/repository
# Check if a path argument is provided
if [ $# -ne 1 ]; then
echo "Usage: $0 /path/to/repository"
exit 1
fi
repository_path="$1"
# Search for lines matching the pattern within the specified repository
result=$(git -C "$repository_path" grep -E '^import pydantic|^from pydantic')
# Check if any matching lines were found
if [ -n "$result" ]; then
echo "ERROR: The following lines need to be updated:"
echo "$result"
echo "Please replace the code with an import from langchain_core.pydantic_v1."
echo "For example, replace 'from pydantic import BaseModel'"
echo "with 'from langchain_core.pydantic_v1 import BaseModel'"
exit 1
fi

@ -1,17 +0,0 @@
#!/bin/bash
set -eu
# Initialize a variable to keep track of errors
errors=0
# make sure not importing from langchain or langchain_experimental
git --no-pager grep '^from langchain\.' . && errors=$((errors+1))
git --no-pager grep '^from langchain_experimental\.' . && errors=$((errors+1))
# Decide on an exit status based on the errors
if [ "$errors" -gt 0 ]; then
exit 1
else
exit 0
fi

@ -1,228 +0,0 @@
"""Test ChatGoogleGenerativeAI chat model."""
from typing import Generator
import pytest
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
from langchain_google_genai import (
ChatGoogleGenerativeAI,
HarmBlockThreshold,
HarmCategory,
)
from langchain_google_genai.chat_models import ChatGoogleGenerativeAIError
_MODEL = "gemini-pro" # TODO: Use nano when it's available.
_VISION_MODEL = "gemini-pro-vision"
_B64_string = """iVBORw0KGgoAAAANSUhEUgAAABQAAAAUCAIAAAAC64paAAABhGlDQ1BJQ0MgUHJvZmlsZQAAeJx9kT1Iw0AcxV8/xCIVQTuIKGSoTi2IijhqFYpQIdQKrTqYXPoFTRqSFBdHwbXg4Mdi1cHFWVcHV0EQ/ABxdXFSdJES/5cUWsR4cNyPd/ced+8Af6PCVDM4DqiaZaSTCSGbWxW6XxHECPoRQ0hipj4niil4jq97+Ph6F+dZ3uf+HL1K3mSATyCeZbphEW8QT29aOud94ggrSQrxOXHMoAsSP3JddvmNc9FhP8+MGJn0PHGEWCh2sNzBrGSoxFPEUUXVKN+fdVnhvMVZrdRY6578heG8trLMdZrDSGIRSxAhQEYNZVRgIU6rRoqJNO0nPPxDjl8kl0yuMhg5FlCFCsnxg//B727NwuSEmxROAF0vtv0xCnTvAs26bX8f23bzBAg8A1da219tADOfpNfbWvQI6NsGLq7bmrwHXO4Ag0+6ZEiOFKDpLxSA9zP6phwwcAv0rLm9tfZx+gBkqKvUDXBwCIwVKXvd492hzt7+PdPq7wdzbXKn5swsVgAAA8lJREFUeJx90dtPHHUUB/Dz+81vZhb2wrDI3soUKBSRcisF21iqqCRNY01NTE0k8aHpi0k18VJfjOFvUF9M44MmGrHFQqSQiKSmFloL5c4CXW6Fhb0vO3ufvczMzweiBGI9+eW8ffI95/yQqqrwv4UxBgCfJ9w/2NfSVB+Nyn6/r+vdLo7H6FkYY6yoABR2PJujj34MSo/d/nHeVLYbydmIp/bEO0fEy/+NMcbTU4/j4Vs6Lr0ccKeYuUKWS4ABVCVHmRdszbfvTgfjR8kz5Jjs+9RREl9Zy2lbVK9wU3/kWLJLCXnqza1bfVe7b9jLbIeTMcYu13Jg/aMiPrCwVFcgtDiMhnxwJ/zXVDwSdVCVMRV7nqzl2i9e/fKrw8mqSp84e2sFj3Oj8/SrF/MaicmyYhAaXu58NPAbeAeyzY0NLecmh2+ODN3BewYBAkAY43giI3kebrnsRmvV9z2D4ciOa3EBAf31Tp9sMgdxMTFm6j74/Ogb70VCYQKAAIDCXkOAIC6pkYBWdwwnpHEdf6L9dJtJKPh95DZhzFKMEWRAGL927XpWTmMA+s8DAOBYAoR483l/iHZ/8bXoODl8b9UfyH72SXepzbyRJNvjFGHKMlhvMBze+cH9+4lEuOOlU2X1tVkFTU7Om03q080NDGXV1cflRpHwaaoiiiildB8jhDLZ7HDfz2Yidba6Vn2L4fhzFrNRKy5OZ2QOZ1U5W8VtqlVH/iUHcM933zZYWS7Wtj66zZr65bzGJQt0glHgudi9XVzEl4vKw2kUPhO020oPYI1qYc+2Xc0bRXFwTLY0VXa2VibD/lBaIXm1UChN5JSRUcQQ1Tk/47Cf3x8bY7y17Y17PVYTG1UkLPBFcqik7Zoa9JcLYoHBqHhXNgd6gS1k9EJ1TQ2l9EDy1saErmQ2kGpwGC2MLOtCM8nZEV1K0tKJtEksSm26J/rHg2zzmabKisq939nHzqUH7efzd4f/nPGW6NP8ybNFrOsWQhpoCuuhnJ4hAnPhFam01K4oQMjBg/mzBjVhuvw2O++KKT+BIVxJKzQECBDLF2qu2WTMmCovtDQ1f8iyoGkUADBCCGPsdnvTW2OtFm01VeB06msvdWlpPZU0wJRG85ns84umU3k+VyxeEcWqvYUBAGsUrbvme4be99HFeisP/pwUOIZaOqQX31ISgrKmZhLHtXNXuJq68orrr5/9mBCglCLAGGPyy81votEbcjlKLrC9E8mhH3wdHRdcyyvjidSlxjftPJpD+o25JYvRHGFoZDdks1mBQhxJu9uxvwEiXuHnHbLd1AAAAABJRU5ErkJggg==""" # noqa: E501
def test_chat_google_genai_stream() -> None:
"""Test streaming tokens from Gemini."""
llm = ChatGoogleGenerativeAI(model=_MODEL)
for token in llm.stream("This is a test. Say 'foo'"):
assert isinstance(token.content, str)
async def test_chat_google_genai_astream() -> None:
"""Test streaming tokens from Gemini."""
llm = ChatGoogleGenerativeAI(model=_MODEL)
async for token in llm.astream("This is a test. Say 'foo'"):
assert isinstance(token.content, str)
async def test_chat_google_genai_abatch() -> None:
"""Test streaming tokens from ChatGoogleGenerativeAI."""
llm = ChatGoogleGenerativeAI(model=_MODEL)
result = await llm.abatch(
["This is a test. Say 'foo'", "This is a test, say 'bar'"]
)
for token in result:
assert isinstance(token.content, str)
async def test_chat_google_genai_abatch_tags() -> None:
"""Test batch tokens from ChatGoogleGenerativeAI."""
llm = ChatGoogleGenerativeAI(model=_MODEL)
result = await llm.abatch(
["This is a test", "This is another test"], config={"tags": ["foo"]}
)
for token in result:
assert isinstance(token.content, str)
def test_chat_google_genai_batch() -> None:
"""Test batch tokens from ChatGoogleGenerativeAI."""
llm = ChatGoogleGenerativeAI(model=_MODEL)
result = llm.batch(["This is a test. Say 'foo'", "This is a test, say 'bar'"])
for token in result:
assert isinstance(token.content, str)
async def test_chat_google_genai_ainvoke() -> None:
"""Test invoke tokens from ChatGoogleGenerativeAI."""
llm = ChatGoogleGenerativeAI(model=_MODEL)
result = await llm.ainvoke("This is a test. Say 'foo'", config={"tags": ["foo"]})
assert isinstance(result.content, str)
def test_chat_google_genai_invoke() -> None:
"""Test invoke tokens from ChatGoogleGenerativeAI."""
llm = ChatGoogleGenerativeAI(model=_MODEL)
result = llm.invoke(
"This is a test. Say 'foo'",
config=dict(tags=["foo"]),
generation_config=dict(top_k=2, top_p=1, temperature=0.7),
)
assert isinstance(result.content, str)
assert not result.content.startswith(" ")
def test_chat_google_genai_invoke_multimodal() -> None:
messages: list = [
HumanMessage(
content=[
{
"type": "text",
"text": "Guess what's in this picture! You have 3 guesses.",
},
{
"type": "image_url",
"image_url": "data:image/png;base64," + _B64_string,
},
]
),
]
llm = ChatGoogleGenerativeAI(model=_VISION_MODEL)
response = llm.invoke(messages)
assert isinstance(response.content, str)
assert len(response.content.strip()) > 0
# Try streaming
for chunk in llm.stream(messages):
print(chunk) # noqa: T201
assert isinstance(chunk.content, str)
assert len(chunk.content.strip()) > 0
def test_chat_google_genai_invoke_multimodal_too_many_messages() -> None:
# Only supports 1 turn...
messages: list = [
HumanMessage(content="Hi there"),
AIMessage(content="Hi, how are you?"),
HumanMessage(
content=[
{
"type": "text",
"text": "I'm doing great! Guess what's in this picture!",
},
{
"type": "image_url",
"image_url": "data:image/png;base64," + _B64_string,
},
]
),
]
llm = ChatGoogleGenerativeAI(model=_VISION_MODEL)
with pytest.raises(ChatGoogleGenerativeAIError):
llm.invoke(messages)
def test_chat_google_genai_invoke_multimodal_invalid_model() -> None:
# need the vision model to support this.
messages: list = [
HumanMessage(
content=[
{
"type": "text",
"text": "I'm doing great! Guess what's in this picture!",
},
{
"type": "image_url",
"image_url": "data:image/png;base64," + _B64_string,
},
]
),
]
llm = ChatGoogleGenerativeAI(model=_MODEL)
with pytest.raises(ChatGoogleGenerativeAIError):
llm.invoke(messages)
def test_chat_google_genai_single_call_with_history() -> None:
model = ChatGoogleGenerativeAI(model=_MODEL)
text_question1, text_answer1 = "How much is 2+2?", "4"
text_question2 = "How much is 3+3?"
message1 = HumanMessage(content=text_question1)
message2 = AIMessage(content=text_answer1)
message3 = HumanMessage(content=text_question2)
response = model([message1, message2, message3])
assert isinstance(response, AIMessage)
assert isinstance(response.content, str)
def test_chat_google_genai_system_message_error() -> None:
model = ChatGoogleGenerativeAI(model=_MODEL)
text_question1, text_answer1 = "How much is 2+2?", "4"
text_question2 = "How much is 3+3?"
system_message = SystemMessage(content="You're supposed to answer math questions.")
message1 = HumanMessage(content=text_question1)
message2 = AIMessage(content=text_answer1)
message3 = HumanMessage(content=text_question2)
with pytest.raises(ValueError):
model([system_message, message1, message2, message3])
def test_chat_google_genai_system_message() -> None:
model = ChatGoogleGenerativeAI(model=_MODEL, convert_system_message_to_human=True)
text_question1, text_answer1 = "How much is 2+2?", "4"
text_question2 = "How much is 3+3?"
system_message = SystemMessage(content="You're supposed to answer math questions.")
message1 = HumanMessage(content=text_question1)
message2 = AIMessage(content=text_answer1)
message3 = HumanMessage(content=text_question2)
response = model([system_message, message1, message2, message3])
assert isinstance(response, AIMessage)
assert isinstance(response.content, str)
def test_generativeai_get_num_tokens_gemini() -> None:
llm = ChatGoogleGenerativeAI(temperature=0, model="gemini-pro")
output = llm.get_num_tokens("How are you?")
assert output == 4
def test_safety_settings_gemini() -> None:
safety_settings = {
HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
}
# test with safety filters on bind
llm = ChatGoogleGenerativeAI(temperature=0, model="gemini-pro").bind(
safety_settings=safety_settings
)
output = llm.invoke("how to make a bomb?")
assert isinstance(output, AIMessage)
assert len(output.content) > 0
# test direct to stream
streamed_messages = []
output_stream = llm.stream("how to make a bomb?", safety_settings=safety_settings)
assert isinstance(output_stream, Generator)
for message in output_stream:
streamed_messages.append(message)
assert len(streamed_messages) > 0
# test as init param
llm = ChatGoogleGenerativeAI(
temperature=0, model="gemini-pro", safety_settings=safety_settings
)
out2 = llm.invoke("how to make a bomb")
assert isinstance(out2, AIMessage)
assert len(out2.content) > 0

@ -1,7 +0,0 @@
import pytest
@pytest.mark.compile
def test_placeholder() -> None:
"""Used for compiling integration tests without running any real tests."""
pass

@ -1,98 +0,0 @@
import numpy as np
import pytest
from langchain_google_genai._common import GoogleGenerativeAIError
from langchain_google_genai.embeddings import GoogleGenerativeAIEmbeddings
_MODEL = "models/embedding-001"
@pytest.mark.parametrize(
"query",
[
"Hi",
"This is a longer query string to test the embedding functionality of the"
" model against the pickle rick?",
],
)
def test_embed_query_different_lengths(query: str) -> None:
"""Test embedding queries of different lengths."""
model = GoogleGenerativeAIEmbeddings(model=_MODEL)
result = model.embed_query(query)
assert len(result) == 768
@pytest.mark.parametrize(
"query",
[
"Hi",
"This is a longer query string to test the embedding functionality of the"
" model against the pickle rick?",
],
)
async def test_aembed_query_different_lengths(query: str) -> None:
"""Test embedding queries of different lengths."""
model = GoogleGenerativeAIEmbeddings(model=_MODEL)
result = await model.aembed_query(query)
assert len(result) == 768
def test_embed_documents() -> None:
"""Test embedding a query."""
model = GoogleGenerativeAIEmbeddings(
model=_MODEL,
)
result = model.embed_documents(["Hello world", "Good day, world"])
assert len(result) == 2
assert len(result[0]) == 768
assert len(result[1]) == 768
async def test_aembed_documents() -> None:
"""Test embedding a query."""
model = GoogleGenerativeAIEmbeddings(
model=_MODEL,
)
result = await model.aembed_documents(["Hello world", "Good day, world"])
assert len(result) == 2
assert len(result[0]) == 768
assert len(result[1]) == 768
def test_invalid_model_error_handling() -> None:
"""Test error handling with an invalid model name."""
with pytest.raises(GoogleGenerativeAIError):
GoogleGenerativeAIEmbeddings(model="invalid_model").embed_query("Hello world")
def test_invalid_api_key_error_handling() -> None:
"""Test error handling with an invalid API key."""
with pytest.raises(GoogleGenerativeAIError):
GoogleGenerativeAIEmbeddings(
model=_MODEL, google_api_key="invalid_key"
).embed_query("Hello world")
def test_embed_documents_consistency() -> None:
"""Test embedding consistency for the same document."""
model = GoogleGenerativeAIEmbeddings(model=_MODEL)
doc = "Consistent document for testing"
result1 = model.embed_documents([doc])
result2 = model.embed_documents([doc])
assert result1 == result2
def test_embed_documents_quality() -> None:
"""Smoke test embedding quality by comparing similar and dissimilar documents."""
model = GoogleGenerativeAIEmbeddings(model=_MODEL)
similar_docs = ["Document A", "Similar Document A"]
dissimilar_docs = ["Document A", "Completely Different Zebra"]
similar_embeddings = model.embed_documents(similar_docs)
dissimilar_embeddings = model.embed_documents(dissimilar_docs)
similar_distance = np.linalg.norm(
np.array(similar_embeddings[0]) - np.array(similar_embeddings[1])
)
dissimilar_distance = np.linalg.norm(
np.array(dissimilar_embeddings[0]) - np.array(dissimilar_embeddings[1])
)
assert similar_distance < dissimilar_distance

@ -1,84 +0,0 @@
"""Test ChatGoogleGenerativeAI function call."""
import json
from langchain_core.messages import AIMessage
from langchain_core.pydantic_v1 import BaseModel
from langchain_core.tools import tool
from langchain_google_genai.chat_models import (
ChatGoogleGenerativeAI,
)
def test_function_call() -> None:
functions = [
{
"name": "get_weather",
"description": "Determine weather in my location",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state e.g. San Francisco, CA",
},
"unit": {"type": "string", "enum": ["c", "f"]},
},
"required": ["location"],
},
}
]
llm = ChatGoogleGenerativeAI(model="gemini-pro").bind(functions=functions)
res = llm.invoke("what weather is today in san francisco?")
assert res
assert res.additional_kwargs
assert "function_call" in res.additional_kwargs
assert "get_weather" == res.additional_kwargs["function_call"]["name"]
arguments_str = res.additional_kwargs["function_call"]["arguments"]
assert isinstance(arguments_str, str)
arguments = json.loads(arguments_str)
assert "location" in arguments
def test_tool_call() -> None:
@tool
def search_tool(query: str) -> str:
"""Searches the web for `query` and returns the result."""
raise NotImplementedError
llm = ChatGoogleGenerativeAI(model="gemini-pro").bind(functions=[search_tool])
response = llm.invoke("weather in san francisco")
assert isinstance(response, AIMessage)
assert isinstance(response.content, str)
assert response.content == ""
function_call = response.additional_kwargs.get("function_call")
assert function_call
assert function_call["name"] == "search_tool"
arguments_str = function_call.get("arguments")
assert arguments_str
arguments = json.loads(arguments_str)
assert "query" in arguments
class MyModel(BaseModel):
name: str
age: int
def test_pydantic_call() -> None:
llm = ChatGoogleGenerativeAI(model="gemini-pro").bind(functions=[MyModel])
response = llm.invoke("my name is Erick and I am 27 years old")
assert isinstance(response, AIMessage)
assert isinstance(response.content, str)
assert response.content == ""
function_call = response.additional_kwargs.get("function_call")
assert function_call
assert function_call["name"] == "MyModel"
arguments_str = function_call.get("arguments")
assert arguments_str
arguments = json.loads(arguments_str)
assert arguments == {
"name": "Erick",
"age": 27.0,
}

@ -1,106 +0,0 @@
"""Test Google GenerativeAI API wrapper.
Note: This test must be run with the GOOGLE_API_KEY environment variable set to a
valid API key.
"""
from typing import Generator
import pytest
from langchain_core.outputs import LLMResult
from langchain_google_genai import GoogleGenerativeAI, HarmBlockThreshold, HarmCategory
model_names = ["models/text-bison-001", "gemini-pro"]
@pytest.mark.parametrize(
"model_name",
model_names,
)
def test_google_generativeai_call(model_name: str) -> None:
"""Test valid call to Google GenerativeAI text API."""
if model_name:
llm = GoogleGenerativeAI(max_output_tokens=10, model=model_name)
else:
llm = GoogleGenerativeAI(max_output_tokens=10)
output = llm("Say foo:")
assert isinstance(output, str)
assert llm._llm_type == "google_palm"
if model_name and "gemini" in model_name:
assert llm.client.model_name == "models/gemini-pro"
else:
assert llm.model == "models/text-bison-001"
@pytest.mark.parametrize(
"model_name",
model_names,
)
def test_google_generativeai_generate(model_name: str) -> None:
n = 1 if model_name == "gemini-pro" else 2
llm = GoogleGenerativeAI(temperature=0.3, n=n, model=model_name)
output = llm.generate(["Say foo:"])
assert isinstance(output, LLMResult)
assert len(output.generations) == 1
assert len(output.generations[0]) == n
def test_google_generativeai_get_num_tokens() -> None:
llm = GoogleGenerativeAI(model="models/text-bison-001")
output = llm.get_num_tokens("How are you?")
assert output == 4
async def test_google_generativeai_agenerate() -> None:
llm = GoogleGenerativeAI(temperature=0, model="gemini-pro")
output = await llm.agenerate(["Please say foo:"])
assert isinstance(output, LLMResult)
def test_generativeai_stream() -> None:
llm = GoogleGenerativeAI(temperature=0, model="gemini-pro")
outputs = list(llm.stream("Please say foo:"))
assert isinstance(outputs[0], str)
def test_generativeai_get_num_tokens_gemini() -> None:
llm = GoogleGenerativeAI(temperature=0, model="gemini-pro")
output = llm.get_num_tokens("How are you?")
assert output == 4
def test_safety_settings_gemini() -> None:
# test with blocked prompt
llm = GoogleGenerativeAI(temperature=0, model="gemini-pro")
output = llm.generate(prompts=["how to make a bomb?"])
assert isinstance(output, LLMResult)
assert len(output.generations[0]) == 0
# safety filters
safety_settings = {
HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
}
# test with safety filters directly to generate
output = llm.generate(["how to make a bomb?"], safety_settings=safety_settings)
assert isinstance(output, LLMResult)
assert len(output.generations[0]) > 0
# test with safety filters directly to stream
streamed_messages = []
output_stream = llm.stream("how to make a bomb?", safety_settings=safety_settings)
assert isinstance(output_stream, Generator)
for message in output_stream:
streamed_messages.append(message)
assert len(streamed_messages) > 0
# test with safety filters on instantiation
llm = GoogleGenerativeAI(
model="gemini-pro",
safety_settings=safety_settings,
temperature=0,
)
output = llm.generate(prompts=["how to make a bomb?"])
assert isinstance(output, LLMResult)
assert len(output.generations[0]) > 0

@ -1,75 +0,0 @@
"""Test chat model integration."""
from typing import Dict, List, Union
import pytest
from langchain_core.messages import (
AIMessage,
FunctionMessage,
HumanMessage,
SystemMessage,
)
from langchain_core.pydantic_v1 import SecretStr
from pytest import CaptureFixture
from langchain_google_genai.chat_models import (
ChatGoogleGenerativeAI,
_parse_chat_history,
)
def test_integration_initialization() -> None:
"""Test chat model initialization."""
ChatGoogleGenerativeAI(
model="gemini-nano",
google_api_key="...",
top_k=2,
top_p=1,
temperature=0.7,
n=2,
)
ChatGoogleGenerativeAI(
model="gemini-nano",
google_api_key="...",
top_k=2,
top_p=1,
temperature=0.7,
candidate_count=2,
)
def test_api_key_is_string() -> None:
chat = ChatGoogleGenerativeAI(model="gemini-nano", google_api_key="secret-api-key")
assert isinstance(chat.google_api_key, SecretStr)
def test_api_key_masked_when_passed_via_constructor(capsys: CaptureFixture) -> None:
chat = ChatGoogleGenerativeAI(model="gemini-nano", google_api_key="secret-api-key")
print(chat.google_api_key, end="") # noqa: T201
captured = capsys.readouterr()
assert captured.out == "**********"
def test_parse_history() -> None:
system_input = "You're supposed to answer math questions."
text_question1, text_answer1 = "How much is 2+2?", "4"
text_question2 = "How much is 3+3?"
system_message = SystemMessage(content=system_input)
message1 = HumanMessage(content=text_question1)
message2 = AIMessage(content=text_answer1)
message3 = HumanMessage(content=text_question2)
messages = [system_message, message1, message2, message3]
history = _parse_chat_history(messages, convert_system_message_to_human=True)
assert len(history) == 3
assert history[0] == {
"role": "user",
"parts": [{"text": system_input}, {"text": text_question1}],
}
assert history[1] == {"role": "model", "parts": [{"text": text_answer1}]}
@pytest.mark.parametrize("content", ['["a"]', '{"a":"b"}', "function output"])
def test_parse_function_history(content: Union[str, List[Union[str, Dict]]]) -> None:
function_message = FunctionMessage(name="search_tool", content=content)
_parse_chat_history([function_message], convert_system_message_to_human=True)

@ -1,38 +0,0 @@
"""Test embeddings model integration."""
from langchain_core.pydantic_v1 import SecretStr
from pytest import CaptureFixture
from langchain_google_genai.embeddings import GoogleGenerativeAIEmbeddings
def test_integration_initialization() -> None:
"""Test chat model initialization."""
GoogleGenerativeAIEmbeddings(
model="models/embedding-001",
google_api_key="...",
)
GoogleGenerativeAIEmbeddings(
model="models/embedding-001",
google_api_key="...",
task_type="retrieval_document",
)
def test_api_key_is_string() -> None:
embeddings = GoogleGenerativeAIEmbeddings(
model="models/embedding-001",
google_api_key="secret-api-key",
)
assert isinstance(embeddings.google_api_key, SecretStr)
def test_api_key_masked_when_passed_via_constructor(capsys: CaptureFixture) -> None:
embeddings = GoogleGenerativeAIEmbeddings(
model="models/embedding-001",
google_api_key="secret-api-key",
)
print(embeddings.google_api_key, end="") # noqa: T201
captured = capsys.readouterr()
assert captured.out == "**********"

@ -1,13 +0,0 @@
from langchain_google_genai import __all__
EXPECTED_ALL = [
"ChatGoogleGenerativeAI",
"GoogleGenerativeAIEmbeddings",
"GoogleGenerativeAI",
"HarmBlockThreshold",
"HarmCategory",
]
def test_all_imports() -> None:
assert sorted(EXPECTED_ALL) == sorted(__all__)

@ -1,8 +0,0 @@
from langchain_google_genai.llms import GoogleModelFamily
def test_model_family() -> None:
model = GoogleModelFamily("gemini-pro")
assert model == GoogleModelFamily.GEMINI
model = GoogleModelFamily("gemini-ultra")
assert model == GoogleModelFamily.GEMINI

@ -1,21 +0,0 @@
MIT License
Copyright (c) 2023 LangChain, Inc.
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

@ -1,61 +0,0 @@
.PHONY: all format lint test tests integration_tests docker_tests help extended_tests
# Default target executed when no arguments are given to make.
all: help
# Define a variable for the test file path.
TEST_FILE ?= tests/unit_tests/
integration_tests: TEST_FILE = tests/integration_tests/
test integration_tests:
poetry run pytest $(TEST_FILE)
tests:
poetry run pytest $(TEST_FILE)
######################
# LINTING AND FORMATTING
######################
# Define a variable for Python and notebook files.
PYTHON_FILES=.
MYPY_CACHE=.mypy_cache
lint format: PYTHON_FILES=.
lint_diff format_diff: PYTHON_FILES=$(shell git diff --relative=libs/partners/google-vertexai --name-only --diff-filter=d master | grep -E '\.py$$|\.ipynb$$')
lint_package: PYTHON_FILES=langchain_google_vertexai
lint_tests: PYTHON_FILES=tests
lint_tests: MYPY_CACHE=.mypy_cache_test
lint lint_diff lint_package lint_tests:
poetry run ruff .
poetry run ruff format $(PYTHON_FILES) --diff
poetry run ruff --select I $(PYTHON_FILES)
mkdir $(MYPY_CACHE); poetry run mypy $(PYTHON_FILES) --cache-dir $(MYPY_CACHE)
format format_diff:
poetry run ruff format $(PYTHON_FILES)
poetry run ruff --select I --fix $(PYTHON_FILES)
spell_check:
poetry run codespell --toml pyproject.toml
spell_fix:
poetry run codespell --toml pyproject.toml -w
check_imports: $(shell find langchain_google_vertexai -name '*.py')
poetry run python ./scripts/check_imports.py $^
######################
# HELP
######################
help:
@echo '----'
@echo 'check_imports - check imports'
@echo 'format - run code formatters'
@echo 'lint - run linters'
@echo 'test - run unit tests'
@echo 'tests - run unit tests'
@echo 'test TEST_FILE=<test_file> - run all tests in file'

@ -1,100 +1,3 @@
# langchain-google-vertexai
This package has moved!
This package contains the LangChain integrations for Google Cloud generative models.
## Installation
```bash
pip install -U langchain-google-vertexai
```
## Chat Models
`ChatVertexAI` class exposes models such as `gemini-pro` and `chat-bison`.
To use, you should have Google Cloud project with APIs enabled, and configured credentials. Initialize the model as:
```python
from langchain_google_vertexai import ChatVertexAI
llm = ChatVertexAI(model_name="gemini-pro")
llm.invoke("Sing a ballad of LangChain.")
```
You can use other models, e.g. `chat-bison`:
```python
from langchain_google_vertexai import ChatVertexAI
llm = ChatVertexAI(model_name="chat-bison", temperature=0.3)
llm.invoke("Sing a ballad of LangChain.")
```
#### Multimodal inputs
Gemini vision model supports image inputs when providing a single chat message. Example:
```python
from langchain_core.messages import HumanMessage
from langchain_google_vertexai import ChatVertexAI
llm = ChatVertexAI(model_name="gemini-pro-vision")
# example
message = HumanMessage(
content=[
{
"type": "text",
"text": "What's in this image?",
}, # You can optionally provide text parts
{"type": "image_url", "image_url": {"url": "https://picsum.photos/seed/picsum/200/300"}},
]
)
llm.invoke([message])
```
The value of `image_url` can be any of the following:
- A public image URL
- An accessible gcs file (e.g., "gcs://path/to/file.png")
- A local file path
- A base64 encoded image (e.g., ``)
## Embeddings
You can use Google Cloud's embeddings models as:
```python
from langchain_google_vertexai import VertexAIEmbeddings
embeddings = VertexAIEmbeddings()
embeddings.embed_query("hello, world!")
```
## LLMs
You can use Google Cloud's generative AI models as Langchain LLMs:
```python
from langchain.prompts import PromptTemplate
from langchain_google_vertexai import VertexAI
template = """Question: {question}
Answer: Let's think step by step."""
prompt = PromptTemplate.from_template(template)
chain = prompt | llm
question = "Who was the president in the year Justin Beiber was born?"
print(chain.invoke({"question": question}))
```
You can use Gemini and Palm models, including code-generations ones:
```python
from langchain_google_vertexai import VertexAI
llm = VertexAI(model_name="code-bison", max_output_tokens=1000, temperature=0.3)
question = "Write a python function that checks if a string is a valid email address"
output = llm(question)
```
https://github.com/langchain-ai/langchain-google/tree/main/libs/vertexai

@ -1,17 +0,0 @@
from langchain_google_vertexai._enums import HarmBlockThreshold, HarmCategory
from langchain_google_vertexai.chains import create_structured_runnable
from langchain_google_vertexai.chat_models import ChatVertexAI
from langchain_google_vertexai.embeddings import VertexAIEmbeddings
from langchain_google_vertexai.functions_utils import PydanticFunctionsOutputParser
from langchain_google_vertexai.llms import VertexAI, VertexAIModelGarden
__all__ = [
"ChatVertexAI",
"VertexAIEmbeddings",
"VertexAI",
"VertexAIModelGarden",
"HarmBlockThreshold",
"HarmCategory",
"PydanticFunctionsOutputParser",
"create_structured_runnable",
]

@ -1,6 +0,0 @@
from vertexai.preview.generative_models import ( # type: ignore
HarmBlockThreshold,
HarmCategory,
)
__all__ = ["HarmBlockThreshold", "HarmCategory"]

@ -1,132 +0,0 @@
"""Utilities to init Vertex AI."""
import dataclasses
from importlib import metadata
from typing import Any, Callable, Dict, Optional, Union
import google.api_core
import proto # type: ignore[import-untyped]
from google.api_core.gapic_v1.client_info import ClientInfo
from google.cloud import storage # type: ignore[attr-defined]
from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain_core.language_models.llms import create_base_retry_decorator
from vertexai.generative_models._generative_models import ( # type: ignore[import-untyped]
Candidate,
)
from vertexai.language_models import ( # type: ignore[import-untyped]
TextGenerationResponse,
)
from vertexai.preview.generative_models import Image # type: ignore[import-untyped]
def create_retry_decorator(
*,
max_retries: int = 1,
run_manager: Optional[
Union[AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun]
] = None,
) -> Callable[[Any], Any]:
"""Creates a retry decorator for Vertex / Palm LLMs."""
errors = [
google.api_core.exceptions.ResourceExhausted,
google.api_core.exceptions.ServiceUnavailable,
google.api_core.exceptions.Aborted,
google.api_core.exceptions.DeadlineExceeded,
google.api_core.exceptions.GoogleAPIError,
]
decorator = create_base_retry_decorator(
error_types=errors, max_retries=max_retries, run_manager=run_manager
)
return decorator
def raise_vertex_import_error(minimum_expected_version: str = "1.38.0") -> None:
"""Raise ImportError related to Vertex SDK being not available.
Args:
minimum_expected_version: The lowest expected version of the SDK.
Raises:
ImportError: an ImportError that mentions a required version of the SDK.
"""
raise ImportError(
"Please, install or upgrade the google-cloud-aiplatform library: "
f"pip install google-cloud-aiplatform>={minimum_expected_version}"
)
def get_client_info(module: Optional[str] = None) -> "ClientInfo":
r"""Returns a custom user agent header.
Args:
module (Optional[str]):
Optional. The module for a custom user agent header.
Returns:
google.api_core.gapic_v1.client_info.ClientInfo
"""
langchain_version = metadata.version("langchain")
client_library_version = (
f"{langchain_version}-{module}" if module else langchain_version
)
return ClientInfo(
client_library_version=client_library_version,
user_agent=f"langchain/{client_library_version}",
)
def load_image_from_gcs(path: str, project: Optional[str] = None) -> Image:
"""Loads im Image from GCS."""
gcs_client = storage.Client(project=project)
pieces = path.split("/")
blobs = list(gcs_client.list_blobs(pieces[2], prefix="/".join(pieces[3:])))
if len(blobs) > 1:
raise ValueError(f"Found more than one candidate for {path}!")
return Image.from_bytes(blobs[0].download_as_bytes())
def is_codey_model(model_name: str) -> bool:
"""Returns True if the model name is a Codey model."""
return "code" in model_name
def is_gemini_model(model_name: str) -> bool:
"""Returns True if the model name is a Gemini model."""
return model_name is not None and "gemini" in model_name
def get_generation_info(
candidate: Union[TextGenerationResponse, Candidate],
is_gemini: bool,
*,
stream: bool = False,
) -> Dict[str, Any]:
if is_gemini:
# https://cloud.google.com/vertex-ai/docs/generative-ai/model-reference/gemini#response_body
info = {
"is_blocked": any([rating.blocked for rating in candidate.safety_ratings]),
"safety_ratings": [
{
"category": rating.category.name,
"probability_label": rating.probability.name,
"blocked": rating.blocked,
}
for rating in candidate.safety_ratings
],
"citation_metadata": (
proto.Message.to_dict(candidate.citation_metadata)
if candidate.citation_metadata
else None
),
}
# https://cloud.google.com/vertex-ai/docs/generative-ai/model-reference/text-chat#response_body
else:
info = dataclasses.asdict(candidate)
info.pop("text")
info = {k: v for k, v in info.items() if not k.startswith("_")}
if stream:
# Remove non-streamable types, like bools.
info.pop("is_blocked")
return info

@ -1,111 +0,0 @@
from typing import (
Dict,
Optional,
Sequence,
Type,
Union,
)
from langchain_core.output_parsers import (
BaseGenerationOutputParser,
BaseOutputParser,
)
from langchain_core.prompts import BasePromptTemplate
from langchain_core.pydantic_v1 import BaseModel
from langchain_core.runnables import Runnable
from langchain_google_vertexai.functions_utils import PydanticFunctionsOutputParser
def get_output_parser(
functions: Sequence[Type[BaseModel]],
) -> Union[BaseOutputParser, BaseGenerationOutputParser]:
"""Get the appropriate function output parser given the user functions.
Args:
functions: Sequence where element is a dictionary, a pydantic.BaseModel class,
or a Python function. If a dictionary is passed in, it is assumed to
already be a valid OpenAI function.
Returns:
A PydanticFunctionsOutputParser
"""
function_names = [f.__name__ for f in functions]
if len(functions) > 1:
pydantic_schema: Union[Dict, Type[BaseModel]] = {
name: fn for name, fn in zip(function_names, functions)
}
else:
pydantic_schema = functions[0]
output_parser: Union[
BaseOutputParser, BaseGenerationOutputParser
] = PydanticFunctionsOutputParser(pydantic_schema=pydantic_schema)
return output_parser
def create_structured_runnable(
function: Union[Type[BaseModel], Sequence[Type[BaseModel]]],
llm: Runnable,
*,
prompt: Optional[BasePromptTemplate] = None,
) -> Runnable:
"""Create a runnable sequence that uses OpenAI functions.
Args:
function: Either a single pydantic.BaseModel class or a sequence
of pydantic.BaseModels classes.
For best results, pydantic.BaseModels
should have descriptions of the parameters.
llm: Language model to use,
assumed to support the Google Vertex function-calling API.
prompt: BasePromptTemplate to pass to the model.
Returns:
A runnable sequence that will pass in the given functions to the model when run.
Example:
.. code-block:: python
from typing import Optional
from langchain_google_vertexai import ChatVertexAI, create_structured_runnable
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.pydantic_v1 import BaseModel, Field
class RecordPerson(BaseModel):
\"\"\"Record some identifying information about a person.\"\"\"
name: str = Field(..., description="The person's name")
age: int = Field(..., description="The person's age")
fav_food: Optional[str] = Field(None, description="The person's favorite food")
class RecordDog(BaseModel):
\"\"\"Record some identifying information about a dog.\"\"\"
name: str = Field(..., description="The dog's name")
color: str = Field(..., description="The dog's color")
fav_food: Optional[str] = Field(None, description="The dog's favorite food")
llm = ChatVertexAI(model_name="gemini-pro")
prompt = ChatPromptTemplate.from_template(\"\"\"
You are a world class algorithm for recording entities.
Make calls to the relevant function to record the entities in the following input: {input}
Tip: Make sure to answer in the correct format\"\"\"
)
chain = create_structured_runnable([RecordPerson, RecordDog], llm, prompt=prompt)
chain.invoke({"input": "Harry was a chubby brown beagle who loved chicken"})
# -> RecordDog(name="Harry", color="brown", fav_food="chicken")
""" # noqa: E501
if not function:
raise ValueError("Need to pass in at least one function. Received zero.")
functions = function if isinstance(function, Sequence) else [function]
output_parser = get_output_parser(functions)
llm_with_functions = llm.bind(functions=functions)
if prompt is None:
initial_chain = llm_with_functions
else:
initial_chain = prompt | llm_with_functions
return initial_chain | output_parser

@ -1,555 +0,0 @@
"""Wrapper around Google VertexAI chat-based models."""
from __future__ import annotations
import base64
import json
import logging
import re
from dataclasses import dataclass, field
from typing import Any, Dict, Iterator, List, Optional, Union, cast
from urllib.parse import urlparse
import proto # type: ignore[import-untyped]
import requests
from google.cloud.aiplatform_v1beta1.types.content import Part as GapicPart
from google.cloud.aiplatform_v1beta1.types.tool import FunctionCall
from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain_core.language_models.chat_models import (
BaseChatModel,
generate_from_stream,
)
from langchain_core.messages import (
AIMessage,
AIMessageChunk,
BaseMessage,
FunctionMessage,
HumanMessage,
SystemMessage,
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.pydantic_v1 import root_validator
from vertexai.language_models import ( # type: ignore
ChatMessage,
ChatModel,
ChatSession,
CodeChatModel,
CodeChatSession,
InputOutputTextPair,
)
from vertexai.preview.generative_models import ( # type: ignore
Candidate,
Content,
GenerativeModel,
Image,
Part,
)
from vertexai.preview.language_models import ( # type: ignore
ChatModel as PreviewChatModel,
)
from vertexai.preview.language_models import (
CodeChatModel as PreviewCodeChatModel,
)
from langchain_google_vertexai._utils import (
get_generation_info,
is_codey_model,
is_gemini_model,
load_image_from_gcs,
)
from langchain_google_vertexai.functions_utils import (
_format_tools_to_vertex_tool,
)
from langchain_google_vertexai.llms import (
_VertexAICommon,
)
logger = logging.getLogger(__name__)
@dataclass
class _ChatHistory:
"""Represents a context and a history of messages."""
history: List[ChatMessage] = field(default_factory=list)
context: Optional[str] = None
def _parse_chat_history(history: List[BaseMessage]) -> _ChatHistory:
"""Parse a sequence of messages into history.
Args:
history: The list of messages to re-create the history of the chat.
Returns:
A parsed chat history.
Raises:
ValueError: If a sequence of message has a SystemMessage not at the
first place.
"""
vertex_messages, context = [], None
for i, message in enumerate(history):
content = cast(str, message.content)
if i == 0 and isinstance(message, SystemMessage):
context = content
elif isinstance(message, AIMessage):
vertex_message = ChatMessage(content=message.content, author="bot")
vertex_messages.append(vertex_message)
elif isinstance(message, HumanMessage):
vertex_message = ChatMessage(content=message.content, author="user")
vertex_messages.append(vertex_message)
else:
raise ValueError(
f"Unexpected message with type {type(message)} at the position {i}."
)
chat_history = _ChatHistory(context=context, history=vertex_messages)
return chat_history
def _is_url(s: str) -> bool:
try:
result = urlparse(s)
return all([result.scheme, result.netloc])
except Exception as e:
logger.debug(f"Unable to parse URL: {e}")
return False
def _parse_chat_history_gemini(
history: List[BaseMessage],
project: Optional[str] = None,
convert_system_message_to_human: Optional[bool] = False,
) -> List[Content]:
def _convert_to_prompt(part: Union[str, Dict]) -> Part:
if isinstance(part, str):
return Part.from_text(part)
if not isinstance(part, Dict):
raise ValueError(
f"Message's content is expected to be a dict, got {type(part)}!"
)
if part["type"] == "text":
return Part.from_text(part["text"])
elif part["type"] == "image_url":
path = part["image_url"]["url"]
if path.startswith("gs://"):
image = load_image_from_gcs(path=path, project=project)
elif path.startswith("data:image/"):
# extract base64 component from image uri
try:
regexp = r"data:image/\w{2,4};base64,(.*)"
encoded = re.search(regexp, path).group(1) # type: ignore
except AttributeError:
raise ValueError(
"Invalid image uri. It should be in the format "
"data:image/<image_type>;base64,<base64_encoded_image>."
)
image = Image.from_bytes(base64.b64decode(encoded))
elif _is_url(path):
response = requests.get(path)
response.raise_for_status()
image = Image.from_bytes(response.content)
else:
image = Image.load_from_file(path)
else:
raise ValueError("Only text and image_url types are supported!")
return Part.from_image(image)
def _convert_to_parts(message: BaseMessage) -> List[Part]:
raw_content = message.content
if isinstance(raw_content, str):
raw_content = [raw_content]
return [_convert_to_prompt(part) for part in raw_content]
vertex_messages = []
raw_system_message = None
for i, message in enumerate(history):
if (
i == 0
and isinstance(message, SystemMessage)
and not convert_system_message_to_human
):
raise ValueError(
"""SystemMessages are not yet supported!
To automatically convert the leading SystemMessage to a HumanMessage,
set `convert_system_message_to_human` to True. Example:
llm = ChatVertexAI(model_name="gemini-pro", convert_system_message_to_human=True)
"""
)
elif i == 0 and isinstance(message, SystemMessage):
raw_system_message = message
continue
elif isinstance(message, AIMessage):
raw_function_call = message.additional_kwargs.get("function_call")
role = "model"
if raw_function_call:
function_call = FunctionCall(
{
"name": raw_function_call["name"],
"args": json.loads(raw_function_call["arguments"]),
}
)
gapic_part = GapicPart(function_call=function_call)
parts = [Part._from_gapic(gapic_part)]
else:
parts = _convert_to_parts(message)
elif isinstance(message, HumanMessage):
role = "user"
parts = _convert_to_parts(message)
elif isinstance(message, FunctionMessage):
role = "user"
parts = [
Part.from_function_response(
name=message.name,
response={
"content": message.content,
},
)
]
else:
raise ValueError(
f"Unexpected message with type {type(message)} at the position {i}."
)
if raw_system_message:
if role == "model":
raise ValueError(
"SystemMessage should be followed by a HumanMessage and "
"not by AIMessage."
)
parts = _convert_to_parts(raw_system_message) + parts
raw_system_message = None
vertex_message = Content(role=role, parts=parts)
vertex_messages.append(vertex_message)
return vertex_messages
def _parse_examples(examples: List[BaseMessage]) -> List[InputOutputTextPair]:
if len(examples) % 2 != 0:
raise ValueError(
f"Expect examples to have an even amount of messages, got {len(examples)}."
)
example_pairs = []
input_text = None
for i, example in enumerate(examples):
if i % 2 == 0:
if not isinstance(example, HumanMessage):
raise ValueError(
f"Expected the first message in a part to be from human, got "
f"{type(example)} for the {i}th message."
)
input_text = example.content
if i % 2 == 1:
if not isinstance(example, AIMessage):
raise ValueError(
f"Expected the second message in a part to be from AI, got "
f"{type(example)} for the {i}th message."
)
pair = InputOutputTextPair(
input_text=input_text, output_text=example.content
)
example_pairs.append(pair)
return example_pairs
def _get_question(messages: List[BaseMessage]) -> HumanMessage:
"""Get the human message at the end of a list of input messages to a chat model."""
if not messages:
raise ValueError("You should provide at least one message to start the chat!")
question = messages[-1]
if not isinstance(question, HumanMessage):
raise ValueError(
f"Last message in the list should be from human, got {question.type}."
)
return question
def _parse_response_candidate(response_candidate: "Candidate") -> AIMessage:
try:
content = response_candidate.text
except ValueError:
content = ""
additional_kwargs = {}
first_part = response_candidate.content.parts[0]
if first_part.function_call:
function_call = {"name": first_part.function_call.name}
# dump to match other function calling llm for now
function_call_args_dict = proto.Message.to_dict(first_part.function_call)[
"args"
]
function_call["arguments"] = json.dumps(
{k: function_call_args_dict[k] for k in function_call_args_dict}
)
additional_kwargs["function_call"] = function_call
return AIMessage(content=content, additional_kwargs=additional_kwargs)
class ChatVertexAI(_VertexAICommon, BaseChatModel):
"""`Vertex AI` Chat large language models API."""
model_name: str = "chat-bison"
"Underlying model name."
examples: Optional[List[BaseMessage]] = None
convert_system_message_to_human: bool = False
"""Whether to merge any leading SystemMessage into the following HumanMessage.
Gemini does not support system messages; any unsupported messages will
raise an error."""
@classmethod
def is_lc_serializable(self) -> bool:
return True
@classmethod
def get_lc_namespace(cls) -> List[str]:
"""Get the namespace of the langchain object."""
return ["langchain", "chat_models", "vertexai"]
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that the python package exists in environment."""
is_gemini = is_gemini_model(values["model_name"])
safety_settings = values["safety_settings"]
if safety_settings and not is_gemini:
raise ValueError("Safety settings are only supported for Gemini models")
cls._init_vertexai(values)
if is_gemini:
values["client"] = GenerativeModel(
model_name=values["model_name"], safety_settings=safety_settings
)
values["client_preview"] = GenerativeModel(
model_name=values["model_name"], safety_settings=safety_settings
)
else:
if is_codey_model(values["model_name"]):
model_cls = CodeChatModel
model_cls_preview = PreviewCodeChatModel
else:
model_cls = ChatModel
model_cls_preview = PreviewChatModel
values["client"] = model_cls.from_pretrained(values["model_name"])
values["client_preview"] = model_cls_preview.from_pretrained(
values["model_name"]
)
return values
def _generate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
stream: Optional[bool] = None,
**kwargs: Any,
) -> ChatResult:
"""Generate next turn in the conversation.
Args:
messages: The history of the conversation as a list of messages. Code chat
does not support context.
stop: The list of stop words (optional).
run_manager: The CallbackManager for LLM run, it's not used at the moment.
stream: Whether to use the streaming endpoint.
Returns:
The ChatResult that contains outputs generated by the model.
Raises:
ValueError: if the last message in the list is not from human.
"""
should_stream = stream if stream is not None else self.streaming
safety_settings = kwargs.pop("safety_settings", None)
if should_stream:
stream_iter = self._stream(
messages, stop=stop, run_manager=run_manager, **kwargs
)
return generate_from_stream(stream_iter)
params = self._prepare_params(stop=stop, stream=False, **kwargs)
msg_params = {}
if "candidate_count" in params:
msg_params["candidate_count"] = params.pop("candidate_count")
if self._is_gemini_model:
history_gemini = _parse_chat_history_gemini(
messages,
project=self.project,
convert_system_message_to_human=self.convert_system_message_to_human,
)
message = history_gemini.pop()
chat = self.client.start_chat(history=history_gemini)
# set param to `functions` until core tool/function calling implemented
raw_tools = params.pop("functions") if "functions" in params else None
tools = _format_tools_to_vertex_tool(raw_tools) if raw_tools else None
response = chat.send_message(
message,
generation_config=params,
tools=tools,
safety_settings=safety_settings,
)
generations = [
ChatGeneration(
message=_parse_response_candidate(c),
generation_info=get_generation_info(c, self._is_gemini_model),
)
for c in response.candidates
]
else:
question = _get_question(messages)
history = _parse_chat_history(messages[:-1])
examples = kwargs.get("examples") or self.examples
if examples:
params["examples"] = _parse_examples(examples)
chat = self._start_chat(history, **params)
response = chat.send_message(question.content, **msg_params)
generations = [
ChatGeneration(
message=AIMessage(content=r.text),
generation_info=get_generation_info(r, self._is_gemini_model),
)
for r in response.candidates
]
return ChatResult(generations=generations)
async def _agenerate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
"""Asynchronously generate next turn in the conversation.
Args:
messages: The history of the conversation as a list of messages. Code chat
does not support context.
stop: The list of stop words (optional).
run_manager: The CallbackManager for LLM run, it's not used at the moment.
Returns:
The ChatResult that contains outputs generated by the model.
Raises:
ValueError: if the last message in the list is not from human.
"""
if "stream" in kwargs:
kwargs.pop("stream")
logger.warning("ChatVertexAI does not currently support async streaming.")
params = self._prepare_params(stop=stop, **kwargs)
safety_settings = kwargs.pop("safety_settings", None)
msg_params = {}
if "candidate_count" in params:
msg_params["candidate_count"] = params.pop("candidate_count")
if self._is_gemini_model:
history_gemini = _parse_chat_history_gemini(
messages,
project=self.project,
convert_system_message_to_human=self.convert_system_message_to_human,
)
message = history_gemini.pop()
chat = self.client.start_chat(history=history_gemini)
# set param to `functions` until core tool/function calling implemented
raw_tools = params.pop("functions") if "functions" in params else None
tools = _format_tools_to_vertex_tool(raw_tools) if raw_tools else None
response = await chat.send_message_async(
message,
generation_config=params,
tools=tools,
safety_settings=safety_settings,
)
generations = [
ChatGeneration(
message=_parse_response_candidate(c),
generation_info=get_generation_info(c, self._is_gemini_model),
)
for c in response.candidates
]
else:
question = _get_question(messages)
history = _parse_chat_history(messages[:-1])
examples = kwargs.get("examples", None) or self.examples
if examples:
params["examples"] = _parse_examples(examples)
chat = self._start_chat(history, **params)
response = await chat.send_message_async(question.content, **msg_params)
generations = [
ChatGeneration(
message=AIMessage(content=r.text),
generation_info=get_generation_info(r, self._is_gemini_model),
)
for r in response.candidates
]
return ChatResult(generations=generations)
def _stream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
params = self._prepare_params(stop=stop, stream=True, **kwargs)
if self._is_gemini_model:
history_gemini = _parse_chat_history_gemini(
messages,
project=self.project,
convert_system_message_to_human=self.convert_system_message_to_human,
)
message = history_gemini.pop()
chat = self.client.start_chat(history=history_gemini)
# set param to `functions` until core tool/function calling implemented
raw_tools = params.pop("functions") if "functions" in params else None
tools = _format_tools_to_vertex_tool(raw_tools) if raw_tools else None
safety_settings = params.pop("safety_settings", None)
responses = chat.send_message(
message,
stream=True,
generation_config=params,
safety_settings=safety_settings,
tools=tools,
)
for response in responses:
message = _parse_response_candidate(response.candidates[0])
if run_manager:
run_manager.on_llm_new_token(message.content)
yield ChatGenerationChunk(
message=AIMessageChunk(
content=message.content,
additional_kwargs=message.additional_kwargs,
)
)
else:
question = _get_question(messages)
history = _parse_chat_history(messages[:-1])
examples = kwargs.get("examples", None)
if examples:
params["examples"] = _parse_examples(examples)
chat = self._start_chat(history, **params)
responses = chat.send_message_streaming(question.content, **params)
for response in responses:
if run_manager:
run_manager.on_llm_new_token(response.text)
yield ChatGenerationChunk(
message=AIMessageChunk(content=response.text),
generation_info=get_generation_info(response, self._is_gemini_model),
)
def _start_chat(
self, history: _ChatHistory, **kwargs: Any
) -> Union[ChatSession, CodeChatSession]:
if not self.is_codey_model:
return self.client.start_chat(
context=history.context, message_history=history.history, **kwargs
)
else:
return self.client.start_chat(message_history=history.history, **kwargs)

@ -1,336 +0,0 @@
import logging
import re
import string
import threading
from concurrent.futures import ThreadPoolExecutor, wait
from typing import Any, Dict, List, Literal, Optional, Tuple, Type
from google.api_core.exceptions import (
Aborted,
DeadlineExceeded,
InvalidArgument,
ResourceExhausted,
ServiceUnavailable,
)
from langchain_core.embeddings import Embeddings
from langchain_core.language_models.llms import create_base_retry_decorator
from langchain_core.pydantic_v1 import root_validator
from vertexai.language_models import ( # type: ignore
TextEmbeddingInput,
TextEmbeddingModel,
)
from langchain_google_vertexai.llms import _VertexAICommon
logger = logging.getLogger(__name__)
_MAX_TOKENS_PER_BATCH = 20000
_MAX_BATCH_SIZE = 250
_MIN_BATCH_SIZE = 5
class VertexAIEmbeddings(_VertexAICommon, Embeddings):
"""Google Cloud VertexAI embedding models."""
# Instance context
instance: Dict[str, Any] = {} #: :meta private:
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validates that the python package exists in environment."""
cls._init_vertexai(values)
if values["model_name"] == "textembedding-gecko-default":
logger.warning(
"Model_name will become a required arg for VertexAIEmbeddings "
"starting from Feb-01-2024. Currently the default is set to "
"textembedding-gecko@001"
)
values["model_name"] = "textembedding-gecko@001"
values["client"] = TextEmbeddingModel.from_pretrained(values["model_name"])
return values
def __init__(
self,
# the default value would be removed after Feb-01-2024
model_name: str = "textembedding-gecko-default",
project: Optional[str] = None,
location: str = "us-central1",
request_parallelism: int = 5,
max_retries: int = 6,
credentials: Optional[Any] = None,
**kwargs: Any,
):
"""Initialize the sentence_transformer."""
super().__init__(
project=project,
location=location,
credentials=credentials,
request_parallelism=request_parallelism,
max_retries=max_retries,
model_name=model_name,
**kwargs,
)
self.instance["max_batch_size"] = kwargs.get("max_batch_size", _MAX_BATCH_SIZE)
self.instance["batch_size"] = self.instance["max_batch_size"]
self.instance["min_batch_size"] = kwargs.get("min_batch_size", _MIN_BATCH_SIZE)
self.instance["min_good_batch_size"] = self.instance["min_batch_size"]
self.instance["lock"] = threading.Lock()
self.instance["batch_size_validated"] = False
self.instance["task_executor"] = ThreadPoolExecutor(
max_workers=request_parallelism
)
self.instance[
"embeddings_task_type_supported"
] = not self.client._endpoint_name.endswith("/textembedding-gecko@001")
@staticmethod
def _split_by_punctuation(text: str) -> List[str]:
"""Splits a string by punctuation and whitespace characters."""
split_by = string.punctuation + "\t\n "
pattern = f"([{split_by}])"
# Using re.split to split the text based on the pattern
return [segment for segment in re.split(pattern, text) if segment]
@staticmethod
def _prepare_batches(texts: List[str], batch_size: int) -> List[List[str]]:
"""Splits texts in batches based on current maximum batch size
and maximum tokens per request.
"""
text_index = 0
texts_len = len(texts)
batch_token_len = 0
batches: List[List[str]] = []
current_batch: List[str] = []
if texts_len == 0:
return []
while text_index < texts_len:
current_text = texts[text_index]
# Number of tokens per a text is conservatively estimated
# as 2 times number of words, punctuation and whitespace characters.
# Using `count_tokens` API will make batching too expensive.
# Utilizing a tokenizer, would add a dependency that would not
# necessarily be reused by the application using this class.
current_text_token_cnt = (
len(VertexAIEmbeddings._split_by_punctuation(current_text)) * 2
)
end_of_batch = False
if current_text_token_cnt > _MAX_TOKENS_PER_BATCH:
# Current text is too big even for a single batch.
# Such request will fail, but we still make a batch
# so that the app can get the error from the API.
if len(current_batch) > 0:
# Adding current batch if not empty.
batches.append(current_batch)
current_batch = [current_text]
text_index += 1
end_of_batch = True
elif (
batch_token_len + current_text_token_cnt > _MAX_TOKENS_PER_BATCH
or len(current_batch) == batch_size
):
end_of_batch = True
else:
if text_index == texts_len - 1:
# Last element - even though the batch may be not big,
# we still need to make it.
end_of_batch = True
batch_token_len += current_text_token_cnt
current_batch.append(current_text)
text_index += 1
if end_of_batch:
batches.append(current_batch)
current_batch = []
batch_token_len = 0
return batches
def _get_embeddings_with_retry(
self, texts: List[str], embeddings_type: Optional[str] = None
) -> List[List[float]]:
"""Makes a Vertex AI model request with retry logic."""
errors: List[Type[BaseException]] = [
ResourceExhausted,
ServiceUnavailable,
Aborted,
DeadlineExceeded,
]
retry_decorator = create_base_retry_decorator(
error_types=errors, max_retries=self.max_retries
)
@retry_decorator
def _completion_with_retry(texts_to_process: List[str]) -> Any:
if embeddings_type and self.instance["embeddings_task_type_supported"]:
requests = [
TextEmbeddingInput(text=t, task_type=embeddings_type)
for t in texts_to_process
]
else:
requests = texts_to_process
embeddings = self.client.get_embeddings(requests)
return [embs.values for embs in embeddings]
return _completion_with_retry(texts)
def _prepare_and_validate_batches(
self, texts: List[str], embeddings_type: Optional[str] = None
) -> Tuple[List[List[float]], List[List[str]]]:
"""Prepares text batches with one-time validation of batch size.
Batch size varies between GCP regions and individual project quotas.
# Returns embeddings of the first text batch that went through,
# and text batches for the rest of the texts.
"""
batches = VertexAIEmbeddings._prepare_batches(
texts, self.instance["batch_size"]
)
# If batch size if less or equal to one that went through before,
# then keep batches as they are.
if len(batches[0]) <= self.instance["min_good_batch_size"]:
return [], batches
with self.instance["lock"]:
# If largest possible batch size was validated
# while waiting for the lock, then check for rebuilding
# our batches, and return.
if self.instance["batch_size_validated"]:
if len(batches[0]) <= self.instance["batch_size"]:
return [], batches
else:
return [], VertexAIEmbeddings._prepare_batches(
texts, self.instance["batch_size"]
)
# Figure out largest possible batch size by trying to push
# batches and lowering their size in half after every failure.
first_batch = batches[0]
first_result = []
had_failure = False
while True:
try:
first_result = self._get_embeddings_with_retry(
first_batch, embeddings_type
)
break
except InvalidArgument:
had_failure = True
first_batch_len = len(first_batch)
if first_batch_len == self.instance["min_batch_size"]:
raise
first_batch_len = max(
self.instance["min_batch_size"], int(first_batch_len / 2)
)
first_batch = first_batch[:first_batch_len]
first_batch_len = len(first_batch)
self.instance["min_good_batch_size"] = max(
self.instance["min_good_batch_size"], first_batch_len
)
# If had a failure and recovered
# or went through with the max size, then it's a legit batch size.
if had_failure or first_batch_len == self.instance["max_batch_size"]:
self.instance["batch_size"] = first_batch_len
self.instance["batch_size_validated"] = True
# If batch size was updated,
# rebuild batches with the new batch size
# (texts that went through are excluded here).
if first_batch_len != self.instance["max_batch_size"]:
batches = VertexAIEmbeddings._prepare_batches(
texts[first_batch_len:], self.instance["batch_size"]
)
else:
# Still figuring out max batch size.
batches = batches[1:]
# Returning embeddings of the first text batch that went through,
# and text batches for the rest of texts.
return first_result, batches
def embed(
self,
texts: List[str],
batch_size: int = 0,
embeddings_task_type: Optional[
Literal[
"RETRIEVAL_QUERY",
"RETRIEVAL_DOCUMENT",
"SEMANTIC_SIMILARITY",
"CLASSIFICATION",
"CLUSTERING",
]
] = None,
) -> List[List[float]]:
"""Embed a list of strings.
Args:
texts: List[str] The list of strings to embed.
batch_size: [int] The batch size of embeddings to send to the model.
If zero, then the largest batch size will be detected dynamically
at the first request, starting from 250, down to 5.
embeddings_task_type: [str] optional embeddings task type,
one of the following
RETRIEVAL_QUERY - Text is a query
in a search/retrieval setting.
RETRIEVAL_DOCUMENT - Text is a document
in a search/retrieval setting.
SEMANTIC_SIMILARITY - Embeddings will be used
for Semantic Textual Similarity (STS).
CLASSIFICATION - Embeddings will be used for classification.
CLUSTERING - Embeddings will be used for clustering.
Returns:
List of embeddings, one for each text.
"""
if len(texts) == 0:
return []
embeddings: List[List[float]] = []
first_batch_result: List[List[float]] = []
if batch_size > 0:
# Fixed batch size.
batches = VertexAIEmbeddings._prepare_batches(texts, batch_size)
else:
# Dynamic batch size, starting from 250 at the first call.
first_batch_result, batches = self._prepare_and_validate_batches(
texts, embeddings_task_type
)
# First batch result may have some embeddings already.
# In such case, batches have texts that were not processed yet.
embeddings.extend(first_batch_result)
tasks = []
for batch in batches:
tasks.append(
self.instance["task_executor"].submit(
self._get_embeddings_with_retry,
texts=batch,
embeddings_type=embeddings_task_type,
)
)
if len(tasks) > 0:
wait(tasks)
for t in tasks:
embeddings.extend(t.result())
return embeddings
def embed_documents(
self, texts: List[str], batch_size: int = 0
) -> List[List[float]]:
"""Embed a list of documents.
Args:
texts: List[str] The list of texts to embed.
batch_size: [int] The batch size of embeddings to send to the model.
If zero, then the largest batch size will be detected dynamically
at the first request, starting from 250, down to 5.
Returns:
List of embeddings, one for each text.
"""
return self.embed(texts, batch_size, "RETRIEVAL_DOCUMENT")
def embed_query(self, text: str) -> List[float]:
"""Embed a text.
Args:
text: The text to embed.
Returns:
Embedding for the text.
"""
embeddings = self.embed([text], 1, "RETRIEVAL_QUERY")
return embeddings[0]

@ -1,152 +0,0 @@
import json
from typing import Dict, List, Type, Union
from langchain_core.exceptions import OutputParserException
from langchain_core.output_parsers import BaseOutputParser
from langchain_core.outputs import ChatGeneration, Generation
from langchain_core.pydantic_v1 import BaseModel
from langchain_core.tools import BaseTool
from langchain_core.utils.function_calling import FunctionDescription
from langchain_core.utils.json_schema import dereference_refs
from vertexai.preview.generative_models import ( # type: ignore
FunctionDeclaration,
)
from vertexai.preview.generative_models import Tool as VertexTool
def _format_pydantic_to_vertex_function(
pydantic_model: Type[BaseModel],
) -> FunctionDescription:
schema = dereference_refs(pydantic_model.schema())
schema.pop("definitions", None)
return {
"name": schema["title"],
"description": schema.get("description", ""),
"parameters": {
"properties": {
k: {
"type": v["type"],
"description": v.get("description"),
}
for k, v in schema["properties"].items()
},
"required": schema["required"],
"type": schema["type"],
},
}
def _format_tool_to_vertex_function(tool: BaseTool) -> FunctionDescription:
"Format tool into the Vertex function API."
if tool.args_schema:
schema = dereference_refs(tool.args_schema.schema())
schema.pop("definitions", None)
return {
"name": tool.name or schema["title"],
"description": tool.description or schema["description"],
"parameters": {
"properties": {
k: {
"type": v["type"],
"description": v.get("description"),
}
for k, v in schema["properties"].items()
},
"required": schema["required"],
"type": schema["type"],
},
}
else:
return {
"name": tool.name,
"description": tool.description,
"parameters": {
"properties": {
"__arg1": {"type": "string"},
},
"required": ["__arg1"],
"type": "object",
},
}
def _format_tools_to_vertex_tool(
tools: List[Union[BaseTool, Type[BaseModel]]],
) -> List[VertexTool]:
"Format tool into the Vertex Tool instance."
function_declarations = []
for tool in tools:
if isinstance(tool, BaseTool):
func = _format_tool_to_vertex_function(tool)
else:
func = _format_pydantic_to_vertex_function(tool)
function_declarations.append(FunctionDeclaration(**func))
return [VertexTool(function_declarations=function_declarations)]
class PydanticFunctionsOutputParser(BaseOutputParser):
"""Parse an output as a pydantic object.
This parser is used to parse the output of a ChatModel that uses
Google Vertex function format to invoke functions.
The parser extracts the function call invocation and matches
them to the pydantic schema provided.
An exception will be raised if the function call does not match
the provided schema.
Example:
... code-block:: python
message = AIMessage(
content="This is a test message",
additional_kwargs={
"function_call": {
"name": "cookie",
"arguments": json.dumps({"name": "value", "age": 10}),
}
},
)
chat_generation = ChatGeneration(message=message)
class Cookie(BaseModel):
name: str
age: int
class Dog(BaseModel):
species: str
# Full output
parser = PydanticOutputFunctionsParser(
pydantic_schema={"cookie": Cookie, "dog": Dog}
)
result = parser.parse_result([chat_generation])
"""
pydantic_schema: Union[Type[BaseModel], Dict[str, Type[BaseModel]]]
def parse_result(
self, result: List[Generation], *, partial: bool = False
) -> BaseModel:
if not isinstance(result[0], ChatGeneration):
raise ValueError("This output parser only works on ChatGeneration output")
message = result[0].message
function_call = message.additional_kwargs.get("function_call", {})
if function_call:
function_name = function_call["name"]
tool_input = function_call.get("arguments", {})
if isinstance(self.pydantic_schema, dict):
schema = self.pydantic_schema[function_name]
else:
schema = self.pydantic_schema
return schema(**json.loads(tool_input))
else:
raise OutputParserException(f"Could not parse function call: {message}")
def parse(self, text: str) -> BaseModel:
raise ValueError("Can only parse messages")

@ -1,555 +0,0 @@
from __future__ import annotations
from concurrent.futures import Executor
from typing import Any, ClassVar, Dict, Iterator, List, Optional, Union
import vertexai # type: ignore[import-untyped]
from google.api_core.client_options import ClientOptions
from google.cloud.aiplatform.gapic import (
PredictionServiceAsyncClient,
PredictionServiceClient,
)
from google.cloud.aiplatform.models import Prediction
from google.protobuf import json_format
from google.protobuf.struct_pb2 import Value
from langchain_core.callbacks.manager import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain_core.language_models.llms import BaseLLM
from langchain_core.outputs import Generation, GenerationChunk, LLMResult
from langchain_core.pydantic_v1 import BaseModel, Field, root_validator
from vertexai.language_models import ( # type: ignore[import-untyped]
CodeGenerationModel,
TextGenerationModel,
)
from vertexai.language_models._language_models import ( # type: ignore[import-untyped]
TextGenerationResponse,
)
from vertexai.preview.generative_models import ( # type: ignore[import-untyped]
GenerativeModel,
Image,
)
from vertexai.preview.language_models import ( # type: ignore[import-untyped]
ChatModel as PreviewChatModel,
)
from vertexai.preview.language_models import (
CodeChatModel as PreviewCodeChatModel,
)
from vertexai.preview.language_models import (
CodeGenerationModel as PreviewCodeGenerationModel,
)
from vertexai.preview.language_models import (
TextGenerationModel as PreviewTextGenerationModel,
)
from langchain_google_vertexai._enums import HarmBlockThreshold, HarmCategory
from langchain_google_vertexai._utils import (
create_retry_decorator,
get_client_info,
get_generation_info,
is_codey_model,
is_gemini_model,
)
_PALM_DEFAULT_MAX_OUTPUT_TOKENS = TextGenerationModel._DEFAULT_MAX_OUTPUT_TOKENS
_PALM_DEFAULT_TEMPERATURE = 0.0
_PALM_DEFAULT_TOP_P = 0.95
_PALM_DEFAULT_TOP_K = 40
def _completion_with_retry(
llm: VertexAI,
prompt: List[Union[str, Image]],
stream: bool = False,
is_gemini: bool = False,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Any:
"""Use tenacity to retry the completion call."""
retry_decorator = create_retry_decorator(
max_retries=llm.max_retries, run_manager=run_manager
)
@retry_decorator
def _completion_with_retry_inner(
prompt: List[Union[str, Image]], is_gemini: bool = False, **kwargs: Any
) -> Any:
if is_gemini:
return llm.client.generate_content(
prompt,
stream=stream,
safety_settings=kwargs.pop("safety_settings", None),
generation_config=kwargs,
)
else:
if stream:
return llm.client.predict_streaming(prompt[0], **kwargs)
return llm.client.predict(prompt[0], **kwargs)
return _completion_with_retry_inner(prompt, is_gemini, **kwargs)
async def _acompletion_with_retry(
llm: VertexAI,
prompt: str,
is_gemini: bool = False,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Any:
"""Use tenacity to retry the completion call."""
retry_decorator = create_retry_decorator(
max_retries=llm.max_retries, run_manager=run_manager
)
@retry_decorator
async def _acompletion_with_retry_inner(
prompt: str, is_gemini: bool = False, **kwargs: Any
) -> Any:
if is_gemini:
return await llm.client.generate_content_async(
prompt,
generation_config=kwargs,
safety_settings=kwargs.pop("safety_settings", None),
)
return await llm.client.predict_async(prompt, **kwargs)
return await _acompletion_with_retry_inner(prompt, is_gemini, **kwargs)
class _VertexAIBase(BaseModel):
project: Optional[str] = None
"The default GCP project to use when making Vertex API calls."
location: str = "us-central1"
"The default location to use when making API calls."
request_parallelism: int = 5
"The amount of parallelism allowed for requests issued to VertexAI models. "
"Default is 5."
max_retries: int = 6
"""The maximum number of retries to make when generating."""
task_executor: ClassVar[Optional[Executor]] = Field(default=None, exclude=True)
stop: Optional[List[str]] = None
"Optional list of stop words to use when generating."
model_name: Optional[str] = None
"Underlying model name."
class _VertexAICommon(_VertexAIBase):
client: Any = None #: :meta private:
client_preview: Any = None #: :meta private:
model_name: str
"Underlying model name."
temperature: Optional[float] = None
"Sampling temperature, it controls the degree of randomness in token selection."
max_output_tokens: Optional[int] = None
"Token limit determines the maximum amount of text output from one prompt."
top_p: Optional[float] = None
"Tokens are selected from most probable to least until the sum of their "
"probabilities equals the top-p value. Top-p is ignored for Codey models."
top_k: Optional[int] = None
"How the model selects tokens for output, the next token is selected from "
"among the top-k most probable tokens. Top-k is ignored for Codey models."
credentials: Any = Field(default=None, exclude=True)
"The default custom credentials (google.auth.credentials.Credentials) to use "
"when making API calls. If not provided, credentials will be ascertained from "
"the environment."
n: int = 1
"""How many completions to generate for each prompt."""
streaming: bool = False
"""Whether to stream the results or not."""
safety_settings: Optional[Dict[HarmCategory, HarmBlockThreshold]] = None
"""The default safety settings to use for all generations.
For example:
from langchain_google_vertexai import HarmBlockThreshold, HarmCategory
safety_settings = {
HarmCategory.HARM_CATEGORY_UNSPECIFIED: HarmBlockThreshold.BLOCK_NONE,
HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_ONLY_HIGH,
HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_LOW_AND_ABOVE,
HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
}
""" # noqa: E501
@property
def _llm_type(self) -> str:
return "vertexai"
@property
def is_codey_model(self) -> bool:
return is_codey_model(self.model_name)
@property
def _is_gemini_model(self) -> bool:
return is_gemini_model(self.model_name)
@property
def _identifying_params(self) -> Dict[str, Any]:
"""Gets the identifying parameters."""
return {**{"model_name": self.model_name}, **self._default_params}
@property
def _default_params(self) -> Dict[str, Any]:
if self._is_gemini_model:
default_params = {}
else:
default_params = {
"temperature": _PALM_DEFAULT_TEMPERATURE,
"max_output_tokens": _PALM_DEFAULT_MAX_OUTPUT_TOKENS,
"top_p": _PALM_DEFAULT_TOP_P,
"top_k": _PALM_DEFAULT_TOP_K,
}
params = {
"temperature": self.temperature,
"max_output_tokens": self.max_output_tokens,
"candidate_count": self.n,
}
if not self.is_codey_model:
params.update(
{
"top_k": self.top_k,
"top_p": self.top_p,
}
)
updated_params = {}
for param_name, param_value in params.items():
default_value = default_params.get(param_name)
if param_value or default_value:
updated_params[param_name] = (
param_value if param_value else default_value
)
return updated_params
@classmethod
def _init_vertexai(cls, values: Dict) -> None:
vertexai.init(
project=values.get("project"),
location=values.get("location"),
credentials=values.get("credentials"),
)
return None
def _prepare_params(
self,
stop: Optional[List[str]] = None,
stream: bool = False,
**kwargs: Any,
) -> dict:
stop_sequences = stop or self.stop
params_mapping = {"n": "candidate_count"}
params = {params_mapping.get(k, k): v for k, v in kwargs.items()}
params = {**self._default_params, "stop_sequences": stop_sequences, **params}
if stream or self.streaming:
params.pop("candidate_count")
return params
def get_num_tokens(self, text: str) -> int:
"""Get the number of tokens present in the text.
Useful for checking if an input will fit in a model's context window.
Args:
text: The string input to tokenize.
Returns:
The integer number of tokens in the text.
"""
is_palm_chat_model = isinstance(
self.client_preview, PreviewChatModel
) or isinstance(self.client_preview, PreviewCodeChatModel)
if is_palm_chat_model:
result = self.client_preview.start_chat().count_tokens(text)
else:
result = self.client_preview.count_tokens([text])
return result.total_tokens
class VertexAI(_VertexAICommon, BaseLLM):
"""Google Vertex AI large language models."""
model_name: str = "text-bison"
"The name of the Vertex AI large language model."
tuned_model_name: Optional[str] = None
"The name of a tuned model. If provided, model_name is ignored."
@classmethod
def is_lc_serializable(self) -> bool:
return True
@classmethod
def get_lc_namespace(cls) -> List[str]:
"""Get the namespace of the langchain object."""
return ["langchain", "llms", "vertexai"]
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that the python package exists in environment."""
tuned_model_name = values.get("tuned_model_name")
model_name = values["model_name"]
safety_settings = values["safety_settings"]
is_gemini = is_gemini_model(values["model_name"])
cls._init_vertexai(values)
if safety_settings and (not is_gemini or tuned_model_name):
raise ValueError("Safety settings are only supported for Gemini models")
if is_codey_model(model_name):
model_cls = CodeGenerationModel
preview_model_cls = PreviewCodeGenerationModel
elif is_gemini:
model_cls = GenerativeModel
preview_model_cls = GenerativeModel
else:
model_cls = TextGenerationModel
preview_model_cls = PreviewTextGenerationModel
if tuned_model_name:
values["client"] = model_cls.get_tuned_model(tuned_model_name)
values["client_preview"] = preview_model_cls.get_tuned_model(
tuned_model_name
)
else:
if is_gemini:
values["client"] = model_cls(
model_name=model_name, safety_settings=safety_settings
)
values["client_preview"] = preview_model_cls(
model_name=model_name, safety_settings=safety_settings
)
else:
values["client"] = model_cls.from_pretrained(model_name)
values["client_preview"] = preview_model_cls.from_pretrained(model_name)
if values["streaming"] and values["n"] > 1:
raise ValueError("Only one candidate can be generated with streaming!")
return values
def _response_to_generation(
self, response: TextGenerationResponse, *, stream: bool = False
) -> GenerationChunk:
"""Converts a stream response to a generation chunk."""
generation_info = get_generation_info(
response, self._is_gemini_model, stream=stream
)
try:
text = response.text
except AttributeError:
text = ""
except ValueError:
text = ""
return GenerationChunk(
text=text,
generation_info=generation_info,
)
def _generate(
self,
prompts: List[str],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
stream: Optional[bool] = None,
**kwargs: Any,
) -> LLMResult:
should_stream = stream if stream is not None else self.streaming
params = self._prepare_params(stop=stop, stream=should_stream, **kwargs)
generations: List[List[Generation]] = []
for prompt in prompts:
if should_stream:
generation = GenerationChunk(text="")
for chunk in self._stream(
prompt, stop=stop, run_manager=run_manager, **kwargs
):
generation += chunk
generations.append([generation])
else:
res = _completion_with_retry(
self,
[prompt],
stream=should_stream,
is_gemini=self._is_gemini_model,
run_manager=run_manager,
**params,
)
generations.append(
[self._response_to_generation(r) for r in res.candidates]
)
return LLMResult(generations=generations)
async def _agenerate(
self,
prompts: List[str],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> LLMResult:
params = self._prepare_params(stop=stop, **kwargs)
generations: List[List[Generation]] = []
for prompt in prompts:
res = await _acompletion_with_retry(
self,
prompt,
is_gemini=self._is_gemini_model,
run_manager=run_manager,
**params,
)
generations.append(
[self._response_to_generation(r) for r in res.candidates]
)
return LLMResult(generations=generations)
def _stream(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[GenerationChunk]:
params = self._prepare_params(stop=stop, stream=True, **kwargs)
for stream_resp in _completion_with_retry(
self,
[prompt],
stream=True,
is_gemini=self._is_gemini_model,
run_manager=run_manager,
**params,
):
# Gemini models return GenerationResponse even when streaming, which has a
# candidates field.
stream_resp = (
stream_resp
if isinstance(stream_resp, TextGenerationResponse)
else stream_resp.candidates[0]
)
chunk = self._response_to_generation(stream_resp, stream=True)
yield chunk
if run_manager:
run_manager.on_llm_new_token(
chunk.text,
chunk=chunk,
verbose=self.verbose,
)
class VertexAIModelGarden(_VertexAIBase, BaseLLM):
"""Large language models served from Vertex AI Model Garden."""
client: Any = None #: :meta private:
async_client: Any = None #: :meta private:
endpoint_id: str
"A name of an endpoint where the model has been deployed."
allowed_model_args: Optional[List[str]] = None
"Allowed optional args to be passed to the model."
prompt_arg: str = "prompt"
result_arg: Optional[str] = "generated_text"
"Set result_arg to None if output of the model is expected to be a string."
"Otherwise, if it's a dict, provided an argument that contains the result."
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that the python package exists in environment."""
if not values["project"]:
raise ValueError(
"A GCP project should be provided to run inference on Model Garden!"
)
client_options = ClientOptions(
api_endpoint=f"{values['location']}-aiplatform.googleapis.com"
)
client_info = get_client_info(module="vertex-ai-model-garden")
values["client"] = PredictionServiceClient(
client_options=client_options, client_info=client_info
)
values["async_client"] = PredictionServiceAsyncClient(
client_options=client_options, client_info=client_info
)
return values
@property
def endpoint_path(self) -> str:
return self.client.endpoint_path(
project=self.project, location=self.location, endpoint=self.endpoint_id
)
@property
def _llm_type(self) -> str:
return "vertexai_model_garden"
def _prepare_request(self, prompts: List[str], **kwargs: Any) -> List["Value"]:
instances = []
for prompt in prompts:
if self.allowed_model_args:
instance = {
k: v for k, v in kwargs.items() if k in self.allowed_model_args
}
else:
instance = {}
instance[self.prompt_arg] = prompt
instances.append(instance)
predict_instances = [
json_format.ParseDict(instance_dict, Value()) for instance_dict in instances
]
return predict_instances
def _generate(
self,
prompts: List[str],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> LLMResult:
"""Run the LLM on the given prompt and input."""
instances = self._prepare_request(prompts, **kwargs)
response = self.client.predict(endpoint=self.endpoint_path, instances=instances)
return self._parse_response(response)
def _parse_response(self, predictions: "Prediction") -> LLMResult:
generations: List[List[Generation]] = []
for result in predictions.predictions:
generations.append(
[
Generation(text=self._parse_prediction(prediction))
for prediction in result
]
)
return LLMResult(generations=generations)
def _parse_prediction(self, prediction: Any) -> str:
if isinstance(prediction, str):
return prediction
if self.result_arg:
try:
return prediction[self.result_arg]
except KeyError:
if isinstance(prediction, str):
error_desc = (
"Provided non-None `result_arg` (result_arg="
f"{self.result_arg}). But got prediction of type "
f"{type(prediction)} instead of dict. Most probably, you"
"need to set `result_arg=None` during VertexAIModelGarden "
"initialization."
)
raise ValueError(error_desc)
else:
raise ValueError(f"{self.result_arg} key not found in prediction!")
return prediction
async def _agenerate(
self,
prompts: List[str],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> LLMResult:
"""Run the LLM on the given prompt and input."""
instances = self._prepare_request(prompts, **kwargs)
response = await self.async_client.predict(
endpoint=self.endpoint_path, instances=instances
)
return self._parse_response(response)

File diff suppressed because it is too large Load Diff

@ -1,110 +0,0 @@
[tool.poetry]
name = "langchain-google-vertexai"
version = "0.0.5"
description = "An integration package connecting GoogleVertexAI and LangChain"
authors = []
readme = "README.md"
repository = "https://github.com/langchain-ai/langchain"
license = "MIT"
[tool.poetry.urls]
"Source Code" = "https://github.com/langchain-ai/langchain/tree/master/libs/partners/google-vertexai"
[tool.poetry.dependencies]
python = ">=3.8.1,<4.0"
langchain-core = "^0.1.7"
google-cloud-aiplatform = "^1.39.0"
google-cloud-storage = "^2.14.0"
[tool.poetry.group.test]
optional = true
[tool.poetry.group.test.dependencies]
pytest = "^7.3.0"
freezegun = "^1.2.2"
pytest-mock = "^3.10.0"
syrupy = "^4.0.2"
pytest-watcher = "^0.3.4"
pytest-asyncio = "^0.21.1"
langchain-core = { path = "../../core", develop = true }
[tool.poetry.group.codespell]
optional = true
[tool.poetry.group.codespell.dependencies]
codespell = "^2.2.0"
[tool.poetry.group.test_integration]
optional = true
[tool.poetry.group.test_integration.dependencies]
langchain = { path = "../../langchain" }
langchain-community = { path = "../../community" }
numexpr = { version = "^2.8.8", python = ">=3.9,<4.0" }
google-api-python-client = "^2.114.0"
[tool.poetry.group.lint]
optional = true
[tool.poetry.group.lint.dependencies]
ruff = "^0.1.5"
[tool.poetry.group.typing.dependencies]
mypy = "^1"
langchain-core = { path = "../../core", develop = true }
types-google-cloud-ndb = "^2.2.0.20240106"
types-requests = "^2.31.0.20231231"
types-protobuf = "^4.24.0.4"
[tool.poetry.group.dev]
optional = true
[tool.poetry.group.dev.dependencies]
langchain-core = { path = "../../core", develop = true }
[tool.ruff.lint]
select = [
"E", # pycodestyle
"F", # pyflakes
"I", # isort
"T201", # print
]
[tool.mypy]
check_untyped_defs = true
error_summary = false
pretty = true
show_column_numbers = true
show_error_codes = true
show_error_context = true
warn_redundant_casts = true
warn_unreachable = true
warn_unused_configs = true
warn_unused_ignores = true
[tool.coverage.run]
omit = ["tests/*"]
[build-system]
requires = ["poetry-core>=1.0.0"]
build-backend = "poetry.core.masonry.api"
[tool.pytest.ini_options]
# --strict-markers will raise errors on unknown marks.
# https://docs.pytest.org/en/7.1.x/how-to/mark.html#raising-errors-on-unknown-marks
#
# https://docs.pytest.org/en/7.1.x/reference/reference.html
# --strict-config any warnings encountered while parsing the `pytest`
# section of the configuration file raise errors.
#
# https://github.com/tophat/syrupy
# --snapshot-warn-unused Prints a warning on unused snapshots rather than fail the test suite.
addopts = "--snapshot-warn-unused --strict-markers --strict-config --durations=5"
# Registering custom markers.
# https://docs.pytest.org/en/7.1.x/example/markers.html#registering-markers
markers = [
"requires: mark tests as requiring a specific library",
"asyncio: mark tests as requiring asyncio",
"compile: mark placeholder test used to compile integration tests without running them",
]
asyncio_mode = "auto"

@ -1,17 +0,0 @@
import sys
import traceback
from importlib.machinery import SourceFileLoader
if __name__ == "__main__":
files = sys.argv[1:]
has_failure = False
for file in files:
try:
SourceFileLoader("x", file).load_module()
except Exception:
has_faillure = True
print(file) # noqa: T201
traceback.print_exc()
print() # noqa: T201
sys.exit(1 if has_failure else 0)

@ -1,27 +0,0 @@
#!/bin/bash
#
# This script searches for lines starting with "import pydantic" or "from pydantic"
# in tracked files within a Git repository.
#
# Usage: ./scripts/check_pydantic.sh /path/to/repository
# Check if a path argument is provided
if [ $# -ne 1 ]; then
echo "Usage: $0 /path/to/repository"
exit 1
fi
repository_path="$1"
# Search for lines matching the pattern within the specified repository
result=$(git -C "$repository_path" grep -E '^import pydantic|^from pydantic')
# Check if any matching lines were found
if [ -n "$result" ]; then
echo "ERROR: The following lines need to be updated:"
echo "$result"
echo "Please replace the code with an import from langchain_core.pydantic_v1."
echo "For example, replace 'from pydantic import BaseModel'"
echo "with 'from langchain_core.pydantic_v1 import BaseModel'"
exit 1
fi

@ -1,17 +0,0 @@
#!/bin/bash
set -eu
# Initialize a variable to keep track of errors
errors=0
# make sure not importing from langchain or langchain_experimental
git --no-pager grep '^from langchain\.' . && errors=$((errors+1))
git --no-pager grep '^from langchain_experimental\.' . && errors=$((errors+1))
# Decide on an exit status based on the errors
if [ "$errors" -gt 0 ]; then
exit 1
else
exit 0
fi

@ -1,260 +0,0 @@
"""Test ChatGoogleVertexAI chat model."""
import json
from typing import Optional, cast
import pytest
from langchain_core.messages import (
AIMessage,
AIMessageChunk,
HumanMessage,
SystemMessage,
)
from langchain_core.outputs import ChatGeneration, LLMResult
from langchain_core.pydantic_v1 import BaseModel
from langchain_google_vertexai.chat_models import ChatVertexAI
model_names_to_test = [None, "codechat-bison", "chat-bison", "gemini-pro"]
@pytest.mark.parametrize("model_name", model_names_to_test)
def test_initialization(model_name: Optional[str]) -> None:
"""Test chat model initialization."""
if model_name:
model = ChatVertexAI(model_name=model_name)
else:
model = ChatVertexAI()
assert model._llm_type == "vertexai"
try:
assert model.model_name == model.client._model_id
except AttributeError:
assert model.model_name == model.client._model_name.split("/")[-1]
@pytest.mark.parametrize("model_name", model_names_to_test)
def test_vertexai_single_call(model_name: Optional[str]) -> None:
if model_name:
model = ChatVertexAI(model_name=model_name)
else:
model = ChatVertexAI()
message = HumanMessage(content="Hello")
response = model([message])
assert isinstance(response, AIMessage)
assert isinstance(response.content, str)
# mark xfail because Vertex API randomly doesn't respect
# the n/candidate_count parameter
@pytest.mark.xfail
def test_candidates() -> None:
model = ChatVertexAI(model_name="chat-bison@001", temperature=0.3, n=2)
message = HumanMessage(content="Hello")
response = model.generate(messages=[[message]])
assert isinstance(response, LLMResult)
assert len(response.generations) == 1
assert len(response.generations[0]) == 2
@pytest.mark.parametrize("model_name", ["chat-bison@001", "gemini-pro"])
async def test_vertexai_agenerate(model_name: str) -> None:
model = ChatVertexAI(temperature=0, model_name=model_name)
message = HumanMessage(content="Hello")
response = await model.agenerate([[message]])
assert isinstance(response, LLMResult)
assert isinstance(response.generations[0][0].message, AIMessage) # type: ignore
sync_response = model.generate([[message]])
sync_generation = cast(ChatGeneration, sync_response.generations[0][0])
async_generation = cast(ChatGeneration, response.generations[0][0])
# assert some properties to make debugging easier
# xfail: this is not equivalent with temp=0 right now
# assert sync_generation.message.content == async_generation.message.content
assert sync_generation.generation_info == async_generation.generation_info
# xfail: content is not same right now
# assert sync_generation == async_generation
@pytest.mark.parametrize("model_name", ["chat-bison@001", "gemini-pro"])
def test_vertexai_stream(model_name: str) -> None:
model = ChatVertexAI(temperature=0, model_name=model_name)
message = HumanMessage(content="Hello")
sync_response = model.stream([message])
for chunk in sync_response:
assert isinstance(chunk, AIMessageChunk)
def test_vertexai_single_call_with_context() -> None:
model = ChatVertexAI()
raw_context = (
"My name is Ned. You are my personal assistant. My favorite movies "
"are Lord of the Rings and Hobbit."
)
question = (
"Hello, could you recommend a good movie for me to watch this evening, please?"
)
context = SystemMessage(content=raw_context)
message = HumanMessage(content=question)
response = model([context, message])
assert isinstance(response, AIMessage)
assert isinstance(response.content, str)
def test_multimodal() -> None:
llm = ChatVertexAI(model_name="gemini-pro-vision")
gcs_url = (
"gs://cloud-samples-data/generative-ai/image/"
"320px-Felis_catus-cat_on_snow.jpg"
)
image_message = {
"type": "image_url",
"image_url": {"url": gcs_url},
}
text_message = {
"type": "text",
"text": "What is shown in this image?",
}
message = HumanMessage(content=[text_message, image_message])
output = llm([message])
assert isinstance(output.content, str)
@pytest.mark.xfail(reason="problem on vertex side")
def test_multimodal_history() -> None:
llm = ChatVertexAI(model_name="gemini-pro-vision")
gcs_url = (
"gs://cloud-samples-data/generative-ai/image/"
"320px-Felis_catus-cat_on_snow.jpg"
)
image_message = {
"type": "image_url",
"image_url": {"url": gcs_url},
}
text_message = {
"type": "text",
"text": "What is shown in this image?",
}
message1 = HumanMessage(content=[text_message, image_message])
message2 = AIMessage(
content=(
"This is a picture of a cat in the snow. The cat is a tabby cat, which is "
"a type of cat with a striped coat. The cat is standing in the snow, and "
"its fur is covered in snow."
)
)
message3 = HumanMessage(content="What time of day is it?")
response = llm([message1, message2, message3])
assert isinstance(response, AIMessage)
assert isinstance(response.content, str)
def test_vertexai_single_call_with_examples() -> None:
model = ChatVertexAI()
raw_context = "My name is Ned. You are my personal assistant."
question = "2+2"
text_question, text_answer = "4+4", "8"
inp = HumanMessage(content=text_question)
output = AIMessage(content=text_answer)
context = SystemMessage(content=raw_context)
message = HumanMessage(content=question)
response = model([context, message], examples=[inp, output])
assert isinstance(response, AIMessage)
assert isinstance(response.content, str)
@pytest.mark.parametrize("model_name", model_names_to_test)
def test_vertexai_single_call_with_history(model_name: Optional[str]) -> None:
if model_name:
model = ChatVertexAI(model_name=model_name)
else:
model = ChatVertexAI()
text_question1, text_answer1 = "How much is 2+2?", "4"
text_question2 = "How much is 3+3?"
message1 = HumanMessage(content=text_question1)
message2 = AIMessage(content=text_answer1)
message3 = HumanMessage(content=text_question2)
response = model([message1, message2, message3])
assert isinstance(response, AIMessage)
assert isinstance(response.content, str)
def test_vertexai_single_call_fails_no_message() -> None:
chat = ChatVertexAI()
with pytest.raises(ValueError) as exc_info:
_ = chat([])
assert (
str(exc_info.value)
== "You should provide at least one message to start the chat!"
)
@pytest.mark.parametrize("model_name", ["gemini-pro"])
def test_chat_vertexai_gemini_system_message_error(model_name: str) -> None:
model = ChatVertexAI(model_name=model_name)
text_question1, text_answer1 = "How much is 2+2?", "4"
text_question2 = "How much is 3+3?"
system_message = SystemMessage(content="You're supposed to answer math questions.")
message1 = HumanMessage(content=text_question1)
message2 = AIMessage(content=text_answer1)
message3 = HumanMessage(content=text_question2)
with pytest.raises(ValueError):
model([system_message, message1, message2, message3])
@pytest.mark.parametrize("model_name", model_names_to_test)
def test_chat_vertexai_system_message(model_name: Optional[str]) -> None:
if model_name:
model = ChatVertexAI(
model_name=model_name, convert_system_message_to_human=True
)
else:
model = ChatVertexAI()
text_question1, text_answer1 = "How much is 2+2?", "4"
text_question2 = "How much is 3+3?"
system_message = SystemMessage(content="You're supposed to answer math questions.")
message1 = HumanMessage(content=text_question1)
message2 = AIMessage(content=text_answer1)
message3 = HumanMessage(content=text_question2)
response = model([system_message, message1, message2, message3])
assert isinstance(response, AIMessage)
assert isinstance(response.content, str)
@pytest.mark.parametrize("model_name", model_names_to_test)
def test_get_num_tokens_from_messages(model_name: str) -> None:
if model_name:
model = ChatVertexAI(model_name=model_name, temperature=0.0)
else:
model = ChatVertexAI(temperature=0.0)
message = HumanMessage(content="Hello")
token = model.get_num_tokens_from_messages(messages=[message])
assert isinstance(token, int)
assert token == 3
def test_chat_vertexai_gemini_function_calling() -> None:
class MyModel(BaseModel):
name: str
age: int
model = ChatVertexAI(model_name="gemini-pro").bind(functions=[MyModel])
message = HumanMessage(content="My name is Erick and I am 27 years old")
response = model.invoke([message])
assert isinstance(response, AIMessage)
assert isinstance(response.content, str)
assert response.content == ""
function_call = response.additional_kwargs.get("function_call")
assert function_call
assert function_call["name"] == "MyModel"
arguments_str = function_call.get("arguments")
assert arguments_str
arguments = json.loads(arguments_str)
assert arguments == {
"name": "Erick",
"age": 27.0,
}

@ -1,7 +0,0 @@
import pytest
@pytest.mark.compile
def test_placeholder() -> None:
"""Used for compiling integration tests without running any real tests."""
pass

@ -1,70 +0,0 @@
"""Test Vertex AI API wrapper.
Your end-user credentials would be used to make the calls (make sure you've run
`gcloud auth login` first).
"""
import pytest
from langchain_google_vertexai.embeddings import VertexAIEmbeddings
def test_initialization() -> None:
"""Test embedding model initialization."""
VertexAIEmbeddings()
def test_langchain_google_vertexai_embedding_documents() -> None:
documents = ["foo bar"]
model = VertexAIEmbeddings()
output = model.embed_documents(documents)
assert len(output) == 1
assert len(output[0]) == 768
assert model.model_name == model.client._model_id
assert model.model_name == "textembedding-gecko@001"
def test_langchain_google_vertexai_embedding_query() -> None:
document = "foo bar"
model = VertexAIEmbeddings()
output = model.embed_query(document)
assert len(output) == 768
def test_langchain_google_vertexai_large_batches() -> None:
documents = ["foo bar" for _ in range(0, 251)]
model_uscentral1 = VertexAIEmbeddings(location="us-central1")
model_asianortheast1 = VertexAIEmbeddings(location="asia-northeast1")
model_uscentral1.embed_documents(documents)
model_asianortheast1.embed_documents(documents)
assert model_uscentral1.instance["batch_size"] >= 250
assert model_asianortheast1.instance["batch_size"] < 50
def test_langchain_google_vertexai_paginated_texts() -> None:
documents = [
"foo bar",
"foo baz",
"bar foo",
"baz foo",
"bar bar",
"foo foo",
"baz baz",
"baz bar",
]
model = VertexAIEmbeddings()
output = model.embed_documents(documents)
assert len(output) == 8
assert len(output[0]) == 768
assert model.model_name == model.client._model_id
def test_warning(caplog: pytest.LogCaptureFixture) -> None:
_ = VertexAIEmbeddings()
assert len(caplog.records) == 1
record = caplog.records[0]
assert record.levelname == "WARNING"
expected_message = (
"Model_name will become a required arg for VertexAIEmbeddings starting from "
"Feb-01-2024. Currently the default is set to textembedding-gecko@001"
)
assert record.message == expected_message

@ -1,195 +0,0 @@
"""Test Vertex AI API wrapper.
Your end-user credentials would be used to make the calls (make sure you've run
`gcloud auth login` first).
"""
import os
from typing import Optional
import pytest
from langchain_core.outputs import LLMResult
from langchain_google_vertexai.llms import VertexAI, VertexAIModelGarden
model_names_to_test = ["text-bison@001", "gemini-pro"]
model_names_to_test_with_default = [None] + model_names_to_test
@pytest.mark.parametrize(
"model_name",
model_names_to_test_with_default,
)
def test_vertex_initialization(model_name: str) -> None:
llm = VertexAI(model_name=model_name) if model_name else VertexAI()
assert llm._llm_type == "vertexai"
try:
assert llm.model_name == llm.client._model_id
except AttributeError:
assert llm.model_name == llm.client._model_name.split("/")[-1]
@pytest.mark.parametrize(
"model_name",
model_names_to_test_with_default,
)
def test_vertex_invoke(model_name: str) -> None:
llm = (
VertexAI(model_name=model_name, temperature=0)
if model_name
else VertexAI(temperature=0.0)
)
output = llm.invoke("Say foo:")
assert isinstance(output, str)
@pytest.mark.parametrize(
"model_name",
model_names_to_test_with_default,
)
def test_vertex_generate(model_name: str) -> None:
llm = (
VertexAI(model_name=model_name, temperature=0)
if model_name
else VertexAI(temperature=0.0)
)
output = llm.generate(["Say foo:"])
assert isinstance(output, LLMResult)
assert len(output.generations) == 1
@pytest.mark.xfail(reason="VertexAI doesn't always respect number of candidates")
def test_vertex_generate_multiple_candidates() -> None:
llm = VertexAI(temperature=0.3, n=2, model_name="text-bison@001")
output = llm.generate(["Say foo:"])
assert isinstance(output, LLMResult)
assert len(output.generations) == 1
assert len(output.generations[0]) == 2
@pytest.mark.xfail(reason="VertexAI doesn't always respect number of candidates")
def test_vertex_generate_code() -> None:
llm = VertexAI(temperature=0.3, n=2, model_name="code-bison@001")
output = llm.generate(["generate a python method that says foo:"])
assert isinstance(output, LLMResult)
assert len(output.generations) == 1
assert len(output.generations[0]) == 2
async def test_vertex_agenerate() -> None:
llm = VertexAI(temperature=0)
output = await llm.agenerate(["Please say foo:"])
assert isinstance(output, LLMResult)
@pytest.mark.parametrize(
"model_name",
model_names_to_test_with_default,
)
def test_stream(model_name: str) -> None:
llm = (
VertexAI(temperature=0, model_name=model_name)
if model_name
else VertexAI(temperature=0)
)
for token in llm.stream("I'm Pickle Rick"):
assert isinstance(token, str)
async def test_vertex_consistency() -> None:
llm = VertexAI(temperature=0)
output = llm.generate(["Please say foo:"])
streaming_output = llm.generate(["Please say foo:"], stream=True)
async_output = await llm.agenerate(["Please say foo:"])
assert output.generations[0][0].text == streaming_output.generations[0][0].text
assert output.generations[0][0].text == async_output.generations[0][0].text
@pytest.mark.skip("CI testing not set up")
@pytest.mark.parametrize(
"endpoint_os_variable_name,result_arg",
[("FALCON_ENDPOINT_ID", "generated_text"), ("LLAMA_ENDPOINT_ID", None)],
)
def test_model_garden(
endpoint_os_variable_name: str, result_arg: Optional[str]
) -> None:
"""In order to run this test, you should provide endpoint names.
Example:
export FALCON_ENDPOINT_ID=...
export LLAMA_ENDPOINT_ID=...
export PROJECT=...
"""
endpoint_id = os.environ[endpoint_os_variable_name]
project = os.environ["PROJECT"]
location = "europe-west4"
llm = VertexAIModelGarden(
endpoint_id=endpoint_id,
project=project,
result_arg=result_arg,
location=location,
)
output = llm("What is the meaning of life?")
assert isinstance(output, str)
assert llm._llm_type == "vertexai_model_garden"
@pytest.mark.skip("CI testing not set up")
@pytest.mark.parametrize(
"endpoint_os_variable_name,result_arg",
[("FALCON_ENDPOINT_ID", "generated_text"), ("LLAMA_ENDPOINT_ID", None)],
)
def test_model_garden_generate(
endpoint_os_variable_name: str, result_arg: Optional[str]
) -> None:
"""In order to run this test, you should provide endpoint names.
Example:
export FALCON_ENDPOINT_ID=...
export LLAMA_ENDPOINT_ID=...
export PROJECT=...
"""
endpoint_id = os.environ[endpoint_os_variable_name]
project = os.environ["PROJECT"]
location = "europe-west4"
llm = VertexAIModelGarden(
endpoint_id=endpoint_id,
project=project,
result_arg=result_arg,
location=location,
)
output = llm.generate(["What is the meaning of life?", "How much is 2+2"])
assert isinstance(output, LLMResult)
assert len(output.generations) == 2
@pytest.mark.skip("CI testing not set up")
@pytest.mark.asyncio
@pytest.mark.parametrize(
"endpoint_os_variable_name,result_arg",
[("FALCON_ENDPOINT_ID", "generated_text"), ("LLAMA_ENDPOINT_ID", None)],
)
async def test_model_garden_agenerate(
endpoint_os_variable_name: str, result_arg: Optional[str]
) -> None:
endpoint_id = os.environ[endpoint_os_variable_name]
project = os.environ["PROJECT"]
location = "europe-west4"
llm = VertexAIModelGarden(
endpoint_id=endpoint_id,
project=project,
result_arg=result_arg,
location=location,
)
output = await llm.agenerate(["What is the meaning of life?", "How much is 2+2"])
assert isinstance(output, LLMResult)
assert len(output.generations) == 2
@pytest.mark.parametrize(
"model_name",
model_names_to_test,
)
def test_vertex_call_count_tokens(model_name: str) -> None:
llm = VertexAI(model_name=model_name)
output = llm.get_num_tokens("How are you?")
assert output == 4

@ -1,97 +0,0 @@
from langchain_core.outputs import LLMResult
from langchain_google_vertexai import HarmBlockThreshold, HarmCategory, VertexAI
SAFETY_SETTINGS = {
HarmCategory.HARM_CATEGORY_UNSPECIFIED: HarmBlockThreshold.BLOCK_NONE,
HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
}
# below context and question are taken from one of opensource QA datasets
BLOCKED_PROMPT = """
You are agent designed to answer questions.
You are given context in triple backticks.
```
The religion\'s failure to report abuse allegations to authorities has also been
criticized. The Watch Tower Society\'s policy is that elders inform authorities when
required by law to do so, but otherwise leave that action up to the victim and his
or her family. The Australian Royal Commission into Institutional Responses to Child
Sexual Abuse found that of 1006 alleged perpetrators of child sexual abuse
identified by the Jehovah\'s Witnesses within their organization since 1950,
"not one was reported by the church to secular authorities." William Bowen, a former
Jehovah\'s Witness elder who established the Silentlambs organization to assist sex
abuse victims within the religion, has claimed Witness leaders discourage followers
from reporting incidents of sexual misconduct to authorities, and other critics claim
the organization is reluctant to alert authorities in order to protect its "crime-free"
reputation. In court cases in the United Kingdom and the United States the Watch Tower
Society has been found to have been negligent in its failure to protect children from
known sex offenders within the congregation and the Society has settled other child
abuse lawsuits out of court, reportedly paying as much as $780,000 to one plaintiff
without admitting wrongdoing.
```
Question: What have courts in both the UK and the US found the Watch Tower Society to
have been for failing to protect children from sexual predators within the
congregation ?
Answer:
"""
def test_gemini_safety_settings_generate() -> None:
llm = VertexAI(model_name="gemini-pro", safety_settings=SAFETY_SETTINGS)
output = llm.generate(["What do you think about child abuse:"])
assert isinstance(output, LLMResult)
assert len(output.generations) == 1
generation_info = output.generations[0][0].generation_info
assert generation_info is not None
assert len(generation_info) > 0
assert not generation_info.get("is_blocked")
blocked_output = llm.generate([BLOCKED_PROMPT])
assert isinstance(blocked_output, LLMResult)
assert len(blocked_output.generations) == 1
assert len(blocked_output.generations[0]) == 0
# test safety_settings passed directly to generate
llm = VertexAI(model_name="gemini-pro")
output = llm.generate(
["What do you think about child abuse:"], safety_settings=SAFETY_SETTINGS
)
assert isinstance(output, LLMResult)
assert len(output.generations) == 1
generation_info = output.generations[0][0].generation_info
assert generation_info is not None
assert len(generation_info) > 0
assert not generation_info.get("is_blocked")
async def test_gemini_safety_settings_agenerate() -> None:
llm = VertexAI(model_name="gemini-pro", safety_settings=SAFETY_SETTINGS)
output = await llm.agenerate(["What do you think about child abuse:"])
assert isinstance(output, LLMResult)
assert len(output.generations) == 1
generation_info = output.generations[0][0].generation_info
assert generation_info is not None
assert len(generation_info) > 0
assert not generation_info.get("is_blocked")
blocked_output = await llm.agenerate([BLOCKED_PROMPT])
assert isinstance(blocked_output, LLMResult)
assert len(blocked_output.generations) == 1
# assert len(blocked_output.generations[0][0].generation_info) > 0
# assert blocked_output.generations[0][0].generation_info.get("is_blocked")
# test safety_settings passed directly to agenerate
llm = VertexAI(model_name="gemini-pro")
output = await llm.agenerate(
["What do you think about child abuse:"], safety_settings=SAFETY_SETTINGS
)
assert isinstance(output, LLMResult)
assert len(output.generations) == 1
generation_info = output.generations[0][0].generation_info
assert generation_info is not None
assert len(generation_info) > 0
assert not generation_info.get("is_blocked")

@ -1,172 +0,0 @@
import os
import re
from typing import Any, List, Union
from langchain_core.agents import AgentAction, AgentActionMessageLog, AgentFinish
from langchain_core.messages import AIMessageChunk
from langchain_core.output_parsers import BaseOutputParser
from langchain_core.outputs import ChatGeneration, Generation
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.tools import Tool
from langchain_google_vertexai.chat_models import ChatVertexAI
class _TestOutputParser(BaseOutputParser):
def parse_result(
self, result: List[Generation], *, partial: bool = False
) -> Union[AgentAction, AgentFinish]:
if not isinstance(result[0], ChatGeneration):
raise ValueError("This output parser only works on ChatGeneration output")
message = result[0].message
function_call = message.additional_kwargs.get("function_call", {})
if function_call:
function_name = function_call["name"]
tool_input = function_call.get("arguments", {})
content_msg = f"responded: {message.content}\n" if message.content else "\n"
log_msg = (
f"\nInvoking: `{function_name}` with `{tool_input}`\n{content_msg}\n"
)
return AgentActionMessageLog(
tool=function_name,
tool_input=tool_input,
log=log_msg,
message_log=[message],
)
return AgentFinish(
return_values={"output": message.content}, log=str(message.content)
)
def parse(self, text: str) -> Union[AgentAction, AgentFinish]:
raise ValueError("Can only parse messages")
def test_tools() -> None:
from langchain.agents import AgentExecutor
from langchain.agents.format_scratchpad import (
format_to_openai_function_messages,
)
from langchain.chains import LLMMathChain
llm = ChatVertexAI(model_name="gemini-pro")
math_chain = LLMMathChain.from_llm(llm=llm)
tools = [
Tool(
name="Calculator",
func=math_chain.run,
description="useful for when you need to answer questions about math",
)
]
prompt = ChatPromptTemplate.from_messages(
[
("user", "{input}"),
MessagesPlaceholder(variable_name="agent_scratchpad"),
]
)
llm_with_tools = llm.bind(functions=tools)
agent: Any = (
{
"input": lambda x: x["input"],
"agent_scratchpad": lambda x: format_to_openai_function_messages(
x["intermediate_steps"]
),
}
| prompt
| llm_with_tools
| _TestOutputParser()
)
agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True)
response = agent_executor.invoke({"input": "What is 6 raised to the 0.43 power?"})
assert isinstance(response, dict)
assert response["input"] == "What is 6 raised to the 0.43 power?"
# convert string " The result is 2.160752567226312" to just numbers/periods
# use regex to find \d+\.\d+
just_numbers = re.findall(r"\d+\.\d+", response["output"])[0]
assert round(float(just_numbers), 2) == 2.16
def test_stream() -> None:
from langchain.chains import LLMMathChain
llm = ChatVertexAI(model_name="gemini-pro")
math_chain = LLMMathChain.from_llm(llm=llm)
tools = [
Tool(
name="Calculator",
func=math_chain.run,
description="useful for when you need to answer questions about math",
)
]
response = list(llm.stream("What is 6 raised to the 0.43 power?", functions=tools))
assert len(response) == 1
assert isinstance(response[0], AIMessageChunk)
assert "function_call" in response[0].additional_kwargs
def test_multiple_tools() -> None:
from langchain.agents import AgentExecutor
from langchain.agents.format_scratchpad import format_to_openai_function_messages
from langchain.chains import LLMMathChain
from langchain.utilities import (
GoogleSearchAPIWrapper,
)
llm = ChatVertexAI(model_name="gemini-pro", max_output_tokens=1024)
math_chain = LLMMathChain.from_llm(llm=llm)
google_search_api_key = os.environ["GOOGLE_SEARCH_API_KEY"]
google_cse_id = os.environ["GOOGLE_CSE_ID"]
search = GoogleSearchAPIWrapper(
k=10, google_api_key=google_search_api_key, google_cse_id=google_cse_id
)
tools = [
Tool(
name="Calculator",
func=math_chain.run,
description="useful for when you need to answer questions about math",
),
Tool(
name="Search",
func=search.run,
description=(
"useful for when you need to answer questions about current events. "
"You should ask targeted questions"
),
),
]
prompt = ChatPromptTemplate.from_messages(
[
("user", "{input}"),
MessagesPlaceholder(variable_name="agent_scratchpad"),
]
)
llm_with_tools = llm.bind(functions=tools)
agent: Any = (
{
"input": lambda x: x["input"],
"agent_scratchpad": lambda x: format_to_openai_function_messages(
x["intermediate_steps"]
),
}
| prompt
| llm_with_tools
| _TestOutputParser()
)
agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True)
question = (
"Who is Leo DiCaprio's girlfriend? What is her "
"current age raised to the 0.43 power?"
)
response = agent_executor.invoke({"input": question})
assert isinstance(response, dict)
assert response["input"] == question
# xfail: not getting age in search result most of time
# assert "3.850" in response["output"]

@ -1,318 +0,0 @@
"""Test chat model integration."""
import json
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional
from unittest.mock import MagicMock, Mock, patch
import pytest
from google.cloud.aiplatform_v1beta1.types import (
Content,
FunctionCall,
Part,
)
from google.cloud.aiplatform_v1beta1.types import (
content as gapic_content_types,
)
from langchain_core.messages import (
AIMessage,
HumanMessage,
SystemMessage,
)
from vertexai.language_models import ChatMessage, InputOutputTextPair # type: ignore
from vertexai.preview.generative_models import ( # type: ignore
Candidate,
)
from langchain_google_vertexai.chat_models import (
ChatVertexAI,
_parse_chat_history,
_parse_chat_history_gemini,
_parse_examples,
_parse_response_candidate,
)
def test_parse_examples_correct() -> None:
text_question = (
"Hello, could you recommend a good movie for me to watch this evening, please?"
)
question = HumanMessage(content=text_question)
text_answer = (
"Sure, You might enjoy The Lord of the Rings: The Fellowship of the Ring "
"(2001): This is the first movie in the Lord of the Rings trilogy."
)
answer = AIMessage(content=text_answer)
examples = _parse_examples([question, answer, question, answer])
assert len(examples) == 2
assert examples == [
InputOutputTextPair(input_text=text_question, output_text=text_answer),
InputOutputTextPair(input_text=text_question, output_text=text_answer),
]
def test_parse_examples_failes_wrong_sequence() -> None:
with pytest.raises(ValueError) as exc_info:
_ = _parse_examples([AIMessage(content="a")])
assert (
str(exc_info.value)
== "Expect examples to have an even amount of messages, got 1."
)
@dataclass
class StubTextChatResponse:
"""Stub text-chat response from VertexAI for testing."""
text: str
@pytest.mark.parametrize("stop", [None, "stop1"])
def test_vertexai_args_passed(stop: Optional[str]) -> None:
response_text = "Goodbye"
user_prompt = "Hello"
prompt_params: Dict[str, Any] = {
"max_output_tokens": 1,
"temperature": 10000.0,
"top_k": 10,
"top_p": 0.5,
}
# Mock the library to ensure the args are passed correctly
with patch("vertexai._model_garden._model_garden_models._from_pretrained") as mg:
mock_response = MagicMock()
mock_response.candidates = [StubTextChatResponse(text=response_text)]
mock_chat = MagicMock()
mock_send_message = MagicMock(return_value=mock_response)
mock_chat.send_message = mock_send_message
mock_model = MagicMock()
mock_start_chat = MagicMock(return_value=mock_chat)
mock_model.start_chat = mock_start_chat
mg.return_value = mock_model
model = ChatVertexAI(**prompt_params)
message = HumanMessage(content=user_prompt)
if stop:
response = model([message], stop=[stop])
else:
response = model([message])
assert response.content == response_text
mock_send_message.assert_called_once_with(user_prompt, candidate_count=1)
expected_stop_sequence = [stop] if stop else None
mock_start_chat.assert_called_once_with(
context=None,
message_history=[],
**prompt_params,
stop_sequences=expected_stop_sequence,
)
def test_parse_chat_history_correct() -> None:
text_context = (
"My name is Ned. You are my personal assistant. My "
"favorite movies are Lord of the Rings and Hobbit."
)
context = SystemMessage(content=text_context)
text_question = (
"Hello, could you recommend a good movie for me to watch this evening, please?"
)
question = HumanMessage(content=text_question)
text_answer = (
"Sure, You might enjoy The Lord of the Rings: The Fellowship of the Ring "
"(2001): This is the first movie in the Lord of the Rings trilogy."
)
answer = AIMessage(content=text_answer)
history = _parse_chat_history([context, question, answer, question, answer])
assert history.context == context.content
assert len(history.history) == 4
assert history.history == [
ChatMessage(content=text_question, author="user"),
ChatMessage(content=text_answer, author="bot"),
ChatMessage(content=text_question, author="user"),
ChatMessage(content=text_answer, author="bot"),
]
def test_parse_history_gemini() -> None:
system_input = "You're supposed to answer math questions."
text_question1, text_answer1 = "How much is 2+2?", "4"
text_question2 = "How much is 3+3?"
system_message = SystemMessage(content=system_input)
message1 = HumanMessage(content=text_question1)
message2 = AIMessage(content=text_answer1)
message3 = HumanMessage(content=text_question2)
messages = [system_message, message1, message2, message3]
history = _parse_chat_history_gemini(messages, convert_system_message_to_human=True)
assert len(history) == 3
assert history[0].role == "user"
assert history[0].parts[0].text == system_input
assert history[0].parts[1].text == text_question1
assert history[1].role == "model"
assert history[1].parts[0].text == text_answer1
def test_default_params_palm() -> None:
user_prompt = "Hello"
with patch("vertexai._model_garden._model_garden_models._from_pretrained") as mg:
mock_response = MagicMock()
mock_response.candidates = [StubTextChatResponse(text="Goodbye")]
mock_chat = MagicMock()
mock_send_message = MagicMock(return_value=mock_response)
mock_chat.send_message = mock_send_message
mock_model = MagicMock()
mock_start_chat = MagicMock(return_value=mock_chat)
mock_model.start_chat = mock_start_chat
mg.return_value = mock_model
model = ChatVertexAI(model_name="text-bison@001")
message = HumanMessage(content=user_prompt)
_ = model([message])
mock_start_chat.assert_called_once_with(
context=None,
message_history=[],
max_output_tokens=128,
top_k=40,
top_p=0.95,
stop_sequences=None,
)
@dataclass
class StubGeminiResponse:
"""Stub gemini response from VertexAI for testing."""
text: str
content: Any
citation_metadata: Any
safety_ratings: List[Any] = field(default_factory=list)
def test_default_params_gemini() -> None:
user_prompt = "Hello"
with patch("langchain_google_vertexai.chat_models.GenerativeModel") as gm:
mock_response = MagicMock()
mock_response.candidates = [
StubGeminiResponse(
text="Goodbye",
content=Mock(parts=[Mock(function_call=None)]),
citation_metadata=None,
)
]
mock_chat = MagicMock()
mock_send_message = MagicMock(return_value=mock_response)
mock_chat.send_message = mock_send_message
mock_model = MagicMock()
mock_start_chat = MagicMock(return_value=mock_chat)
mock_model.start_chat = mock_start_chat
gm.return_value = mock_model
model = ChatVertexAI(model_name="gemini-pro")
message = HumanMessage(content=user_prompt)
_ = model([message])
mock_start_chat.assert_called_once_with(history=[])
@pytest.mark.parametrize(
"raw_candidate, expected",
[
(
gapic_content_types.Candidate(
content=Content(
role="model",
parts=[
Part(
function_call=FunctionCall(
name="Information",
args={"name": "Ben"},
),
)
],
)
),
{
"name": "Information",
"arguments": {"name": "Ben"},
},
),
(
gapic_content_types.Candidate(
content=Content(
role="model",
parts=[
Part(
function_call=FunctionCall(
name="Information",
args={"info": ["A", "B", "C"]},
),
)
],
)
),
{
"name": "Information",
"arguments": {"info": ["A", "B", "C"]},
},
),
(
gapic_content_types.Candidate(
content=Content(
role="model",
parts=[
Part(
function_call=FunctionCall(
name="Information",
args={
"people": [
{"name": "Joe", "age": 30},
{"name": "Martha"},
]
},
),
)
],
)
),
{
"name": "Information",
"arguments": {
"people": [
{"name": "Joe", "age": 30},
{"name": "Martha"},
]
},
},
),
(
gapic_content_types.Candidate(
content=Content(
role="model",
parts=[
Part(
function_call=FunctionCall(
name="Information",
args={"info": [[1, 2, 3], [4, 5, 6]]},
),
)
],
)
),
{
"name": "Information",
"arguments": {"info": [[1, 2, 3], [4, 5, 6]]},
},
),
],
)
def test_parse_response_candidate(raw_candidate, expected) -> None:
response_candidate = Candidate._from_gapic(raw_candidate)
result = _parse_response_candidate(response_candidate)
result_arguments = json.loads(
result.additional_kwargs["function_call"]["arguments"]
)
assert result_arguments == expected["arguments"]

@ -1,16 +0,0 @@
from langchain_google_vertexai import __all__
EXPECTED_ALL = [
"ChatVertexAI",
"VertexAIEmbeddings",
"VertexAI",
"VertexAIModelGarden",
"HarmBlockThreshold",
"HarmCategory",
"PydanticFunctionsOutputParser",
"create_structured_runnable",
]
def test_all_imports() -> None:
assert sorted(EXPECTED_ALL) == sorted(__all__)
Loading…
Cancel
Save