diff --git a/docs/extras/modules/model_io/models/chat/integrations/google_vertex_ai_palm.ipynb b/docs/extras/modules/model_io/models/chat/integrations/google_vertex_ai_palm.ipynb
index 67981d98..95a00a47 100644
--- a/docs/extras/modules/model_io/models/chat/integrations/google_vertex_ai_palm.ipynb
+++ b/docs/extras/modules/model_io/models/chat/integrations/google_vertex_ai_palm.ipynb
@@ -141,6 +141,73 @@
     ")"
    ]
   },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "execution": {
+     "iopub.execute_input": "2023-06-17T21:09:25.423568Z",
+     "iopub.status.busy": "2023-06-17T21:09:25.423213Z",
+     "iopub.status.idle": "2023-06-17T21:09:25.429641Z",
+     "shell.execute_reply": "2023-06-17T21:09:25.429060Z",
+     "shell.execute_reply.started": "2023-06-17T21:09:25.423546Z"
+    },
+    "tags": []
+   },
+   "source": [
+    "You can now leverage the Codey API for code chat within Vertex AI. The model name is:\n",
+    "- codechat-bison: for code assistance"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "metadata": {
+    "execution": {
+     "iopub.execute_input": "2023-06-17T21:30:43.974841Z",
+     "iopub.status.busy": "2023-06-17T21:30:43.974431Z",
+     "iopub.status.idle": "2023-06-17T21:30:44.248119Z",
+     "shell.execute_reply": "2023-06-17T21:30:44.247362Z",
+     "shell.execute_reply.started": "2023-06-17T21:30:43.974820Z"
+    },
+    "tags": []
+   },
+   "outputs": [],
+   "source": [
+    "chat = ChatVertexAI(model_name=\"codechat-bison\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "metadata": {
+    "execution": {
+     "iopub.execute_input": "2023-06-17T21:30:45.146093Z",
+     "iopub.status.busy": "2023-06-17T21:30:45.145752Z",
+     "iopub.status.idle": "2023-06-17T21:30:47.449126Z",
+     "shell.execute_reply": "2023-06-17T21:30:47.448609Z",
+     "shell.execute_reply.started": "2023-06-17T21:30:45.146069Z"
+    },
+    "tags": []
+   },
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "AIMessage(content='The following Python function can be used to identify all prime numbers up to a given integer:\\n\\n```\\ndef is_prime(n):\\n \"\"\"\\n Determines whether the given integer is prime.\\n\\n Args:\\n n: The integer to be tested for primality.\\n\\n Returns:\\n True if n is prime, False otherwise.\\n \"\"\"\\n\\n # Check if n is divisible by 2.\\n if n % 2 == 0:\\n return False\\n\\n # Check if n is divisible by any integer from 3 to the square root', additional_kwargs={}, example=False)"
+      ]
+     },
+     "execution_count": 4,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "messages = [\n",
+    "    HumanMessage(content=\"How do I create a python function to identify all prime numbers?\")\n",
+    "]\n",
+    "chat(messages)"
+   ]
+  },
   {
    "cell_type": "code",
    "execution_count": null,
diff --git a/docs/extras/modules/model_io/models/llms/integrations/google_vertex_ai_palm.ipynb b/docs/extras/modules/model_io/models/llms/integrations/google_vertex_ai_palm.ipynb
index d2551fd6..0854478d 100644
--- a/docs/extras/modules/model_io/models/llms/integrations/google_vertex_ai_palm.ipynb
+++ b/docs/extras/modules/model_io/models/llms/integrations/google_vertex_ai_palm.ipynb
@@ -101,11 +101,80 @@
    ]
   },
   {
-   "cell_type": "code",
-   "execution_count": null,
+   "cell_type": "markdown",
    "metadata": {},
+   "source": [
+    "You can now leverage the Codey API for code generation within Vertex AI. The model names are:\n",
+    "- code-bison: for code suggestion\n",
+    "- code-gecko: for code completion"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 9,
+   "metadata": {
+    "execution": {
+     "iopub.execute_input": "2023-06-17T21:16:53.149438Z",
+     "iopub.status.busy": "2023-06-17T21:16:53.149065Z",
+     "iopub.status.idle": "2023-06-17T21:16:53.421824Z",
+     "shell.execute_reply": "2023-06-17T21:16:53.421136Z",
+     "shell.execute_reply.started": "2023-06-17T21:16:53.149415Z"
+    },
+    "tags": []
+   },
    "outputs": [],
-   "source": []
+   "source": [
+    "llm = VertexAI(model_name=\"code-bison\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 12,
+   "metadata": {
+    "execution": {
+     "iopub.execute_input": "2023-06-17T21:17:11.179077Z",
+     "iopub.status.busy": "2023-06-17T21:17:11.178686Z",
+     "iopub.status.idle": "2023-06-17T21:17:11.182499Z",
+     "shell.execute_reply": "2023-06-17T21:17:11.181895Z",
+     "shell.execute_reply.started": "2023-06-17T21:17:11.179052Z"
+    },
+    "tags": []
+   },
+   "outputs": [],
+   "source": [
+    "llm_chain = LLMChain(prompt=prompt, llm=llm)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 15,
+   "metadata": {
+    "execution": {
+     "iopub.execute_input": "2023-06-17T21:18:47.024785Z",
+     "iopub.status.busy": "2023-06-17T21:18:47.024230Z",
+     "iopub.status.idle": "2023-06-17T21:18:49.352249Z",
+     "shell.execute_reply": "2023-06-17T21:18:49.351695Z",
+     "shell.execute_reply.started": "2023-06-17T21:18:47.024762Z"
+    },
+    "tags": []
+   },
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "'```python\\ndef is_prime(n):\\n \"\"\"\\n Determines if a number is prime.\\n\\n Args:\\n n: The number to be tested.\\n\\n Returns:\\n True if the number is prime, False otherwise.\\n \"\"\"\\n\\n # Check if the number is 1.\\n if n == 1:\\n return False\\n\\n # Check if the number is 2.\\n if n == 2:\\n return True\\n\\n'"
+      ]
+     },
+     "execution_count": 15,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "question = \"Write a python function that identifies if the number is a prime number?\"\n",
+    "\n",
+    "llm_chain.run(question)"
+   ]
   }
  ],
  "metadata": {
diff --git a/langchain/chat_models/vertexai.py b/langchain/chat_models/vertexai.py
index bd2ecbb2..b9440476 100644
--- a/langchain/chat_models/vertexai.py
+++ b/langchain/chat_models/vertexai.py
@@ -9,7 +9,7 @@ from langchain.callbacks.manager import (
     CallbackManagerForLLMRun,
 )
 from langchain.chat_models.base import BaseChatModel
-from langchain.llms.vertexai import _VertexAICommon
+from langchain.llms.vertexai import _VertexAICommon, is_codey_model
 from langchain.schema import (
     AIMessage,
     BaseMessage,
@@ -42,7 +42,7 @@ def _parse_chat_history(history: List[BaseMessage]) -> _ChatHistory:
 
     A sequence should be either (SystemMessage, HumanMessage, AIMessage,
     HumanMessage, AIMessage, ...) or (HumanMessage, AIMessage, HumanMessage,
-    AIMessage, ...).
+    AIMessage, ...). CodeChat does not support SystemMessage.
 
     Args:
         history: The list of messages to re-create the history of the chat.
@@ -82,10 +82,16 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
         """Validate that the python package exists in environment."""
         cls._try_init_vertexai(values)
         try:
-            from vertexai.preview.language_models import ChatModel
+            if is_codey_model(values["model_name"]):
+                from vertexai.preview.language_models import CodeChatModel
+
+                values["client"] = CodeChatModel.from_pretrained(values["model_name"])
+            else:
+                from vertexai.preview.language_models import ChatModel
+
+                values["client"] = ChatModel.from_pretrained(values["model_name"])
         except ImportError:
             raise_vertex_import_error()
-        values["client"] = ChatModel.from_pretrained(values["model_name"])
         return values
 
     def _generate(
@@ -98,9 +104,10 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
         """Generate next turn in the conversation.
 
         Args:
-            messages: The history of the conversation as a list of messages.
+            messages: The history of the conversation as a list of messages. Code chat
+                does not support context.
             stop: The list of stop words (optional).
-            run_manager: The Callbackmanager for LLM run, it's not used at the moment.
+            run_manager: The CallbackManager for LLM run, it's not used at the moment.
 
         Returns:
             The ChatResult that contains outputs generated by the model.
@@ -121,10 +128,12 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
         history = _parse_chat_history(messages[:-1])
         context = history.system_message.content if history.system_message else None
         params = {**self._default_params, **kwargs}
-        chat = self.client.start_chat(context=context, **params)
+        if not self.is_codey_model:
+            params["context"] = context
+        chat = self.client.start_chat(**params)
         for pair in history.history:
             chat._history.append((pair.question.content, pair.answer.content))
-        response = chat.send_message(question.content, **self._default_params)
+        response = chat.send_message(question.content, **params)
         text = self._enforce_stop_words(response.text, stop)
         return ChatResult(generations=[ChatGeneration(message=AIMessage(content=text))])
 
diff --git a/langchain/llms/vertexai.py b/langchain/llms/vertexai.py
index 522c8cd5..0f85920c 100644
--- a/langchain/llms/vertexai.py
+++ b/langchain/llms/vertexai.py
@@ -15,6 +15,10 @@ if TYPE_CHECKING:
     from vertexai.language_models._language_models import _LanguageModel
 
 
+def is_codey_model(model_name: str) -> bool:
+    return "code" in model_name
+
+
 class _VertexAICommon(BaseModel):
     client: "_LanguageModel" = None  #: :meta private:
     model_name: str
@@ -25,10 +29,10 @@ class _VertexAICommon(BaseModel):
     "Token limit determines the maximum amount of text output from one prompt."
     top_p: float = 0.95
     "Tokens are selected from most probable to least until the sum of their "
-    "probabilities equals the top-p value."
+    "probabilities equals the top-p value. Top-p is ignored for Codey models."
     top_k: int = 40
     "How the model selects tokens for output, the next token is selected from "
-    "among the top-k most probable tokens."
+    "among the top-k most probable tokens. Top-k is ignored for Codey models."
     stop: Optional[List[str]] = None
     "Optional list of stop words to use when generating."
     project: Optional[str] = None
@@ -40,15 +44,24 @@ class _VertexAICommon(BaseModel):
     "when making API calls. If not provided, credentials will be ascertained from "
     "the environment."
 
+    @property
+    def is_codey_model(self) -> bool:
+        return is_codey_model(self.model_name)
+
     @property
     def _default_params(self) -> Dict[str, Any]:
-        base_params = {
-            "temperature": self.temperature,
-            "max_output_tokens": self.max_output_tokens,
-            "top_k": self.top_k,
-            "top_p": self.top_p,
-        }
-        return {**base_params}
+        if self.is_codey_model:
+            return {
+                "temperature": self.temperature,
+                "max_output_tokens": self.max_output_tokens,
+            }
+        else:
+            return {
+                "temperature": self.temperature,
+                "max_output_tokens": self.max_output_tokens,
+                "top_k": self.top_k,
+                "top_p": self.top_p,
+            }
 
     def _predict(
         self, prompt: str, stop: Optional[List[str]] = None, **kwargs: Any
@@ -80,22 +93,32 @@ class VertexAI(_VertexAICommon, LLM):
     """Wrapper around Google Vertex AI large language models."""
 
     model_name: str = "text-bison"
+    "The name of the Vertex AI large language model."
     tuned_model_name: Optional[str] = None
-    "The name of a tuned model, if it's provided, model_name is ignored."
+    "The name of a tuned model. If provided, model_name is ignored."
 
     @root_validator()
     def validate_environment(cls, values: Dict) -> Dict:
         """Validate that the python package exists in environment."""
         cls._try_init_vertexai(values)
+        tuned_model_name = values.get("tuned_model_name")
+        model_name = values["model_name"]
         try:
-            from vertexai.preview.language_models import TextGenerationModel
+            if tuned_model_name or not is_codey_model(model_name):
+                from vertexai.preview.language_models import TextGenerationModel
+
+                if tuned_model_name:
+                    values["client"] = TextGenerationModel.get_tuned_model(
+                        tuned_model_name
+                    )
+                else:
+                    values["client"] = TextGenerationModel.from_pretrained(model_name)
+            else:
+                from vertexai.preview.language_models import CodeGenerationModel
+
+                values["client"] = CodeGenerationModel.from_pretrained(model_name)
         except ImportError:
             raise_vertex_import_error()
-        tuned_model_name = values.get("tuned_model_name")
-        if tuned_model_name:
-            values["client"] = TextGenerationModel.get_tuned_model(tuned_model_name)
-        else:
-            values["client"] = TextGenerationModel.from_pretrained(values["model_name"])
         return values
 
     def _call(
diff --git a/langchain/utilities/vertexai.py b/langchain/utilities/vertexai.py
index 60050f11..8dde3d3c 100644
--- a/langchain/utilities/vertexai.py
+++ b/langchain/utilities/vertexai.py
@@ -11,7 +11,7 @@ def raise_vertex_import_error() -> None:
     Raises:
         ImportError: an ImportError that mentions a required version of the SDK.
     """
-    sdk = "'google-cloud-aiplatform>=1.25.0'"
+    sdk = "'google-cloud-aiplatform>=1.26.0'"
     raise ImportError(
         "Could not import VertexAI. Please, install it with " f"pip install {sdk}"
     )
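To make the LLM-side change concrete, here is a minimal usage sketch (not part of the diff): it assumes an authenticated Google Cloud environment with `google-cloud-aiplatform>=1.26.0` installed, and the prompt template below is illustrative only (the notebook defines its own `prompt` in an earlier cell). Because `"code"` appears in `"code-bison"`, `validate_environment` loads `CodeGenerationModel`, and `_default_params` drops `top_k`/`top_p` for Codey models.

```python
# Sketch only (not part of this diff): assumes Vertex AI credentials are
# configured and google-cloud-aiplatform>=1.26.0 is installed.
from langchain import LLMChain, PromptTemplate
from langchain.llms import VertexAI

# "code" in the model name, so is_codey_model() routes construction to
# CodeGenerationModel; top_k/top_p are omitted from the default params.
llm = VertexAI(model_name="code-bison")

# Illustrative template; the notebook builds its own `prompt` in an earlier cell.
prompt = PromptTemplate(
    template="Question: {question}\n\nAnswer:",
    input_variables=["question"],
)
llm_chain = LLMChain(prompt=prompt, llm=llm)

question = "Write a python function that identifies if the number is a prime number?"
print(llm_chain.run(question))
```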
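And the chat-side counterpart, again as a hedged sketch under the same environment assumptions: `"codechat-bison"` contains `"code"`, so `CodeChatModel` is loaded, and the `context` kwarg is skipped when the chat is started. Per the `_parse_chat_history` docstring change, no `SystemMessage` is included.

```python
# Sketch only (not part of this diff): same environment assumptions as above.
from langchain.chat_models import ChatVertexAI
from langchain.schema import HumanMessage

# "codechat-bison" contains "code", so CodeChatModel is selected and the
# context kwarg is not passed to start_chat().
chat = ChatVertexAI(model_name="codechat-bison")

# Note: no SystemMessage here; CodeChat does not support it.
messages = [
    HumanMessage(content="How do I create a python function to identify all prime numbers?")
]
print(chat(messages))  # -> AIMessage with the generated answer
```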