Harrison/vertex (#5049)

Co-authored-by: Leonid Kuligin <kuligin@google.com>
Co-authored-by: Leonid Kuligin <lkuligin@yandex.ru>
Co-authored-by: sasha-gitg <44654632+sasha-gitg@users.noreply.github.com>
Co-authored-by: Justin Flick <Justinjayflick@gmail.com>
Co-authored-by: Justin Flick <jflick@homesite.com>
searx_updates
Harrison Chase 12 months ago committed by GitHub
parent e6c4571191
commit a775aa6389
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -0,0 +1,170 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Google Cloud Platform Vertex AI PaLM \n",
"\n",
"Note: This is seperate from the Google PaLM integration. Google has chosen to offer an enterprise version of PaLM through GCP, and this supports the models made available through there. \n",
"\n",
"PaLM API on Vertex AI is a Preview offering, subject to the Pre-GA Offerings Terms of the [GCP Service Specific Terms](https://cloud.google.com/terms/service-terms). \n",
"\n",
"Pre-GA products and features may have limited support, and changes to pre-GA products and features may not be compatible with other pre-GA versions. For more information, see the [launch stage descriptions](https://cloud.google.com/products#product-launch-stages). Further, by using PaLM API on Vertex AI, you agree to the Generative AI Preview [terms and conditions](https://cloud.google.com/trustedtester/aitos) (Preview Terms).\n",
"\n",
"For PaLM API on Vertex AI, you can process personal data as outlined in the Cloud Data Processing Addendum, subject to applicable restrictions and obligations in the Agreement (as defined in the Preview Terms).\n",
"\n",
"To use Vertex AI PaLM you must have the `google-cloud-aiplatform` Python package installed and either:\n",
"- Have credentials configured for your environment (gcloud, workload identity, etc...)\n",
"- Store the path to a service account JSON file as the GOOGLE_APPLICATION_CREDENTIALS environment variable\n",
"\n",
"This codebase uses the `google.auth` library which first looks for the application credentials variable mentioned above, and then looks for system-level auth.\n",
"\n",
"For more information, see: \n",
"- https://cloud.google.com/docs/authentication/application-default-credentials#GAC\n",
"- https://googleapis.dev/python/google-auth/latest/reference/google.auth.html#module-google.auth\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"#!pip install google-cloud-aiplatform"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"\n",
"from langchain.chat_models import ChatVertexAI\n",
"from langchain.prompts.chat import (\n",
" ChatPromptTemplate,\n",
" SystemMessagePromptTemplate,\n",
" HumanMessagePromptTemplate,\n",
")\n",
"from langchain.schema import (\n",
" HumanMessage,\n",
" SystemMessage\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"chat = ChatVertexAI()"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"AIMessage(content='Sure, here is the translation of the sentence \"I love programming\" from English to French:\\n\\nJ\\'aime programmer.', additional_kwargs={}, example=False)"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"messages = [\n",
" SystemMessage(content=\"You are a helpful assistant that translates English to French.\"),\n",
" HumanMessage(content=\"Translate this sentence from English to French. I love programming.\")\n",
"]\n",
"chat(messages)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"You can make use of templating by using a `MessagePromptTemplate`. You can build a `ChatPromptTemplate` from one or more `MessagePromptTemplates`. You can use `ChatPromptTemplate`'s `format_prompt` -- this returns a `PromptValue`, which you can convert to a string or Message object, depending on whether you want to use the formatted value as input to an llm or chat model.\n",
"\n",
"For convenience, there is a `from_template` method exposed on the template. If you were to use this template, this is what it would look like:"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"template=\"You are a helpful assistant that translates {input_language} to {output_language}.\"\n",
"system_message_prompt = SystemMessagePromptTemplate.from_template(template)\n",
"human_template=\"{text}\"\n",
"human_message_prompt = HumanMessagePromptTemplate.from_template(human_template)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"AIMessage(content='Sure, here is the translation of \"I love programming\" in French:\\n\\nJ\\'aime programmer.', additional_kwargs={}, example=False)"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"chat_prompt = ChatPromptTemplate.from_messages([system_message_prompt, human_message_prompt])\n",
"\n",
"# get a chat completion from the formatted messages\n",
"chat(chat_prompt.format_prompt(input_language=\"English\", output_language=\"French\", text=\"I love programming.\").to_messages())"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.1"
},
"vscode": {
"interpreter": {
"hash": "cc99336516f23363341912c6723b01ace86f02e26b4290be1efc0677e2e2ec24"
}
}
},
"nbformat": 4,
"nbformat_minor": 4
}

