Be able to use Codey models on Vertex AI (#6354)

Added the functionality to leverage three new Codey models from Vertex AI (see the usage sketch below):
- code-bison: code generation, using the existing LLM integration
- code-gecko: code completion, using the existing LLM integration
- codechat-bison: code chat, using the existing chat_model integration
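
A minimal usage sketch, based on the notebook cells in this PR (assumes `google-cloud-aiplatform>=1.26.0` is installed and Vertex AI credentials are configured; prompts are illustrative):

```python
from langchain.chat_models import ChatVertexAI
from langchain.llms import VertexAI
from langchain.schema import HumanMessage

code_llm = VertexAI(model_name="code-bison")        # code generation
completion_llm = VertexAI(model_name="code-gecko")  # code completion
chat = ChatVertexAI(model_name="codechat-bison")    # code chat

print(code_llm("Write a python function that identifies if the number is a prime number?"))
print(completion_llm("def is_prime(n):"))
print(chat([HumanMessage(content="How do I create a python function to identify all prime numbers?")]).content)
```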

---------

Co-authored-by: Dev 2049 <dev.dev2049@gmail.com>
Hassan Ouda committed 12 months ago via GitHub
parent 0fce8ef178
commit 456ca3d587

@@ -141,6 +141,73 @@
     ")"
    ]
   },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "execution": {
+     "iopub.execute_input": "2023-06-17T21:09:25.423568Z",
+     "iopub.status.busy": "2023-06-17T21:09:25.423213Z",
+     "iopub.status.idle": "2023-06-17T21:09:25.429641Z",
+     "shell.execute_reply": "2023-06-17T21:09:25.429060Z",
+     "shell.execute_reply.started": "2023-06-17T21:09:25.423546Z"
+    },
+    "tags": []
+   },
+   "source": [
+    "You can now leverage the Codey API for code chat within Vertex AI. The model name is:\n",
+    "- codechat-bison: for code assistance"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "metadata": {
+    "execution": {
+     "iopub.execute_input": "2023-06-17T21:30:43.974841Z",
+     "iopub.status.busy": "2023-06-17T21:30:43.974431Z",
+     "iopub.status.idle": "2023-06-17T21:30:44.248119Z",
+     "shell.execute_reply": "2023-06-17T21:30:44.247362Z",
+     "shell.execute_reply.started": "2023-06-17T21:30:43.974820Z"
+    },
+    "tags": []
+   },
+   "outputs": [],
+   "source": [
+    "chat = ChatVertexAI(model_name=\"codechat-bison\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "metadata": {
+    "execution": {
+     "iopub.execute_input": "2023-06-17T21:30:45.146093Z",
+     "iopub.status.busy": "2023-06-17T21:30:45.145752Z",
+     "iopub.status.idle": "2023-06-17T21:30:47.449126Z",
+     "shell.execute_reply": "2023-06-17T21:30:47.448609Z",
+     "shell.execute_reply.started": "2023-06-17T21:30:45.146069Z"
+    },
+    "tags": []
+   },
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "AIMessage(content='The following Python function can be used to identify all prime numbers up to a given integer:\\n\\n```\\ndef is_prime(n):\\n \"\"\"\\n Determines whether the given integer is prime.\\n\\n Args:\\n n: The integer to be tested for primality.\\n\\n Returns:\\n True if n is prime, False otherwise.\\n \"\"\"\\n\\n # Check if n is divisible by 2.\\n if n % 2 == 0:\\n return False\\n\\n # Check if n is divisible by any integer from 3 to the square root', additional_kwargs={}, example=False)"
+      ]
+     },
+     "execution_count": 4,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "messages = [\n",
+    " HumanMessage(content=\"How do I create a python function to identify all prime numbers?\")\n",
+    "]\n",
+    "chat(messages)"
+   ]
+  },
   {
    "cell_type": "code",
    "execution_count": null,

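Not part of the notebook, but consistent with the `_parse_chat_history` change further down: prior turns passed to `codechat-bison` must alternate HumanMessage/AIMessage, and a SystemMessage is not supported. A hypothetical multi-turn sketch:

```python
from langchain.chat_models import ChatVertexAI
from langchain.schema import AIMessage, HumanMessage

chat = ChatVertexAI(model_name="codechat-bison")
messages = [
    # Earlier turns are replayed as (question, answer) pairs; no SystemMessage allowed.
    HumanMessage(content="How do I create a python function to identify all prime numbers?"),
    AIMessage(content="def is_prime(n): ..."),
    HumanMessage(content="Now return the full list of primes up to n."),
]
print(chat(messages).content)
```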
@@ -101,11 +101,80 @@
    ]
   },
   {
-   "cell_type": "code",
-   "execution_count": null,
+   "cell_type": "markdown",
    "metadata": {},
+   "source": [
+    "You can now leverage the Codey API for code generation within Vertex AI. The model names are:\n",
+    "- code-bison: for code suggestion\n",
+    "- code-gecko: for code completion"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 9,
+   "metadata": {
+    "execution": {
+     "iopub.execute_input": "2023-06-17T21:16:53.149438Z",
+     "iopub.status.busy": "2023-06-17T21:16:53.149065Z",
+     "iopub.status.idle": "2023-06-17T21:16:53.421824Z",
+     "shell.execute_reply": "2023-06-17T21:16:53.421136Z",
+     "shell.execute_reply.started": "2023-06-17T21:16:53.149415Z"
+    },
+    "tags": []
+   },
    "outputs": [],
-   "source": []
+   "source": [
+    "llm = VertexAI(model_name=\"code-bison\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 12,
+   "metadata": {
+    "execution": {
+     "iopub.execute_input": "2023-06-17T21:17:11.179077Z",
+     "iopub.status.busy": "2023-06-17T21:17:11.178686Z",
+     "iopub.status.idle": "2023-06-17T21:17:11.182499Z",
+     "shell.execute_reply": "2023-06-17T21:17:11.181895Z",
+     "shell.execute_reply.started": "2023-06-17T21:17:11.179052Z"
+    },
+    "tags": []
+   },
+   "outputs": [],
+   "source": [
+    "llm_chain = LLMChain(prompt=prompt, llm=llm)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 15,
+   "metadata": {
+    "execution": {
+     "iopub.execute_input": "2023-06-17T21:18:47.024785Z",
+     "iopub.status.busy": "2023-06-17T21:18:47.024230Z",
+     "iopub.status.idle": "2023-06-17T21:18:49.352249Z",
+     "shell.execute_reply": "2023-06-17T21:18:49.351695Z",
+     "shell.execute_reply.started": "2023-06-17T21:18:47.024762Z"
+    },
+    "tags": []
+   },
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "'```python\\ndef is_prime(n):\\n \"\"\"\\n Determines if a number is prime.\\n\\n Args:\\n n: The number to be tested.\\n\\n Returns:\\n True if the number is prime, False otherwise.\\n \"\"\"\\n\\n # Check if the number is 1.\\n if n == 1:\\n return False\\n\\n # Check if the number is 2.\\n if n == 2:\\n return True\\n\\n'"
+      ]
+     },
+     "execution_count": 15,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "question = \"Write a python function that identifies if the number is a prime number?\"\n",
+    "\n",
+    "llm_chain.run(question)"
+   ]
   }
  ],
  "metadata": {

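The notebook above exercises `code-bison` through an `LLMChain`; a parallel `code-gecko` completion call (illustrative, not from the diff) would prompt with a code prefix and get back the continuation:

```python
from langchain.llms import VertexAI

# code-gecko is a completion model: give it a code prefix, it continues it.
llm = VertexAI(model_name="code-gecko", max_output_tokens=64)
print(llm("def reverse_string(s: str) -> str:\n    return"))
```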
@@ -9,7 +9,7 @@ from langchain.callbacks.manager import (
     CallbackManagerForLLMRun,
 )
 from langchain.chat_models.base import BaseChatModel
-from langchain.llms.vertexai import _VertexAICommon
+from langchain.llms.vertexai import _VertexAICommon, is_codey_model
 from langchain.schema import (
     AIMessage,
     BaseMessage,
@@ -42,7 +42,7 @@ def _parse_chat_history(history: List[BaseMessage]) -> _ChatHistory:
     A sequence should be either (SystemMessage, HumanMessage, AIMessage,
     HumanMessage, AIMessage, ...) or (HumanMessage, AIMessage, HumanMessage,
-    AIMessage, ...).
+    AIMessage, ...). CodeChat does not support SystemMessage.
 
     Args:
         history: The list of messages to re-create the history of the chat.
@@ -82,10 +82,16 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
         """Validate that the python package exists in environment."""
         cls._try_init_vertexai(values)
         try:
-            from vertexai.preview.language_models import ChatModel
+            if is_codey_model(values["model_name"]):
+                from vertexai.preview.language_models import CodeChatModel
+
+                values["client"] = CodeChatModel.from_pretrained(values["model_name"])
+            else:
+                from vertexai.preview.language_models import ChatModel
+
+                values["client"] = ChatModel.from_pretrained(values["model_name"])
         except ImportError:
             raise_vertex_import_error()
-        values["client"] = ChatModel.from_pretrained(values["model_name"])
         return values
 
     def _generate(
@@ -98,9 +104,10 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
         """Generate next turn in the conversation.
 
         Args:
-            messages: The history of the conversation as a list of messages.
+            messages: The history of the conversation as a list of messages. Code chat
+                does not support context.
             stop: The list of stop words (optional).
-            run_manager: The Callbackmanager for LLM run, it's not used at the moment.
+            run_manager: The CallbackManager for LLM run, it's not used at the moment.
 
         Returns:
             The ChatResult that contains outputs generated by the model.
@@ -121,10 +128,12 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
         history = _parse_chat_history(messages[:-1])
         context = history.system_message.content if history.system_message else None
         params = {**self._default_params, **kwargs}
-        chat = self.client.start_chat(context=context, **params)
+        if not self.is_codey_model:
+            params["context"] = context
+        chat = self.client.start_chat(**params)
         for pair in history.history:
             chat._history.append((pair.question.content, pair.answer.content))
-        response = chat.send_message(question.content, **self._default_params)
+        response = chat.send_message(question.content, **params)
         text = self._enforce_stop_words(response.text, stop)
         return ChatResult(generations=[ChatGeneration(message=AIMessage(content=text))])

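Both wrappers route on a plain substring check, `is_codey_model`, introduced in the next file. A quick illustration of what it matches:

```python
from langchain.llms.vertexai import is_codey_model

assert is_codey_model("code-bison")       # LLM integration
assert is_codey_model("code-gecko")       # LLM integration
assert is_codey_model("codechat-bison")   # chat integration
assert not is_codey_model("text-bison")   # regular PaLM text model
```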
@@ -15,6 +15,10 @@ if TYPE_CHECKING:
     from vertexai.language_models._language_models import _LanguageModel
 
 
+def is_codey_model(model_name: str) -> bool:
+    return "code" in model_name
+
+
 class _VertexAICommon(BaseModel):
     client: "_LanguageModel" = None  #: :meta private:
     model_name: str
@@ -25,10 +29,10 @@ class _VertexAICommon(BaseModel):
     "Token limit determines the maximum amount of text output from one prompt."
     top_p: float = 0.95
     "Tokens are selected from most probable to least until the sum of their "
-    "probabilities equals the top-p value."
+    "probabilities equals the top-p value. Top-p is ignored for Codey models."
     top_k: int = 40
     "How the model selects tokens for output, the next token is selected from "
-    "among the top-k most probable tokens."
+    "among the top-k most probable tokens. Top-k is ignored for Codey models."
     stop: Optional[List[str]] = None
     "Optional list of stop words to use when generating."
     project: Optional[str] = None
@@ -40,15 +44,24 @@ class _VertexAICommon(BaseModel):
     "when making API calls. If not provided, credentials will be ascertained from "
     "the environment."
 
+    @property
+    def is_codey_model(self) -> bool:
+        return is_codey_model(self.model_name)
+
     @property
     def _default_params(self) -> Dict[str, Any]:
-        base_params = {
-            "temperature": self.temperature,
-            "max_output_tokens": self.max_output_tokens,
-            "top_k": self.top_k,
-            "top_p": self.top_p,
-        }
-        return {**base_params}
+        if self.is_codey_model:
+            return {
+                "temperature": self.temperature,
+                "max_output_tokens": self.max_output_tokens,
+            }
+        else:
+            return {
+                "temperature": self.temperature,
+                "max_output_tokens": self.max_output_tokens,
+                "top_k": self.top_k,
+                "top_p": self.top_p,
+            }
 
     def _predict(
         self, prompt: str, stop: Optional[List[str]] = None, **kwargs: Any
@@ -80,22 +93,32 @@ class VertexAI(_VertexAICommon, LLM):
     """Wrapper around Google Vertex AI large language models."""
 
     model_name: str = "text-bison"
+    "The name of the Vertex AI large language model."
     tuned_model_name: Optional[str] = None
-    "The name of a tuned model, if it's provided, model_name is ignored."
+    "The name of a tuned model. If provided, model_name is ignored."
 
     @root_validator()
     def validate_environment(cls, values: Dict) -> Dict:
         """Validate that the python package exists in environment."""
         cls._try_init_vertexai(values)
+        tuned_model_name = values.get("tuned_model_name")
+        model_name = values["model_name"]
         try:
-            from vertexai.preview.language_models import TextGenerationModel
+            if tuned_model_name or not is_codey_model(model_name):
+                from vertexai.preview.language_models import TextGenerationModel
+
+                if tuned_model_name:
+                    values["client"] = TextGenerationModel.get_tuned_model(
+                        tuned_model_name
+                    )
+                else:
+                    values["client"] = TextGenerationModel.from_pretrained(model_name)
+            else:
+                from vertexai.preview.language_models import CodeGenerationModel
+
+                values["client"] = CodeGenerationModel.from_pretrained(model_name)
         except ImportError:
             raise_vertex_import_error()
-        tuned_model_name = values.get("tuned_model_name")
-        if tuned_model_name:
-            values["client"] = TextGenerationModel.get_tuned_model(tuned_model_name)
-        else:
-            values["client"] = TextGenerationModel.from_pretrained(values["model_name"])
         return values
 
     def _call(

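One consequence of the `_default_params` change above, sketched here (instantiating the wrapper requires working Vertex AI credentials): Codey models only send `temperature` and `max_output_tokens`, while `top_k` and `top_p` are dropped.

```python
from langchain.llms import VertexAI

text_llm = VertexAI(model_name="text-bison")
code_llm = VertexAI(model_name="code-bison")
print(sorted(text_llm._default_params))  # ['max_output_tokens', 'temperature', 'top_k', 'top_p']
print(sorted(code_llm._default_params))  # ['max_output_tokens', 'temperature']
```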
@@ -11,7 +11,7 @@ def raise_vertex_import_error() -> None:
     Raises:
         ImportError: an ImportError that mentions a required version of the SDK.
     """
-    sdk = "'google-cloud-aiplatform>=1.25.0'"
+    sdk = "'google-cloud-aiplatform>=1.26.0'"
     raise ImportError(
         "Could not import VertexAI. Please, install it with " f"pip install {sdk}"
     )

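The SDK floor moves to 1.26.0, presumably because the `CodeGenerationModel` and `CodeChatModel` classes imported above first ship in that release. An illustrative fail-fast check:

```python
try:
    # Both Codey classes require google-cloud-aiplatform>=1.26.0.
    from vertexai.preview.language_models import CodeChatModel, CodeGenerationModel  # noqa: F401
except ImportError as exc:
    raise ImportError("Please install 'google-cloud-aiplatform>=1.26.0'") from exc
```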