diff --git a/docs/modules/models/chat/integrations/google_vertex_ai_palm.ipynb b/docs/modules/models/chat/integrations/google_vertex_ai_palm.ipynb
new file mode 100644
index 00000000..f5333d8c
--- /dev/null
+++ b/docs/modules/models/chat/integrations/google_vertex_ai_palm.ipynb
@@ -0,0 +1,170 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Google Cloud Platform Vertex AI PaLM \n",
+    "\n",
+    "Note: This is separate from the Google PaLM integration. Google has chosen to offer an enterprise version of PaLM through GCP, and this integration supports the models made available there. \n",
+    "\n",
+    "PaLM API on Vertex AI is a Preview offering, subject to the Pre-GA Offerings Terms of the [GCP Service Specific Terms](https://cloud.google.com/terms/service-terms). \n",
+    "\n",
+    "Pre-GA products and features may have limited support, and changes to pre-GA products and features may not be compatible with other pre-GA versions. For more information, see the [launch stage descriptions](https://cloud.google.com/products#product-launch-stages). Further, by using PaLM API on Vertex AI, you agree to the Generative AI Preview [terms and conditions](https://cloud.google.com/trustedtester/aitos) (Preview Terms).\n",
+    "\n",
+    "For PaLM API on Vertex AI, you can process personal data as outlined in the Cloud Data Processing Addendum, subject to applicable restrictions and obligations in the Agreement (as defined in the Preview Terms).\n",
+    "\n",
+    "To use Vertex AI PaLM you must have the `google-cloud-aiplatform` Python package installed and either:\n",
+    "- Have credentials configured for your environment (gcloud, workload identity, etc.)\n",
+    "- Store the path to a service account JSON file as the `GOOGLE_APPLICATION_CREDENTIALS` environment variable\n",
+    "\n",
+    "This codebase uses the `google.auth` library, which first looks for the application credentials variable mentioned above and then looks for system-level auth.\n",
+    "\n",
+    "For more information, see: \n",
+    "- https://cloud.google.com/docs/authentication/application-default-credentials#GAC\n",
+    "- https://googleapis.dev/python/google-auth/latest/reference/google.auth.html#module-google.auth\n",
+    "\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [],
+   "source": [
+    "#!pip install google-cloud-aiplatform"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from langchain.chat_models import ChatVertexAI\n",
+    "from langchain.prompts.chat import (\n",
+    "    ChatPromptTemplate,\n",
+    "    SystemMessagePromptTemplate,\n",
+    "    HumanMessagePromptTemplate,\n",
+    ")\n",
+    "from langchain.schema import (\n",
+    "    HumanMessage,\n",
+    "    SystemMessage\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "chat = ChatVertexAI()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "AIMessage(content='Sure, here is the translation of the sentence \"I love programming\" from English to French:\\n\\nJ\\'aime programmer.', additional_kwargs={}, example=False)"
+      ]
+     },
+     "execution_count": 4,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "messages = [\n",
+    "    SystemMessage(content=\"You are a helpful assistant that translates English to French.\"),\n",
+    "    HumanMessage(content=\"Translate this sentence from English to French. I love programming.\")\n",
+    "]\n",
+    "chat(messages)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "You can make use of templating by using a `MessagePromptTemplate`. You can build a `ChatPromptTemplate` from one or more `MessagePromptTemplates`. You can use `ChatPromptTemplate`'s `format_prompt` method -- this returns a `PromptValue`, which you can convert to a string or `Message` object, depending on whether you want to use the formatted value as input to an LLM or chat model.\n",
+    "\n",
+    "For convenience, there is a `from_template` method exposed on the template. If you were to use this template, this is what it would look like:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 6,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "template = \"You are a helpful assistant that translates {input_language} to {output_language}.\"\n",
+    "system_message_prompt = SystemMessagePromptTemplate.from_template(template)\n",
+    "human_template = \"{text}\"\n",
+    "human_message_prompt = HumanMessagePromptTemplate.from_template(human_template)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 7,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "AIMessage(content='Sure, here is the translation of \"I love programming\" in French:\\n\\nJ\\'aime programmer.', additional_kwargs={}, example=False)"
+      ]
+     },
+     "execution_count": 7,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "chat_prompt = ChatPromptTemplate.from_messages([system_message_prompt, human_message_prompt])\n",
+    "\n",
+    "# get a chat completion from the formatted messages\n",
+    "chat(chat_prompt.format_prompt(input_language=\"English\", output_language=\"French\", text=\"I love programming.\").to_messages())"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3 (ipykernel)",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.9.1"
+  },
+  "vscode": {
+   "interpreter": {
+    "hash": "cc99336516f23363341912c6723b01ace86f02e26b4290be1efc0677e2e2ec24"
+   }
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}
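The notebook above relies on Application Default Credentials, as described in its intro. If you need to point the wrapper at a specific project or service account explicitly, a minimal sketch is below; the project ID and key-file path are placeholders, and `project`, `location`, and `credentials` are the wrapper fields defined later in this diff on `_VertexAICommon`:

```python
from google.oauth2 import service_account

from langchain.chat_models import ChatVertexAI

# Placeholder path and IDs -- substitute your own GCP project and key file.
credentials = service_account.Credentials.from_service_account_file(
    "/path/to/service-account.json"
)
chat = ChatVertexAI(
    project="my-gcp-project",  # GCP project the calls are billed to
    location="us-central1",    # region where the PaLM models are served
    credentials=credentials,   # explicit credentials instead of ADC
)
```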
diff --git a/docs/modules/models/llms/integrations/google_vertex_ai_palm.ipynb b/docs/modules/models/llms/integrations/google_vertex_ai_palm.ipynb
new file mode 100644
index 00000000..0000d73a
--- /dev/null
+++ b/docs/modules/models/llms/integrations/google_vertex_ai_palm.ipynb
@@ -0,0 +1,138 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Google Cloud Platform Vertex AI PaLM \n",
+    "\n",
+    "Note: This is separate from the Google PaLM integration. Google has chosen to offer an enterprise version of PaLM through GCP, and this integration supports the models made available there. \n",
+    "\n",
+    "PaLM API on Vertex AI is a Preview offering, subject to the Pre-GA Offerings Terms of the [GCP Service Specific Terms](https://cloud.google.com/terms/service-terms). \n",
+    "\n",
+    "Pre-GA products and features may have limited support, and changes to pre-GA products and features may not be compatible with other pre-GA versions. For more information, see the [launch stage descriptions](https://cloud.google.com/products#product-launch-stages). Further, by using PaLM API on Vertex AI, you agree to the Generative AI Preview [terms and conditions](https://cloud.google.com/trustedtester/aitos) (Preview Terms).\n",
+    "\n",
+    "For PaLM API on Vertex AI, you can process personal data as outlined in the Cloud Data Processing Addendum, subject to applicable restrictions and obligations in the Agreement (as defined in the Preview Terms).\n",
+    "\n",
+    "To use Vertex AI PaLM you must have the `google-cloud-aiplatform` Python package installed and either:\n",
+    "- Have credentials configured for your environment (gcloud, workload identity, etc.)\n",
+    "- Store the path to a service account JSON file as the `GOOGLE_APPLICATION_CREDENTIALS` environment variable\n",
+    "\n",
+    "This codebase uses the `google.auth` library, which first looks for the application credentials variable mentioned above and then looks for system-level auth.\n",
+    "\n",
+    "For more information, see: \n",
+    "- https://cloud.google.com/docs/authentication/application-default-credentials#GAC\n",
+    "- https://googleapis.dev/python/google-auth/latest/reference/google.auth.html#module-google.auth\n",
+    "\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [],
+   "source": [
+    "#!pip install google-cloud-aiplatform"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from langchain.llms import VertexAI\n",
+    "from langchain import PromptTemplate, LLMChain"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 6,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "template = \"\"\"Question: {question}\n",
+    "\n",
+    "Answer: Let's think step by step.\"\"\"\n",
+    "\n",
+    "prompt = PromptTemplate(template=template, input_variables=[\"question\"])"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "llm = VertexAI()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 7,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "llm_chain = LLMChain(prompt=prompt, llm=llm)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 8,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "'Justin Bieber was born on March 1, 1994. The Super Bowl in 1994 was won by the San Francisco 49ers.\\nThe final answer: San Francisco 49ers.'"
+      ]
+     },
+     "execution_count": 8,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "question = \"What NFL team won the Super Bowl in the year Justin Bieber was born?\"\n",
+    "\n",
+    "llm_chain.run(question)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3 (ipykernel)",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.9.1"
+  },
+  "vscode": {
+   "interpreter": {
+    "hash": "cc99336516f23363341912c6723b01ace86f02e26b4290be1efc0677e2e2ec24"
+   }
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}
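The `VertexAI` wrapper also exposes the standard PaLM sampling knobs (`temperature`, `max_output_tokens`, `top_p`, `top_k` -- the fields defined on `_VertexAICommon` later in this diff). A minimal sketch of tightening them for more deterministic, longer output; the values are illustrative only:

```python
from langchain.llms import VertexAI

# Illustrative settings -- tune for your own use case.
llm = VertexAI(
    temperature=0.0,        # minimize randomness in token selection
    max_output_tokens=256,  # allow longer answers than the 128-token default
    top_p=0.95,
    top_k=40,
)
print(llm("Say foo:"))
```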
diff --git a/docs/modules/models/text_embedding/examples/google_vertex_ai_palm.ipynb b/docs/modules/models/text_embedding/examples/google_vertex_ai_palm.ipynb
new file mode 100644
index 00000000..ed40ca3f
--- /dev/null
+++ b/docs/modules/models/text_embedding/examples/google_vertex_ai_palm.ipynb
@@ -0,0 +1,113 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Google Cloud Platform Vertex AI PaLM \n",
+    "\n",
+    "Note: This is separate from the Google PaLM integration. Google has chosen to offer an enterprise version of PaLM through GCP, and this integration supports the models made available there. \n",
+    "\n",
+    "PaLM API on Vertex AI is a Preview offering, subject to the Pre-GA Offerings Terms of the [GCP Service Specific Terms](https://cloud.google.com/terms/service-terms). \n",
+    "\n",
+    "Pre-GA products and features may have limited support, and changes to pre-GA products and features may not be compatible with other pre-GA versions. For more information, see the [launch stage descriptions](https://cloud.google.com/products#product-launch-stages). Further, by using PaLM API on Vertex AI, you agree to the Generative AI Preview [terms and conditions](https://cloud.google.com/trustedtester/aitos) (Preview Terms).\n",
+    "\n",
+    "For PaLM API on Vertex AI, you can process personal data as outlined in the Cloud Data Processing Addendum, subject to applicable restrictions and obligations in the Agreement (as defined in the Preview Terms).\n",
+    "\n",
+    "To use Vertex AI PaLM you must have the `google-cloud-aiplatform` Python package installed and either:\n",
+    "- Have credentials configured for your environment (gcloud, workload identity, etc.)\n",
+    "- Store the path to a service account JSON file as the `GOOGLE_APPLICATION_CREDENTIALS` environment variable\n",
+    "\n",
+    "This codebase uses the `google.auth` library, which first looks for the application credentials variable mentioned above and then looks for system-level auth.\n",
+    "\n",
+    "For more information, see: \n",
+    "- https://cloud.google.com/docs/authentication/application-default-credentials#GAC\n",
+    "- https://googleapis.dev/python/google-auth/latest/reference/google.auth.html#module-google.auth\n",
+    "\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [],
+   "source": [
+    "#!pip install google-cloud-aiplatform"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from langchain.embeddings import VertexAIEmbeddings"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "embeddings = VertexAIEmbeddings()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "text = \"This is a test document.\""
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "query_result = embeddings.embed_query(text)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 6,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "doc_result = embeddings.embed_documents([text])"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3 (ipykernel)",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.9.1"
+  },
+  "vscode": {
+   "interpreter": {
+    "hash": "cc99336516f23363341912c6723b01ace86f02e26b4290be1efc0677e2e2ec24"
+   }
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}
diff --git a/langchain/chat_models/__init__.py b/langchain/chat_models/__init__.py
index 11322b04..0c048ec3 100644
--- a/langchain/chat_models/__init__.py
+++ b/langchain/chat_models/__init__.py
@@ -3,6 +3,7 @@ from langchain.chat_models.azure_openai import AzureChatOpenAI
 from langchain.chat_models.google_palm import ChatGooglePalm
 from langchain.chat_models.openai import ChatOpenAI
 from langchain.chat_models.promptlayer_openai import PromptLayerChatOpenAI
+from langchain.chat_models.vertexai import ChatVertexAI
 
 __all__ = [
     "ChatOpenAI",
@@ -10,4 +11,5 @@ __all__ = [
     "PromptLayerChatOpenAI",
     "ChatAnthropic",
     "ChatGooglePalm",
+    "ChatVertexAI",
 ]
diff --git a/langchain/chat_models/vertexai.py b/langchain/chat_models/vertexai.py
new file mode 100644
index 00000000..87cf0ee9
--- /dev/null
+++ b/langchain/chat_models/vertexai.py
@@ -0,0 +1,137 @@
+"""Wrapper around Google VertexAI chat-based models."""
+from dataclasses import dataclass, field
+from typing import Dict, List, Optional
+
+from pydantic import root_validator
+
+from langchain.callbacks.manager import (
+    AsyncCallbackManagerForLLMRun,
+    CallbackManagerForLLMRun,
+)
+from langchain.chat_models.base import BaseChatModel
+from langchain.llms.vertexai import _VertexAICommon
+from langchain.schema import (
+    AIMessage,
+    BaseMessage,
+    ChatGeneration,
+    ChatResult,
+    HumanMessage,
+    SystemMessage,
+)
+from langchain.utilities.vertexai import raise_vertex_import_error
+
+
+@dataclass
+class _MessagePair:
+    """A pair of input (human) and output (AI) texts."""
+
+    question: HumanMessage
+    answer: AIMessage
+
+
+@dataclass
+class _ChatHistory:
+    """A chat history: an optional system message plus question/answer pairs."""
+
+    history: List[_MessagePair] = field(default_factory=list)
+    system_message: Optional[SystemMessage] = None
+
+
+def _parse_chat_history(history: List[BaseMessage]) -> _ChatHistory:
+    """Parse a sequence of messages into history.
+
+    A sequence should be either (SystemMessage, HumanMessage, AIMessage,
+    HumanMessage, AIMessage, ...) or (HumanMessage, AIMessage, HumanMessage,
+    AIMessage, ...).
+
+    Args:
+        history: The list of messages to re-create the history of the chat.
+    Returns:
+        A parsed chat history.
+    Raises:
+        ValueError: If the number of messages is odd, or a human message is not
+            followed by an AI message (e.g., Human, Human, AI or AI, AI, Human).
+    """
+    if not history:
+        return _ChatHistory()
+    first_message = history[0]
+    system_message = first_message if isinstance(first_message, SystemMessage) else None
+    chat_history = _ChatHistory(system_message=system_message)
+    messages_left = history[1:] if system_message else history
+    if len(messages_left) % 2 != 0:
+        raise ValueError(
+            f"Number of messages in history should be even, got {len(messages_left)}!"
+        )
+    for question, answer in zip(messages_left[::2], messages_left[1::2]):
+        if not isinstance(question, HumanMessage) or not isinstance(answer, AIMessage):
+            raise ValueError(
+                "A human message should be followed by an AI message, "
+                f"got {question.type}, {answer.type}."
+            )
+        chat_history.history.append(_MessagePair(question=question, answer=answer))
+    return chat_history
+
+
+class ChatVertexAI(_VertexAICommon, BaseChatModel):
+    """Wrapper around Vertex AI large language models."""
+
+    model_name: str = "chat-bison"
+
+    @root_validator()
+    def validate_environment(cls, values: Dict) -> Dict:
+        """Validate that the python package exists in the environment."""
+        cls._try_init_vertexai(values)
+        try:
+            from vertexai.preview.language_models import ChatModel
+        except ImportError:
+            raise_vertex_import_error()
+        values["client"] = ChatModel.from_pretrained(values["model_name"])
+        return values
+
+    def _generate(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+    ) -> ChatResult:
+        """Generate the next turn in the conversation.
+
+        Args:
+            messages: The history of the conversation as a list of messages.
+            stop: The list of stop words (optional).
+            run_manager: The CallbackManager for the LLM run; not used at the moment.
+
+        Returns:
+            The ChatResult that contains outputs generated by the model.
+
+        Raises:
+            ValueError: if the last message in the list is not from a human.
+        """
+        if not messages:
+            raise ValueError(
+                "You should provide at least one message to start the chat!"
+            )
+        question = messages[-1]
+        if not isinstance(question, HumanMessage):
+            raise ValueError(
+                f"Last message in the list should be from a human, got {question.type}."
+            )
+        history = _parse_chat_history(messages[:-1])
+        context = history.system_message.content if history.system_message else None
+        chat = self.client.start_chat(context=context, **self._default_params)
+        for pair in history.history:
+            chat._history.append((pair.question.content, pair.answer.content))
+        response = chat.send_message(question.content)
+        text = self._enforce_stop_words(response.text, stop)
+        return ChatResult(generations=[ChatGeneration(message=AIMessage(content=text))])
+
+    async def _agenerate(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+    ) -> ChatResult:
+        raise NotImplementedError(
+            "Vertex AI doesn't support async requests at the moment."
+        )
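As the `_parse_chat_history` docstring above describes, the wrapper expects strictly alternating human/AI turns, optionally preceded by a system message, with the trailing human message as the new question. A minimal multi-turn sketch against `ChatVertexAI`; the conversation content is made up:

```python
from langchain.chat_models import ChatVertexAI
from langchain.schema import AIMessage, HumanMessage, SystemMessage

chat = ChatVertexAI()

# System message first, then alternating Human/AI pairs (the AI messages are
# earlier model answers replayed as history), ending with the new question.
messages = [
    SystemMessage(content="You are a terse movie critic."),
    HumanMessage(content="Name a classic heist film."),
    AIMessage(content="Rififi (1955)."),
    HumanMessage(content="And a more recent one?"),
]
response = chat(messages)  # returns an AIMessage
```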
diff --git a/langchain/embeddings/__init__.py b/langchain/embeddings/__init__.py
index 2ae59266..d0db6078 100644
--- a/langchain/embeddings/__init__.py
+++ b/langchain/embeddings/__init__.py
@@ -28,6 +28,7 @@ from langchain.embeddings.self_hosted_hugging_face import (
 )
 from langchain.embeddings.sentence_transformer import SentenceTransformerEmbeddings
 from langchain.embeddings.tensorflow_hub import TensorflowHubEmbeddings
+from langchain.embeddings.vertexai import VertexAIEmbeddings
 
 logger = logging.getLogger(__name__)
 
@@ -52,6 +53,7 @@ __all__ = [
     "AlephAlphaSymmetricSemanticEmbedding",
     "SentenceTransformerEmbeddings",
     "GooglePalmEmbeddings",
+    "VertexAIEmbeddings",
 ]
diff --git a/langchain/embeddings/vertexai.py b/langchain/embeddings/vertexai.py
new file mode 100644
index 00000000..0730a5fc
--- /dev/null
+++ b/langchain/embeddings/vertexai.py
@@ -0,0 +1,47 @@
+"""Wrapper around Google VertexAI embedding models."""
+from typing import Dict, List
+
+from pydantic import root_validator
+
+from langchain.embeddings.base import Embeddings
+from langchain.llms.vertexai import _VertexAICommon
+from langchain.utilities.vertexai import raise_vertex_import_error
+
+
+class VertexAIEmbeddings(_VertexAICommon, Embeddings):
+    model_name: str = "textembedding-gecko"
+
+    @root_validator()
+    def validate_environment(cls, values: Dict) -> Dict:
+        """Validate that the python package exists in the environment."""
+        cls._try_init_vertexai(values)
+        try:
+            from vertexai.preview.language_models import TextEmbeddingModel
+        except ImportError:
+            raise_vertex_import_error()
+        values["client"] = TextEmbeddingModel.from_pretrained(values["model_name"])
+        return values
+
+    def embed_documents(self, texts: List[str]) -> List[List[float]]:
+        """Embed a list of strings.
+
+        Args:
+            texts: List[str] The list of strings to embed.
+
+        Returns:
+            List of embeddings, one for each text.
+        """
+        embeddings = self.client.get_embeddings(texts)
+        return [el.values for el in embeddings]
+
+    def embed_query(self, text: str) -> List[float]:
+        """Embed a text.
+
+        Args:
+            text: The text to embed.
+
+        Returns:
+            Embedding for the text.
+        """
+        embeddings = self.client.get_embeddings([text])
+        return embeddings[0].values
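Both methods return plain Python lists (768-dimensional for `textembedding-gecko`, per the integration tests below), so downstream similarity math needs no SDK types. A minimal sketch of cosine-similarity ranking over the returned vectors; the sample sentences are made up:

```python
import math

from langchain.embeddings import VertexAIEmbeddings

embeddings = VertexAIEmbeddings()
docs = ["The cat sat on the mat.", "GPUs accelerate matrix multiplication."]
query = "Where did the cat sit?"

doc_vectors = embeddings.embed_documents(docs)  # List[List[float]]
query_vector = embeddings.embed_query(query)    # List[float]


def cosine(a: List[float], b: List[float]) -> float:
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


best = max(range(len(docs)), key=lambda i: cosine(query_vector, doc_vectors[i]))
print(docs[best])  # expected: the cat sentence ranks highest
```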
+ """ + embeddings = self.client.get_embeddings([text]) + return embeddings[0].values diff --git a/langchain/llms/__init__.py b/langchain/llms/__init__.py index b3454394..6786f8cd 100644 --- a/langchain/llms/__init__.py +++ b/langchain/llms/__init__.py @@ -37,6 +37,7 @@ from langchain.llms.sagemaker_endpoint import SagemakerEndpoint from langchain.llms.self_hosted import SelfHostedPipeline from langchain.llms.self_hosted_hugging_face import SelfHostedHuggingFaceLLM from langchain.llms.stochasticai import StochasticAI +from langchain.llms.vertexai import VertexAI from langchain.llms.writer import Writer __all__ = [ @@ -79,6 +80,7 @@ __all__ = [ "HumanInputLLM", "HuggingFaceTextGenInference", "FakeListLLM", + "VertexAI", ] type_to_cls_dict: Dict[str, Type[BaseLLM]] = { @@ -117,4 +119,5 @@ type_to_cls_dict: Dict[str, Type[BaseLLM]] = { "rwkv": RWKV, "huggingface_textgen_inference": HuggingFaceTextGenInference, "fake-list": FakeListLLM, + "vertexai": VertexAI, } diff --git a/langchain/llms/vertexai.py b/langchain/llms/vertexai.py new file mode 100644 index 00000000..3c006f4f --- /dev/null +++ b/langchain/llms/vertexai.py @@ -0,0 +1,110 @@ +"""Wrapper around Google VertexAI models.""" +from typing import TYPE_CHECKING, Any, Dict, List, Optional + +from pydantic import BaseModel, root_validator + +from langchain.callbacks.manager import CallbackManagerForLLMRun +from langchain.llms.base import LLM +from langchain.llms.utils import enforce_stop_tokens +from langchain.utilities.vertexai import ( + init_vertexai, + raise_vertex_import_error, +) + +if TYPE_CHECKING: + from google.auth.credentials import Credentials + from vertexai.language_models._language_models import _LanguageModel + + +class _VertexAICommon(BaseModel): + client: "_LanguageModel" = None #: :meta private: + model_name: str + "Model name to use." + temperature: float = 0.0 + "Sampling temperature, it controls the degree of randomness in token selection." + max_output_tokens: int = 128 + "Token limit determines the maximum amount of text output from one prompt." + top_p: float = 0.95 + "Tokens are selected from most probable to least until the sum of their " + "probabilities equals the top-p value." + top_k: int = 40 + "How the model selects tokens for output, the next token is selected from " + "among the top-k most probable tokens." + project: Optional[str] = None + "The default GCP project to use when making Vertex API calls." + location: str = "us-central1" + "The default location to use when making API calls." + credentials: Optional["Credentials"] = None + "The default custom credentials to use when making API calls. If not provided " + "credentials will be ascertained from the environment." 
"" + + @property + def _default_params(self) -> Dict[str, Any]: + base_params = { + "temperature": self.temperature, + "max_output_tokens": self.max_output_tokens, + "top_k": self.top_p, + "top_p": self.top_k, + } + return {**base_params} + + def _predict(self, prompt: str, stop: Optional[List[str]]) -> str: + res = self.client.predict(prompt, **self._default_params) + return self._enforce_stop_words(res.text, stop) + + def _enforce_stop_words(self, text: str, stop: Optional[List[str]]) -> str: + if stop: + return enforce_stop_tokens(text, stop) + return text + + @property + def _llm_type(self) -> str: + return "vertexai" + + @classmethod + def _try_init_vertexai(cls, values: Dict) -> None: + allowed_params = ["project", "location", "credentials"] + params = {k: v for k, v in values.items() if v in allowed_params} + init_vertexai(**params) + return None + + +class VertexAI(_VertexAICommon, LLM): + """Wrapper around Google Vertex AI large language models.""" + + model_name: str = "text-bison" + tuned_model_name: Optional[str] = None + "The name of a tuned model, if it's provided, model_name is ignored." + + @root_validator() + def validate_environment(cls, values: Dict) -> Dict: + """Validate that the python package exists in environment.""" + cls._try_init_vertexai(values) + try: + from vertexai.preview.language_models import TextGenerationModel + except ImportError: + raise_vertex_import_error() + tuned_model_name = values.get("tuned_model_name") + if tuned_model_name: + values["client"] = TextGenerationModel.get_tuned_model(tuned_model_name) + else: + values["client"] = TextGenerationModel.from_pretrained(values["model_name"]) + return values + + def _call( + self, + prompt: str, + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + ) -> str: + """Call Vertex model to get predictions based on the prompt. + + Args: + prompt: The prompt to pass into the model. + stop: A list of stop words (optional). + run_manager: A Callbackmanager for LLM run, optional. + + Returns: + The string generated by the model. + """ + return self._predict(prompt, stop) diff --git a/langchain/utilities/vertexai.py b/langchain/utilities/vertexai.py new file mode 100644 index 00000000..68455873 --- /dev/null +++ b/langchain/utilities/vertexai.py @@ -0,0 +1,46 @@ +"""Utilities to init Vertex AI.""" +from typing import TYPE_CHECKING, Optional + +if TYPE_CHECKING: + from google.auth.credentials import Credentials + + +def raise_vertex_import_error() -> None: + """Raise ImportError related to Vertex SDK being not available. + + Raises: + ImportError: an ImportError that mentions a required version of the SDK. + """ + sdk = "'google-cloud-aiplatform>=1.25.0'" + raise ImportError( + "Could not import VertexAI. Please, install it with " f"pip install {sdk}" + ) + + +def init_vertexai( + project_id: Optional[str] = None, + location: Optional[str] = None, + credentials: Optional["Credentials"] = None, +) -> None: + """Init vertexai. + + Args: + project: The default GCP project to use when making Vertex API calls. + location: The default location to use when making API calls. + credentials: The default custom + credentials to use when making API calls. If not provided credentials + will be ascertained from the environment. + + Raises: + ImportError: If importing vertexai SDK didn't not succeed. 
+ """ + try: + import vertexai + except ImportError: + raise_vertex_import_error() + + vertexai.init( + project=project_id, + location=location, + credentials=credentials, + ) diff --git a/tests/integration_tests/chat_models/test_vertexai.py b/tests/integration_tests/chat_models/test_vertexai.py new file mode 100644 index 00000000..10b9c698 --- /dev/null +++ b/tests/integration_tests/chat_models/test_vertexai.py @@ -0,0 +1,88 @@ +"""Test Vertex AI API wrapper. +In order to run this test, you need to install VertexAI SDK (that is is the private +preview) and be whitelisted to list the models themselves: +In order to run this test, you need to install VertexAI SDK +pip install google-cloud-aiplatform>=1.25.0 + +Your end-user credentials would be used to make the calls (make sure you've run +`gcloud auth login` first). +""" +import pytest + +from langchain.chat_models import ChatVertexAI +from langchain.chat_models.vertexai import _MessagePair, _parse_chat_history +from langchain.schema import ( + AIMessage, + HumanMessage, + SystemMessage, +) + + +def test_vertexai_single_call() -> None: + model = ChatVertexAI() + message = HumanMessage(content="Hello") + response = model([message]) + assert isinstance(response, AIMessage) + assert isinstance(response.content, str) + assert model._llm_type == "vertexai" + assert model.model_name == model.client._model_id + + +def test_vertexai_single_call_with_context() -> None: + model = ChatVertexAI() + raw_context = ( + "My name is Ned. You are my personal assistant. My favorite movies " + "are Lord of the Rings and Hobbit." + ) + question = ( + "Hello, could you recommend a good movie for me to watch this evening, please?" + ) + context = SystemMessage(content=raw_context) + message = HumanMessage(content=question) + response = model([context, message]) + assert isinstance(response, AIMessage) + assert isinstance(response.content, str) + + +def test_parse_chat_history_correct() -> None: + text_context = ( + "My name is Ned. You are my personal assistant. My " + "favorite movies are Lord of the Rings and Hobbit." + ) + context = SystemMessage(content=text_context) + text_question = ( + "Hello, could you recommend a good movie for me to watch this evening, please?" + ) + question = HumanMessage(content=text_question) + text_answer = ( + "Sure, You might enjoy The Lord of the Rings: The Fellowship of the Ring " + "(2001): This is the first movie in the Lord of the Rings trilogy." + ) + answer = AIMessage(content=text_answer) + history = _parse_chat_history([context, question, answer, question, answer]) + assert history.system_message == context + assert len(history.history) == 2 + assert history.history[0] == _MessagePair(question=question, answer=answer) + + +def test_parse_chat_history_wrong_sequence() -> None: + text_question = ( + "Hello, could you recommend a good movie for me to watch this evening, please?" + ) + question = HumanMessage(content=text_question) + with pytest.raises(ValueError) as exc_info: + _ = _parse_chat_history([question, question]) + assert ( + str(exc_info.value) + == "A human message should follow a bot one, got human, human." + ) + + +def test_vertexai_single_call_failes_no_message() -> None: + chat = ChatVertexAI() + with pytest.raises(ValueError) as exc_info: + _ = chat([]) + assert ( + str(exc_info.value) + == "You should provide at least one message to start the chat!" 
diff --git a/tests/integration_tests/chat_models/test_vertexai.py b/tests/integration_tests/chat_models/test_vertexai.py
new file mode 100644
index 00000000..10b9c698
--- /dev/null
+++ b/tests/integration_tests/chat_models/test_vertexai.py
@@ -0,0 +1,88 @@
+"""Test Vertex AI API wrapper.
+In order to run this test, you need to install the Vertex AI SDK (currently in
+private preview) and be whitelisted to use the models:
+pip install google-cloud-aiplatform>=1.25.0
+
+Your end-user credentials would be used to make the calls (make sure you've run
+`gcloud auth login` first).
+"""
+import pytest
+
+from langchain.chat_models import ChatVertexAI
+from langchain.chat_models.vertexai import _MessagePair, _parse_chat_history
+from langchain.schema import (
+    AIMessage,
+    HumanMessage,
+    SystemMessage,
+)
+
+
+def test_vertexai_single_call() -> None:
+    model = ChatVertexAI()
+    message = HumanMessage(content="Hello")
+    response = model([message])
+    assert isinstance(response, AIMessage)
+    assert isinstance(response.content, str)
+    assert model._llm_type == "vertexai"
+    assert model.model_name == model.client._model_id
+
+
+def test_vertexai_single_call_with_context() -> None:
+    model = ChatVertexAI()
+    raw_context = (
+        "My name is Ned. You are my personal assistant. My favorite movies "
+        "are Lord of the Rings and Hobbit."
+    )
+    question = (
+        "Hello, could you recommend a good movie for me to watch this evening, please?"
+    )
+    context = SystemMessage(content=raw_context)
+    message = HumanMessage(content=question)
+    response = model([context, message])
+    assert isinstance(response, AIMessage)
+    assert isinstance(response.content, str)
+
+
+def test_parse_chat_history_correct() -> None:
+    text_context = (
+        "My name is Ned. You are my personal assistant. My "
+        "favorite movies are Lord of the Rings and Hobbit."
+    )
+    context = SystemMessage(content=text_context)
+    text_question = (
+        "Hello, could you recommend a good movie for me to watch this evening, please?"
+    )
+    question = HumanMessage(content=text_question)
+    text_answer = (
+        "Sure, You might enjoy The Lord of the Rings: The Fellowship of the Ring "
+        "(2001): This is the first movie in the Lord of the Rings trilogy."
+    )
+    answer = AIMessage(content=text_answer)
+    history = _parse_chat_history([context, question, answer, question, answer])
+    assert history.system_message == context
+    assert len(history.history) == 2
+    assert history.history[0] == _MessagePair(question=question, answer=answer)
+
+
+def test_parse_chat_history_wrong_sequence() -> None:
+    text_question = (
+        "Hello, could you recommend a good movie for me to watch this evening, please?"
+    )
+    question = HumanMessage(content=text_question)
+    with pytest.raises(ValueError) as exc_info:
+        _ = _parse_chat_history([question, question])
+    assert (
+        str(exc_info.value)
+        == "A human message should be followed by an AI message, got human, human."
+    )
+
+
+def test_vertexai_single_call_fails_no_message() -> None:
+    chat = ChatVertexAI()
+    with pytest.raises(ValueError) as exc_info:
+        _ = chat([])
+    assert (
+        str(exc_info.value)
+        == "You should provide at least one message to start the chat!"
+    )
diff --git a/tests/integration_tests/embeddings/test_vertexai.py b/tests/integration_tests/embeddings/test_vertexai.py
new file mode 100644
index 00000000..ce389f89
--- /dev/null
+++ b/tests/integration_tests/embeddings/test_vertexai.py
@@ -0,0 +1,25 @@
+"""Test Vertex AI API wrapper.
+In order to run this test, you need to install the Vertex AI SDK:
+pip install google-cloud-aiplatform>=1.25.0
+
+Your end-user credentials would be used to make the calls (make sure you've run
+`gcloud auth login` first).
+"""
+from langchain.embeddings import VertexAIEmbeddings
+
+
+def test_embedding_documents() -> None:
+    documents = ["foo bar"]
+    model = VertexAIEmbeddings()
+    output = model.embed_documents(documents)
+    assert len(output) == 1
+    assert len(output[0]) == 768
+    assert model._llm_type == "vertexai"
+    assert model.model_name == model.client._model_id
+
+
+def test_embedding_query() -> None:
+    document = "foo bar"
+    model = VertexAIEmbeddings()
+    output = model.embed_query(document)
+    assert len(output) == 768
diff --git a/tests/integration_tests/llms/test_vertexai.py b/tests/integration_tests/llms/test_vertexai.py
new file mode 100644
index 00000000..8ee7836b
--- /dev/null
+++ b/tests/integration_tests/llms/test_vertexai.py
@@ -0,0 +1,18 @@
+"""Test Vertex AI API wrapper.
+In order to run this test, you need to install the Vertex AI SDK (currently in
+private preview) and be whitelisted to use the models:
+pip install google-cloud-aiplatform>=1.25.0
+
+Your end-user credentials would be used to make the calls (make sure you've run
+`gcloud auth login` first).
+"""
+from langchain.llms import VertexAI
+
+
+def test_vertex_call() -> None:
+    llm = VertexAI()
+    output = llm("Say foo:")
+    assert isinstance(output, str)
+    assert llm._llm_type == "vertexai"
+    assert llm.model_name == llm.client._model_id