Be able to use Codey models on Vertex AI (#6354)

Added support for three new Codey models from Vertex AI:
- code-bison - Code generation using the existing LLM integration
- code-gecko - Code completion using the existing LLM integration
- codechat-bison - Code chat using the existing chat_model integration
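
For illustration, a minimal usage sketch, assuming `google-cloud-aiplatform>=1.26.0` is installed and Vertex AI credentials are configured (the prompts and model names below are examples, not taken from the PR itself):

```python
from langchain.llms import VertexAI
from langchain.chat_models import ChatVertexAI
from langchain.schema import HumanMessage

# code-bison and code-gecko go through the LLM integration.
generator = VertexAI(model_name="code-bison")
completer = VertexAI(model_name="code-gecko")
print(generator("Write a Python function that reverses a string."))
print(completer("def reverse_string(s):"))

# codechat-bison goes through the chat_model integration.
chat = ChatVertexAI(model_name="codechat-bison")
print(chat([HumanMessage(content="How do I profile a Python script?")]).content)
```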

---------

Co-authored-by: Dev 2049 <dev.dev2049@gmail.com>
Hassan Ouda committed 456ca3d587 (parent 0fce8ef178)

@@ -141,6 +141,73 @@
")"
]
},
{
"cell_type": "markdown",
"metadata": {
"execution": {
"iopub.execute_input": "2023-06-17T21:09:25.423568Z",
"iopub.status.busy": "2023-06-17T21:09:25.423213Z",
"iopub.status.idle": "2023-06-17T21:09:25.429641Z",
"shell.execute_reply": "2023-06-17T21:09:25.429060Z",
"shell.execute_reply.started": "2023-06-17T21:09:25.423546Z"
},
"tags": []
},
"source": [
"You can now leverage the Codey API for code chat within Vertex AI. The model name is:\n",
"- codechat-bison: for code assistance"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"execution": {
"iopub.execute_input": "2023-06-17T21:30:43.974841Z",
"iopub.status.busy": "2023-06-17T21:30:43.974431Z",
"iopub.status.idle": "2023-06-17T21:30:44.248119Z",
"shell.execute_reply": "2023-06-17T21:30:44.247362Z",
"shell.execute_reply.started": "2023-06-17T21:30:43.974820Z"
},
"tags": []
},
"outputs": [],
"source": [
"chat = ChatVertexAI(model_name=\"codechat-bison\")"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"execution": {
"iopub.execute_input": "2023-06-17T21:30:45.146093Z",
"iopub.status.busy": "2023-06-17T21:30:45.145752Z",
"iopub.status.idle": "2023-06-17T21:30:47.449126Z",
"shell.execute_reply": "2023-06-17T21:30:47.448609Z",
"shell.execute_reply.started": "2023-06-17T21:30:45.146069Z"
},
"tags": []
},
"outputs": [
{
"data": {
"text/plain": [
"AIMessage(content='The following Python function can be used to identify all prime numbers up to a given integer:\\n\\n```\\ndef is_prime(n):\\n \"\"\"\\n Determines whether the given integer is prime.\\n\\n Args:\\n n: The integer to be tested for primality.\\n\\n Returns:\\n True if n is prime, False otherwise.\\n \"\"\"\\n\\n # Check if n is divisible by 2.\\n if n % 2 == 0:\\n return False\\n\\n # Check if n is divisible by any integer from 3 to the square root', additional_kwargs={}, example=False)"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"messages = [\n",
" HumanMessage(content=\"How do I create a python function to identify all prime numbers?\")\n",
"]\n",
"chat(messages)"
]
},
{
"cell_type": "code",
"execution_count": null,

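The notebook cell above sends a single message. Because the chat model replays prior (question, answer) pairs from the history, a follow-up turn can be made by passing the accumulated messages back in; a hypothetical continuation, not part of the notebook:

```python
# Continue the codechat-bison conversation from the cells above.
reply = chat(messages)  # AIMessage holding the model's first answer
follow_up = chat(
    messages + [reply, HumanMessage(content="Now add a docstring and type hints.")]
)
print(follow_up.content)
```
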
@@ -101,11 +101,80 @@
]
},
{
"cell_type": "code",
"execution_count": null,
"cell_type": "markdown",
"metadata": {},
"source": [
"You can now leverage the Codey API for code generation within Vertex AI. The model names are:\n",
"- code-bison: for code suggestion\n",
"- code-gecko: for code completion"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"execution": {
"iopub.execute_input": "2023-06-17T21:16:53.149438Z",
"iopub.status.busy": "2023-06-17T21:16:53.149065Z",
"iopub.status.idle": "2023-06-17T21:16:53.421824Z",
"shell.execute_reply": "2023-06-17T21:16:53.421136Z",
"shell.execute_reply.started": "2023-06-17T21:16:53.149415Z"
},
"tags": []
},
"outputs": [],
"source": []
"source": [
"llm = VertexAI(model_name=\"code-bison\")"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"execution": {
"iopub.execute_input": "2023-06-17T21:17:11.179077Z",
"iopub.status.busy": "2023-06-17T21:17:11.178686Z",
"iopub.status.idle": "2023-06-17T21:17:11.182499Z",
"shell.execute_reply": "2023-06-17T21:17:11.181895Z",
"shell.execute_reply.started": "2023-06-17T21:17:11.179052Z"
},
"tags": []
},
"outputs": [],
"source": [
"llm_chain = LLMChain(prompt=prompt, llm=llm)"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {
"execution": {
"iopub.execute_input": "2023-06-17T21:18:47.024785Z",
"iopub.status.busy": "2023-06-17T21:18:47.024230Z",
"iopub.status.idle": "2023-06-17T21:18:49.352249Z",
"shell.execute_reply": "2023-06-17T21:18:49.351695Z",
"shell.execute_reply.started": "2023-06-17T21:18:47.024762Z"
},
"tags": []
},
"outputs": [
{
"data": {
"text/plain": [
"'```python\\ndef is_prime(n):\\n \"\"\"\\n Determines if a number is prime.\\n\\n Args:\\n n: The number to be tested.\\n\\n Returns:\\n True if the number is prime, False otherwise.\\n \"\"\"\\n\\n # Check if the number is 1.\\n if n == 1:\\n return False\\n\\n # Check if the number is 2.\\n if n == 2:\\n return True\\n\\n'"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"question = \"Write a python function that identifies if the number is a prime number?\"\n",
"\n",
"llm_chain.run(question)"
]
}
],
"metadata": {

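The notebook drives code-bison through an LLMChain but leaves code-gecko unused. A minimal hypothetical sketch of the completion model, which takes partial code rather than a natural-language question:

```python
# code-gecko completes the code fragment it is given.
completion_llm = VertexAI(model_name="code-gecko")
print(completion_llm("def is_prime(n):\n    "))
```
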
@@ -9,7 +9,7 @@ from langchain.callbacks.manager import (
CallbackManagerForLLMRun,
)
from langchain.chat_models.base import BaseChatModel
from langchain.llms.vertexai import _VertexAICommon, is_codey_model
from langchain.schema import (
AIMessage,
BaseMessage,
@@ -42,7 +42,7 @@ def _parse_chat_history(history: List[BaseMessage]) -> _ChatHistory:
A sequence should be either (SystemMessage, HumanMessage, AIMessage,
HumanMessage, AIMessage, ...) or (HumanMessage, AIMessage, HumanMessage,
AIMessage, ...). CodeChat does not support SystemMessage.
Args:
history: The list of messages to re-create the history of the chat.
@@ -82,10 +82,16 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
"""Validate that the python package exists in environment."""
cls._try_init_vertexai(values)
try:
if is_codey_model(values["model_name"]):
from vertexai.preview.language_models import CodeChatModel
values["client"] = CodeChatModel.from_pretrained(values["model_name"])
else:
from vertexai.preview.language_models import ChatModel
values["client"] = ChatModel.from_pretrained(values["model_name"])
except ImportError:
raise_vertex_import_error()
values["client"] = ChatModel.from_pretrained(values["model_name"])
return values
def _generate(
@@ -98,9 +104,10 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
"""Generate next turn in the conversation.
Args:
messages: The history of the conversation as a list of messages. Code chat
does not support context.
stop: The list of stop words (optional).
run_manager: The CallbackManager for LLM run, it's not used at the moment.
Returns:
The ChatResult that contains outputs generated by the model.
@@ -121,10 +128,12 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
history = _parse_chat_history(messages[:-1])
context = history.system_message.content if history.system_message else None
params = {**self._default_params, **kwargs}
if not self.is_codey_model:
params["context"] = context
chat = self.client.start_chat(**params)
for pair in history.history:
chat._history.append((pair.question.content, pair.answer.content))
response = chat.send_message(question.content, **params)
text = self._enforce_stop_words(response.text, stop)
return ChatResult(generations=[ChatGeneration(message=AIMessage(content=text))])

@@ -15,6 +15,10 @@ if TYPE_CHECKING:
from vertexai.language_models._language_models import _LanguageModel
def is_codey_model(model_name: str) -> bool:
return "code" in model_name
class _VertexAICommon(BaseModel):
client: "_LanguageModel" = None #: :meta private:
model_name: str
@@ -25,10 +29,10 @@ class _VertexAICommon(BaseModel):
"Token limit determines the maximum amount of text output from one prompt."
top_p: float = 0.95
"Tokens are selected from most probable to least until the sum of their "
"probabilities equals the top-p value."
"probabilities equals the top-p value. Top-p is ignored for Codey models."
top_k: int = 40
"How the model selects tokens for output, the next token is selected from "
"among the top-k most probable tokens."
"among the top-k most probable tokens. Top-k is ignored for Codey models."
stop: Optional[List[str]] = None
"Optional list of stop words to use when generating."
project: Optional[str] = None
@@ -40,15 +44,24 @@ class _VertexAICommon(BaseModel):
"when making API calls. If not provided, credentials will be ascertained from "
"the environment."
@property
def is_codey_model(self) -> bool:
return is_codey_model(self.model_name)
@property
def _default_params(self) -> Dict[str, Any]:
if self.is_codey_model:
return {
"temperature": self.temperature,
"max_output_tokens": self.max_output_tokens,
}
else:
return {
"temperature": self.temperature,
"max_output_tokens": self.max_output_tokens,
"top_k": self.top_k,
"top_p": self.top_p,
}
def _predict(
self, prompt: str, stop: Optional[List[str]] = None, **kwargs: Any
@@ -80,22 +93,32 @@ class VertexAI(_VertexAICommon, LLM):
"""Wrapper around Google Vertex AI large language models."""
model_name: str = "text-bison"
"The name of the Vertex AI large language model."
tuned_model_name: Optional[str] = None
"The name of a tuned model, if it's provided, model_name is ignored."
"The name of a tuned model. If provided, model_name is ignored."
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that the python package exists in environment."""
cls._try_init_vertexai(values)
tuned_model_name = values.get("tuned_model_name")
model_name = values["model_name"]
try:
if tuned_model_name or not is_codey_model(model_name):
from vertexai.preview.language_models import TextGenerationModel
if tuned_model_name:
values["client"] = TextGenerationModel.get_tuned_model(
tuned_model_name
)
else:
values["client"] = TextGenerationModel.from_pretrained(model_name)
else:
from vertexai.preview.language_models import CodeGenerationModel
values["client"] = CodeGenerationModel.from_pretrained(model_name)
except ImportError:
raise_vertex_import_error()
return values
def _call(

@@ -11,7 +11,7 @@ def raise_vertex_import_error() -> None:
Raises:
ImportError: an ImportError that mentions a required version of the SDK.
"""
sdk = "'google-cloud-aiplatform>=1.25.0'"
sdk = "'google-cloud-aiplatform>=1.26.0'"
raise ImportError(
"Could not import VertexAI. Please, install it with " f"pip install {sdk}"
)
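
A sketch of what a caller sees when the SDK is missing or too old (hypothetical session; the message text comes from the function above):

```python
from langchain.llms import VertexAI

try:
    llm = VertexAI(model_name="code-bison")
except ImportError as err:
    print(err)
    # Could not import VertexAI. Please, install it with pip install 'google-cloud-aiplatform>=1.26.0'
```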