@ -0,0 +1,138 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Google Cloud Platform Vertex AI PaLM \n",
"\n",
"Note: This is seperate from the Google PaLM integration. Google has chosen to offer an enterprise version of PaLM through GCP, and this supports the models made available through there. \n",
"\n",
"PaLM API on Vertex AI is a Preview offering, subject to the Pre-GA Offerings Terms of the [GCP Service Specific Terms](https://cloud.google.com/terms/service-terms). \n",
"\n",
"Pre-GA products and features may have limited support, and changes to pre-GA products and features may not be compatible with other pre-GA versions. For more information, see the [launch stage descriptions](https://cloud.google.com/products#product-launch-stages). Further, by using PaLM API on Vertex AI, you agree to the Generative AI Preview [terms and conditions](https://cloud.google.com/trustedtester/aitos) (Preview Terms).\n",
"\n",
"For PaLM API on Vertex AI, you can process personal data as outlined in the Cloud Data Processing Addendum, subject to applicable restrictions and obligations in the Agreement (as defined in the Preview Terms).\n",
"\n",
"To use Vertex AI PaLM you must have the `google-cloud-aiplatform` Python package installed and either:\n",
"- Have credentials configured for your environment (gcloud, workload identity, etc...)\n",
"- Store the path to a service account JSON file as the GOOGLE_APPLICATION_CREDENTIALS environment variable\n",
"\n",
"This codebase uses the `google.auth` library which first looks for the application credentials variable mentioned above, and then looks for system-level auth.\n",
"\n",
"For more information, see: \n",
"- https://cloud.google.com/docs/authentication/application-default-credentials#GAC\n",
"- https://googleapis.dev/python/google-auth/latest/reference/google.auth.html#module-google.auth\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"#!pip install google-cloud-aiplatform"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"\n",
"from langchain.llms import VertexAI\n",
"from langchain import PromptTemplate, LLMChain"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"template = \"\"\"Question: {question}\n",
"\n",
"Answer: Let's think step by step.\"\"\"\n",
"\n",
"prompt = PromptTemplate(template=template, input_variables=[\"question\"])"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"llm = VertexAI()"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"llm_chain = LLMChain(prompt=prompt, llm=llm)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'Justin Bieber was born on March 1, 1994. The Super Bowl in 1994 was won by the San Francisco 49ers.\\nThe final answer: San Francisco 49ers.'"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"question = \"What NFL team won the Super Bowl in the year Justin Beiber was born?\"\n",
"\n",
"llm_chain.run(question)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.1"
},
"vscode": {
"interpreter": {
"hash": "cc99336516f23363341912c6723b01ace86f02e26b4290be1efc0677e2e2ec24"
}
}
},
"nbformat": 4,
"nbformat_minor": 4
}

