add vertex prod features (#10910)

- chat vertex async
- vertex stream
- vertex full generation info
- vertex use server-side stopping
- model garden async
- update docs for all the above

in follow up will add
[] chat vertex full generation info
[] chat vertex retries
[] scheduled tests
pull/10928/head
Bagatur 10 months ago committed by GitHub
parent dccc20b402
commit cab55e9bc1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -5,7 +5,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"# Google Cloud Platform Vertex AI PaLM \n",
"# GCP Vertex AI \n",
"\n",
"Note: This is seperate from the Google PaLM integration. Google has chosen to offer an enterprise version of PaLM through GCP, and this supports the models made available through there. \n",
"\n",
@ -31,7 +31,7 @@
},
"outputs": [],
"source": [
"#!pip install google-cloud-aiplatform"
"#!pip install langchain google-cloud-aiplatform"
]
},
{
@ -41,12 +41,7 @@
"outputs": [],
"source": [
"from langchain.chat_models import ChatVertexAI\n",
"from langchain.prompts.chat import (\n",
" ChatPromptTemplate,\n",
" SystemMessagePromptTemplate,\n",
" HumanMessagePromptTemplate,\n",
")\n",
"from langchain.schema import HumanMessage, SystemMessage"
"from langchain.prompts import ChatPromptTemplate"
]
},
{
@ -60,82 +55,78 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 34,
"metadata": {},
"outputs": [],
"source": [
"system = \"You are a helpful assistant who translate English to French\"\n",
"human = \"Translate this sentence from English to French. I love programming.\"\n",
"prompt = ChatPromptTemplate.from_messages(\n",
" [(\"system\", system), (\"human\", human)]\n",
")\n",
"messages = prompt.format_messages()"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"AIMessage(content='Sure, here is the translation of the sentence \"I love programming\" from English to French:\\n\\nJ\\'aime programmer.', additional_kwargs={}, example=False)"
"AIMessage(content=\" J'aime la programmation.\", additional_kwargs={}, example=False)"
]
},
"execution_count": 4,
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"messages = [\n",
" SystemMessage(\n",
" content=\"You are a helpful assistant that translates English to French.\"\n",
" ),\n",
" HumanMessage(\n",
" content=\"Translate this sentence from English to French. I love programming.\"\n",
" ),\n",
"]\n",
"chat(messages)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"You can make use of templating by using a `MessagePromptTemplate`. You can build a `ChatPromptTemplate` from one or more `MessagePromptTemplates`. You can use `ChatPromptTemplate`'s `format_prompt` -- this returns a `PromptValue`, which you can convert to a string or Message object, depending on whether you want to use the formatted value as input to an llm or chat model.\n",
"\n",
"For convenience, there is a `from_template` method exposed on the template. If you were to use this template, this is what it would look like:"
"If we want to construct a simple chain that takes user specified parameters:"
]
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"template = (\n",
" \"You are a helpful assistant that translates {input_language} to {output_language}.\"\n",
")\n",
"system_message_prompt = SystemMessagePromptTemplate.from_template(template)\n",
"human_template = \"{text}\"\n",
"human_message_prompt = HumanMessagePromptTemplate.from_template(human_template)"
"system = \"You are a helpful assistant that translates {input_language} to {output_language}.\"\n",
"human = \"{text}\"\n",
"prompt = ChatPromptTemplate.from_messages(\n",
" [(\"system\", system), (\"human\", human)]\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 13,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"AIMessage(content='Sure, here is the translation of \"I love programming\" in French:\\n\\nJ\\'aime programmer.', additional_kwargs={}, example=False)"
"AIMessage(content=' 私はプログラミングが大好きです。', additional_kwargs={}, example=False)"
]
},
"execution_count": 7,
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"chat_prompt = ChatPromptTemplate.from_messages(\n",
" [system_message_prompt, human_message_prompt]\n",
")\n",
"\n",
"# get a chat completion from the formatted messages\n",
"chat(\n",
" chat_prompt.format_prompt(\n",
" input_language=\"English\", output_language=\"French\", text=\"I love programming.\"\n",
" ).to_messages()\n",
"chain = prompt | chat\n",
"chain.invoke(\n",
" {\"input_language\": \"English\", \"output_language\": \"Japanese\", \"text\": \"I love programming\"}\n",
")"
]
},
@ -153,60 +144,129 @@
"tags": []
},
"source": [
"## Code generation chat models\n",
"You can now leverage the Codey API for code chat within Vertex AI. The model name is:\n",
"- codechat-bison: for code assistance"
]
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 18,
"metadata": {
"execution": {
"iopub.execute_input": "2023-06-17T21:30:43.974841Z",
"iopub.status.busy": "2023-06-17T21:30:43.974431Z",
"iopub.status.idle": "2023-06-17T21:30:44.248119Z",
"shell.execute_reply": "2023-06-17T21:30:44.247362Z",
"shell.execute_reply.started": "2023-06-17T21:30:43.974820Z"
},
"tags": []
},
"outputs": [],
"source": [
"chat = ChatVertexAI(model_name=\"codechat-bison\")"
"chat = ChatVertexAI(\n",
" model_name=\"codechat-bison\",\n",
" max_output_tokens=1000,\n",
" temperature=0.5\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 20,
"metadata": {
"execution": {
"iopub.execute_input": "2023-06-17T21:30:45.146093Z",
"iopub.status.busy": "2023-06-17T21:30:45.145752Z",
"iopub.status.idle": "2023-06-17T21:30:47.449126Z",
"shell.execute_reply": "2023-06-17T21:30:47.448609Z",
"shell.execute_reply.started": "2023-06-17T21:30:45.146069Z"
},
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" ```python\n",
"def is_prime(x): \n",
" if (x <= 1): \n",
" return False\n",
" for i in range(2, x): \n",
" if (x % i == 0): \n",
" return False\n",
" return True\n",
"```\n"
]
}
],
"source": [
"# For simple string in string out usage, we can use the `predict` method:\n",
"print(chat.predict(\"Write a Python function to identify all prime numbers\"))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Asynchronous calls\n",
"\n",
"We can make asynchronous calls via the `agenerate` and `ainvoke` methods."
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [],
"source": [
"import asyncio\n",
"# import nest_asyncio\n",
"# nest_asyncio.apply()"
]
},
{
"cell_type": "code",
"execution_count": 35,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"AIMessage(content='The following Python function can be used to identify all prime numbers up to a given integer:\\n\\n```\\ndef is_prime(n):\\n \"\"\"\\n Determines whether the given integer is prime.\\n\\n Args:\\n n: The integer to be tested for primality.\\n\\n Returns:\\n True if n is prime, False otherwise.\\n \"\"\"\\n\\n # Check if n is divisible by 2.\\n if n % 2 == 0:\\n return False\\n\\n # Check if n is divisible by any integer from 3 to the square root', additional_kwargs={}, example=False)"
"LLMResult(generations=[[ChatGeneration(text=\" J'aime la programmation.\", generation_info=None, message=AIMessage(content=\" J'aime la programmation.\", additional_kwargs={}, example=False))]], llm_output={}, run=[RunInfo(run_id=UUID('223599ef-38f8-4c79-ac6d-a5013060eb9d'))])"
]
},
"execution_count": 4,
"execution_count": 35,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"messages = [\n",
" HumanMessage(\n",
" content=\"How do I create a python function to identify all prime numbers?\"\n",
" )\n",
"]\n",
"chat(messages)"
"chat = ChatVertexAI(\n",
" model_name=\"chat-bison\",\n",
" max_output_tokens=1000,\n",
" temperature=0.7,\n",
" top_p=0.95,\n",
" top_k=40,\n",
")\n",
"\n",
"asyncio.run(chat.agenerate([messages]))"
]
},
{
"cell_type": "code",
"execution_count": 36,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"AIMessage(content=' अहं प्रोग्रामिंग प्रेमामि', additional_kwargs={}, example=False)"
]
},
"execution_count": 36,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"asyncio.run(chain.ainvoke({\"input_language\": \"English\", \"output_language\": \"Sanskrit\", \"text\": \"I love programming\"}))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Streaming calls\n",
"\n",
"We can also stream outputs via the `stream` method:"
]
},
{
@ -214,14 +274,51 @@
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
"source": [
"import sys"
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" 1. China (1,444,216,107)\n",
"2. India (1,393,409,038)\n",
"3. United States (332,403,650)\n",
"4. Indonesia (273,523,615)\n",
"5. Pakistan (220,892,340)\n",
"6. Brazil (212,559,409)\n",
"7. Nigeria (206,139,589)\n",
"8. Bangladesh (164,689,383)\n",
"9. Russia (145,934,462)\n",
"10. Mexico (128,932,488)\n",
"11. Japan (126,476,461)\n",
"12. Ethiopia (115,063,982)\n",
"13. Philippines (109,581,078)\n",
"14. Egypt (102,334,404)\n",
"15. Vietnam (97,338,589)"
]
}
],
"source": [
"prompt = ChatPromptTemplate.from_messages([(\"human\", \"List out the 15 most populous countries in the world\")])\n",
"messages = prompt.format_messages()\n",
"for chunk in chat.stream(messages):\n",
" sys.stdout.write(chunk.content)\n",
" sys.stdout.flush()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"display_name": "poetry-venv",
"language": "python",
"name": "python3"
"name": "poetry-venv"
},
"language_info": {
"codemirror_mode": {

@ -26,7 +26,7 @@ ChatLiteLLM|✅|✅|✅|✅
ChatMLflowAIGateway|✅|❌|❌|❌
ChatOllama|✅|❌|✅|❌
ChatOpenAI|✅|✅|✅|✅
ChatVertexAI|✅||✅|❌
ChatVertexAI|✅||✅|❌
ErnieBotChat|✅|❌|❌|❌
JinaChat|✅|✅|✅|✅
MiniMaxChat|✅|✅|❌|❌

@ -4,7 +4,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"# Google Vertex AI PaLM \n",
"# GCP Vertex AI\n",
"\n",
"**Note:** This is separate from the `Google PaLM` integration, it exposes [Vertex AI PaLM API](https://cloud.google.com/vertex-ai/docs/generative-ai/learn/overview) on `Google Cloud`. \n"
]
@ -41,12 +41,12 @@
},
"outputs": [],
"source": [
"#!pip install google-cloud-aiplatform"
"#!pip install langchain google-cloud-aiplatform"
]
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
@ -54,41 +54,55 @@
]
},
{
"cell_type": "markdown",
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" Python is a widely used, interpreted, object-oriented, and high-level programming language with dynamic semantics, used for general-purpose programming. It is known for its readability, simplicity, and versatility. Here are some of the pros and cons of Python:\n",
"\n",
"**Pros:**\n",
"\n",
"- **Easy to learn:** Python is known for its simple and intuitive syntax, making it easy for beginners to learn. It has a relatively shallow learning curve compared to other programming languages.\n",
"\n",
"- **Versatile:** Python is a general-purpose programming language, meaning it can be used for a wide variety of tasks, including web development, data science, machine\n"
]
}
],
"source": [
"## Question-answering example"
"llm = VertexAI()\n",
"print(llm(\"What are some of the pros and cons of Python as a programming language?\"))"
]
},
{
"cell_type": "code",
"execution_count": null,
"cell_type": "markdown",
"metadata": {},
"outputs": [],
"source": [
"from langchain.prompts import PromptTemplate\nfrom langchain.chains import LLMChain"
"## Using in a chain"
]
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"template = \"\"\"Question: {question}\n",
"\n",
"Answer: Let's think step by step.\"\"\"\n",
"\n",
"prompt = PromptTemplate(template=template, input_variables=[\"question\"])"
"from langchain.prompts import PromptTemplate"
]
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"llm = VertexAI()"
"template = \"\"\"Question: {question}\n",
"\n",
"Answer: Let's think step by step.\"\"\"\n",
"prompt = PromptTemplate.from_template(template)"
]
},
{
@ -97,29 +111,26 @@
"metadata": {},
"outputs": [],
"source": [
"llm_chain = LLMChain(prompt=prompt, llm=llm)"
"chain = prompt | llm"
]
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'Justin Bieber was born on March 1, 1994. The Super Bowl in 1994 was won by the San Francisco 49ers.\\nThe final answer: San Francisco 49ers.'"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
"name": "stdout",
"output_type": "stream",
"text": [
" Justin Bieber was born on March 1, 1994. Bill Clinton was the president of the United States from January 20, 1993, to January 20, 2001.\n",
"The final answer is Bill Clinton\n"
]
}
],
"source": [
"question = \"What NFL team won the Super Bowl in the year Justin Beiber was born?\"\n",
"\n",
"llm_chain.run(question)"
"question = \"Who was the president in the year Justin Beiber was born?\"\n",
"print(chain.invoke({\"question\": question}))"
]
},
{
@ -142,76 +153,198 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 15,
"metadata": {
"execution": {
"iopub.execute_input": "2023-06-17T21:16:53.149438Z",
"iopub.status.busy": "2023-06-17T21:16:53.149065Z",
"iopub.status.idle": "2023-06-17T21:16:53.421824Z",
"shell.execute_reply": "2023-06-17T21:16:53.421136Z",
"shell.execute_reply.started": "2023-06-17T21:16:53.149415Z"
},
"tags": []
},
"outputs": [],
"source": [
"llm = VertexAI(model_name=\"code-bison\")"
"llm = VertexAI(model_name=\"code-bison\", max_output_tokens=1000, temperature=0.3)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"execution": {
"iopub.execute_input": "2023-06-17T21:17:11.179077Z",
"iopub.status.busy": "2023-06-17T21:17:11.178686Z",
"iopub.status.idle": "2023-06-17T21:17:11.182499Z",
"shell.execute_reply": "2023-06-17T21:17:11.181895Z",
"shell.execute_reply.started": "2023-06-17T21:17:11.179052Z"
},
"tags": []
},
"execution_count": 21,
"metadata": {},
"outputs": [],
"source": [
"llm_chain = LLMChain(prompt=prompt, llm=llm)"
"question = \"Write a python function that checks if a string is a valid email address\""
]
},
{
"cell_type": "code",
"execution_count": 15,
"execution_count": 19,
"metadata": {
"execution": {
"iopub.execute_input": "2023-06-17T21:18:47.024785Z",
"iopub.status.busy": "2023-06-17T21:18:47.024230Z",
"iopub.status.idle": "2023-06-17T21:18:49.352249Z",
"shell.execute_reply": "2023-06-17T21:18:49.351695Z",
"shell.execute_reply.started": "2023-06-17T21:18:47.024762Z"
},
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"```python\n",
"import re\n",
"\n",
"def is_valid_email(email):\n",
" pattern = re.compile(r\"[^@]+@[^@]+\\.[^@]+\")\n",
" return pattern.match(email)\n",
"```\n"
]
}
],
"source": [
"print(llm(question))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Full generation info\n",
"\n",
"We can use the `generate` method to get back extra metadata like [safety attributes](https://cloud.google.com/vertex-ai/docs/generative-ai/learn/responsible-ai#safety_attribute_confidence_scoring) and not just text completions"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[[GenerationChunk(text='```python\\nimport re\\n\\ndef is_valid_email(email):\\n pattern = re.compile(r\"[^@]+@[^@]+\\\\.[^@]+\")\\n return pattern.match(email)\\n```', generation_info={'is_blocked': False, 'safety_attributes': {'Health': 0.1}})]]"
]
},
"execution_count": 23,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"result = llm.generate([question])\n",
"result.generations"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Asynchronous calls\n",
"\n",
"With `agenerate` we can make asynchronous calls"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# If running in a Jupyter notebook you'll need to install nest_asyncio\n",
"\n",
"# !pip install nest_asyncio"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [],
"source": [
"import asyncio\n",
"# import nest_asyncio\n",
"# nest_asyncio.apply()"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'```python\\ndef is_prime(n):\\n \"\"\"\\n Determines if a number is prime.\\n\\n Args:\\n n: The number to be tested.\\n\\n Returns:\\n True if the number is prime, False otherwise.\\n \"\"\"\\n\\n # Check if the number is 1.\\n if n == 1:\\n return False\\n\\n # Check if the number is 2.\\n if n == 2:\\n return True\\n\\n'"
"LLMResult(generations=[[GenerationChunk(text='```python\\nimport re\\n\\ndef is_valid_email(email):\\n pattern = re.compile(r\"[^@]+@[^@]+\\\\.[^@]+\")\\n return pattern.match(email)\\n```', generation_info={'is_blocked': False, 'safety_attributes': {'Health': 0.1}})]], llm_output=None, run=[RunInfo(run_id=UUID('caf74e91-aefb-48ac-8031-0c505fcbbcc6'))])"
]
},
"execution_count": 15,
"execution_count": 25,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"question = \"Write a python function that identifies if the number is a prime number?\"\n",
"asyncio.run(llm.agenerate([question]))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Streaming calls\n",
"\n",
"llm_chain.run(question)"
"With `stream` we can stream results from the model"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [],
"source": [
"import sys"
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"```python\n",
"import re\n",
"\n",
"def is_valid_email(email):\n",
" \"\"\"\n",
" Checks if a string is a valid email address.\n",
"\n",
" Args:\n",
" email: The string to check.\n",
"\n",
" Returns:\n",
" True if the string is a valid email address, False otherwise.\n",
" \"\"\"\n",
"\n",
" # Check for a valid email address format.\n",
" if not re.match(r\"^[A-Za-z0-9\\.\\+_-]+@[A-Za-z0-9\\._-]+\\.[a-zA-Z]*$\", email):\n",
" return False\n",
"\n",
" # Check if the domain name exists.\n",
" try:\n",
" domain = email.split(\"@\")[1]\n",
" socket.gethostbyname(domain)\n",
" except socket.gaierror:\n",
" return False\n",
"\n",
" return True\n",
"```"
]
}
],
"source": [
"for chunk in llm.stream(question):\n",
" sys.stdout.write(chunk)\n",
" sys.stdout.flush()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Using models deployed on Vertex Model Garden"
"## Vertex Model Garden"
]
},
{
@ -248,7 +381,7 @@
"metadata": {},
"outputs": [],
"source": [
"llm(\"What is the meaning of life?\")"
"print(llm(\"What is the meaning of life?\"))"
]
},
{
@ -264,8 +397,6 @@
"metadata": {},
"outputs": [],
"source": [
"from langchain.prompts import PromptTemplate\n",
"\n",
"prompt = PromptTemplate.from_template(\"What is the meaning of {thing}?\")"
]
},
@ -275,9 +406,8 @@
"metadata": {},
"outputs": [],
"source": [
"llm_oss_chain = prompt | llm\n",
"\n",
"llm_oss_chain.invoke({\"thing\": \"life\"})"
"chian = prompt | llm\n",
"print(chain.invoke({\"thing\": \"life\"}))"
]
}
],

