mirror of https://github.com/hwchase17/langchain
google-vertexai: added langchain_google_vertexai package (#15218)

added langchain_google_vertexai package

---------

Co-authored-by: Erick Friis <erick@langchain.dev>

parent e1fc4d5b95
commit f73bf4ee54

libs/partners/google-vertexai/.gitignore
@@ -0,0 +1 @@
__pycache__

libs/partners/google-vertexai/LICENSE
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2023 LangChain, Inc.

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

libs/partners/google-vertexai/Makefile
@@ -0,0 +1,59 @@
.PHONY: all format lint test tests integration_tests docker_tests help extended_tests

# Default target executed when no arguments are given to make.
all: help

# Define a variable for the test file path.
TEST_FILE ?= tests/unit_tests/

test:
	poetry run pytest $(TEST_FILE)

tests:
	poetry run pytest $(TEST_FILE)


######################
# LINTING AND FORMATTING
######################

# Define a variable for Python and notebook files.
PYTHON_FILES=.
MYPY_CACHE=.mypy_cache
lint format: PYTHON_FILES=.
lint_diff format_diff: PYTHON_FILES=$(shell git diff --relative=libs/partners/google-vertexai --name-only --diff-filter=d master | grep -E '\.py$$|\.ipynb$$')
lint_package: PYTHON_FILES=langchain_google_vertexai
lint_tests: PYTHON_FILES=tests
lint_tests: MYPY_CACHE=.mypy_cache_test

lint lint_diff lint_package lint_tests:
	poetry run ruff .
	poetry run ruff format $(PYTHON_FILES) --diff
	poetry run ruff --select I $(PYTHON_FILES)
	mkdir $(MYPY_CACHE); poetry run mypy $(PYTHON_FILES) --cache-dir $(MYPY_CACHE)

format format_diff:
	poetry run ruff format $(PYTHON_FILES)
	poetry run ruff --select I --fix $(PYTHON_FILES)

spell_check:
	poetry run codespell --toml pyproject.toml

spell_fix:
	poetry run codespell --toml pyproject.toml -w

check_imports: $(shell find langchain_google_vertexai -name '*.py')
	poetry run python ./scripts/check_imports.py $^

######################
# HELP
######################

help:
	@echo '----'
	@echo 'check_imports - check imports'
	@echo 'format - run code formatters'
	@echo 'lint - run linters'
	@echo 'test - run unit tests'
	@echo 'tests - run unit tests'
	@echo 'test TEST_FILE=<test_file> - run all tests in file'

libs/partners/google-vertexai/README.md
@@ -0,0 +1,100 @@
# langchain-google-vertexai

This package contains the LangChain integrations for Google Cloud generative models.

## Installation

```bash
pip install -U langchain-google-vertexai
```

## Chat Models

The `ChatVertexAI` class exposes chat models available in Google Vertex AI, such as `gemini-pro` and `chat-bison`.

To use it, you should have a Google Cloud project with the Vertex AI API enabled and configured credentials. Initialize the model as:

```python
from langchain_google_vertexai import ChatVertexAI

llm = ChatVertexAI(model_name="gemini-pro")
llm.invoke("Sing a ballad of LangChain.")
```

You can use other models, e.g. `chat-bison`:

```python
from langchain_google_vertexai import ChatVertexAI

llm = ChatVertexAI(model_name="chat-bison", temperature=0.3)
llm.invoke("Sing a ballad of LangChain.")
```
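
For PaLM 2 chat models such as `chat-bison`, a leading `SystemMessage` is passed to the API as the chat context (per this package's message parsing; Gemini models currently raise an error for system messages). A minimal sketch, assuming valid credentials:

```python
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_google_vertexai import ChatVertexAI

llm = ChatVertexAI(model_name="chat-bison")
messages = [
    # The first message may be a SystemMessage; it becomes the chat context.
    SystemMessage(content="You are a helpful assistant who answers briefly."),
    HumanMessage(content="Hello!"),
]
llm.invoke(messages)
```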

#### Multimodal inputs

The Gemini vision model supports image inputs when provided in a single chat message. Example:

```python
from langchain_core.messages import HumanMessage
from langchain_google_vertexai import ChatVertexAI

llm = ChatVertexAI(model_name="gemini-pro-vision")
# example
message = HumanMessage(
    content=[
        {
            "type": "text",
            "text": "What's in this image?",
        },  # You can optionally provide text parts
        {"type": "image_url", "image_url": {"url": "https://picsum.photos/seed/picsum/200/300"}},
    ]
)
llm.invoke([message])
```

The value of `image_url` can be any of the following (an example follows the list):

- A public image URL
- An accessible Google Cloud Storage (GCS) file (e.g., "gs://path/to/file.png")
- A local file path
- A base64 encoded image (e.g., `data:image/png;base64,abcd124`)
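
For instance, an image stored in GCS can be passed by its `gs://` URI (the bucket and object below are hypothetical placeholders; use a path your credentials can read):

```python
from langchain_core.messages import HumanMessage
from langchain_google_vertexai import ChatVertexAI

llm = ChatVertexAI(model_name="gemini-pro-vision")
message = HumanMessage(
    content=[
        {"type": "text", "text": "Describe this image."},
        # hypothetical path; the image is downloaded from GCS before being sent
        {"type": "image_url", "image_url": {"url": "gs://my-bucket/images/photo.png"}},
    ]
)
llm.invoke([message])
```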

## Embeddings

You can use Google Cloud's embedding models as follows:

```python
from langchain_google_vertexai import VertexAIEmbeddings

embeddings = VertexAIEmbeddings()
embeddings.embed_query("hello, world!")
```
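
Embedding many documents at once goes through `embed_documents`, which batches requests internally (up to 250 texts per request, with the workable batch size detected dynamically on the first call); a minimal sketch:

```python
from langchain_google_vertexai import VertexAIEmbeddings

embeddings = VertexAIEmbeddings()
# Requests are batched and issued in parallel under the hood.
vectors = embeddings.embed_documents(["hello", "world"])
print(len(vectors), len(vectors[0]))
```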

## LLMs

You can use Google Cloud's generative AI models as LangChain LLMs:

```python
from langchain.prompts import PromptTemplate
from langchain_google_vertexai import VertexAI

# text-bison is the default model
llm = VertexAI(model_name="text-bison")

template = """Question: {question}

Answer: Let's think step by step."""
prompt = PromptTemplate.from_template(template)

chain = prompt | llm

question = "Who was the president in the year Justin Bieber was born?"
print(chain.invoke({"question": question}))
```

You can use Gemini and PaLM models, including code-generation ones:

```python
from langchain_google_vertexai import VertexAI

llm = VertexAI(model_name="code-bison", max_output_tokens=1000, temperature=0.3)

question = "Write a python function that checks if a string is a valid email address"

output = llm.invoke(question)
```
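
Streaming works through the standard LangChain `stream` interface; a minimal sketch (note the package allows only one candidate when streaming):

```python
from langchain_google_vertexai import VertexAI

llm = VertexAI(model_name="gemini-pro")
for chunk in llm.stream("Write a haiku about the ocean."):
    print(chunk, end="", flush=True)
```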

libs/partners/google-vertexai/langchain_google_vertexai/__init__.py
@@ -0,0 +1,5 @@
from langchain_google_vertexai.chat_models import ChatVertexAI
from langchain_google_vertexai.embeddings import VertexAIEmbeddings
from langchain_google_vertexai.llms import VertexAI, VertexAIModelGarden

__all__ = ["ChatVertexAI", "VertexAIEmbeddings", "VertexAI", "VertexAIModelGarden"]

libs/partners/google-vertexai/langchain_google_vertexai/_utils.py
@@ -0,0 +1,88 @@
"""Utilities to init Vertex AI."""
from importlib import metadata
from typing import Any, Callable, Optional, Union

import google.api_core
from google.api_core.gapic_v1.client_info import ClientInfo
from google.cloud import storage  # type: ignore
from langchain_core.callbacks import (
    AsyncCallbackManagerForLLMRun,
    CallbackManagerForLLMRun,
)
from langchain_core.language_models.llms import create_base_retry_decorator
from vertexai.preview.generative_models import Image  # type: ignore


def create_retry_decorator(
    *,
    max_retries: int = 1,
    run_manager: Optional[
        Union[AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun]
    ] = None,
) -> Callable[[Any], Any]:
    """Creates a retry decorator for Vertex / PaLM LLMs."""

    errors = [
        google.api_core.exceptions.ResourceExhausted,
        google.api_core.exceptions.ServiceUnavailable,
        google.api_core.exceptions.Aborted,
        google.api_core.exceptions.DeadlineExceeded,
        google.api_core.exceptions.GoogleAPIError,
    ]
    decorator = create_base_retry_decorator(
        error_types=errors, max_retries=max_retries, run_manager=run_manager
    )
    return decorator


def raise_vertex_import_error(minimum_expected_version: str = "1.38.0") -> None:
    """Raise an ImportError when the Vertex SDK is not available.

    Args:
        minimum_expected_version: The lowest expected version of the SDK.
    Raises:
        ImportError: an ImportError that mentions a required version of the SDK.
    """
    raise ImportError(
        "Please install or upgrade the google-cloud-aiplatform library: "
        f"pip install google-cloud-aiplatform>={minimum_expected_version}"
    )


def get_client_info(module: Optional[str] = None) -> "ClientInfo":
    r"""Returns a custom user agent header.

    Args:
        module (Optional[str]):
            Optional. The module for a custom user agent header.
    Returns:
        google.api_core.gapic_v1.client_info.ClientInfo
    """
    langchain_version = metadata.version("langchain")
    client_library_version = (
        f"{langchain_version}-{module}" if module else langchain_version
    )
    return ClientInfo(
        client_library_version=client_library_version,
        user_agent=f"langchain/{client_library_version}",
    )


def load_image_from_gcs(path: str, project: Optional[str] = None) -> Image:
    """Loads an Image from GCS."""
    gcs_client = storage.Client(project=project)
    pieces = path.split("/")
    blobs = list(gcs_client.list_blobs(pieces[2], prefix="/".join(pieces[3:])))
    if len(blobs) > 1:
        raise ValueError(f"Found more than one candidate for {path}!")
    return Image.from_bytes(blobs[0].download_as_bytes())


def is_codey_model(model_name: str) -> bool:
    """Returns True if the model name is a Codey model."""
    return "code" in model_name


def is_gemini_model(model_name: str) -> bool:
    """Returns True if the model name is a Gemini model."""
    return model_name is not None and "gemini" in model_name

libs/partners/google-vertexai/langchain_google_vertexai/chat_models.py
@@ -0,0 +1,366 @@
"""Wrapper around Google VertexAI chat-based models."""
from __future__ import annotations

import base64
import logging
import re
from dataclasses import dataclass, field
from typing import Any, Dict, Iterator, List, Optional, Union, cast
from urllib.parse import urlparse

import requests
from langchain_core.callbacks import (
    AsyncCallbackManagerForLLMRun,
    CallbackManagerForLLMRun,
)
from langchain_core.language_models.chat_models import (
    BaseChatModel,
    generate_from_stream,
)
from langchain_core.messages import (
    AIMessage,
    AIMessageChunk,
    BaseMessage,
    HumanMessage,
    SystemMessage,
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.pydantic_v1 import root_validator
from vertexai.language_models import (  # type: ignore
    ChatMessage,
    ChatModel,
    ChatSession,
    CodeChatModel,
    CodeChatSession,
    InputOutputTextPair,
)
from vertexai.preview.generative_models import (  # type: ignore
    Content,
    GenerativeModel,
    Image,
    Part,
)

from langchain_google_vertexai._utils import (
    is_codey_model,
    is_gemini_model,
    load_image_from_gcs,
)
from langchain_google_vertexai.llms import (
    _VertexAICommon,
)

logger = logging.getLogger(__name__)


@dataclass
class _ChatHistory:
    """Represents a context and a history of messages."""

    history: List[ChatMessage] = field(default_factory=list)
    context: Optional[str] = None


def _parse_chat_history(history: List[BaseMessage]) -> _ChatHistory:
    """Parse a sequence of messages into history.

    Args:
        history: The list of messages to re-create the history of the chat.
    Returns:
        A parsed chat history.
    Raises:
        ValueError: If the sequence of messages contains a SystemMessage
            anywhere other than the first position.
    """

    vertex_messages, context = [], None
    for i, message in enumerate(history):
        content = cast(str, message.content)
        if i == 0 and isinstance(message, SystemMessage):
            context = content
        elif isinstance(message, AIMessage):
            vertex_message = ChatMessage(content=message.content, author="bot")
            vertex_messages.append(vertex_message)
        elif isinstance(message, HumanMessage):
            vertex_message = ChatMessage(content=message.content, author="user")
            vertex_messages.append(vertex_message)
        else:
            raise ValueError(
                f"Unexpected message with type {type(message)} at the position {i}."
            )
    chat_history = _ChatHistory(context=context, history=vertex_messages)
    return chat_history

def _is_url(s: str) -> bool:
    try:
        result = urlparse(s)
        return all([result.scheme, result.netloc])
    except Exception as e:
        logger.debug(f"Unable to parse URL: {e}")
        return False


def _parse_chat_history_gemini(
    history: List[BaseMessage], project: Optional[str]
) -> List[Content]:
    def _convert_to_prompt(part: Union[str, Dict]) -> Part:
        if isinstance(part, str):
            return Part.from_text(part)

        if not isinstance(part, Dict):
            raise ValueError(
                f"Message's content is expected to be a dict, got {type(part)}!"
            )
        if part["type"] == "text":
            return Part.from_text(part["text"])
        elif part["type"] == "image_url":
            path = part["image_url"]["url"]
            if path.startswith("gs://"):
                image = load_image_from_gcs(path=path, project=project)
            elif path.startswith("data:image/"):
                # extract base64 component from image uri
                try:
                    regexp = r"data:image/\w{2,4};base64,(.*)"
                    encoded = re.search(regexp, path).group(1)  # type: ignore
                except AttributeError:
                    raise ValueError(
                        "Invalid image uri. It should be in the format "
                        "data:image/<image_type>;base64,<base64_encoded_image>."
                    )
                image = Image.from_bytes(base64.b64decode(encoded))
            elif _is_url(path):
                response = requests.get(path)
                response.raise_for_status()
                image = Image.from_bytes(response.content)
            else:
                image = Image.load_from_file(path)
        else:
            raise ValueError("Only text and image_url types are supported!")
        return Part.from_image(image)

    vertex_messages = []
    for i, message in enumerate(history):
        if i == 0 and isinstance(message, SystemMessage):
            raise ValueError("SystemMessages are not yet supported!")
        elif isinstance(message, AIMessage):
            role = "model"
        elif isinstance(message, HumanMessage):
            role = "user"
        else:
            raise ValueError(
                f"Unexpected message with type {type(message)} at the position {i}."
            )

        raw_content = message.content
        if isinstance(raw_content, str):
            raw_content = [raw_content]
        parts = [_convert_to_prompt(part) for part in raw_content]
        vertex_message = Content(role=role, parts=parts)
        vertex_messages.append(vertex_message)
    return vertex_messages


def _parse_examples(examples: List[BaseMessage]) -> List[InputOutputTextPair]:
    if len(examples) % 2 != 0:
        raise ValueError(
            f"Expected examples to have an even number of messages, got {len(examples)}."
        )
    example_pairs = []
    input_text = None
    for i, example in enumerate(examples):
        if i % 2 == 0:
            if not isinstance(example, HumanMessage):
                raise ValueError(
                    f"Expected the first message in a pair to be from human, got "
                    f"{type(example)} for the {i}th message."
                )
            input_text = example.content
        if i % 2 == 1:
            if not isinstance(example, AIMessage):
                raise ValueError(
                    f"Expected the second message in a pair to be from AI, got "
                    f"{type(example)} for the {i}th message."
                )
            pair = InputOutputTextPair(
                input_text=input_text, output_text=example.content
            )
            example_pairs.append(pair)
    return example_pairs


def _get_question(messages: List[BaseMessage]) -> HumanMessage:
    """Get the human message at the end of a list of input messages to a chat model."""
    if not messages:
        raise ValueError("You should provide at least one message to start the chat!")
    question = messages[-1]
    if not isinstance(question, HumanMessage):
        raise ValueError(
            f"Last message in the list should be from human, got {question.type}."
        )
    return question

class ChatVertexAI(_VertexAICommon, BaseChatModel):
    """`Vertex AI` Chat large language models API."""

    model_name: str = "chat-bison"
    "Underlying model name."
    examples: Optional[List[BaseMessage]] = None

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that the python package exists in environment."""
        is_gemini = is_gemini_model(values["model_name"])
        cls._init_vertexai(values)
        if is_gemini:
            values["client"] = GenerativeModel(model_name=values["model_name"])
        else:
            if is_codey_model(values["model_name"]):
                model_cls = CodeChatModel
            else:
                model_cls = ChatModel
            values["client"] = model_cls.from_pretrained(values["model_name"])
        return values

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        stream: Optional[bool] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Generate the next turn in the conversation.

        Args:
            messages: The history of the conversation as a list of messages. Code chat
                does not support context.
            stop: The list of stop words (optional).
            run_manager: The CallbackManager for the LLM run; it is not used
                at the moment.
            stream: Whether to use the streaming endpoint.

        Returns:
            The ChatResult that contains outputs generated by the model.

        Raises:
            ValueError: if the last message in the list is not from human.
        """
        should_stream = stream if stream is not None else self.streaming
        if should_stream:
            stream_iter = self._stream(
                messages, stop=stop, run_manager=run_manager, **kwargs
            )
            return generate_from_stream(stream_iter)

        question = _get_question(messages)
        params = self._prepare_params(stop=stop, stream=False, **kwargs)
        msg_params = {}
        if "candidate_count" in params:
            msg_params["candidate_count"] = params.pop("candidate_count")

        if self._is_gemini_model:
            history_gemini = _parse_chat_history_gemini(messages, project=self.project)
            # The last message is sent separately; the rest becomes the chat history.
            message = history_gemini.pop()
            chat = self.client.start_chat(history=history_gemini)
            response = chat.send_message(message, generation_config=params)
        else:
            history = _parse_chat_history(messages[:-1])
            examples = kwargs.get("examples") or self.examples
            if examples:
                params["examples"] = _parse_examples(examples)
            chat = self._start_chat(history, **params)
            response = chat.send_message(question.content, **msg_params)
        generations = [
            ChatGeneration(message=AIMessage(content=r.text))
            for r in response.candidates
        ]
        return ChatResult(generations=generations)

    async def _agenerate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Asynchronously generate the next turn in the conversation.

        Args:
            messages: The history of the conversation as a list of messages. Code chat
                does not support context.
            stop: The list of stop words (optional).
            run_manager: The CallbackManager for the LLM run; it is not used
                at the moment.

        Returns:
            The ChatResult that contains outputs generated by the model.

        Raises:
            ValueError: if the last message in the list is not from human.
        """
        if "stream" in kwargs:
            kwargs.pop("stream")
            logger.warning("ChatVertexAI does not currently support async streaming.")

        params = self._prepare_params(stop=stop, **kwargs)
        msg_params = {}
        if "candidate_count" in params:
            msg_params["candidate_count"] = params.pop("candidate_count")

        if self._is_gemini_model:
            history_gemini = _parse_chat_history_gemini(messages, project=self.project)
            message = history_gemini.pop()
            chat = self.client.start_chat(history=history_gemini)
            response = await chat.send_message_async(message, generation_config=params)
        else:
            question = _get_question(messages)
            history = _parse_chat_history(messages[:-1])
            examples = kwargs.get("examples", None)
            if examples:
                params["examples"] = _parse_examples(examples)
            chat = self._start_chat(history, **params)
            response = await chat.send_message_async(question.content, **msg_params)

        generations = [
            ChatGeneration(message=AIMessage(content=r.text))
            for r in response.candidates
        ]
        return ChatResult(generations=generations)

    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        params = self._prepare_params(stop=stop, stream=True, **kwargs)
        if self._is_gemini_model:
            history_gemini = _parse_chat_history_gemini(messages, project=self.project)
            message = history_gemini.pop()
            chat = self.client.start_chat(history=history_gemini)
            responses = chat.send_message(
                message, stream=True, generation_config=params
            )
        else:
            question = _get_question(messages)
            history = _parse_chat_history(messages[:-1])
            examples = kwargs.get("examples", None)
            if examples:
                params["examples"] = _parse_examples(examples)
            chat = self._start_chat(history, **params)
            responses = chat.send_message_streaming(question.content, **params)
        for response in responses:
            if run_manager:
                run_manager.on_llm_new_token(response.text)
            yield ChatGenerationChunk(message=AIMessageChunk(content=response.text))

    def _start_chat(
        self, history: _ChatHistory, **kwargs: Any
    ) -> Union[ChatSession, CodeChatSession]:
        if not self.is_codey_model:
            return self.client.start_chat(
                context=history.context, message_history=history.history, **kwargs
            )
        else:
            return self.client.start_chat(message_history=history.history, **kwargs)

libs/partners/google-vertexai/langchain_google_vertexai/embeddings.py
@@ -0,0 +1,336 @@
import logging
import re
import string
import threading
from concurrent.futures import ThreadPoolExecutor, wait
from typing import Any, Dict, List, Literal, Optional, Tuple, Type

from google.api_core.exceptions import (
    Aborted,
    DeadlineExceeded,
    InvalidArgument,
    ResourceExhausted,
    ServiceUnavailable,
)
from langchain_core.embeddings import Embeddings
from langchain_core.language_models.llms import create_base_retry_decorator
from langchain_core.pydantic_v1 import root_validator
from vertexai.language_models import (  # type: ignore
    TextEmbeddingInput,
    TextEmbeddingModel,
)

from langchain_google_vertexai.llms import _VertexAICommon

logger = logging.getLogger(__name__)

_MAX_TOKENS_PER_BATCH = 20000
_MAX_BATCH_SIZE = 250
_MIN_BATCH_SIZE = 5


class VertexAIEmbeddings(_VertexAICommon, Embeddings):
    """Google Cloud VertexAI embedding models."""

    # Instance context
    instance: Dict[str, Any] = {}  #: :meta private:

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validates that the python package exists in environment."""
        cls._init_vertexai(values)
        if values["model_name"] == "textembedding-gecko-default":
            logger.warning(
                "Model_name will become a required arg for VertexAIEmbeddings "
                "starting from Feb-01-2024. Currently the default is set to "
                "textembedding-gecko@001"
            )
            values["model_name"] = "textembedding-gecko@001"
        values["client"] = TextEmbeddingModel.from_pretrained(values["model_name"])
        return values

    def __init__(
        self,
        # the default value would be removed after Feb-01-2024
        model_name: str = "textembedding-gecko-default",
        project: Optional[str] = None,
        location: str = "us-central1",
        request_parallelism: int = 5,
        max_retries: int = 6,
        credentials: Optional[Any] = None,
        **kwargs: Any,
    ):
        """Initialize the VertexAI embeddings model."""
        super().__init__(
            project=project,
            location=location,
            credentials=credentials,
            request_parallelism=request_parallelism,
            max_retries=max_retries,
            model_name=model_name,
            **kwargs,
        )
        self.instance["max_batch_size"] = kwargs.get("max_batch_size", _MAX_BATCH_SIZE)
        self.instance["batch_size"] = self.instance["max_batch_size"]
        self.instance["min_batch_size"] = kwargs.get("min_batch_size", _MIN_BATCH_SIZE)
        self.instance["min_good_batch_size"] = self.instance["min_batch_size"]
        self.instance["lock"] = threading.Lock()
        self.instance["batch_size_validated"] = False
        self.instance["task_executor"] = ThreadPoolExecutor(
            max_workers=request_parallelism
        )
        self.instance[
            "embeddings_task_type_supported"
        ] = not self.client._endpoint_name.endswith("/textembedding-gecko@001")
    @staticmethod
    def _split_by_punctuation(text: str) -> List[str]:
        """Splits a string by punctuation and whitespace characters."""
        split_by = string.punctuation + "\t\n "
        pattern = f"([{split_by}])"
        # Using re.split to split the text based on the pattern
        return [segment for segment in re.split(pattern, text) if segment]

    @staticmethod
    def _prepare_batches(texts: List[str], batch_size: int) -> List[List[str]]:
        """Splits texts into batches based on the current maximum batch size
        and the maximum number of tokens per request.
        """
        text_index = 0
        texts_len = len(texts)
        batch_token_len = 0
        batches: List[List[str]] = []
        current_batch: List[str] = []
        if texts_len == 0:
            return []
        while text_index < texts_len:
            current_text = texts[text_index]
            # The number of tokens per text is conservatively estimated
            # as twice the number of words, punctuation and whitespace characters.
            # Using the `count_tokens` API would make batching too expensive.
            # Utilizing a tokenizer would add a dependency that would not
            # necessarily be reused by the application using this class.
            current_text_token_cnt = (
                len(VertexAIEmbeddings._split_by_punctuation(current_text)) * 2
            )
            end_of_batch = False
            if current_text_token_cnt > _MAX_TOKENS_PER_BATCH:
                # Current text is too big even for a single batch.
                # Such a request will fail, but we still make a batch
                # so that the app can get the error from the API.
                if len(current_batch) > 0:
                    # Adding current batch if not empty.
                    batches.append(current_batch)
                current_batch = [current_text]
                text_index += 1
                end_of_batch = True
            elif (
                batch_token_len + current_text_token_cnt > _MAX_TOKENS_PER_BATCH
                or len(current_batch) == batch_size
            ):
                end_of_batch = True
            else:
                if text_index == texts_len - 1:
                    # Last element - even though the batch may be not big,
                    # we still need to make it.
                    end_of_batch = True
                batch_token_len += current_text_token_cnt
                current_batch.append(current_text)
                text_index += 1
            if end_of_batch:
                batches.append(current_batch)
                current_batch = []
                batch_token_len = 0
        return batches

    def _get_embeddings_with_retry(
        self, texts: List[str], embeddings_type: Optional[str] = None
    ) -> List[List[float]]:
        """Makes a Vertex AI model request with retry logic."""

        errors: List[Type[BaseException]] = [
            ResourceExhausted,
            ServiceUnavailable,
            Aborted,
            DeadlineExceeded,
        ]
        retry_decorator = create_base_retry_decorator(
            error_types=errors, max_retries=self.max_retries
        )

        @retry_decorator
        def _completion_with_retry(texts_to_process: List[str]) -> Any:
            if embeddings_type and self.instance["embeddings_task_type_supported"]:
                requests = [
                    TextEmbeddingInput(text=t, task_type=embeddings_type)
                    for t in texts_to_process
                ]
            else:
                requests = texts_to_process
            embeddings = self.client.get_embeddings(requests)
            return [embs.values for embs in embeddings]

        return _completion_with_retry(texts)
    def _prepare_and_validate_batches(
        self, texts: List[str], embeddings_type: Optional[str] = None
    ) -> Tuple[List[List[float]], List[List[str]]]:
        """Prepares text batches with one-time validation of the batch size.

        Batch size varies between GCP regions and individual project quotas.
        Returns embeddings of the first text batch that went through,
        and text batches for the rest of the texts.
        """

        batches = VertexAIEmbeddings._prepare_batches(
            texts, self.instance["batch_size"]
        )
        # If the batch size is less than or equal to one that went through
        # before, then keep the batches as they are.
        if len(batches[0]) <= self.instance["min_good_batch_size"]:
            return [], batches
        with self.instance["lock"]:
            # If the largest possible batch size was validated
            # while waiting for the lock, then check for rebuilding
            # our batches, and return.
            if self.instance["batch_size_validated"]:
                if len(batches[0]) <= self.instance["batch_size"]:
                    return [], batches
                else:
                    return [], VertexAIEmbeddings._prepare_batches(
                        texts, self.instance["batch_size"]
                    )
            # Figure out the largest possible batch size by trying to push
            # batches and halving their size after every failure.
            first_batch = batches[0]
            first_result = []
            had_failure = False
            while True:
                try:
                    first_result = self._get_embeddings_with_retry(
                        first_batch, embeddings_type
                    )
                    break
                except InvalidArgument:
                    had_failure = True
                    first_batch_len = len(first_batch)
                    if first_batch_len == self.instance["min_batch_size"]:
                        raise
                    first_batch_len = max(
                        self.instance["min_batch_size"], int(first_batch_len / 2)
                    )
                    first_batch = first_batch[:first_batch_len]
            first_batch_len = len(first_batch)
            self.instance["min_good_batch_size"] = max(
                self.instance["min_good_batch_size"], first_batch_len
            )
            # If we had a failure and recovered,
            # or went through with the max size, then it's a legit batch size.
            if had_failure or first_batch_len == self.instance["max_batch_size"]:
                self.instance["batch_size"] = first_batch_len
                self.instance["batch_size_validated"] = True
                # If the batch size was updated,
                # rebuild the batches with the new batch size
                # (texts that went through are excluded here).
                if first_batch_len != self.instance["max_batch_size"]:
                    batches = VertexAIEmbeddings._prepare_batches(
                        texts[first_batch_len:], self.instance["batch_size"]
                    )
            else:
                # Still figuring out the max batch size.
                batches = batches[1:]
        # Returning embeddings of the first text batch that went through,
        # and text batches for the rest of the texts.
        return first_result, batches
    def embed(
        self,
        texts: List[str],
        batch_size: int = 0,
        embeddings_task_type: Optional[
            Literal[
                "RETRIEVAL_QUERY",
                "RETRIEVAL_DOCUMENT",
                "SEMANTIC_SIMILARITY",
                "CLASSIFICATION",
                "CLUSTERING",
            ]
        ] = None,
    ) -> List[List[float]]:
        """Embed a list of strings.

        Args:
            texts: List[str] The list of strings to embed.
            batch_size: [int] The batch size of embeddings to send to the model.
                If zero, then the largest batch size will be detected dynamically
                at the first request, starting from 250, down to 5.
            embeddings_task_type: [str] optional embeddings task type,
                one of the following
                    RETRIEVAL_QUERY - Text is a query
                        in a search/retrieval setting.
                    RETRIEVAL_DOCUMENT - Text is a document
                        in a search/retrieval setting.
                    SEMANTIC_SIMILARITY - Embeddings will be used
                        for Semantic Textual Similarity (STS).
                    CLASSIFICATION - Embeddings will be used for classification.
                    CLUSTERING - Embeddings will be used for clustering.

        Returns:
            List of embeddings, one for each text.
        """
        if len(texts) == 0:
            return []
        embeddings: List[List[float]] = []
        first_batch_result: List[List[float]] = []
        if batch_size > 0:
            # Fixed batch size.
            batches = VertexAIEmbeddings._prepare_batches(texts, batch_size)
        else:
            # Dynamic batch size, starting from 250 at the first call.
            first_batch_result, batches = self._prepare_and_validate_batches(
                texts, embeddings_task_type
            )
        # First batch result may have some embeddings already.
        # In such case, batches have texts that were not processed yet.
        embeddings.extend(first_batch_result)
        tasks = []
        for batch in batches:
            tasks.append(
                self.instance["task_executor"].submit(
                    self._get_embeddings_with_retry,
                    texts=batch,
                    embeddings_type=embeddings_task_type,
                )
            )
        if len(tasks) > 0:
            wait(tasks)
        for t in tasks:
            embeddings.extend(t.result())
        return embeddings

    def embed_documents(
        self, texts: List[str], batch_size: int = 0
    ) -> List[List[float]]:
        """Embed a list of documents.

        Args:
            texts: List[str] The list of texts to embed.
            batch_size: [int] The batch size of embeddings to send to the model.
                If zero, then the largest batch size will be detected dynamically
                at the first request, starting from 250, down to 5.

        Returns:
            List of embeddings, one for each text.
        """
        return self.embed(texts, batch_size, "RETRIEVAL_DOCUMENT")

    def embed_query(self, text: str) -> List[float]:
        """Embed a text.

        Args:
            text: The text to embed.

        Returns:
            Embedding for the text.
        """
        embeddings = self.embed([text], 1, "RETRIEVAL_QUERY")
        return embeddings[0]

libs/partners/google-vertexai/langchain_google_vertexai/llms.py
@@ -0,0 +1,469 @@
from __future__ import annotations

from concurrent.futures import Executor
from typing import Any, ClassVar, Dict, Iterator, List, Optional, Union

import vertexai  # type: ignore
from google.api_core.client_options import ClientOptions
from google.cloud.aiplatform.gapic import (
    PredictionServiceAsyncClient,
    PredictionServiceClient,
)
from google.cloud.aiplatform.models import Prediction
from google.protobuf import json_format
from google.protobuf.struct_pb2 import Value
from langchain_core.callbacks.manager import (
    AsyncCallbackManagerForLLMRun,
    CallbackManagerForLLMRun,
)
from langchain_core.language_models.llms import BaseLLM
from langchain_core.outputs import Generation, GenerationChunk, LLMResult
from langchain_core.pydantic_v1 import BaseModel, Field, root_validator
from vertexai.language_models import (  # type: ignore
    CodeGenerationModel,
    TextGenerationModel,
)
from vertexai.language_models._language_models import (  # type: ignore
    TextGenerationResponse,
)
from vertexai.preview.generative_models import GenerativeModel, Image  # type: ignore
from vertexai.preview.language_models import (  # type: ignore
    CodeGenerationModel as PreviewCodeGenerationModel,
)
from vertexai.preview.language_models import (
    TextGenerationModel as PreviewTextGenerationModel,
)

from langchain_google_vertexai._utils import (
    create_retry_decorator,
    get_client_info,
    is_codey_model,
    is_gemini_model,
)


def _completion_with_retry(
    llm: VertexAI,
    prompt: List[Union[str, Image]],
    stream: bool = False,
    is_gemini: bool = False,
    run_manager: Optional[CallbackManagerForLLMRun] = None,
    **kwargs: Any,
) -> Any:
    """Use tenacity to retry the completion call."""
    retry_decorator = create_retry_decorator(
        max_retries=llm.max_retries, run_manager=run_manager
    )

    @retry_decorator
    def _completion_with_retry_inner(
        prompt: List[Union[str, Image]], is_gemini: bool = False, **kwargs: Any
    ) -> Any:
        if is_gemini:
            return llm.client.generate_content(
                prompt, stream=stream, generation_config=kwargs
            )
        else:
            if stream:
                return llm.client.predict_streaming(prompt[0], **kwargs)
            return llm.client.predict(prompt[0], **kwargs)

    return _completion_with_retry_inner(prompt, is_gemini, **kwargs)


async def _acompletion_with_retry(
    llm: VertexAI,
    prompt: str,
    is_gemini: bool = False,
    run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
    **kwargs: Any,
) -> Any:
    """Use tenacity to retry the completion call."""
    retry_decorator = create_retry_decorator(
        max_retries=llm.max_retries, run_manager=run_manager
    )

    @retry_decorator
    async def _acompletion_with_retry_inner(
        prompt: str, is_gemini: bool = False, **kwargs: Any
    ) -> Any:
        if is_gemini:
            return await llm.client.generate_content_async(
                prompt, generation_config=kwargs
            )
        return await llm.client.predict_async(prompt, **kwargs)

    return await _acompletion_with_retry_inner(prompt, is_gemini, **kwargs)

class _VertexAIBase(BaseModel):
    project: Optional[str] = None
    "The default GCP project to use when making Vertex API calls."
    location: str = "us-central1"
    "The default location to use when making API calls."
    request_parallelism: int = 5
    "The amount of parallelism allowed for requests issued to VertexAI models. "
    "Default is 5."
    max_retries: int = 6
    """The maximum number of retries to make when generating."""
    task_executor: ClassVar[Optional[Executor]] = Field(default=None, exclude=True)
    stop: Optional[List[str]] = None
    "Optional list of stop words to use when generating."
    model_name: Optional[str] = None
    "Underlying model name."


class _VertexAICommon(_VertexAIBase):
    client: Any = None  #: :meta private:
    client_preview: Any = None  #: :meta private:
    model_name: str
    "Underlying model name."
    temperature: float = 0.0
    "Sampling temperature, it controls the degree of randomness in token selection."
    max_output_tokens: int = 128
    "Token limit determines the maximum amount of text output from one prompt."
    top_p: float = 0.95
    "Tokens are selected from most probable to least until the sum of their "
    "probabilities equals the top-p value. Top-p is ignored for Codey models."
    top_k: int = 40
    "How the model selects tokens for output, the next token is selected from "
    "among the top-k most probable tokens. Top-k is ignored for Codey models."
    credentials: Any = Field(default=None, exclude=True)
    "The default custom credentials (google.auth.credentials.Credentials) to use "
    "when making API calls. If not provided, credentials will be ascertained from "
    "the environment."
    n: int = 1
    """How many completions to generate for each prompt."""
    streaming: bool = False
    """Whether to stream the results or not."""

    @property
    def _llm_type(self) -> str:
        return "vertexai"

    @property
    def is_codey_model(self) -> bool:
        return is_codey_model(self.model_name)

    @property
    def _is_gemini_model(self) -> bool:
        return is_gemini_model(self.model_name)

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Gets the identifying parameters."""
        return {**{"model_name": self.model_name}, **self._default_params}

    @property
    def _default_params(self) -> Dict[str, Any]:
        params = {
            "temperature": self.temperature,
            "max_output_tokens": self.max_output_tokens,
            "candidate_count": self.n,
        }
        if not self.is_codey_model:
            params.update(
                {
                    "top_k": self.top_k,
                    "top_p": self.top_p,
                }
            )
        return params

    @classmethod
    def _init_vertexai(cls, values: Dict) -> None:
        vertexai.init(
            project=values.get("project"),
            location=values.get("location"),
            credentials=values.get("credentials"),
        )
        return None

    def _prepare_params(
        self,
        stop: Optional[List[str]] = None,
        stream: bool = False,
        **kwargs: Any,
    ) -> dict:
        stop_sequences = stop or self.stop
        params_mapping = {"n": "candidate_count"}
        params = {params_mapping.get(k, k): v for k, v in kwargs.items()}
        params = {**self._default_params, "stop_sequences": stop_sequences, **params}
        if stream or self.streaming:
            params.pop("candidate_count")
        return params

class VertexAI(_VertexAICommon, BaseLLM):
    """Google Vertex AI large language models."""

    model_name: str = "text-bison"
    "The name of the Vertex AI large language model."
    tuned_model_name: Optional[str] = None
    "The name of a tuned model. If provided, model_name is ignored."

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that the python package exists in environment."""
        tuned_model_name = values.get("tuned_model_name")
        model_name = values["model_name"]
        is_gemini = is_gemini_model(values["model_name"])
        cls._init_vertexai(values)

        if is_codey_model(model_name):
            model_cls = CodeGenerationModel
            preview_model_cls = PreviewCodeGenerationModel
        elif is_gemini:
            model_cls = GenerativeModel
            preview_model_cls = GenerativeModel
        else:
            model_cls = TextGenerationModel
            preview_model_cls = PreviewTextGenerationModel

        if tuned_model_name:
            values["client"] = model_cls.get_tuned_model(tuned_model_name)
            values["client_preview"] = preview_model_cls.get_tuned_model(
                tuned_model_name
            )
        else:
            if is_gemini:
                values["client"] = model_cls(model_name=model_name)
                values["client_preview"] = preview_model_cls(model_name=model_name)
            else:
                values["client"] = model_cls.from_pretrained(model_name)
                values["client_preview"] = preview_model_cls.from_pretrained(model_name)

        if values["streaming"] and values["n"] > 1:
            raise ValueError("Only one candidate can be generated with streaming!")
        return values

    def get_num_tokens(self, text: str) -> int:
        """Get the number of tokens present in the text.

        Useful for checking if an input will fit in a model's context window.

        Args:
            text: The string input to tokenize.

        Returns:
            The integer number of tokens in the text.
        """
        result = self.client_preview.count_tokens([text])
        return result.total_tokens

    def _response_to_generation(
        self, response: TextGenerationResponse
    ) -> GenerationChunk:
        """Converts a stream response to a generation chunk."""
        try:
            generation_info = {
                "is_blocked": response.is_blocked,
                "safety_attributes": response.safety_attributes,
            }
        except Exception:
            generation_info = None
        return GenerationChunk(text=response.text, generation_info=generation_info)

    def _generate(
        self,
        prompts: List[str],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        stream: Optional[bool] = None,
        **kwargs: Any,
    ) -> LLMResult:
        should_stream = stream if stream is not None else self.streaming
        params = self._prepare_params(stop=stop, stream=should_stream, **kwargs)
        generations: List[List[Generation]] = []
        for prompt in prompts:
            if should_stream:
                generation = GenerationChunk(text="")
                for chunk in self._stream(
                    prompt, stop=stop, run_manager=run_manager, **kwargs
                ):
                    generation += chunk
                generations.append([generation])
            else:
                res = _completion_with_retry(
                    self,
                    [prompt],
                    stream=should_stream,
                    is_gemini=self._is_gemini_model,
                    run_manager=run_manager,
                    **params,
                )
                generations.append(
                    [self._response_to_generation(r) for r in res.candidates]
                )
        return LLMResult(generations=generations)

    async def _agenerate(
        self,
        prompts: List[str],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> LLMResult:
        params = self._prepare_params(stop=stop, **kwargs)
        generations = []
        for prompt in prompts:
            res = await _acompletion_with_retry(
                self,
                prompt,
                is_gemini=self._is_gemini_model,
                run_manager=run_manager,
                **params,
            )
            generations.append(
                [self._response_to_generation(r) for r in res.candidates]
            )
        return LLMResult(generations=generations)

    def _stream(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[GenerationChunk]:
        params = self._prepare_params(stop=stop, stream=True, **kwargs)
        for stream_resp in _completion_with_retry(
            self,
            [prompt],
            stream=True,
            is_gemini=self._is_gemini_model,
            run_manager=run_manager,
            **params,
        ):
            chunk = self._response_to_generation(stream_resp)
            yield chunk
            if run_manager:
                run_manager.on_llm_new_token(
                    chunk.text,
                    chunk=chunk,
                    verbose=self.verbose,
                )

class VertexAIModelGarden(_VertexAIBase, BaseLLM):
    """Large language models served from Vertex AI Model Garden."""

    client: Any = None  #: :meta private:
    async_client: Any = None  #: :meta private:
    endpoint_id: str
    "A name of an endpoint where the model has been deployed."
    allowed_model_args: Optional[List[str]] = None
    "Allowed optional args to be passed to the model."
    prompt_arg: str = "prompt"
    result_arg: Optional[str] = "generated_text"
    "Set result_arg to None if the output of the model is expected to be a string. "
    "Otherwise, if it's a dict, provide the argument that contains the result."

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that the python package exists in environment."""

        if not values["project"]:
            raise ValueError(
                "A GCP project should be provided to run inference on Model Garden!"
            )

        client_options = ClientOptions(
            api_endpoint=f"{values['location']}-aiplatform.googleapis.com"
        )
        client_info = get_client_info(module="vertex-ai-model-garden")
        values["client"] = PredictionServiceClient(
            client_options=client_options, client_info=client_info
        )
        values["async_client"] = PredictionServiceAsyncClient(
            client_options=client_options, client_info=client_info
        )
        return values

    @property
    def endpoint_path(self) -> str:
        return self.client.endpoint_path(
            project=self.project,  # type: ignore
            location=self.location,
            endpoint=self.endpoint_id,
        )

    @property
    def _llm_type(self) -> str:
        return "vertexai_model_garden"

    def _prepare_request(self, prompts: List[str], **kwargs: Any) -> List["Value"]:
        instances = []
        for prompt in prompts:
            if self.allowed_model_args:
                instance = {
                    k: v for k, v in kwargs.items() if k in self.allowed_model_args
                }
            else:
                instance = {}
            instance[self.prompt_arg] = prompt
            instances.append(instance)

        predict_instances = [
            json_format.ParseDict(instance_dict, Value()) for instance_dict in instances
        ]
        return predict_instances

    def _generate(
        self,
        prompts: List[str],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> LLMResult:
        """Run the LLM on the given prompt and input."""
        instances = self._prepare_request(prompts, **kwargs)
        response = self.client.predict(
            endpoint=self.endpoint_path, instances=instances
        )
        return self._parse_response(response)

    def _parse_response(self, predictions: "Prediction") -> LLMResult:
        generations: List[List[Generation]] = []
        for result in predictions.predictions:
            generations.append(
                [
                    Generation(text=self._parse_prediction(prediction))
                    for prediction in result
                ]
            )
        return LLMResult(generations=generations)

    def _parse_prediction(self, prediction: Any) -> str:
        if isinstance(prediction, str):
            return prediction

        if self.result_arg:
            try:
                return prediction[self.result_arg]
            except KeyError:
                if isinstance(prediction, str):
                    error_desc = (
                        "Provided non-None `result_arg` (result_arg="
                        f"{self.result_arg}). But got prediction of type "
                        f"{type(prediction)} instead of dict. Most probably, you "
                        "need to set `result_arg=None` during VertexAIModelGarden "
                        "initialization."
                    )
                    raise ValueError(error_desc)
                else:
                    raise ValueError(f"{self.result_arg} key not found in prediction!")

        return prediction

    async def _agenerate(
        self,
        prompts: List[str],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> LLMResult:
        """Run the LLM on the given prompt and input."""
        instances = self._prepare_request(prompts, **kwargs)
        response = await self.async_client.predict(
            endpoint=self.endpoint_path, instances=instances
        )
        return self._parse_response(response)

File diff suppressed because it is too large.

libs/partners/google-vertexai/pyproject.toml
@@ -0,0 +1,94 @@
[tool.poetry]
name = "langchain-google-vertexai"
version = "0.0.1"
description = "An integration package connecting GoogleVertexAI and LangChain"
authors = []
readme = "README.md"

[tool.poetry.dependencies]
python = ">=3.8.1,<4.0"
langchain-core = ">=0.1,<0.2"
google-cloud-aiplatform = "1.38.1"
google-cloud-storage = "^2.14.0"
types-requests = "^2.31.0.20231231"
types-protobuf = "^4.24.0.4"

[tool.poetry.group.test]
optional = true

[tool.poetry.group.test.dependencies]
pytest = "^7.3.0"
freezegun = "^1.2.2"
pytest-mock = "^3.10.0"
syrupy = "^4.0.2"
pytest-watcher = "^0.3.4"
pytest-asyncio = "^0.21.1"
langchain-core = {path = "../../core", develop = true}
types-requests = "^2.31.0.20231231"
types-protobuf = "^4.24.0.4"

[tool.poetry.group.codespell]
optional = true

[tool.poetry.group.codespell.dependencies]
codespell = "^2.2.0"

[tool.poetry.group.test_integration]
optional = true

[tool.poetry.group.test_integration.dependencies]

[tool.poetry.group.lint]
optional = true

[tool.poetry.group.lint.dependencies]
ruff = "^0.1.5"

[tool.poetry.group.typing.dependencies]
mypy = "^0.991"
langchain-core = {path = "../../core", develop = true}

[tool.poetry.group.dev]
optional = true

[tool.poetry.group.dev.dependencies]
langchain-core = {path = "../../core", develop = true}

[tool.ruff]
select = [
  "E",  # pycodestyle
  "F",  # pyflakes
  "I",  # isort
]

[tool.mypy]
disallow_untyped_defs = "True"

[tool.coverage.run]
omit = [
    "tests/*",
]

[build-system]
requires = ["poetry-core>=1.0.0"]
build-backend = "poetry.core.masonry.api"

[tool.pytest.ini_options]
# --strict-markers will raise errors on unknown marks.
# https://docs.pytest.org/en/7.1.x/how-to/mark.html#raising-errors-on-unknown-marks
#
# https://docs.pytest.org/en/7.1.x/reference/reference.html
# --strict-config: any warnings encountered while parsing the `pytest`
# section of the configuration file raise errors.
#
# https://github.com/tophat/syrupy
# --snapshot-warn-unused: prints a warning on unused snapshots rather than failing the test suite.
addopts = "--snapshot-warn-unused --strict-markers --strict-config --durations=5"
# Registering custom markers.
# https://docs.pytest.org/en/7.1.x/example/markers.html#registering-markers
markers = [
    "requires: mark tests as requiring a specific library",
    "asyncio: mark tests as requiring asyncio",
    "compile: mark placeholder test used to compile integration tests without running them",
]
asyncio_mode = "auto"

libs/partners/google-vertexai/scripts/check_imports.py
@@ -0,0 +1,17 @@
import sys
import traceback
from importlib.machinery import SourceFileLoader

if __name__ == "__main__":
    files = sys.argv[1:]
    has_failure = False
    for file in files:
        try:
            SourceFileLoader("x", file).load_module()
        except Exception:
            has_failure = True
            print(file)
            traceback.print_exc()
            print()

    sys.exit(1 if has_failure else 0)
@ -0,0 +1,27 @@
#!/bin/bash
#
# This script searches for lines starting with "import pydantic" or "from pydantic"
# in tracked files within a Git repository.
#
# Usage: ./scripts/check_pydantic.sh /path/to/repository

# Check if a path argument is provided
if [ $# -ne 1 ]; then
  echo "Usage: $0 /path/to/repository"
  exit 1
fi

repository_path="$1"

# Search for lines matching the pattern within the specified repository
result=$(git -C "$repository_path" grep -E '^import pydantic|^from pydantic')

# Check if any matching lines were found
if [ -n "$result" ]; then
  echo "ERROR: The following lines need to be updated:"
  echo "$result"
  echo "Please replace the code with an import from langchain_core.pydantic_v1."
  echo "For example, replace 'from pydantic import BaseModel'"
  echo "with 'from langchain_core.pydantic_v1 import BaseModel'"
  exit 1
fi
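The replacement the script asks for looks like this in practice; the model class below is a hypothetical example, not part of this package:

    # Flagged by check_pydantic.sh:
    # from pydantic import BaseModel

    # Accepted:
    from langchain_core.pydantic_v1 import BaseModel


    class ExampleConfig(BaseModel):  # hypothetical model for illustration
        model_name: str
        temperature: float = 0.0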
@ -0,0 +1,17 @@
#!/bin/bash

set -eu

# Initialize a variable to keep track of errors
errors=0

# make sure not importing from langchain or langchain_experimental
git --no-pager grep '^from langchain\.' . && errors=$((errors+1))
git --no-pager grep '^from langchain_experimental\.' . && errors=$((errors+1))

# Decide on an exit status based on the errors
if [ "$errors" -gt 0 ]; then
  exit 1
else
  exit 0
fi
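In other words, modules in this package may depend on langchain_core but not on the langchain or langchain_experimental packages. A minimal sketch of what trips the check versus what passes:

    # Trips lint_imports.sh (imports from langchain):
    # from langchain.chat_models import ChatVertexAI

    # Passes (imports from langchain_core):
    from langchain_core.messages import HumanMessage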
@ -0,0 +1,176 @@
"""Test ChatGoogleVertexAI chat model."""
import pytest
from langchain_core.messages import (
    AIMessage,
    AIMessageChunk,
    HumanMessage,
    SystemMessage,
)
from langchain_core.outputs import LLMResult

from langchain_google_vertexai.chat_models import ChatVertexAI

model_names_to_test = [None, "codechat-bison", "chat-bison", "gemini-pro"]


@pytest.mark.parametrize("model_name", model_names_to_test)
def test_initialization(model_name: str) -> None:
    """Test chat model initialization."""
    if model_name:
        model = ChatVertexAI(model_name=model_name)
    else:
        model = ChatVertexAI()
    assert model._llm_type == "vertexai"
    try:
        assert model.model_name == model.client._model_id
    except AttributeError:
        assert model.model_name == model.client._model_name.split("/")[-1]


@pytest.mark.parametrize("model_name", model_names_to_test)
def test_vertexai_single_call(model_name: str) -> None:
    if model_name:
        model = ChatVertexAI(model_name=model_name)
    else:
        model = ChatVertexAI()
    message = HumanMessage(content="Hello")
    response = model([message])
    assert isinstance(response, AIMessage)
    assert isinstance(response.content, str)


# mark xfail because Vertex API randomly doesn't respect
# the n/candidate_count parameter
@pytest.mark.xfail
def test_candidates() -> None:
    model = ChatVertexAI(model_name="chat-bison@001", temperature=0.3, n=2)
    message = HumanMessage(content="Hello")
    response = model.generate(messages=[[message]])
    assert isinstance(response, LLMResult)
    assert len(response.generations) == 1
    assert len(response.generations[0]) == 2


@pytest.mark.parametrize("model_name", ["chat-bison@001", "gemini-pro"])
async def test_vertexai_agenerate(model_name: str) -> None:
    model = ChatVertexAI(temperature=0, model_name=model_name)
    message = HumanMessage(content="Hello")
    response = await model.agenerate([[message]])
    assert isinstance(response, LLMResult)
    assert isinstance(response.generations[0][0].message, AIMessage)  # type: ignore

    sync_response = model.generate([[message]])
    assert response.generations[0][0] == sync_response.generations[0][0]


@pytest.mark.parametrize("model_name", ["chat-bison@001", "gemini-pro"])
def test_vertexai_stream(model_name: str) -> None:
    model = ChatVertexAI(temperature=0, model_name=model_name)
    message = HumanMessage(content="Hello")

    sync_response = model.stream([message])
    for chunk in sync_response:
        assert isinstance(chunk, AIMessageChunk)


def test_vertexai_single_call_with_context() -> None:
    model = ChatVertexAI()
    raw_context = (
        "My name is Ned. You are my personal assistant. My favorite movies "
        "are Lord of the Rings and Hobbit."
    )
    question = (
        "Hello, could you recommend a good movie for me to watch this evening, please?"
    )
    context = SystemMessage(content=raw_context)
    message = HumanMessage(content=question)
    response = model([context, message])
    assert isinstance(response, AIMessage)
    assert isinstance(response.content, str)


def test_multimodal() -> None:
    llm = ChatVertexAI(model_name="gemini-pro-vision")
    gcs_url = (
        "gs://cloud-samples-data/generative-ai/image/"
        "320px-Felis_catus-cat_on_snow.jpg"
    )
    image_message = {
        "type": "image_url",
        "image_url": {"url": gcs_url},
    }
    text_message = {
        "type": "text",
        "text": "What is shown in this image?",
    }
    message = HumanMessage(content=[text_message, image_message])
    output = llm([message])
    assert isinstance(output.content, str)


def test_multimodal_history() -> None:
    llm = ChatVertexAI(model_name="gemini-pro-vision")
    gcs_url = (
        "gs://cloud-samples-data/generative-ai/image/"
        "320px-Felis_catus-cat_on_snow.jpg"
    )
    image_message = {
        "type": "image_url",
        "image_url": {"url": gcs_url},
    }
    text_message = {
        "type": "text",
        "text": "What is shown in this image?",
    }
    message1 = HumanMessage(content=[text_message, image_message])
    message2 = AIMessage(
        content=(
            "This is a picture of a cat in the snow. The cat is a tabby cat, which is "
            "a type of cat with a striped coat. The cat is standing in the snow, and "
            "its fur is covered in snow."
        )
    )
    message3 = HumanMessage(content="What time of day is it?")
    response = llm([message1, message2, message3])
    assert isinstance(response, AIMessage)
    assert isinstance(response.content, str)


def test_vertexai_single_call_with_examples() -> None:
    model = ChatVertexAI()
    raw_context = "My name is Ned. You are my personal assistant."
    question = "2+2"
    text_question, text_answer = "4+4", "8"
    inp = HumanMessage(content=text_question)
    output = AIMessage(content=text_answer)
    context = SystemMessage(content=raw_context)
    message = HumanMessage(content=question)
    response = model([context, message], examples=[inp, output])
    assert isinstance(response, AIMessage)
    assert isinstance(response.content, str)


@pytest.mark.parametrize("model_name", model_names_to_test)
def test_vertexai_single_call_with_history(model_name: str) -> None:
    if model_name:
        model = ChatVertexAI(model_name=model_name)
    else:
        model = ChatVertexAI()
    text_question1, text_answer1 = "How much is 2+2?", "4"
    text_question2 = "How much is 3+3?"
    message1 = HumanMessage(content=text_question1)
    message2 = AIMessage(content=text_answer1)
    message3 = HumanMessage(content=text_question2)
    response = model([message1, message2, message3])
    assert isinstance(response, AIMessage)
    assert isinstance(response.content, str)


def test_vertexai_single_call_fails_no_message() -> None:
    chat = ChatVertexAI()
    with pytest.raises(ValueError) as exc_info:
        _ = chat([])
    assert (
        str(exc_info.value)
        == "You should provide at least one message to start the chat!"
    )
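Outside the test suite, the multimodal call pattern exercised above reduces to a few lines. This sketch reuses the same public GCS sample image and the call style used in these tests:

    from langchain_core.messages import HumanMessage

    from langchain_google_vertexai import ChatVertexAI

    llm = ChatVertexAI(model_name="gemini-pro-vision")
    message = HumanMessage(
        content=[
            {"type": "text", "text": "What is shown in this image?"},
            {
                "type": "image_url",
                "image_url": {
                    "url": "gs://cloud-samples-data/generative-ai/image/"
                    "320px-Felis_catus-cat_on_snow.jpg"
                },
            },
        ]
    )
    print(llm([message]).content)  # free-form description of the image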
@ -0,0 +1,7 @@
import pytest


@pytest.mark.compile
def test_placeholder() -> None:
    """Used for compiling integration tests without running any real tests."""
    pass
@ -0,0 +1,70 @@
"""Test Vertex AI API wrapper.

Your end-user credentials would be used to make the calls (make sure you've run
`gcloud auth login` first).
"""
import pytest

from langchain_google_vertexai.embeddings import VertexAIEmbeddings


def test_initialization() -> None:
    """Test embedding model initialization."""
    VertexAIEmbeddings()


def test_langchain_google_vertexai_embedding_documents() -> None:
    documents = ["foo bar"]
    model = VertexAIEmbeddings()
    output = model.embed_documents(documents)
    assert len(output) == 1
    assert len(output[0]) == 768
    assert model.model_name == model.client._model_id
    assert model.model_name == "textembedding-gecko@001"


def test_langchain_google_vertexai_embedding_query() -> None:
    document = "foo bar"
    model = VertexAIEmbeddings()
    output = model.embed_query(document)
    assert len(output) == 768


def test_langchain_google_vertexai_large_batches() -> None:
    documents = ["foo bar" for _ in range(0, 251)]
    model_uscentral1 = VertexAIEmbeddings(location="us-central1")
    model_asianortheast1 = VertexAIEmbeddings(location="asia-northeast1")
    model_uscentral1.embed_documents(documents)
    model_asianortheast1.embed_documents(documents)
    assert model_uscentral1.instance["batch_size"] >= 250
    assert model_asianortheast1.instance["batch_size"] < 50


def test_langchain_google_vertexai_paginated_texts() -> None:
    documents = [
        "foo bar",
        "foo baz",
        "bar foo",
        "baz foo",
        "bar bar",
        "foo foo",
        "baz baz",
        "baz bar",
    ]
    model = VertexAIEmbeddings()
    output = model.embed_documents(documents)
    assert len(output) == 8
    assert len(output[0]) == 768
    assert model.model_name == model.client._model_id


def test_warning(caplog: pytest.LogCaptureFixture) -> None:
    _ = VertexAIEmbeddings()
    assert len(caplog.records) == 1
    record = caplog.records[0]
    assert record.levelname == "WARNING"
    expected_message = (
        "Model_name will become a required arg for VertexAIEmbeddings starting from "
        "Feb-01-2024. Currently the default is set to textembedding-gecko@001"
    )
    assert record.message == expected_message
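For context, the basic embedding workflow these tests cover is the following sketch (it assumes `gcloud auth login` has been run, as the docstring above notes; the model name and 768-dimension output match the assertions in these tests):

    from langchain_google_vertexai import VertexAIEmbeddings

    embeddings = VertexAIEmbeddings(model_name="textembedding-gecko@001")
    vector = embeddings.embed_query("foo bar")         # one 768-dimensional vector
    vectors = embeddings.embed_documents(["foo bar"])  # list with one such vector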
@ -0,0 +1,175 @@
"""Test Vertex AI API wrapper.

Your end-user credentials would be used to make the calls (make sure you've run
`gcloud auth login` first).
"""
import os
from typing import Optional

import pytest
from langchain_core.outputs import LLMResult

from langchain_google_vertexai.llms import VertexAI, VertexAIModelGarden

model_names_to_test = ["text-bison@001", "gemini-pro"]
model_names_to_test_with_default = [None] + model_names_to_test


@pytest.mark.parametrize(
    "model_name",
    model_names_to_test_with_default,
)
def test_vertex_initialization(model_name: str) -> None:
    llm = VertexAI(model_name=model_name) if model_name else VertexAI()
    assert llm._llm_type == "vertexai"
    try:
        assert llm.model_name == llm.client._model_id
    except AttributeError:
        assert llm.model_name == llm.client._model_name.split("/")[-1]


@pytest.mark.parametrize(
    "model_name",
    model_names_to_test_with_default,
)
def test_vertex_call(model_name: str) -> None:
    llm = (
        VertexAI(model_name=model_name, temperature=0)
        if model_name
        else VertexAI(temperature=0.0)
    )
    output = llm("Say foo:")
    assert isinstance(output, str)


def test_vertex_generate() -> None:
    llm = VertexAI(temperature=0.3, n=2, model_name="text-bison@001")
    output = llm.generate(["Say foo:"])
    assert isinstance(output, LLMResult)
    assert len(output.generations) == 1
    assert len(output.generations[0]) == 2


def test_vertex_generate_code() -> None:
    llm = VertexAI(temperature=0.3, n=2, model_name="code-bison@001")
    output = llm.generate(["generate a python method that says foo:"])
    assert isinstance(output, LLMResult)
    assert len(output.generations) == 1
    assert len(output.generations[0]) == 2


async def test_vertex_agenerate() -> None:
    llm = VertexAI(temperature=0)
    output = await llm.agenerate(["Please say foo:"])
    assert isinstance(output, LLMResult)


@pytest.mark.parametrize(
    "model_name",
    model_names_to_test_with_default,
)
def test_stream(model_name: str) -> None:
    llm = (
        VertexAI(temperature=0, model_name=model_name)
        if model_name
        else VertexAI(temperature=0)
    )
    for token in llm.stream("I'm Pickle Rick"):
        assert isinstance(token, str)


async def test_vertex_consistency() -> None:
    llm = VertexAI(temperature=0)
    output = llm.generate(["Please say foo:"])
    streaming_output = llm.generate(["Please say foo:"], stream=True)
    async_output = await llm.agenerate(["Please say foo:"])
    assert output.generations[0][0].text == streaming_output.generations[0][0].text
    assert output.generations[0][0].text == async_output.generations[0][0].text


@pytest.mark.parametrize(
    "endpoint_os_variable_name,result_arg",
    [("FALCON_ENDPOINT_ID", "generated_text"), ("LLAMA_ENDPOINT_ID", None)],
)
def test_model_garden(
    endpoint_os_variable_name: str, result_arg: Optional[str]
) -> None:
    """In order to run this test, you should provide endpoint names.

    Example:
    export FALCON_ENDPOINT_ID=...
    export LLAMA_ENDPOINT_ID=...
    export PROJECT=...
    """
    endpoint_id = os.environ[endpoint_os_variable_name]
    project = os.environ["PROJECT"]
    location = "europe-west4"
    llm = VertexAIModelGarden(
        endpoint_id=endpoint_id,
        project=project,
        result_arg=result_arg,
        location=location,
    )
    output = llm("What is the meaning of life?")
    assert isinstance(output, str)
    assert llm._llm_type == "vertexai_model_garden"


@pytest.mark.parametrize(
    "endpoint_os_variable_name,result_arg",
    [("FALCON_ENDPOINT_ID", "generated_text"), ("LLAMA_ENDPOINT_ID", None)],
)
def test_model_garden_generate(
    endpoint_os_variable_name: str, result_arg: Optional[str]
) -> None:
    """In order to run this test, you should provide endpoint names.

    Example:
    export FALCON_ENDPOINT_ID=...
    export LLAMA_ENDPOINT_ID=...
    export PROJECT=...
    """
    endpoint_id = os.environ[endpoint_os_variable_name]
    project = os.environ["PROJECT"]
    location = "europe-west4"
    llm = VertexAIModelGarden(
        endpoint_id=endpoint_id,
        project=project,
        result_arg=result_arg,
        location=location,
    )
    output = llm.generate(["What is the meaning of life?", "How much is 2+2"])
    assert isinstance(output, LLMResult)
    assert len(output.generations) == 2


@pytest.mark.asyncio
@pytest.mark.parametrize(
    "endpoint_os_variable_name,result_arg",
    [("FALCON_ENDPOINT_ID", "generated_text"), ("LLAMA_ENDPOINT_ID", None)],
)
async def test_model_garden_agenerate(
    endpoint_os_variable_name: str, result_arg: Optional[str]
) -> None:
    endpoint_id = os.environ[endpoint_os_variable_name]
    project = os.environ["PROJECT"]
    location = "europe-west4"
    llm = VertexAIModelGarden(
        endpoint_id=endpoint_id,
        project=project,
        result_arg=result_arg,
        location=location,
    )
    output = await llm.agenerate(["What is the meaning of life?", "How much is 2+2"])
    assert isinstance(output, LLMResult)
    assert len(output.generations) == 2


@pytest.mark.parametrize(
    "model_name",
    model_names_to_test,
)
def test_vertex_call_count_tokens(model_name: str) -> None:
    llm = VertexAI(model_name=model_name)
    output = llm.get_num_tokens("How are you?")
    assert output == 4
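As a usage sketch of the two classes exercised above: the endpoint and project values below are placeholders that must come from your own Vertex AI deployment, and the expected token count comes from the assertion in the last test:

    from langchain_google_vertexai import VertexAI, VertexAIModelGarden

    llm = VertexAI(model_name="text-bison@001", temperature=0)
    print(llm.get_num_tokens("How are you?"))  # 4, per the test above
    print(llm("Say foo:"))

    garden_llm = VertexAIModelGarden(
        endpoint_id="your-endpoint-id",  # placeholder
        project="your-gcp-project",      # placeholder
        location="europe-west4",
    )
    print(garden_llm("What is the meaning of life?"))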
@ -0,0 +1,112 @@
"""Test chat model integration."""
from typing import Optional
from unittest.mock import MagicMock, Mock, patch

import pytest
from langchain_core.messages import (
    AIMessage,
    HumanMessage,
    SystemMessage,
)
from vertexai.language_models import ChatMessage, InputOutputTextPair  # type: ignore

from langchain_google_vertexai.chat_models import (
    ChatVertexAI,
    _parse_chat_history,
    _parse_examples,
)


def test_parse_examples_correct() -> None:
    text_question = (
        "Hello, could you recommend a good movie for me to watch this evening, please?"
    )
    question = HumanMessage(content=text_question)
    text_answer = (
        "Sure, You might enjoy The Lord of the Rings: The Fellowship of the Ring "
        "(2001): This is the first movie in the Lord of the Rings trilogy."
    )
    answer = AIMessage(content=text_answer)
    examples = _parse_examples([question, answer, question, answer])
    assert len(examples) == 2
    assert examples == [
        InputOutputTextPair(input_text=text_question, output_text=text_answer),
        InputOutputTextPair(input_text=text_question, output_text=text_answer),
    ]


def test_parse_examples_fails_wrong_sequence() -> None:
    with pytest.raises(ValueError) as exc_info:
        _ = _parse_examples([AIMessage(content="a")])
    assert (
        str(exc_info.value)
        == "Expect examples to have an even amount of messages, got 1."
    )


@pytest.mark.parametrize("stop", [None, "stop1"])
def test_vertexai_args_passed(stop: Optional[str]) -> None:
    response_text = "Goodbye"
    user_prompt = "Hello"
    prompt_params = {
        "max_output_tokens": 1,
        "temperature": 10000.0,
        "top_k": 10,
        "top_p": 0.5,
    }

    # Mock the library to ensure the args are passed correctly
    with patch("vertexai._model_garden._model_garden_models._from_pretrained") as mg:
        mock_response = MagicMock()
        mock_response.candidates = [Mock(text=response_text)]
        mock_chat = MagicMock()
        mock_send_message = MagicMock(return_value=mock_response)
        mock_chat.send_message = mock_send_message

        mock_model = MagicMock()
        mock_start_chat = MagicMock(return_value=mock_chat)
        mock_model.start_chat = mock_start_chat
        mg.return_value = mock_model

        model = ChatVertexAI(**prompt_params)
        message = HumanMessage(content=user_prompt)
        if stop:
            response = model([message], stop=[stop])
        else:
            response = model([message])

        assert response.content == response_text
        mock_send_message.assert_called_once_with(user_prompt, candidate_count=1)
        expected_stop_sequence = [stop] if stop else None
        mock_start_chat.assert_called_once_with(
            context=None,
            message_history=[],
            **prompt_params,
            stop_sequences=expected_stop_sequence,
        )


def test_parse_chat_history_correct() -> None:
    text_context = (
        "My name is Ned. You are my personal assistant. My "
        "favorite movies are Lord of the Rings and Hobbit."
    )
    context = SystemMessage(content=text_context)
    text_question = (
        "Hello, could you recommend a good movie for me to watch this evening, please?"
    )
    question = HumanMessage(content=text_question)
    text_answer = (
        "Sure, You might enjoy The Lord of the Rings: The Fellowship of the Ring "
        "(2001): This is the first movie in the Lord of the Rings trilogy."
    )
    answer = AIMessage(content=text_answer)
    history = _parse_chat_history([context, question, answer, question, answer])
    assert history.context == context.content
    assert len(history.history) == 4
    assert history.history == [
        ChatMessage(content=text_question, author="user"),
        ChatMessage(content=text_answer, author="bot"),
        ChatMessage(content=text_question, author="user"),
        ChatMessage(content=text_answer, author="bot"),
    ]
@ -0,0 +1,7 @@
from langchain_google_vertexai import __all__

EXPECTED_ALL = ["ChatVertexAI", "VertexAIEmbeddings", "VertexAI", "VertexAIModelGarden"]


def test_all_imports() -> None:
    assert sorted(EXPECTED_ALL) == sorted(__all__)