@ -0,0 +1,113 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Google Cloud Platform Vertex AI PaLM \n",
"\n",
"Note: This is seperate from the Google PaLM integration. Google has chosen to offer an enterprise version of PaLM through GCP, and this supports the models made available through there. \n",
"\n",
"PaLM API on Vertex AI is a Preview offering, subject to the Pre-GA Offerings Terms of the [GCP Service Specific Terms](https://cloud.google.com/terms/service-terms). \n",
"\n",
"Pre-GA products and features may have limited support, and changes to pre-GA products and features may not be compatible with other pre-GA versions. For more information, see the [launch stage descriptions](https://cloud.google.com/products#product-launch-stages). Further, by using PaLM API on Vertex AI, you agree to the Generative AI Preview [terms and conditions](https://cloud.google.com/trustedtester/aitos) (Preview Terms).\n",
"\n",
"For PaLM API on Vertex AI, you can process personal data as outlined in the Cloud Data Processing Addendum, subject to applicable restrictions and obligations in the Agreement (as defined in the Preview Terms).\n",
"\n",
"To use Vertex AI PaLM you must have the `google-cloud-aiplatform` Python package installed and either:\n",
"- Have credentials configured for your environment (gcloud, workload identity, etc...)\n",
"- Store the path to a service account JSON file as the GOOGLE_APPLICATION_CREDENTIALS environment variable\n",
"\n",
"This codebase uses the `google.auth` library which first looks for the application credentials variable mentioned above, and then looks for system-level auth.\n",
"\n",
"For more information, see: \n",
"- https://cloud.google.com/docs/authentication/application-default-credentials#GAC\n",
"- https://googleapis.dev/python/google-auth/latest/reference/google.auth.html#module-google.auth\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"#!pip install google-cloud-aiplatform"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"\n",
"from langchain.embeddings import VertexAIEmbeddings"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"embeddings = VertexAIEmbeddings()"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"text = \"This is a test document.\""
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"query_result = embeddings.embed_query(text)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"doc_result = embeddings.embed_documents([text])"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.1"
},
"vscode": {
"interpreter": {
"hash": "cc99336516f23363341912c6723b01ace86f02e26b4290be1efc0677e2e2ec24"
}
}
},
"nbformat": 4,
"nbformat_minor": 4
}

@ -3,6 +3,7 @@ from langchain.chat_models.azure_openai import AzureChatOpenAI
from langchain.chat_models.google_palm import ChatGooglePalm
from langchain.chat_models.openai import ChatOpenAI
from langchain.chat_models.promptlayer_openai import PromptLayerChatOpenAI
from langchain.chat_models.vertexai import ChatVertexAI
__all__ = [
"ChatOpenAI",
@ -10,4 +11,5 @@ __all__ = [
"PromptLayerChatOpenAI",
"ChatAnthropic",
"ChatGooglePalm",
"ChatVertexAI",
]

@ -0,0 +1,137 @@
"""Wrapper around Google VertexAI chat-based models."""
from dataclasses import dataclass, field
from typing import Dict, List, Optional
from pydantic import root_validator
from langchain.callbacks.manager import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain.chat_models.base import BaseChatModel
from langchain.llms.vertexai import _VertexAICommon
from langchain.schema import (
AIMessage,
BaseMessage,
ChatGeneration,
ChatResult,
HumanMessage,
SystemMessage,
)
from langchain.utilities.vertexai import raise_vertex_import_error
@dataclass
class _MessagePair:
"""InputOutputTextPair represents a pair of input and output texts."""
question: HumanMessage
answer: AIMessage
@dataclass
class _ChatHistory:
"""InputOutputTextPair represents a pair of input and output texts."""
history: List[_MessagePair] = field(default_factory=list)
system_message: Optional[SystemMessage] = None
def _parse_chat_history(history: List[BaseMessage]) -> _ChatHistory:
"""Parse a sequence of messages into history.
A sequence should be either (SystemMessage, HumanMessage, AIMessage,
HumanMessage, AIMessage, ...) or (HumanMessage, AIMessage, HumanMessage,
AIMessage, ...).
Args:
history: The list of messages to re-create the history of the chat.
Returns:
A parsed chat history.
Raises:
ValueError: If a sequence of message is odd, or a human message is not followed
by a message from AI (e.g., Human, Human, AI or AI, AI, Human).
"""
if not history:
return _ChatHistory()
first_message = history[0]
system_message = first_message if isinstance(first_message, SystemMessage) else None
chat_history = _ChatHistory(system_message=system_message)
messages_left = history[1:] if system_message else history
if len(messages_left) % 2 != 0:
raise ValueError(
f"Amount of messages in history should be even, got {len(messages_left)}!"
)
for question, answer in zip(messages_left[::2], messages_left[1::2]):
if not isinstance(question, HumanMessage) or not isinstance(answer, AIMessage):
raise ValueError(
"A human message should follow a bot one, "
f"got {question.type}, {answer.type}."
)
chat_history.history.append(_MessagePair(question=question, answer=answer))
return chat_history
class ChatVertexAI(_VertexAICommon, BaseChatModel):
"""Wrapper around Vertex AI large language models."""
model_name: str = "chat-bison"
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that the python package exists in environment."""
cls._try_init_vertexai(values)
try:
from vertexai.preview.language_models import ChatModel
except ImportError:
raise_vertex_import_error()
values["client"] = ChatModel.from_pretrained(values["model_name"])
return values
def _generate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
) -> ChatResult:
"""Generate next turn in the conversation.
Args:
messages: The history of the conversation as a list of messages.
stop: The list of stop words (optional).
run_manager: The Callbackmanager for LLM run, it's not used at the moment.
Returns:
The ChatResult that contains outputs generated by the model.
Raises:
ValueError: if the last message in the list is not from human.
"""
if not messages:
raise ValueError(
"You should provide at least one message to start the chat!"
)
question = messages[-1]
if not isinstance(question, HumanMessage):
raise ValueError(
f"Last message in the list should be from human, got {question.type}."
)
history = _parse_chat_history(messages[:-1])
context = history.system_message.content if history.system_message else None
chat = self.client.start_chat(context=context, **self._default_params)
for pair in history.history:
chat._history.append((pair.question.content, pair.answer.content))
response = chat.send_message(question.content)
text = self._enforce_stop_words(response.text, stop)
return ChatResult(generations=[ChatGeneration(message=AIMessage(content=text))])
async def _agenerate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
) -> ChatResult:
raise NotImplementedError(
"""Vertex AI doesn't support async requests at the moment."""
)