@ -83,8 +83,8 @@ TitanTakeoff|✅|❌|✅|❌|❌|❌
Tongyi|✅|❌|❌|❌|❌|❌
VLLM|✅|❌|❌|❌|✅|❌
VLLMOpenAI|✅|✅|✅|✅|✅|✅
VertexAI|✅|✅|❌|❌|❌|❌
VertexAIModelGarden|✅|✅|❌|❌|❌|❌
VertexAI|✅|✅|✅|❌|✅|✅
VertexAIModelGarden|✅|✅|❌|❌|✅|✅
Writer|✅|❌|❌|❌|❌|❌
Xinference|✅|❌|❌|❌|❌|❌

@ -2,6 +2,35 @@
All functionality related to Google Platform
## LLMs
### Vertex AI
Access PaLM LLMs like `text-bison` and `code-bison` via Google Cloud.
```python
from langchain.llms import VertexAI
```
### Model Garden
Access PaLM and hundreds of OSS models via Vertex AI Model Garden.
```python
from langchain.llms import VertexAIModelGarden
```
## Chat models
### Vertex AI
Access PaLM chat models like `chat-bison` and `codechat-bison` via Google Cloud.
```python
from langchain.chat_models import ChatVertexAI
```
## Document Loader
### Google BigQuery

@ -1,10 +1,14 @@
"""Wrapper around Google VertexAI chat-based models."""
from __future__ import annotations
import logging
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional, Union
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.callbacks.manager import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain.chat_models.base import BaseChatModel, _generate_from_stream
from langchain.llms.vertexai import _VertexAICommon, is_codey_model
from langchain.pydantic_v1 import root_validator
@ -30,6 +34,8 @@ if TYPE_CHECKING:
InputOutputTextPair,
)
logger = logging.getLogger(__name__)
@dataclass
class _ChatHistory:
@ -116,7 +122,7 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
"""`Vertex AI` Chat large language models API."""
model_name: str = "chat-bison"
streaming: bool = False
"Underlying model name."
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
@ -177,6 +183,42 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
text = self._enforce_stop_words(response.text, stop)
return ChatResult(generations=[ChatGeneration(message=AIMessage(content=text))])
async def _agenerate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
"""Asynchronously generate next turn in the conversation.
Args:
messages: The history of the conversation as a list of messages. Code chat
does not support context.
stop: The list of stop words (optional).
run_manager: The CallbackManager for LLM run, it's not used at the moment.
Returns:
The ChatResult that contains outputs generated by the model.
Raises:
ValueError: if the last message in the list is not from human.
"""
if "stream" in kwargs:
kwargs.pop("stream")
logger.warning("ChatVertexAI does not currently support async streaming.")
question = _get_question(messages)
history = _parse_chat_history(messages[:-1])
params = {**self._default_params, **kwargs}
examples = kwargs.get("examples", None)
if examples:
params["examples"] = _parse_examples(examples)
chat = self._start_chat(history, params)
response = await chat.send_message_async(question.content)
text = self._enforce_stop_words(response.text, stop)
return ChatResult(generations=[ChatGeneration(message=AIMessage(content=text))])
def _stream(
self,
messages: List[BaseMessage],

@ -1,28 +1,58 @@
from __future__ import annotations
import asyncio
from concurrent.futures import Executor, ThreadPoolExecutor
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Dict, List, Optional
from typing import (
TYPE_CHECKING,
Any,
Callable,
ClassVar,
Dict,
Iterator,
List,
Optional,
Union,
)
from langchain.callbacks.manager import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain.llms.base import LLM, create_base_retry_decorator
from langchain.llms.base import BaseLLM, create_base_retry_decorator
from langchain.llms.utils import enforce_stop_tokens
from langchain.pydantic_v1 import BaseModel, root_validator
from langchain.schema import (
Generation,
LLMResult,
)
from langchain.schema.output import GenerationChunk
from langchain.utilities.vertexai import (
init_vertexai,
raise_vertex_import_error,
)
if TYPE_CHECKING:
from google.cloud.aiplatform.gapic import PredictionServiceClient
from vertexai.language_models._language_models import _LanguageModel
from google.cloud.aiplatform.gapic import (
PredictionServiceAsyncClient,
PredictionServiceClient,
)
from vertexai.language_models._language_models import (
TextGenerationResponse,
_LanguageModel,
)
def _response_to_generation(
response: TextGenerationResponse,
) -> GenerationChunk:
"""Convert a stream response to a generation chunk."""
try:
generation_info = {
"is_blocked": response.is_blocked,
"safety_attributes": response.safety_attributes,
}
except Exception:
generation_info = None
return GenerationChunk(text=response.text, generation_info=generation_info)
def is_codey_model(model_name: str) -> bool:
@ -36,7 +66,13 @@ def is_codey_model(model_name: str) -> bool:
return "code" in model_name
def _create_retry_decorator(llm: VertexAI) -> Callable[[Any], Any]:
def _create_retry_decorator(
llm: VertexAI,
*,
run_manager: Optional[
Union[AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun]
] = None,
) -> Callable[[Any], Any]:
import google.api_core
errors = [
@ -46,14 +82,19 @@ def _create_retry_decorator(llm: VertexAI) -> Callable[[Any], Any]:
google.api_core.exceptions.DeadlineExceeded,
]
decorator = create_base_retry_decorator(
error_types=errors, max_retries=llm.max_retries # type: ignore
error_types=errors, max_retries=llm.max_retries, run_manager=run_manager
)
return decorator
def completion_with_retry(llm: VertexAI, *args: Any, **kwargs: Any) -> Any:
def completion_with_retry(
llm: VertexAI,
*args: Any,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Any:
"""Use tenacity to retry the completion call."""
retry_decorator = _create_retry_decorator(llm)
retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)
@retry_decorator
def _completion_with_retry(*args: Any, **kwargs: Any) -> Any:
@ -62,6 +103,38 @@ def completion_with_retry(llm: VertexAI, *args: Any, **kwargs: Any) -> Any:
return _completion_with_retry(*args, **kwargs)
def stream_completion_with_retry(
llm: VertexAI,
*args: Any,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Any:
"""Use tenacity to retry the completion call."""
retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)
@retry_decorator
def _completion_with_retry(*args: Any, **kwargs: Any) -> Any:
return llm.client.predict_streaming(*args, **kwargs)
return _completion_with_retry(*args, **kwargs)
async def acompletion_with_retry(
llm: VertexAI,
*args: Any,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Any:
"""Use tenacity to retry the completion call."""
retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)
@retry_decorator
async def _acompletion_with_retry(*args: Any, **kwargs: Any) -> Any:
return await llm.client.predict_async(*args, **kwargs)
return await _acompletion_with_retry(*args, **kwargs)
class _VertexAIBase(BaseModel):
project: Optional[str] = None
"The default GCP project to use when making Vertex API calls."
@ -110,6 +183,11 @@ class _VertexAICommon(_VertexAIBase):
"The default custom credentials (google.auth.credentials.Credentials) to use "
"when making API calls. If not provided, credentials will be ascertained from "
"the environment."
streaming: bool = False
@property
def _llm_type(self) -> str:
return "vertexai"
@property
def is_codey_model(self) -> bool:
@ -135,17 +213,6 @@ class _VertexAICommon(_VertexAIBase):
"top_p": self.top_p,
}
def _predict(
self, prompt: str, stop: Optional[List[str]] = None, **kwargs: Any
) -> str:
params = {**self._default_params, **kwargs}
res = completion_with_retry(self, prompt, **params) # type: ignore
return self._enforce_stop_words(res.text, stop)
@property
def _llm_type(self) -> str:
return "vertexai"
@classmethod
def _try_init_vertexai(cls, values: Dict) -> None:
allowed_params = ["project", "location", "credentials"]
@ -154,13 +221,14 @@ class _VertexAICommon(_VertexAIBase):
return None
class VertexAI(_VertexAICommon, LLM):
class VertexAI(_VertexAICommon, BaseLLM):
"""Google Vertex AI large language models."""
model_name: str = "text-bison"
"The name of the Vertex AI large language model."
tuned_model_name: Optional[str] = None
"The name of a tuned model. If provided, model_name is ignored."
streaming: bool = False
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
@ -191,51 +259,78 @@ class VertexAI(_VertexAICommon, LLM):
raise_vertex_import_error()
return values
def _call(
def _generate(
self,
prompt: str,
prompts: List[str],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
stream: Optional[bool] = None,
**kwargs: Any,
) -> str:
"""Call Vertex model to get predictions based on the prompt.
Args:
prompt: The prompt to pass into the model.
stop: A list of stop words (optional).
run_manager: A Callbackmanager for LLM run, optional.
) -> LLMResult:
stop_sequences = stop or self.stop
should_stream = stream if stream is not None else self.streaming
Returns:
The string generated by the model.
"""
return self._predict(prompt, stop, **kwargs)
params = {**self._default_params, "stop_sequences": stop_sequences, **kwargs}
generations = []
for prompt in prompts:
if should_stream:
generation = GenerationChunk(text="")
for chunk in self._stream(
prompt, stop=stop, run_manager=run_manager, **kwargs
):
generation += chunk
generations.append([generation])
else:
res = completion_with_retry(
self, prompt, run_manager=run_manager, **params
)
generations.append([_response_to_generation(res)])
return LLMResult(generations=generations)
async def _acall(
async def _agenerate(
self,
prompt: str,
prompts: List[str],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
"""Call Vertex model to get predictions based on the prompt.
Args:
prompt: The prompt to pass into the model.
stop: A list of stop words (optional).
run_manager: A callback manager for async interaction with LLMs.
Returns:
The string generated by the model.
"""
return await asyncio.wrap_future(
self._get_task_executor().submit(self._call, prompt, stop)
)
) -> LLMResult:
stop_sequences = stop or self.stop
params = {**self._default_params, "stop_sequences": stop_sequences, **kwargs}
generations = []
for prompt in prompts:
res = await acompletion_with_retry(
self, prompt, run_manager=run_manager, **params
)
generations.append([_response_to_generation(res)])
return LLMResult(generations=generations)
class VertexAIModelGarden(_VertexAIBase, LLM):
def _stream(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[GenerationChunk]:
stop_sequences = stop or self.stop
params = {**self._default_params, "stop_sequences": stop_sequences, **kwargs}
for stream_resp in stream_completion_with_retry(
self, prompt, run_manager=run_manager, **params
):
chunk = _response_to_generation(stream_resp)
yield chunk
if run_manager:
run_manager.on_llm_new_token(
chunk.text,
chunk=chunk,
verbose=self.verbose,
)
class VertexAIModelGarden(_VertexAIBase, BaseLLM):
"""Large language models served from Vertex AI Model Garden."""
client: "PredictionServiceClient" = None #: :meta private:
async_client: "PredictionServiceAsyncClient" = None #: :meta private:
endpoint_id: str
"A name of an endpoint where the model has been deployed."
allowed_model_args: Optional[List[str]] = None
@ -247,7 +342,11 @@ class VertexAIModelGarden(_VertexAIBase, LLM):
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that the python package exists in environment."""
try:
from google.cloud.aiplatform.gapic import PredictionServiceClient
from google.api_core.client_options import ClientOptions
from google.cloud.aiplatform.gapic import (
PredictionServiceAsyncClient,
PredictionServiceClient,
)
except ImportError:
raise_vertex_import_error()
@ -256,38 +355,19 @@ class VertexAIModelGarden(_VertexAIBase, LLM):
"A GCP project should be provided to run inference on Model Garden!"
)
client_options = {
"api_endpoint": f"{values['location']}-aiplatform.googleapis.com"
}
client_options = ClientOptions(
api_endpoint=f"{values['location']}-aiplatform.googleapis.com"
)
values["client"] = PredictionServiceClient(client_options=client_options)
values["async_client"] = PredictionServiceAsyncClient(
client_options=client_options
)
return values
@property
def _llm_type(self) -> str:
return "vertexai_model_garden"
def _call(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
"""Call Vertex model to get predictions based on the prompt.
Args:
prompt: The prompt to pass into the model.
stop: A list of stop words (optional).
run_manager: A Callbackmanager for LLM run, optional.
Returns:
The string generated by the model.
"""
result = self._generate(
prompts=[prompt], stop=stop, run_manager=run_manager, **kwargs
)
return result.generations[0][0].text
def _generate(
self,
prompts: List[str],
@ -331,23 +411,47 @@ class VertexAIModelGarden(_VertexAIBase, LLM):
)
return LLMResult(generations=generations)
async def _acall(
async def _agenerate(
self,
prompt: str,
prompts: List[str],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
"""Call Vertex model to get predictions based on the prompt.
Args:
prompt: The prompt to pass into the model.
stop: A list of stop words (optional).
run_manager: A callback manager for async interaction with LLMs.
Returns:
The string generated by the model.
"""
return await asyncio.wrap_future(
self._get_task_executor().submit(self._call, prompt, stop)
) -> LLMResult:
"""Run the LLM on the given prompt and input."""
try:
from google.protobuf import json_format
from google.protobuf.struct_pb2 import Value
except ImportError:
raise ImportError(
"protobuf package not found, please install it with"
" `pip install protobuf`"
)
instances = []
for prompt in prompts:
if self.allowed_model_args:
instance = {
k: v for k, v in kwargs.items() if k in self.allowed_model_args
}
else:
instance = {}
instance[self.prompt_arg] = prompt
instances.append(instance)
predict_instances = [
json_format.ParseDict(instance_dict, Value()) for instance_dict in instances
]
endpoint = self.async_client.endpoint_path(
project=self.project, location=self.location, endpoint=self.endpoint_id
)
response = await self.async_client.predict(
endpoint=endpoint, instances=predict_instances
)
generations: List[List[Generation]] = []
for result in response.predictions:
generations.append(
[Generation(text=prediction[self.result_arg]) for prediction in result]
)
return LLMResult(generations=generations)

@ -13,6 +13,7 @@ import pytest
from langchain.chat_models import ChatVertexAI
from langchain.chat_models.vertexai import _parse_chat_history, _parse_examples
from langchain.schema import LLMResult
from langchain.schema.messages import AIMessage, HumanMessage, SystemMessage
@ -26,10 +27,22 @@ def test_vertexai_single_call(model_name: str) -> None:
response = model([message])
assert isinstance(response, AIMessage)
assert isinstance(response.content, str)
assert model._llm_type == "vertexai"
assert model._llm_type == "chat-vertexai"
assert model.model_name == model.client._model_id
@pytest.mark.asyncio
async def test_vertexai_agenerate() -> None:
model = ChatVertexAI(temperature=0)
message = HumanMessage(content="Hello")
response = await model.agenerate([[message]])
assert isinstance(response, LLMResult)
assert isinstance(response.generations[0][0].message, AIMessage) # type: ignore
sync_response = model.generate([[message]])
assert response.generations[0][0] == sync_response.generations[0][0]
def test_vertexai_single_call_with_context() -> None:
model = ChatVertexAI()
raw_context = (

@ -14,7 +14,6 @@ def test_embedding_documents() -> None:
output = model.embed_documents(documents)
assert len(output) == 1
assert len(output[0]) == 768
assert model._llm_type == "vertexai"
assert model.model_name == model.client._model_id
@ -40,5 +39,4 @@ def test_paginated_texts() -> None:
output = model.embed_documents(documents)
assert len(output) == 8
assert len(output[0]) == 768
assert model._llm_type == "vertexai"
assert model.model_name == model.client._model_id

@ -9,18 +9,49 @@ Your end-user credentials would be used to make the calls (make sure you've run
"""
import os
import pytest
from langchain.llms import VertexAI, VertexAIModelGarden
from langchain.schema import LLMResult
def test_vertex_call() -> None:
llm = VertexAI()
llm = VertexAI(temperature=0)
output = llm("Say foo:")
assert isinstance(output, str)
assert llm._llm_type == "vertexai"
assert llm.model_name == llm.client._model_id
def test_vertex_generate() -> None:
llm = VertexAI(temperate=0)
output = llm.generate(["Please say foo:"])
assert isinstance(output, LLMResult)
@pytest.mark.asyncio
async def test_vertex_agenerate() -> None:
llm = VertexAI(temperate=0)
output = await llm.agenerate(["Please say foo:"])
assert isinstance(output, LLMResult)
def test_vertext_stream() -> None:
llm = VertexAI(temperate=0)
outputs = list(llm.stream("Please say foo:"))
assert isinstance(outputs[0], str)
@pytest.mark.asyncio
async def test_vertex_consistency() -> None:
llm = VertexAI(temperate=0)
output = llm.generate(["Please say foo:"])
streaming_output = llm.generate(["Please say foo:"], stream=True)
async_output = await llm.agenerate(["Please say foo:"])
assert output.generations[0][0].text == streaming_output.generations[0][0].text
assert output.generations[0][0].text == async_output.generations[0][0].text
def test_model_garden() -> None:
"""In order to run this test, you should provide an endpoint name.
@ -37,7 +68,7 @@ def test_model_garden() -> None:
assert llm._llm_type == "vertexai_model_garden"
def test_model_garden_batch() -> None:
def test_model_garden_generate() -> None:
"""In order to run this test, you should provide an endpoint name.
Example:
@ -47,6 +78,16 @@ def test_model_garden_batch() -> None:
endpoint_id = os.environ["ENDPOINT_ID"]
project = os.environ["PROJECT"]
llm = VertexAIModelGarden(endpoint_id=endpoint_id, project=project)
output = llm._generate(["What is the meaning of life?", "How much is 2+2"])
output = llm.generate(["What is the meaning of life?", "How much is 2+2"])
assert isinstance(output, LLMResult)
assert len(output.generations) == 2
@pytest.mark.asyncio
async def test_model_garden_agenerate() -> None:
endpoint_id = os.environ["ENDPOINT_ID"]
project = os.environ["PROJECT"]
llm = VertexAIModelGarden(endpoint_id=endpoint_id, project=project)
output = await llm.agenerate(["What is the meaning of life?", "How much is 2+2"])
assert isinstance(output, LLMResult)
assert len(output.generations) == 2

Loading…
Cancel
Save