langchain/cookbook/azure_container_apps_dynamic_sessions_data_analyst.ipynb
2024-06-10 13:33:40 -07:00

{
"cells": [
{
"cell_type": "markdown",
"id": "4153b116-206b-40f8-a684-bf082c5ebcea",
"metadata": {},
"source": [
"# Building a data analyst agent with LangGraph and Azure Container Apps dynamic sessions\n",
"\n",
"In this example we'll build an agent that can query a Postgres database and run Python code to analyze the retrieved data. We'll use [LangGraph](https://langchain-ai.github.io/langgraph/) for agent orchestration and [Azure Container Apps dynamic sessions](https://python.langchain.com/v0.2/docs/integrations/tools/azure_dynamic_sessions/) for safe Python code execution.\n",
"\n",
"**NOTE**: Building LLM systems that interact with SQL databases requires executing model-generated SQL queries. There are inherent risks in doing this. Make sure that your database connection permissions are always scoped as narrowly as possible for your agent's needs. This will mitigate though not eliminate the risks of building a model-driven system. For more on general security best practices, see our [security guidelines](https://python.langchain.com/v0.2/docs/security/)."
]
},
{
"cell_type": "markdown",
"id": "3b70c2be-1141-4107-80db-787f7935102f",
"metadata": {},
"source": [
"## Setup\n",
"\n",
"Let's get set up by installing our Python dependencies and setting our OpenAI credentials, Azure Container Apps sessions pool endpoint, and our SQL database connection string.\n",
"\n",
"### Install dependencies"
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "302f827f-062c-4b83-8239-07b28bfc9651",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Note: you may need to restart the kernel to use updated packages.\n"
]
}
],
"source": [
"%pip install -qU langgraph langchain-azure-dynamic-sessions langchain-openai langchain-community pandas matplotlib"
]
},
{
"cell_type": "markdown",
"id": "7621655b-605c-4690-8ee1-77a4bab8b383",
"metadata": {},
"source": [
"### Set credentials\n",
"\n",
"By default this demo uses:\n",
"- Azure OpenAI for the model: https://learn.microsoft.com/en-us/azure/ai-services/openai/how-to/create-resource\n",
"- Azure PostgreSQL for the db: https://learn.microsoft.com/en-us/cli/azure/postgres/server?view=azure-cli-latest#az-postgres-server-create\n",
"- Azure Container Apps dynamic sessions for code execution: https://learn.microsoft.com/en-us/azure/container-apps/sessions-code-interpreter\n",
"\n",
"This LangGraph architecture can also be used with any other [tool-calling LLM](https://python.langchain.com/v0.2/docs/how_to/tool_calling/) and any SQL database."
]
},
{
"cell_type": "code",
"execution_count": 21,
"id": "be7c74d8-485b-4c51-aded-07e8af838efe",
"metadata": {},
"outputs": [
{
"name": "stdin",
"output_type": "stream",
"text": [
"Azure OpenAI API key ········\n",
"Azure OpenAI endpoint ········\n",
"Azure OpenAI deployment name ········\n",
"Azure Container Apps dynamic sessions pool management endpoint ········\n",
"PostgreSQL connection string ········\n"
]
}
],
"source": [
"import getpass\n",
"import os\n",
"\n",
"os.environ[\"AZURE_OPENAI_API_KEY\"] = getpass.getpass(\"Azure OpenAI API key\")\n",
"os.environ[\"AZURE_OPENAI_ENDPOINT\"] = getpass.getpass(\"Azure OpenAI endpoint\")\n",
"\n",
"AZURE_OPENAI_DEPLOYMENT_NAME = getpass.getpass(\"Azure OpenAI deployment name\")\n",
"SESSIONS_POOL_MANAGEMENT_ENDPOINT = getpass.getpass(\n",
"    \"Azure Container Apps dynamic sessions pool management endpoint\"\n",
")\n",
"SQL_DB_CONNECTION_STRING = getpass.getpass(\"PostgreSQL connection string\")"
]
},
{
"cell_type": "markdown",
"id": "3712a7b0-3f7d-4d90-9319-febf7b046aa6",
"metadata": {},
"source": [
"### Imports"
]
},
{
"cell_type": "code",
"execution_count": 22,
"id": "09c0a46e-a8b4-44e3-8d90-2e5d0f66c1ad",
"metadata": {},
"outputs": [],
"source": [
"import ast\n",
"import base64\n",
"import io\n",
"import json\n",
"import operator\n",
"from functools import partial\n",
"from typing import Annotated, List, Literal, Optional, Sequence, TypedDict\n",
"\n",
"import pandas as pd\n",
"from IPython.display import display\n",
"from langchain_azure_dynamic_sessions import SessionsPythonREPLTool\n",
"from langchain_community.utilities import SQLDatabase\n",
"from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, ToolMessage\n",
"from langchain_core.prompts import ChatPromptTemplate\n",
"from langchain_core.pydantic_v1 import BaseModel, Field\n",
"from langchain_core.runnables import RunnablePassthrough\n",
"from langchain_core.tools import tool\n",
"from langchain_openai import AzureChatOpenAI\n",
"from langgraph.graph import END, StateGraph\n",
"from langgraph.prebuilt import ToolNode\n",
"from matplotlib.pyplot import imshow\n",
"from PIL import Image"
]
},
{
"cell_type": "markdown",
"id": "5cc14582-313c-4a61-be5e-a7a1ba26a6e0",
"metadata": {},
"source": [
"## Instantiate model, DB, code interpreter\n",
"\n",
"We'll use the LangChain [SQLDatabase](https://api.python.langchain.com/en/latest/utilities/langchain_community.utilities.sql_database.SQLDatabase.html#langchain_community.utilities.sql_database.SQLDatabase) interface to connect to our DB and query it. This works with any SQL database supported by [SQLAlchemy](https://www.sqlalchemy.org/)."
]
},
{
"cell_type": "code",
"execution_count": 23,
"id": "9262ea34-c6ac-407c-96c3-aa5eaa1a8039",
"metadata": {},
"outputs": [],
"source": [
"db = SQLDatabase.from_uri(SQL_DB_CONNECTION_STRING)"
]
},
{
"cell_type": "markdown",
"id": "1982c6f2-aa4e-4842-83f2-951205aa0854",
"metadata": {},
"source": [
"For our LLM we need to make sure that we use a model that supports [tool-calling](https://python.langchain.com/v0.2/docs/how_to/tool_calling/)."
]
},
{
"cell_type": "code",
"execution_count": 24,
"id": "ba6201a1-d760-45f1-b14a-bf8d85ceb775",
"metadata": {},
"outputs": [],
"source": [
"llm = AzureChatOpenAI(\n",
"    deployment_name=AZURE_OPENAI_DEPLOYMENT_NAME, openai_api_version=\"2024-02-01\"\n",
")"
]
},
{
"cell_type": "markdown",
"id": "92e2fcc7-812a-4d18-852f-2f814559b415",
"metadata": {},
"source": [
"And the [dynamic sessions tool](https://python.langchain.com/v0.2/docs/integrations/tools/azure_container_apps_dynamic_sessions/) is what we'll use for code execution."
]
},
{
"cell_type": "code",
"execution_count": 25,
"id": "89e5a315-c964-493d-84fb-1f453909caae",
"metadata": {},
"outputs": [],
"source": [
"repl = SessionsPythonREPLTool(\n",
"    pool_management_endpoint=SESSIONS_POOL_MANAGEMENT_ENDPOINT\n",
")"
]
},
{
"cell_type": "markdown",
"id": "ee084fbd-10d3-4328-9d8c-75ffa9437b31",
"metadata": {},
"source": [
"## Define graph\n",
"\n",
"Now we're ready to define our application logic. The core elements are the [agent State, Nodes, and Edges](https://langchain-ai.github.io/langgraph/concepts/#core-design).\n",
"\n",
"### Define State\n",
"We'll use a simple agent State which is just a list of messages that every Node can append to:"
]
},
{
"cell_type": "code",
"execution_count": 31,
"id": "7feef65d-bf11-41bb-9164-5249953eb02e",
"metadata": {},
"outputs": [],
"source": [
"class AgentState(TypedDict):\n",
"    messages: Annotated[Sequence[BaseMessage], operator.add]"
]
},
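{
"cell_type": "markdown",
"id": "sketch-operator-add-reducer",
"metadata": {},
"source": [
"The `operator.add` annotation is what makes the state append-only: each node returns only its new messages, and they are concatenated onto the existing list. The sketch below mimics that merge with plain Python lists and strings. `apply_update` is a hypothetical helper written for illustration; it is an assumption about the observable behavior, not LangGraph's actual internals.\n",
"\n",
"```python\n",
"import operator\n",
"\n",
"\n",
"def apply_update(state, update, reducer=operator.add):\n",
"    # Hypothetical helper: combine the existing list with a node's new messages.\n",
"    return {'messages': reducer(state['messages'], update['messages'])}\n",
"\n",
"\n",
"# Plain strings stand in for message objects here.\n",
"state = {'messages': ['human: graph the average latency by model']}\n",
"state = apply_update(state, {'messages': ['ai: tool call create_df_from_sql']})\n",
"state = apply_update(state, {'messages': ['tool: Generated dataframe latency_df']})\n",
"print(len(state['messages']))\n",
"```\n",
"\n",
"Because the reducer concatenates rather than replaces, a node that returns an empty `messages` list leaves the history unchanged."
]
},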
{
"cell_type": "markdown",
"id": "58fe92a3-9a30-464b-bcf3-972af5b92e40",
"metadata": {},
"source": [
"Since our code interpreter can return results like base64-encoded images which we don't want to pass back to the model, we'll create a custom Tool message that allows us to track raw Tool outputs without sending them back to the model."
]
},
{
"cell_type": "code",
"execution_count": 32,
"id": "36e2d8a2-8881-40bc-81da-b40e8a152d9d",
"metadata": {},
"outputs": [],
"source": [
"class RawToolMessage(ToolMessage):\n",
"    \"\"\"\n",
"    Customized Tool message that lets us pass around the raw tool outputs (along with string contents for passing back to the model).\n",
"    \"\"\"\n",
"\n",
"    raw: dict\n",
"    \"\"\"Arbitrary (non-string) tool outputs. Won't be sent to model.\"\"\"\n",
"    tool_name: str\n",
"    \"\"\"Name of tool that generated output.\"\"\""
]
},
{
"cell_type": "markdown",
"id": "ad1b681c-c918-4dfe-b671-9d6eee457a51",
"metadata": {},
"source": [
"### Define Nodes"
]
},
{
"cell_type": "markdown",
"id": "966aeec1-b930-442c-9ba3-d8ad3800d2a4",
"metadata": {},
"source": [
"First we'll define a node for calling our model. We need to make sure to bind our tools to the model so that it knows to call them. We'll also specify in our prompt the schema of the SQL tables the model has access to, so that it can write relevant SQL queries."
]
},
{
"cell_type": "markdown",
"id": "88f15581-11f6-4421-aa17-5762a84c8032",
"metadata": {},
"source": [
"We'll use our model's tool-calling abilities to reliably generate our SQL queries and Python code. To do this we need to define schemas for our tools that the model can use for structuring its tool calls.\n",
"\n",
"Note that the class names, docstrings, and attribute typing and descriptions are crucial here, as they're actually passed in to the model (you can effectively think of them as part of the prompt)."
]
},
{
"cell_type": "code",
"execution_count": 33,
"id": "390f170b-ba13-41fc-8c9b-ee0efdb13b98",
"metadata": {},
"outputs": [],
"source": [
"# Tool schema for querying SQL db\n",
"class create_df_from_sql(BaseModel):\n",
"    \"\"\"Execute a PostgreSQL SELECT statement and use the results to create a DataFrame with the given column names.\"\"\"\n",
"\n",
"    select_query: str = Field(..., description=\"A PostgreSQL SELECT statement.\")\n",
"    # We're going to convert the results to a Pandas DataFrame that we pass\n",
"    # to the code interpreter, so we also have the model generate useful column and\n",
"    # variable names for this DataFrame that the model will refer to when writing\n",
"    # python code.\n",
"    df_columns: List[str] = Field(\n",
"        ..., description=\"Ordered names to give the DataFrame columns.\"\n",
"    )\n",
"    df_name: str = Field(\n",
"        ..., description=\"The name to give the DataFrame variable in downstream code.\"\n",
"    )\n",
"\n",
"\n",
"# Tool schema for writing Python code\n",
"class python_shell(BaseModel):\n",
"    \"\"\"Execute Python code that analyzes the DataFrames that have been generated. Make sure to print any important results.\"\"\"\n",
"\n",
"    code: str = Field(\n",
"        ...,\n",
"        description=\"The code to execute. Make sure to print any important results.\",\n",
"    )"
]
},
{
"cell_type": "code",
"execution_count": 34,
"id": "a98cf69a-e25b-4016-a565-aa16e43e417a",
"metadata": {},
"outputs": [],
"source": [
"system_prompt = f\"\"\"\\\n",
"You are an expert at PostgreSQL and Python. You have access to a PostgreSQL database \\\n",
"with the following tables\n",
"\n",
"{db.table_info}\n",
"\n",
"Given a user question related to the data in the database, \\\n",
"first get the relevant data from the table as a DataFrame using the create_df_from_sql tool. Then use the \\\n",
"python_shell to do any analysis required to answer the user question.\"\"\"\n",
"\n",
"prompt = ChatPromptTemplate.from_messages(\n",
"    [\n",
"        (\"system\", system_prompt),\n",
"        (\"placeholder\", \"{messages}\"),\n",
"    ]\n",
")\n",
"\n",
"\n",
"def call_model(state: AgentState) -> dict:\n",
"    \"\"\"Call model with tools passed in.\"\"\"\n",
"    messages = []\n",
"\n",
"    chain = prompt | llm.bind_tools([create_df_from_sql, python_shell])\n",
"    messages.append(chain.invoke({\"messages\": state[\"messages\"]}))\n",
"\n",
"    return {\"messages\": messages}"
]
},
{
"cell_type": "markdown",
"id": "4e87c72e-7f9e-4377-94c9-abd9fb869866",
"metadata": {},
"source": [
"Now we can define the node for executing any SQL queries that were generated by the model. Notice that after we run the query we convert the results into Pandas DataFrames. These will be uploaded to the code interpreter tool in the next step so that it can use the retrieved data."
]
},
{
"cell_type": "code",
"execution_count": 35,
"id": "a229efba-e981-4403-a37c-ab030c929ea4",
"metadata": {},
"outputs": [],
"source": [
"def execute_sql_query(state: AgentState) -> dict:\n",
"    \"\"\"Execute the latest SQL queries.\"\"\"\n",
"    messages = []\n",
"\n",
"    for tool_call in state[\"messages\"][-1].tool_calls:\n",
"        if tool_call[\"name\"] != \"create_df_from_sql\":\n",
"            continue\n",
"\n",
"        # Execute SQL query\n",
"        res = db.run(tool_call[\"args\"][\"select_query\"], fetch=\"cursor\").fetchall()\n",
"\n",
"        # Convert result to Pandas DataFrame\n",
"        df_columns = tool_call[\"args\"][\"df_columns\"]\n",
"        df = pd.DataFrame(res, columns=df_columns)\n",
"        df_name = tool_call[\"args\"][\"df_name\"]\n",
"\n",
"        # Add tool output message\n",
"        messages.append(\n",
"            RawToolMessage(\n",
"                f\"Generated dataframe {df_name} with columns {df_columns}\",  # What's sent to model.\n",
"                raw={df_name: df},\n",
"                tool_call_id=tool_call[\"id\"],\n",
"                tool_name=tool_call[\"name\"],\n",
"            )\n",
"        )\n",
"\n",
"    return {\"messages\": messages}"
]
},
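{
"cell_type": "markdown",
"id": "sketch-single-select-precheck",
"metadata": {},
"source": [
"The node above passes the model's `select_query` straight to `db.run`. In addition to scoping database permissions narrowly, one option is a coarse pre-check that rejects anything that doesn't look like a single `SELECT`. The sketch below is an illustration under stated assumptions, not part of this notebook's pipeline: `looks_like_single_select` and the `runs` table are hypothetical, and string matching like this is no substitute for a read-only database role.\n",
"\n",
"```python\n",
"def looks_like_single_select(query):\n",
"    # Hypothetical pre-check: accept one statement whose first word is SELECT.\n",
"    statement = query.strip().rstrip(';').strip()\n",
"    if not statement or ';' in statement:\n",
"        # Empty input, or more than one statement (this also rejects\n",
"        # semicolons inside string literals, which is deliberately conservative).\n",
"        return False\n",
"    first_word = statement.split()[0].lower()\n",
"    return first_word == 'select'\n",
"\n",
"\n",
"print(looks_like_single_select('SELECT model, latency FROM runs;'))\n",
"print(looks_like_single_select('SELECT 1; DROP TABLE runs'))\n",
"```\n",
"\n",
"A node could skip the query and return an error message to the model for any tool call that fails a check like this."
]
},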
{
"cell_type": "markdown",
"id": "7a67eaaf-1587-4f32-ab5c-e1a04d273c3e",
"metadata": {},
"source": [
"Now we need a node for executing any model-generated Python code. The key steps here are:\n",
"- Uploading queried data to the code interpreter\n",
"- Executing model-generated code\n",
"- Parsing results so that images are displayed and not passed in to future model calls\n",
"\n",
"To upload the queried data to the code interpreter, we take the DataFrames generated by executing the SQL queries and upload them as CSVs."
]
},
{
"cell_type": "code",
"execution_count": 36,
"id": "450c1dd0-4fe4-4ab7-b1d7-e012c3cf0102",
"metadata": {},
"outputs": [],
"source": [
"def _upload_dfs_to_repl(state: AgentState) -> str:\n",
"    \"\"\"\n",
"    Upload generated dfs to code interpreter and return code for loading them.\n",
"\n",
"    Note that code interpreter sessions are short-lived so this needs to be done\n",
"    every agent cycle, even if the dfs were previously uploaded.\n",
"    \"\"\"\n",
"    df_dicts = [\n",
"        msg.raw\n",
"        for msg in state[\"messages\"]\n",
"        if isinstance(msg, RawToolMessage) and msg.tool_name == \"create_df_from_sql\"\n",
"    ]\n",
"    name_df_map = {name: df for df_dict in df_dicts for name, df in df_dict.items()}\n",
"\n",
"    # Data should be uploaded as a BinaryIO, so encode the CSV text to bytes.\n",
"    # Files will be uploaded to the \"/mnt/data/\" directory on the container.\n",
"    for name, df in name_df_map.items():\n",
"        # index=False keeps the row index from coming back as an extra column.\n",
"        buffer = io.BytesIO(df.to_csv(index=False).encode(\"utf-8\"))\n",
"        repl.upload_file(data=buffer, remote_file_path=name + \".csv\")\n",
"\n",
"    # Code for loading the uploaded files.\n",
"    df_code = \"import pandas as pd\\n\" + \"\\n\".join(\n",
"        f\"{name} = pd.read_csv('/mnt/data/{name}.csv')\" for name in name_df_map\n",
"    )\n",
"    return df_code\n",
"\n",
"\n",
"def _repl_result_to_msg_content(repl_result: dict) -> str:\n",
"    \"\"\"\n",
"    Display images without including them in tool message content.\n",
"    \"\"\"\n",
"    content = {}\n",
"    for k, v in repl_result.items():\n",
"        # Any image results are returned as a dict of the form:\n",
"        # {\"type\": \"image\", \"base64_data\": \"...\"}\n",
"        if isinstance(v, dict) and v.get(\"type\") == \"image\":\n",
"            # Decode and display image\n",
"            base64_str = v[\"base64_data\"]\n",
"            img = Image.open(io.BytesIO(base64.decodebytes(bytes(base64_str, \"utf-8\"))))\n",
"            display(img)\n",
"        else:\n",
"            content[k] = v\n",
"    return json.dumps(content, indent=2)\n",
"\n",
"\n",
"def execute_python(state: AgentState) -> dict:\n",
"    \"\"\"\n",
"    Execute the latest generated Python code.\n",
"    \"\"\"\n",
"    messages = []\n",
"\n",
"    df_code = _upload_dfs_to_repl(state)\n",
"    last_ai_msg = [msg for msg in state[\"messages\"] if isinstance(msg, AIMessage)][-1]\n",
"    for tool_call in last_ai_msg.tool_calls:\n",
"        if tool_call[\"name\"] != \"python_shell\":\n",
"            continue\n",
"\n",
"        generated_code = tool_call[\"args\"][\"code\"]\n",
"        repl_result = repl.execute(df_code + \"\\n\" + generated_code)\n",
"\n",
"        messages.append(\n",
"            RawToolMessage(\n",
"                _repl_result_to_msg_content(repl_result),\n",
"                raw=repl_result,\n",
"                tool_call_id=tool_call[\"id\"],\n",
"                tool_name=tool_call[\"name\"],\n",
"            )\n",
"        )\n",
"    return {\"messages\": messages}"
]
},
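{
"cell_type": "markdown",
"id": "sketch-split-repl-result",
"metadata": {},
"source": [
"The result-parsing step can be tried out without a sessions pool. The sketch below separates image entries from text entries in a made-up result dict. `split_repl_result`, `fake_result`, and its key names are assumptions for illustration; only the image shape (`type` and `base64_data`) comes from the comment in the cell above.\n",
"\n",
"```python\n",
"import json\n",
"\n",
"\n",
"def split_repl_result(repl_result):\n",
"    # Keep text-like entries for the model; set image entries aside.\n",
"    text_parts, images = {}, {}\n",
"    for key, value in repl_result.items():\n",
"        if isinstance(value, dict) and value.get('type') == 'image':\n",
"            images[key] = value\n",
"        else:\n",
"            text_parts[key] = value\n",
"    return json.dumps(text_parts, indent=2), images\n",
"\n",
"\n",
"# Made-up result for illustration; the real key names may differ.\n",
"fake_result = {\n",
"    'result': {'type': 'image', 'base64_data': 'aGVsbG8='},\n",
"    'stdout': 'plotted 5 bars',\n",
"    'stderr': '',\n",
"}\n",
"content, images = split_repl_result(fake_result)\n",
"print(content)\n",
"```\n",
"\n",
"Only `content` would go into the tool message, so the base64 payload never reaches the model's context."
]
},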
{
"cell_type": "markdown",
"id": "dd530250-60b6-40fb-b1f8-2ff32967ecc8",
"metadata": {},
"source": [
"### Define Edges\n",
"\n",
"Now we're ready to put all the pieces together into a graph."
]
},
{
"cell_type": "code",
"execution_count": 37,
"id": "a04e0a82-1c3e-46d3-95ea-2461c21202ef",
"metadata": {},
"outputs": [],
"source": [
"def should_continue(state: AgentState) -> str:\n",
"    \"\"\"\n",
"    If the model's latest message contains tool calls, route to the tool-execution nodes so the model can be called again with their results. Otherwise end the run.\n",
"    \"\"\"\n",
"    return \"execute_sql_query\" if state[\"messages\"][-1].tool_calls else END"
]
},
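{
"cell_type": "markdown",
"id": "sketch-routing-with-plain-dicts",
"metadata": {},
"source": [
"Note that `should_continue` always routes to `execute_sql_query` when there are tool calls, even if the model only asked for `python_shell`. That works because each execution node skips tool calls that aren't its own. The sketch below illustrates this with plain dicts standing in for messages; `route`, `calls_for`, and the local `END` string are hypothetical stand-ins, not the notebook's objects.\n",
"\n",
"```python\n",
"END = '__end__'  # local stand-in for langgraph.graph.END\n",
"\n",
"\n",
"def route(last_message):\n",
"    # Any tool call sends the run through the execution nodes.\n",
"    return 'execute_sql_query' if last_message.get('tool_calls') else END\n",
"\n",
"\n",
"def calls_for(last_message, tool_name):\n",
"    # Each execution node only handles the calls addressed to its tool.\n",
"    return [c for c in last_message.get('tool_calls', []) if c['name'] == tool_name]\n",
"\n",
"\n",
"msg = {'tool_calls': [{'name': 'python_shell', 'args': {'code': 'print(1)'}}]}\n",
"print(route(msg))\n",
"print(len(calls_for(msg, 'create_df_from_sql')))\n",
"```\n",
"\n",
"Here the SQL node finds nothing to do and the Python node picks up the single `python_shell` call."
]
},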
{
"cell_type": "code",
"execution_count": 38,
"id": "b2857ba9-da80-443f-8217-ac0523f90593",
"metadata": {},
"outputs": [],
"source": [
"workflow = StateGraph(AgentState)\n",
"\n",
"workflow.add_node(\"call_model\", call_model)\n",
"workflow.add_node(\"execute_sql_query\", execute_sql_query)\n",
"workflow.add_node(\"execute_python\", execute_python)\n",
"\n",
"workflow.set_entry_point(\"call_model\")\n",
"workflow.add_edge(\"execute_sql_query\", \"execute_python\")\n",
"workflow.add_edge(\"execute_python\", \"call_model\")\n",
"workflow.add_conditional_edges(\"call_model\", should_continue)\n",
"\n",
"app = workflow.compile()"
]
},
{
"cell_type": "code",
"execution_count": 39,
"id": "74dc8c6c-b520-4f17-88ec-fa789ed911e6",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" +-----------+ \n",
" | __start__ | \n",
" +-----------+ \n",
" * \n",
" * \n",
" * \n",
" +------------+ \n",
" ...| call_model |*** \n",
" ....... +------------+ ******* \n",
" ........ .. ... ******* \n",
" ....... .. ... ****** \n",
" .... .. .. ******* \n",
"+---------+ +-------------------+ .. **** \n",
"| __end__ | | execute_sql_query | . **** \n",
"+---------+ +-------------------+* . **** \n",
" ***** . ***** \n",
" **** . **** \n",
" *** . *** \n",
" +----------------+ \n",
" | execute_python | \n",
" +----------------+ \n"
]
}
],
"source": [
"print(app.get_graph().draw_ascii())"
]
},
{
"cell_type": "markdown",
"id": "6d4e079b-0cf8-4f9d-a52b-6a8f980eee4b",
"metadata": {},
"source": [
"## Test it out\n",
"\n",
"Replace these examples with questions related to the database you've connected your agent to."
]
},
{
"cell_type": "code",
"execution_count": 40,
"id": "2c173d6d-a212-448e-b309-299e87f205b8",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<PIL.PngImagePlugin.PngImageFile image mode=RGBA size=989x590>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"The graph of the average latency by model has been generated successfully. However, it seems that the output is not displayed here directly. To view the graph, you would typically run the provided Python code in an environment where graphical output is supported, such as a Jupyter notebook or a Python script executed in a local environment with access to a display server.\n"
]
}
],
"source": [
"output = app.invoke({\"messages\": [(\"human\", \"graph the average latency by model\")]})\n",
"print(output[\"messages\"][-1].content)"
]
},
{
"cell_type": "markdown",
"id": "a67fbc65-2161-4518-9eea-f0cdd99b5f59",
"metadata": {},
"source": [
"**LangSmith Trace**: https://smith.langchain.com/public/9c8afcce-0ed1-4fb1-b719-767e6432bd8e/r"
]
},
{
"cell_type": "code",
"execution_count": 41,
"id": "1d512f95-7490-483e-a748-abf708fbd20c",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<PIL.PngImagePlugin.PngImageFile image mode=RGBA size=571x453>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"The correlation coefficient between the number of prompt tokens and latency is approximately 0.305, indicating a positive but relatively weak relationship. This suggests that as the number of input tokens increases, there tends to be an increase in latency, but the relationship is not strong and other factors may also influence latency.\n",
"\n",
"Here is the scatter plot showing the relationship visually:\n",
"\n",
"![Scatter Plot of Prompt Tokens and Latency](sandbox:/2)\n"
]
}
],
"source": [
"output = app.invoke(\n",
"    {\n",
"        \"messages\": [\n",
"            (\"human\", \"what's the relationship between latency and input tokens?\")\n",
"        ]\n",
"    }\n",
")\n",
"print(output[\"messages\"][-1].content)"
]
},
{
"cell_type": "code",
"execution_count": 43,
"id": "10071b83-19c6-468d-b5fc-600b42cd57ac",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<PIL.PngImagePlugin.PngImageFile image mode=RGBA size=670x453>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Continue the conversation\n",
"output = app.invoke(\n",
"    {\"messages\": output[\"messages\"] + [(\"human\", \"now control for model\")]}\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 44,
"id": "81fb6102-c427-41c1-97cf-54e5944d1c79",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"After controlling for each model, here are the individual correlations between prompt tokens and latency:\n",
"\n",
"- `anthropic_claude_3_sonnet`: Correlation = 0.7659\n",
"- `openai_gpt_3_5_turbo`: Correlation = 0.2833\n",
"- `fireworks_mixtral`: Correlation = 0.1673\n",
"- `cohere_command`: Correlation = 0.1434\n",
"- `google_gemini_pro`: Correlation = 0.4928\n",
"\n",
"These correlations indicate that the `anthropic_claude_3_sonnet` model has the strongest positive correlation between the number of prompt tokens and latency, while the `cohere_command` model has the weakest positive correlation.\n",
"\n",
"Scatter plots were generated for each model individually to illustrate the relationship between prompt tokens and latency. Below are the plots for each model:\n",
"\n",
"1. Model: anthropic_claude_3_sonnet\n",
"![Scatter Plot for anthropic_claude_3_sonnet](sandbox:/2)\n",
"\n",
"2. Model: openai_gpt_3_5_turbo\n",
"![Scatter Plot for openai_gpt_3_5_turbo](sandbox:/2)\n",
"\n",
"3. Model: fireworks_mixtral\n",
"![Scatter Plot for fireworks_mixtral](sandbox:/2)\n",
"\n",
"4. Model: cohere_command\n",
"![Scatter Plot for cohere_command](sandbox:/2)\n",
"\n",
"5. Model: google_gemini_pro\n",
"![Scatter Plot for google_gemini_pro](sandbox:/2)\n",
"\n",
"The plots and correlations together provide an understanding of how latency changes with the number of prompt tokens for each model.\n"
]
}
],
"source": [
"print(output[\"messages\"][-1].content)"
]
},
{
"cell_type": "code",
"execution_count": 46,
"id": "09167fa6-132a-4696-a4ee-eda80a41d3dd",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<PIL.PngImagePlugin.PngImageFile image mode=RGBA size=703x510>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"output = app.invoke(\n",
"    {\n",
"        \"messages\": output[\"messages\"]\n",
"        + [(\"human\", \"what about latency vs output tokens\")]\n",
"    }\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 47,
"id": "f0c48828-07ae-43df-b27f-14fdfbd835f6",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The correlation between the number of output tokens (completion_tokens) and latency varies by model, as shown below:\n",
"\n",
"- `anthropic_claude_3_sonnet`: Correlation = 0.910274\n",
"- `cohere_command`: Correlation = 0.910292\n",
"- `fireworks_mixtral`: Correlation = 0.681286\n",
"- `google_gemini_pro`: Correlation = 0.151549\n",
"- `openai_gpt_3_5_turbo`: Correlation = 0.449127\n",
"\n",
"The `anthropic_claude_3_sonnet` and `cohere_command` models show a very strong positive correlation, indicating that an increase in the number of output tokens is associated with a substantial increase in latency for these models. The `fireworks_mixtral` model also shows a strong positive correlation, but less strong than the first two. The `google_gemini_pro` model shows a weak positive correlation, and the `openai_gpt_3_5_turbo` model shows a moderate positive correlation.\n",
"\n",
"Below is the scatter plot with a regression line showing the relationship between output tokens and latency for each model:\n",
"\n",
"![Scatter Plot with Regression Line for Each Model](sandbox:/2)\n"
]
}
],
"source": [
"print(output[\"messages\"][-1].content)"
]
},
{
"cell_type": "code",
"execution_count": 48,
"id": "4114c16d-c727-49c2-beb1-27c5982b0948",
"metadata": {},
"outputs": [],
"source": [
"output = app.invoke(\n",
"    {\n",
"        \"messages\": [\n",
"            (\n",
"                \"human\",\n",
"                \"what's the better explanatory variable for latency: input or output tokens?\",\n",
"            )\n",
"        ]\n",
"    }\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 49,
"id": "7f983c4a-60b6-4dd6-ab22-2b59971e2fcd",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The correlation between input tokens and latency is 0.305, while the correlation between output tokens and latency is 0.487. Therefore, the better explanatory variable for latency is output tokens.\n"
]
}
],
"source": [
"print(output[\"messages\"][-1].content)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "poetry-venv-2",
"language": "python",
"name": "poetry-venv-2"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.1"
}
},
"nbformat": 4,
"nbformat_minor": 5
}