@ -28,6 +28,7 @@ from langchain.embeddings.self_hosted_hugging_face import (
)
from langchain.embeddings.sentence_transformer import SentenceTransformerEmbeddings
from langchain.embeddings.tensorflow_hub import TensorflowHubEmbeddings
from langchain.embeddings.vertexai import VertexAIEmbeddings
logger = logging.getLogger(__name__)
@ -52,6 +53,7 @@ __all__ = [
"AlephAlphaSymmetricSemanticEmbedding",
"SentenceTransformerEmbeddings",
"GooglePalmEmbeddings",
"VertexAIEmbeddings",
]

@ -0,0 +1,47 @@
"""Wrapper around Google VertexAI embedding models."""
from typing import Dict, List
from pydantic import root_validator
from langchain.embeddings.base import Embeddings
from langchain.llms.vertexai import _VertexAICommon
from langchain.utilities.vertexai import raise_vertex_import_error
class VertexAIEmbeddings(_VertexAICommon, Embeddings):
model_name: str = "textembedding-gecko"
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validates that the python package exists in environment."""
cls._try_init_vertexai(values)
try:
from vertexai.preview.language_models import TextEmbeddingModel
except ImportError:
raise_vertex_import_error()
values["client"] = TextEmbeddingModel.from_pretrained(values["model_name"])
return values
def embed_documents(self, texts: List[str]) -> List[List[float]]:
"""Embed a list of strings.
Args:
texts: List[str] The list of strings to embed.
Returns:
List of embeddings, one for each text.
"""
embeddings = self.client.get_embeddings(texts)
return [el.values for el in embeddings]
def embed_query(self, text: str) -> List[float]:
"""Embed a text.
Args:
text: The text to embed.
Returns:
Embedding for the text.
"""
embeddings = self.client.get_embeddings([text])
return embeddings[0].values

@ -37,6 +37,7 @@ from langchain.llms.sagemaker_endpoint import SagemakerEndpoint
from langchain.llms.self_hosted import SelfHostedPipeline
from langchain.llms.self_hosted_hugging_face import SelfHostedHuggingFaceLLM
from langchain.llms.stochasticai import StochasticAI
from langchain.llms.vertexai import VertexAI
from langchain.llms.writer import Writer
__all__ = [
@ -79,6 +80,7 @@ __all__ = [
"HumanInputLLM",
"HuggingFaceTextGenInference",
"FakeListLLM",
"VertexAI",
]
type_to_cls_dict: Dict[str, Type[BaseLLM]] = {
@ -117,4 +119,5 @@ type_to_cls_dict: Dict[str, Type[BaseLLM]] = {
"rwkv": RWKV,
"huggingface_textgen_inference": HuggingFaceTextGenInference,
"fake-list": FakeListLLM,
"vertexai": VertexAI,
}

@ -0,0 +1,110 @@
"""Wrapper around Google VertexAI models."""
from typing import TYPE_CHECKING, Any, Dict, List, Optional
from pydantic import BaseModel, root_validator
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM
from langchain.llms.utils import enforce_stop_tokens
from langchain.utilities.vertexai import (
init_vertexai,
raise_vertex_import_error,
)
if TYPE_CHECKING:
from google.auth.credentials import Credentials
from vertexai.language_models._language_models import _LanguageModel
class _VertexAICommon(BaseModel):
client: "_LanguageModel" = None #: :meta private:
model_name: str
"Model name to use."
temperature: float = 0.0
"Sampling temperature, it controls the degree of randomness in token selection."
max_output_tokens: int = 128
"Token limit determines the maximum amount of text output from one prompt."
top_p: float = 0.95
"Tokens are selected from most probable to least until the sum of their "
"probabilities equals the top-p value."
top_k: int = 40
"How the model selects tokens for output, the next token is selected from "
"among the top-k most probable tokens."
project: Optional[str] = None
"The default GCP project to use when making Vertex API calls."
location: str = "us-central1"
"The default location to use when making API calls."
credentials: Optional["Credentials"] = None
"The default custom credentials to use when making API calls. If not provided "
"credentials will be ascertained from the environment." ""
@property
def _default_params(self) -> Dict[str, Any]:
base_params = {
"temperature": self.temperature,
"max_output_tokens": self.max_output_tokens,
"top_k": self.top_p,
"top_p": self.top_k,
}
return {**base_params}
def _predict(self, prompt: str, stop: Optional[List[str]]) -> str:
res = self.client.predict(prompt, **self._default_params)
return self._enforce_stop_words(res.text, stop)
def _enforce_stop_words(self, text: str, stop: Optional[List[str]]) -> str:
if stop:
return enforce_stop_tokens(text, stop)
return text
@property
def _llm_type(self) -> str:
return "vertexai"
@classmethod
def _try_init_vertexai(cls, values: Dict) -> None:
allowed_params = ["project", "location", "credentials"]
params = {k: v for k, v in values.items() if v in allowed_params}
init_vertexai(**params)
return None
class VertexAI(_VertexAICommon, LLM):
"""Wrapper around Google Vertex AI large language models."""
model_name: str = "text-bison"
tuned_model_name: Optional[str] = None
"The name of a tuned model, if it's provided, model_name is ignored."
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that the python package exists in environment."""
cls._try_init_vertexai(values)
try:
from vertexai.preview.language_models import TextGenerationModel
except ImportError:
raise_vertex_import_error()
tuned_model_name = values.get("tuned_model_name")
if tuned_model_name:
values["client"] = TextGenerationModel.get_tuned_model(tuned_model_name)
else:
values["client"] = TextGenerationModel.from_pretrained(values["model_name"])
return values
def _call(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
) -> str:
"""Call Vertex model to get predictions based on the prompt.
Args:
prompt: The prompt to pass into the model.
stop: A list of stop words (optional).
run_manager: A Callbackmanager for LLM run, optional.
Returns:
The string generated by the model.
"""
return self._predict(prompt, stop)

