From 57bb940c17fd8d7037b95817524e6a540305c569 Mon Sep 17 00:00:00 2001 From: Bagatur <22008038+baskaryan@users.noreply.github.com> Date: Fri, 12 Apr 2024 10:09:54 -0700 Subject: [PATCH] docs: vertexai tool call update (#20362) --- .../chat/google_vertex_ai_palm.ipynb | 229 ++++++++++++++---- 1 file changed, 183 insertions(+), 46 deletions(-) diff --git a/docs/docs/integrations/chat/google_vertex_ai_palm.ipynb b/docs/docs/integrations/chat/google_vertex_ai_palm.ipynb index ecc7a653a0..6f44b09459 100644 --- a/docs/docs/integrations/chat/google_vertex_ai_palm.ipynb +++ b/docs/docs/integrations/chat/google_vertex_ai_palm.ipynb @@ -51,7 +51,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ @@ -261,31 +261,46 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "from pprint import pprint\n", + "\n", + "from langchain_core.messages import HumanMessage\n", + "from langchain_google_vertexai import HarmBlockThreshold, HarmCategory" + ] + }, + { + "cell_type": "code", + "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "{'is_blocked': False,\n", - " 'safety_ratings': [{'category': 'HARM_CATEGORY_HARASSMENT',\n", + "{'citation_metadata': None,\n", + " 'is_blocked': False,\n", + " 'safety_ratings': [{'blocked': False,\n", + " 'category': 'HARM_CATEGORY_HATE_SPEECH',\n", " 'probability_label': 'NEGLIGIBLE'},\n", - " {'category': 'HARM_CATEGORY_HATE_SPEECH',\n", + " {'blocked': False,\n", + " 'category': 'HARM_CATEGORY_DANGEROUS_CONTENT',\n", " 'probability_label': 'NEGLIGIBLE'},\n", - " {'category': 'HARM_CATEGORY_SEXUALLY_EXPLICIT',\n", + " {'blocked': False,\n", + " 'category': 'HARM_CATEGORY_HARASSMENT',\n", " 'probability_label': 'NEGLIGIBLE'},\n", - " {'category': 'HARM_CATEGORY_DANGEROUS_CONTENT',\n", - " 'probability_label': 'NEGLIGIBLE'}]}\n" + " {'blocked': False,\n", + " 'category': 'HARM_CATEGORY_SEXUALLY_EXPLICIT',\n", + " 'probability_label': 'NEGLIGIBLE'}],\n", + " 'usage_metadata': {'candidates_token_count': 6,\n", + " 'prompt_token_count': 12,\n", + " 'total_token_count': 18}}\n" ] } ], "source": [ - "from pprint import pprint\n", - "\n", - "from langchain_core.messages import HumanMessage\n", - "from langchain_google_vertexai import ChatVertexAI, HarmBlockThreshold, HarmCategory\n", - "\n", "human = \"Translate this sentence from English to French. I love programming.\"\n", "messages = [HumanMessage(content=human)]\n", "\n", @@ -315,18 +330,21 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "{'is_blocked': False,\n", - " 'safety_attributes': {'Derogatory': 0.1,\n", - " 'Finance': 0.3,\n", - " 'Insult': 0.1,\n", - " 'Sexual': 0.1}}\n" + "{'errors': (),\n", + " 'grounding_metadata': {'citations': [], 'search_queries': []},\n", + " 'is_blocked': False,\n", + " 'safety_attributes': [{'Derogatory': 0.1, 'Insult': 0.1, 'Sexual': 0.2}],\n", + " 'usage_metadata': {'candidates_billable_characters': 88.0,\n", + " 'candidates_token_count': 24.0,\n", + " 'prompt_billable_characters': 58.0,\n", + " 'prompt_token_count': 12.0}}\n" ] } ], @@ -341,40 +359,149 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## Function Calling with Gemini\n", + "## Tool calling (a.k.a. function calling) with Gemini\n", "\n", - "We can call Gemini models with tools." + "We can pass tool definitions to Gemini models to get the model to invoke those tools when appropriate. This is useful not only for LLM-powered tool use but also for getting structured outputs out of models more generally.\n", + "\n", + "With `ChatVertexAI.bind_tools()`, we can easily pass in Pydantic classes, dict schemas, LangChain tools, or even functions as tools to the model. Under the hood these are converted to a Gemini tool schema, which looks like:\n", + "```python\n", + "{\n", + " \"name\": \"...\", # tool name\n", + " \"description\": \"...\", # tool description\n", + " \"parameters\": {...} # tool input schema as JSONSchema\n", + "}\n", + "```" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "AIMessage(content='', additional_kwargs={'function_call': {'name': 'GetWeather', 'arguments': '{\"location\": \"San Francisco, CA\"}'}}, response_metadata={'is_blocked': False, 'safety_ratings': [{'category': 'HARM_CATEGORY_HATE_SPEECH', 'probability_label': 'NEGLIGIBLE', 'blocked': False}, {'category': 'HARM_CATEGORY_DANGEROUS_CONTENT', 'probability_label': 'NEGLIGIBLE', 'blocked': False}, {'category': 'HARM_CATEGORY_HARASSMENT', 'probability_label': 'NEGLIGIBLE', 'blocked': False}, {'category': 'HARM_CATEGORY_SEXUALLY_EXPLICIT', 'probability_label': 'NEGLIGIBLE', 'blocked': False}], 'citation_metadata': None, 'usage_metadata': {'prompt_token_count': 41, 'candidates_token_count': 7, 'total_token_count': 48}}, id='run-05e760dc-0682-4286-88e1-5b23df69b083-0', tool_calls=[{'name': 'GetWeather', 'args': {'location': 'San Francisco, CA'}, 'id': 'cd2499c4-4513-4059-bfff-5321b6e922d0'}])" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from langchain.pydantic_v1 import BaseModel, Field\n", + "\n", + "\n", + "class GetWeather(BaseModel):\n", + " \"\"\"Get the current weather in a given location\"\"\"\n", + "\n", + " location: str = Field(..., description=\"The city and state, e.g. San Francisco, CA\")\n", + "\n", + "\n", + "llm = ChatVertexAI(model_name=\"gemini-pro\", temperature=0)\n", + "llm_with_tools = llm.bind_tools([GetWeather])\n", + "ai_msg = llm_with_tools.invoke(\n", + " \"what is the weather like in San Francisco\",\n", + ")\n", + "ai_msg" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The tool calls can be access via the `AIMessage.tool_calls` attribute, where they are extracted in a model-agnostic format:" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[{'name': 'GetWeather',\n", + " 'args': {'location': 'San Francisco, CA'},\n", + " 'id': 'cd2499c4-4513-4059-bfff-5321b6e922d0'}]" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "ai_msg.tool_calls" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "For a complete guide on tool calling [head here](/docs/modules/model_io/chat/function_calling/)." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Structured outputs\n", + "\n", + "Many applications require structured model outputs. Tool calling makes it much easier to do this reliably. The [with_structured_outputs](https://api.python.langchain.com/en/latest/chat_models/langchain_google_vertexai.chat_models.ChatVertexAI.html) constructor provides a simple interface built on top of tool calling for getting structured outputs out of a model. For a complete guide on structured outputs [head here](/docs/modules/model_io/chat/structured_output/).\n", + "\n", + "### ChatVertexAI.with_structured_outputs()\n", + "\n", + "To get structured outputs from our Gemini model all we need to do is to specify a desired schema, either as a Pydantic class or as a JSON schema, " + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Person(name='Stefan', age=13)" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "class Person(BaseModel):\n", + " \"\"\"Save information about a person.\"\"\"\n", + "\n", + " name: str = Field(..., description=\"The person's name.\")\n", + " age: int = Field(..., description=\"The person's age.\")\n", + "\n", + "\n", + "structured_llm = llm.with_structured_output(Person)\n", + "structured_llm.invoke(\"Stefan is already 13 years old\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### [Legacy] Using `create_structured_runnable()`\n", + "\n", + "The legacy wasy to get structured outputs is using the `create_structured_runnable` constructor:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "MyModel(name='Erick', age=27)" - ] - }, - "execution_count": null, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ - "from langchain.pydantic_v1 import BaseModel\n", "from langchain_google_vertexai import create_structured_runnable\n", "\n", - "llm = ChatVertexAI(model_name=\"gemini-pro\")\n", - "\n", - "\n", - "class MyModel(BaseModel):\n", - " name: str\n", - " age: int\n", - "\n", - "\n", - "chain = create_structured_runnable(MyModel, llm)\n", + "chain = create_structured_runnable(Person, llm)\n", "chain.invoke(\"My name is Erick and I'm 27 years old\")" ] }, @@ -484,11 +611,21 @@ ], "metadata": { "kernelspec": { - "display_name": "", - "name": "" + "display_name": "poetry-venv-2", + "language": "python", + "name": "poetry-venv-2" }, "language_info": { - "name": "python" + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.1" } }, "nbformat": 4,