@ -0,0 +1,46 @@
"""Utilities to init Vertex AI."""
from typing import TYPE_CHECKING, Optional
if TYPE_CHECKING:
from google.auth.credentials import Credentials
def raise_vertex_import_error() -> None:
"""Raise ImportError related to Vertex SDK being not available.
Raises:
ImportError: an ImportError that mentions a required version of the SDK.
"""
sdk = "'google-cloud-aiplatform>=1.25.0'"
raise ImportError(
"Could not import VertexAI. Please, install it with " f"pip install {sdk}"
)
def init_vertexai(
project_id: Optional[str] = None,
location: Optional[str] = None,
credentials: Optional["Credentials"] = None,
) -> None:
"""Init vertexai.
Args:
project: The default GCP project to use when making Vertex API calls.
location: The default location to use when making API calls.
credentials: The default custom
credentials to use when making API calls. If not provided credentials
will be ascertained from the environment.
Raises:
ImportError: If importing vertexai SDK didn't not succeed.
"""
try:
import vertexai
except ImportError:
raise_vertex_import_error()
vertexai.init(
project=project_id,
location=location,
credentials=credentials,
)

@ -0,0 +1,88 @@
"""Test Vertex AI API wrapper.
In order to run this test, you need to install VertexAI SDK (that is is the private
preview) and be whitelisted to list the models themselves:
In order to run this test, you need to install VertexAI SDK
pip install google-cloud-aiplatform>=1.25.0
Your end-user credentials would be used to make the calls (make sure you've run
`gcloud auth login` first).
"""
import pytest
from langchain.chat_models import ChatVertexAI
from langchain.chat_models.vertexai import _MessagePair, _parse_chat_history
from langchain.schema import (
AIMessage,
HumanMessage,
SystemMessage,
)
def test_vertexai_single_call() -> None:
model = ChatVertexAI()
message = HumanMessage(content="Hello")
response = model([message])
assert isinstance(response, AIMessage)
assert isinstance(response.content, str)
assert model._llm_type == "vertexai"
assert model.model_name == model.client._model_id
def test_vertexai_single_call_with_context() -> None:
model = ChatVertexAI()
raw_context = (
"My name is Ned. You are my personal assistant. My favorite movies "
"are Lord of the Rings and Hobbit."
)
question = (
"Hello, could you recommend a good movie for me to watch this evening, please?"
)
context = SystemMessage(content=raw_context)
message = HumanMessage(content=question)
response = model([context, message])
assert isinstance(response, AIMessage)
assert isinstance(response.content, str)
def test_parse_chat_history_correct() -> None:
text_context = (
"My name is Ned. You are my personal assistant. My "
"favorite movies are Lord of the Rings and Hobbit."
)
context = SystemMessage(content=text_context)
text_question = (
"Hello, could you recommend a good movie for me to watch this evening, please?"
)
question = HumanMessage(content=text_question)
text_answer = (
"Sure, You might enjoy The Lord of the Rings: The Fellowship of the Ring "
"(2001): This is the first movie in the Lord of the Rings trilogy."
)
answer = AIMessage(content=text_answer)
history = _parse_chat_history([context, question, answer, question, answer])
assert history.system_message == context
assert len(history.history) == 2
assert history.history[0] == _MessagePair(question=question, answer=answer)
def test_parse_chat_history_wrong_sequence() -> None:
text_question = (
"Hello, could you recommend a good movie for me to watch this evening, please?"
)
question = HumanMessage(content=text_question)
with pytest.raises(ValueError) as exc_info:
_ = _parse_chat_history([question, question])
assert (
str(exc_info.value)
== "A human message should follow a bot one, got human, human."
)
def test_vertexai_single_call_failes_no_message() -> None:
chat = ChatVertexAI()
with pytest.raises(ValueError) as exc_info:
_ = chat([])
assert (
str(exc_info.value)
== "You should provide at least one message to start the chat!"
)

@ -0,0 +1,25 @@
"""Test Vertex AI API wrapper.
In order to run this test, you need to install VertexAI SDK
pip install google-cloud-aiplatform>=1.25.0
Your end-user credentials would be used to make the calls (make sure you've run
`gcloud auth login` first).
"""
from langchain.embeddings import VertexAIEmbeddings
def test_embedding_documents() -> None:
documents = ["foo bar"]
model = VertexAIEmbeddings()
output = model.embed_documents(documents)
assert len(output) == 1
assert len(output[0]) == 768
assert model._llm_type == "vertexai"
assert model.model_name == model.client._model_id
def test_embedding_query() -> None:
document = "foo bar"
model = VertexAIEmbeddings()
output = model.embed_query(document)
assert len(output) == 768

@ -0,0 +1,18 @@
"""Test Vertex AI API wrapper.
In order to run this test, you need to install VertexAI SDK (that is is the private
preview) and be whitelisted to list the models themselves:
In order to run this test, you need to install VertexAI SDK
pip install google-cloud-aiplatform>=1.25.0
Your end-user credentials would be used to make the calls (make sure you've run
`gcloud auth login` first).
"""
from langchain.llms import VertexAI
def test_vertex_call() -> None:
llm = VertexAI()
output = llm("Say foo:")
assert isinstance(output, str)
assert llm._llm_type == "vertexai"
assert llm.model_name == llm.client._model_id
Loading…
Cancel
Save