mirror of
https://github.com/openai/openai-cookbook
synced 2024-11-11 13:11:02 +00:00
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "xs8_Q1nPwxlK"
|
|
},
|
|
"source": [
|
|
"# Structured Outputs for Multi-Agent Systems\n",
|
|
"\n",
|
|
"In this cookbook, we will explore how to use Structured Outputs to build multi-agent systems.\n",
|
|
"\n",
|
|
"Structured Outputs is a new capability that builds upon JSON mode and function calling to enforce a strict schema in a model output.\n",
|
|
"\n",
|
|
"By using the new parameter `strict: true`, we are able to guarantee the response abides by a provided schema.\n",
|
|
"\n",
|
|
"To demonstrate the power of this feature, we will use it to build a multi-agent system.\n",
|
|
"\n",
|
|
"### Why build a Multi-Agent System?\n",
|
|
"\n",
|
|
"When using function calling, if the number of functions (or tools) increases, the performance may suffer.\n",
|
|
"\n",
|
|
"To mitigate this, we can logically group the tools together and have specialized \"agents\" that are able to solve specific tasks or sub-tasks, which will increase the overall system performance."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "7SLTnfLRKVnP"
|
|
},
|
|
"source": [
"## Environment setup"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 52,
|
|
"metadata": {
|
|
"id": "UCySx7jT6T7Y"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"from openai import OpenAI\n",
|
|
"from IPython.display import Image\n",
|
|
"import json\n",
|
|
"import pandas as pd\n",
|
|
"import matplotlib.pyplot as plt\n",
|
|
"from io import StringIO\n",
|
|
"import numpy as np\n",
|
|
"client = OpenAI()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 53,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"MODEL = \"gpt-4o-2024-08-06\""
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "X34-4ZYyK6-S"
|
|
},
|
|
"source": [
"## Agent setup\n",
"\n",
"The use case we will tackle is a data analysis task.\n",
"\n",
"Let's first set up our four-agent system:\n",
"\n",
"1. **Triaging Agent:** Decides which agent(s) to call\n",
"2. **Data Pre-processing Agent:** Prepares data for analysis, for example by cleaning it up\n",
"3. **Data Analysis Agent:** Performs analysis on the data\n",
"4. **Data Visualization Agent:** Visualizes the output of the analysis to extract insights\n",
|
|
"\n",
|
|
"We will start by defining the system prompts for each of these agents."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 54,
|
|
"metadata": {
|
|
"id": "CewlAQuhKUIe"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"triaging_system_prompt = \"\"\"You are a Triaging Agent. Your role is to assess the user's query and route it to the relevant agents. The agents available are:\n",
|
|
"- Data Processing Agent: Cleans, transforms, and aggregates data.\n",
|
|
"- Analysis Agent: Performs statistical, correlation, and regression analysis.\n",
|
|
"- Visualization Agent: Creates bar charts, line charts, and pie charts.\n",
|
|
"\n",
|
|
"Use the send_query_to_agents tool to forward the user's query to the relevant agents. Also, use the speak_to_user tool to get more information from the user if needed.\"\"\"\n",
|
|
"\n",
|
|
"processing_system_prompt = \"\"\"You are a Data Processing Agent. Your role is to clean, transform, and aggregate data using the following tools:\n",
|
|
"- clean_data\n",
|
|
"- transform_data\n",
|
|
"- aggregate_data\"\"\"\n",
|
|
"\n",
|
|
"analysis_system_prompt = \"\"\"You are an Analysis Agent. Your role is to perform statistical, correlation, and regression analysis using the following tools:\n",
|
|
"- stat_analysis\n",
|
|
"- correlation_analysis\n",
|
|
"- regression_analysis\"\"\"\n",
|
|
"\n",
|
|
"visualization_system_prompt = \"\"\"You are a Visualization Agent. Your role is to create bar charts, line charts, and pie charts using the following tools:\n",
|
|
"- create_bar_chart\n",
|
|
"- create_line_chart\n",
|
|
"- create_pie_chart\"\"\""
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "vkpZ409POhiS"
|
|
},
|
|
"source": [
|
|
"We will then define the tools for each agent.\n",
|
|
"\n",
|
|
"Apart from the triaging agent, each agent will be equipped with tools specific to their role:\n",
|
|
"\n",
|
|
"#### Data pre-processing agent\n",
|
|
"\n",
|
|
"\n",
|
|
"1. Clean data\n",
|
|
"2. Transform data\n",
|
|
"3. Aggregate data\n",
|
|
"\n",
|
|
"#### Data analysis agent\n",
|
|
"\n",
|
|
"1. Statistical analysis\n",
|
|
"2. Correlation analysis\n",
|
|
"3. Regression Analysis\n",
|
|
"\n",
|
|
"#### Data visualization agent\n",
|
|
"\n",
|
|
"1. Create bar chart\n",
|
|
"2. Create line chart\n",
|
|
"3. Create pie chart"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 55,
|
|
"metadata": {
|
|
"id": "MzBvgBliOc9Y"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"triage_tools = [\n",
|
|
" {\n",
|
|
" \"type\": \"function\",\n",
|
|
" \"function\": {\n",
|
|
" \"name\": \"send_query_to_agents\",\n",
|
|
" \"description\": \"Sends the user query to relevant agents based on their capabilities.\",\n",
|
|
" \"parameters\": {\n",
|
|
" \"type\": \"object\",\n",
|
|
" \"properties\": {\n",
|
|
" \"agents\": {\n",
|
|
" \"type\": \"array\",\n",
|
|
" \"items\": {\"type\": \"string\"},\n",
|
|
" \"description\": \"An array of agent names to send the query to.\"\n",
|
|
" },\n",
|
|
" \"query\": {\n",
|
|
" \"type\": \"string\",\n",
|
|
" \"description\": \"The user query to send.\"\n",
|
|
" }\n",
|
|
" },\n",
" \"required\": [\"agents\", \"query\"],\n",
" \"additionalProperties\": False\n",
" },\n",
" \"strict\": True\n",
" }\n",
|
|
" }\n",
|
|
"]\n",
|
|
"\n",
|
|
"preprocess_tools = [\n",
|
|
" {\n",
|
|
" \"type\": \"function\",\n",
|
|
" \"function\": {\n",
|
|
" \"name\": \"clean_data\",\n",
|
|
" \"description\": \"Cleans the provided data by removing duplicates and handling missing values.\",\n",
|
|
" \"parameters\": {\n",
|
|
" \"type\": \"object\",\n",
|
|
" \"properties\": {\n",
|
|
" \"data\": {\n",
|
|
" \"type\": \"string\",\n",
|
|
" \"description\": \"The dataset to clean. Should be in a suitable format such as JSON or CSV.\"\n",
|
|
" }\n",
|
|
" },\n",
|
|
" \"required\": [\"data\"],\n",
|
|
" \"additionalProperties\": False\n",
" },\n",
" \"strict\": True\n",
" }\n",
|
|
" },\n",
|
|
" {\n",
|
|
" \"type\": \"function\",\n",
|
|
" \"function\": {\n",
|
|
" \"name\": \"transform_data\",\n",
|
|
" \"description\": \"Transforms data based on specified rules.\",\n",
|
|
" \"parameters\": {\n",
|
|
" \"type\": \"object\",\n",
|
|
" \"properties\": {\n",
|
|
" \"data\": {\n",
|
|
" \"type\": \"string\",\n",
|
|
" \"description\": \"The data to transform. Should be in a suitable format such as JSON or CSV.\"\n",
|
|
" },\n",
|
|
" \"rules\": {\n",
|
|
" \"type\": \"string\",\n",
|
|
" \"description\": \"Transformation rules to apply, specified in a structured format.\"\n",
|
|
" }\n",
|
|
" },\n",
|
|
" \"required\": [\"data\", \"rules\"],\n",
|
|
" \"additionalProperties\": False\n",
" },\n",
" \"strict\": True\n",
" }\n",
|
|
" },\n",
|
|
" {\n",
|
|
" \"type\": \"function\",\n",
|
|
" \"function\": {\n",
|
|
" \"name\": \"aggregate_data\",\n",
|
|
" \"description\": \"Aggregates data by specified columns and operations.\",\n",
|
|
" \"parameters\": {\n",
|
|
" \"type\": \"object\",\n",
|
|
" \"properties\": {\n",
|
|
" \"data\": {\n",
|
|
" \"type\": \"string\",\n",
|
|
" \"description\": \"The data to aggregate. Should be in a suitable format such as JSON or CSV.\"\n",
|
|
" },\n",
|
|
" \"group_by\": {\n",
|
|
" \"type\": \"array\",\n",
|
|
" \"items\": {\"type\": \"string\"},\n",
|
|
" \"description\": \"Columns to group by.\"\n",
|
|
" },\n",
|
|
" \"operations\": {\n",
|
|
" \"type\": \"string\",\n",
|
|
" \"description\": \"Aggregation operations to perform, specified in a structured format.\"\n",
|
|
" }\n",
|
|
" },\n",
|
|
" \"required\": [\"data\", \"group_by\", \"operations\"],\n",
|
|
" \"additionalProperties\": False\n",
" },\n",
" \"strict\": True\n",
" }\n",
|
|
" }\n",
|
|
"]\n",
|
|
"\n",
|
|
"\n",
|
|
"analysis_tools = [\n",
|
|
" {\n",
|
|
" \"type\": \"function\",\n",
|
|
" \"function\": {\n",
|
|
" \"name\": \"stat_analysis\",\n",
|
|
" \"description\": \"Performs statistical analysis on the given dataset.\",\n",
|
|
" \"parameters\": {\n",
|
|
" \"type\": \"object\",\n",
|
|
" \"properties\": {\n",
|
|
" \"data\": {\n",
|
|
" \"type\": \"string\",\n",
|
|
" \"description\": \"The dataset to analyze. Should be in a suitable format such as JSON or CSV.\"\n",
|
|
" }\n",
|
|
" },\n",
|
|
" \"required\": [\"data\"],\n",
|
|
" \"additionalProperties\": False\n",
" },\n",
" \"strict\": True\n",
" }\n",
|
|
" },\n",
|
|
" {\n",
|
|
" \"type\": \"function\",\n",
|
|
" \"function\": {\n",
|
|
" \"name\": \"correlation_analysis\",\n",
|
|
" \"description\": \"Calculates correlation coefficients between variables in the dataset.\",\n",
|
|
" \"parameters\": {\n",
|
|
" \"type\": \"object\",\n",
|
|
" \"properties\": {\n",
|
|
" \"data\": {\n",
|
|
" \"type\": \"string\",\n",
|
|
" \"description\": \"The dataset to analyze. Should be in a suitable format such as JSON or CSV.\"\n",
|
|
" },\n",
|
|
" \"variables\": {\n",
|
|
" \"type\": \"array\",\n",
|
|
" \"items\": {\"type\": \"string\"},\n",
|
|
" \"description\": \"List of variables to calculate correlations for.\"\n",
|
|
" }\n",
|
|
" },\n",
|
|
" \"required\": [\"data\", \"variables\"],\n",
|
|
" \"additionalProperties\": False\n",
" },\n",
" \"strict\": True\n",
" }\n",
|
|
" },\n",
|
|
" {\n",
|
|
" \"type\": \"function\",\n",
|
|
" \"function\": {\n",
|
|
" \"name\": \"regression_analysis\",\n",
|
|
" \"description\": \"Performs regression analysis on the dataset.\",\n",
|
|
" \"parameters\": {\n",
|
|
" \"type\": \"object\",\n",
|
|
" \"properties\": {\n",
|
|
" \"data\": {\n",
|
|
" \"type\": \"string\",\n",
|
|
" \"description\": \"The dataset to analyze. Should be in a suitable format such as JSON or CSV.\"\n",
|
|
" },\n",
|
|
" \"dependent_var\": {\n",
|
|
" \"type\": \"string\",\n",
|
|
" \"description\": \"The dependent variable for regression.\"\n",
|
|
" },\n",
|
|
" \"independent_vars\": {\n",
|
|
" \"type\": \"array\",\n",
|
|
" \"items\": {\"type\": \"string\"},\n",
|
|
" \"description\": \"List of independent variables.\"\n",
|
|
" }\n",
|
|
" },\n",
|
|
" \"required\": [\"data\", \"dependent_var\", \"independent_vars\"],\n",
|
|
" \"additionalProperties\": False\n",
" },\n",
" \"strict\": True\n",
" }\n",
|
|
" }\n",
|
|
"]\n",
|
|
"\n",
|
|
"visualization_tools = [\n",
|
|
" {\n",
|
|
" \"type\": \"function\",\n",
|
|
" \"function\": {\n",
|
|
" \"name\": \"create_bar_chart\",\n",
|
|
" \"description\": \"Creates a bar chart from the provided data.\",\n",
|
|
" \"parameters\": {\n",
|
|
" \"type\": \"object\",\n",
|
|
" \"properties\": {\n",
|
|
" \"data\": {\n",
|
|
" \"type\": \"string\",\n",
|
|
" \"description\": \"The data for the bar chart. Should be in a suitable format such as JSON or CSV.\"\n",
|
|
" },\n",
|
|
" \"x\": {\n",
|
|
" \"type\": \"string\",\n",
|
|
" \"description\": \"Column for the x-axis.\"\n",
|
|
" },\n",
|
|
" \"y\": {\n",
|
|
" \"type\": \"string\",\n",
|
|
" \"description\": \"Column for the y-axis.\"\n",
|
|
" }\n",
|
|
" },\n",
|
|
" \"required\": [\"data\", \"x\", \"y\"],\n",
|
|
" \"additionalProperties\": False\n",
" },\n",
" \"strict\": True\n",
" }\n",
|
|
" },\n",
|
|
" {\n",
|
|
" \"type\": \"function\",\n",
|
|
" \"function\": {\n",
|
|
" \"name\": \"create_line_chart\",\n",
|
|
" \"description\": \"Creates a line chart from the provided data.\",\n",
|
|
" \"parameters\": {\n",
|
|
" \"type\": \"object\",\n",
|
|
" \"properties\": {\n",
|
|
" \"data\": {\n",
|
|
" \"type\": \"string\",\n",
|
|
" \"description\": \"The data for the line chart. Should be in a suitable format such as JSON or CSV.\"\n",
|
|
" },\n",
|
|
" \"x\": {\n",
|
|
" \"type\": \"string\",\n",
|
|
" \"description\": \"Column for the x-axis.\"\n",
|
|
" },\n",
|
|
" \"y\": {\n",
|
|
" \"type\": \"string\",\n",
|
|
" \"description\": \"Column for the y-axis.\"\n",
|
|
" }\n",
|
|
" },\n",
|
|
" \"required\": [\"data\", \"x\", \"y\"],\n",
|
|
" \"additionalProperties\": False\n",
" },\n",
" \"strict\": True\n",
" }\n",
|
|
" },\n",
|
|
" {\n",
|
|
" \"type\": \"function\",\n",
|
|
" \"function\": {\n",
|
|
" \"name\": \"create_pie_chart\",\n",
|
|
" \"description\": \"Creates a pie chart from the provided data.\",\n",
|
|
" \"parameters\": {\n",
|
|
" \"type\": \"object\",\n",
|
|
" \"properties\": {\n",
|
|
" \"data\": {\n",
|
|
" \"type\": \"string\",\n",
|
|
" \"description\": \"The data for the pie chart. Should be in a suitable format such as JSON or CSV.\"\n",
|
|
" },\n",
|
|
" \"labels\": {\n",
|
|
" \"type\": \"string\",\n",
|
|
" \"description\": \"Column for the labels.\"\n",
|
|
" },\n",
|
|
" \"values\": {\n",
|
|
" \"type\": \"string\",\n",
|
|
" \"description\": \"Column for the values.\"\n",
|
|
" }\n",
|
|
" },\n",
|
|
" \"required\": [\"data\", \"labels\", \"values\"],\n",
|
|
" \"additionalProperties\": False\n",
" },\n",
" \"strict\": True\n",
" }\n",
|
|
" }\n",
|
|
"]"
|
|
]
|
|
},
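{
"cell_type": "markdown",
"metadata": {},
"source": [
"Strict mode places extra requirements on each tool schema: as described in the Structured Outputs documentation, every object must set `additionalProperties` to `false` and list all of its properties under `required`.\n",
"\n",
"The cell below is a minimal sketch of a local check for those two rules. `find_strict_issues` and `example_parameters` are hypothetical names introduced for illustration, and the check covers only these two rules rather than the full set of supported schema features."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def find_strict_issues(schema, path='parameters'):\n",
"    # Recursively collect violations of the two strict-mode rules named above.\n",
"    issues = []\n",
"    if schema.get('type') == 'object':\n",
"        properties = schema.get('properties', {})\n",
"        if schema.get('additionalProperties') is not False:\n",
"            issues.append(f'{path}: additionalProperties must be False')\n",
"        missing = sorted(set(properties) - set(schema.get('required', [])))\n",
"        if missing:\n",
"            issues.append(f'{path}: not listed in required: {missing}')\n",
"        for name, sub_schema in properties.items():\n",
"            issues.extend(find_strict_issues(sub_schema, f'{path}.{name}'))\n",
"    elif schema.get('type') == 'array':\n",
"        issues.extend(find_strict_issues(schema.get('items', {}), f'{path}[]'))\n",
"    return issues\n",
"\n",
"# Hypothetical schema that breaks both rules on purpose.\n",
"example_parameters = {\n",
"    'type': 'object',\n",
"    'properties': {\n",
"        'agents': {'type': 'array', 'items': {'type': 'string'}},\n",
"        'query': {'type': 'string'},\n",
"    },\n",
"    'required': ['agents'],\n",
"}\n",
"\n",
"for issue in find_strict_issues(example_parameters):\n",
"    print(issue)"
]
},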
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "yh8tRZHkQJVv"
|
|
},
|
|
"source": [
|
|
"## Tool execution\n",
|
|
"\n",
|
|
"We need to write the code logic to:\n",
|
|
"- handle passing the user query to the multi-agent system\n",
|
|
"- handle the internal workings of the multi-agent system\n",
|
|
"- execute the tool calls\n",
|
|
"\n",
|
|
"For the sake of brevity, we will only define the logic for tools that are relevant to the user query."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 56,
|
|
"metadata": {
|
|
"id": "dwM_0mHZ5pXx"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Example query\n",
|
|
"\n",
|
|
"user_query = \"\"\"\n",
|
|
"Below is some data. I want you to first remove the duplicates then analyze the statistics of the data as well as plot a line chart.\n",
|
|
"\n",
|
|
"house_size (m3), house_price ($)\n",
|
|
"90, 100\n",
|
|
"80, 90\n",
|
|
"100, 120\n",
|
|
"90, 100\n",
|
|
"\"\"\"\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
"From the user query, we can infer that the tools we would need to call are `clean_data`, `stat_analysis` and `create_line_chart`.\n",
|
|
"\n",
|
|
"We will first define the execution function which runs tool calls.\n",
|
|
"\n",
|
|
"This maps a tool call to the corresponding function. It then appends the output of the function to the conversation history."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 57,
|
|
"metadata": {
|
|
"id": "XH6wgrATUA_l"
|
|
},
|
|
"outputs": [],
|
|
"source": [
"def clean_data(data):\n",
"    data_io = StringIO(data)\n",
"    df = pd.read_csv(data_io, sep=\",\")\n",
"    df_deduplicated = df.drop_duplicates()\n",
"    return df_deduplicated\n",
"\n",
"def stat_analysis(data):\n",
"    data_io = StringIO(data)\n",
"    df = pd.read_csv(data_io, sep=\",\")\n",
"    return df.describe()\n",
"\n",
"def plot_line_chart(data):\n",
"    data_io = StringIO(data)\n",
"    df = pd.read_csv(data_io, sep=\",\")\n",
"\n",
"    x = df.iloc[:, 0]\n",
"    y = df.iloc[:, 1]\n",
"\n",
"    coefficients = np.polyfit(x, y, 1)\n",
"    polynomial = np.poly1d(coefficients)\n",
"    y_fit = polynomial(x)\n",
"\n",
"    plt.figure(figsize=(10, 6))\n",
"    plt.plot(x, y, 'o', label='Data Points')\n",
"    plt.plot(x, y_fit, '-', label='Best Fit Line')\n",
"    plt.title('Line Chart with Best Fit Line')\n",
"    plt.xlabel(df.columns[0])\n",
"    plt.ylabel(df.columns[1])\n",
"    plt.legend()\n",
"    plt.grid(True)\n",
"    plt.show()\n",
"\n",
"# Define the function to execute the tools\n",
"def execute_tool(tool_calls, messages):\n",
"    for tool_call in tool_calls:\n",
"        tool_name = tool_call.function.name\n",
"        tool_arguments = json.loads(tool_call.function.arguments)\n",
"\n",
"        if tool_name == 'clean_data':\n",
"            # Simulate data cleaning\n",
"            cleaned_df = clean_data(tool_arguments['data'])\n",
"            cleaned_data = {\"cleaned_data\": cleaned_df.to_dict()}\n",
"            messages.append({\"role\": \"tool\", \"name\": tool_name, \"content\": json.dumps(cleaned_data)})\n",
"            print('Cleaned data: ', cleaned_df)\n",
"        elif tool_name == 'transform_data':\n",
"            # Simulate data transformation\n",
"            transformed_data = {\"transformed_data\": \"sample_transformed_data\"}\n",
"            messages.append({\"role\": \"tool\", \"name\": tool_name, \"content\": json.dumps(transformed_data)})\n",
"        elif tool_name == 'aggregate_data':\n",
"            # Simulate data aggregation\n",
"            aggregated_data = {\"aggregated_data\": \"sample_aggregated_data\"}\n",
"            messages.append({\"role\": \"tool\", \"name\": tool_name, \"content\": json.dumps(aggregated_data)})\n",
"        elif tool_name == 'stat_analysis':\n",
"            # Simulate statistical analysis\n",
"            stats_df = stat_analysis(tool_arguments['data'])\n",
"            stats = {\"stats\": stats_df.to_dict()}\n",
"            messages.append({\"role\": \"tool\", \"name\": tool_name, \"content\": json.dumps(stats)})\n",
"            print('Statistical Analysis: ', stats_df)\n",
"        elif tool_name == 'correlation_analysis':\n",
"            # Simulate correlation analysis\n",
"            correlations = {\"correlations\": \"sample_correlations\"}\n",
"            messages.append({\"role\": \"tool\", \"name\": tool_name, \"content\": json.dumps(correlations)})\n",
"        elif tool_name == 'regression_analysis':\n",
"            # Simulate regression analysis\n",
"            regression_results = {\"regression_results\": \"sample_regression_results\"}\n",
"            messages.append({\"role\": \"tool\", \"name\": tool_name, \"content\": json.dumps(regression_results)})\n",
"        elif tool_name == 'create_bar_chart':\n",
"            # Simulate bar chart creation\n",
"            bar_chart = {\"bar_chart\": \"sample_bar_chart\"}\n",
"            messages.append({\"role\": \"tool\", \"name\": tool_name, \"content\": json.dumps(bar_chart)})\n",
"        elif tool_name == 'create_line_chart':\n",
"            # Simulate line chart creation\n",
"            line_chart = {\"line_chart\": \"sample_line_chart\"}\n",
"            messages.append({\"role\": \"tool\", \"name\": tool_name, \"content\": json.dumps(line_chart)})\n",
"            plot_line_chart(tool_arguments['data'])\n",
"        elif tool_name == 'create_pie_chart':\n",
"            # Simulate pie chart creation\n",
"            pie_chart = {\"pie_chart\": \"sample_pie_chart\"}\n",
"            messages.append({\"role\": \"tool\", \"name\": tool_name, \"content\": json.dumps(pie_chart)})\n",
"    return messages"
|
|
]
|
|
},
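{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `if`/`elif` chain in `execute_tool` maps each tool name to a Python function. One alternative, sketched here under the assumption that every tool function accepts its arguments as keyword arguments, is a dictionary-based dispatch table.\n",
"\n",
"The names `TOOL_REGISTRY`, `register_tool`, `dispatch_tool_call` and `count_rows` are hypothetical and are not part of the OpenAI SDK; the stub tool stands in for the real functions."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import json\n",
"\n",
"TOOL_REGISTRY = {}\n",
"\n",
"def register_tool(func):\n",
"    # Store the function under its own name so tool names map to callables.\n",
"    TOOL_REGISTRY[func.__name__] = func\n",
"    return func\n",
"\n",
"@register_tool\n",
"def count_rows(data):\n",
"    # Stub tool: counts the non-empty data rows after the CSV header.\n",
"    lines = [line for line in data.splitlines() if line.strip()]\n",
"    return {'rows': max(len(lines) - 1, 0)}\n",
"\n",
"def dispatch_tool_call(tool_name, raw_arguments):\n",
"    # Parse the JSON argument string and call the matching function.\n",
"    func = TOOL_REGISTRY.get(tool_name)\n",
"    if func is None:\n",
"        return {'error': f'unknown tool: {tool_name}'}\n",
"    return func(**json.loads(raw_arguments))\n",
"\n",
"print(dispatch_tool_call('count_rows', json.dumps({'data': 'a,b\\n1,2\\n3,4'})))"
]
},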
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"Next, we will create the tool handlers for each of the sub-agents.\n",
|
|
"\n",
|
|
"These have a unique prompt and tool set passed to the model. \n",
|
|
"\n",
|
|
"The output is then passed to an execution function which runs the tool calls.\n",
|
|
"\n",
|
|
"We will also append the messages to the conversation history."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 58,
|
|
"metadata": {
|
|
"id": "EcOGJ0AZTmkp"
|
|
},
|
|
"outputs": [],
|
|
"source": [
"# Define the functions to handle each agent's processing\n",
"def handle_data_processing_agent(query, conversation_messages):\n",
"    messages = [{\"role\": \"system\", \"content\": processing_system_prompt}]\n",
"    messages.append({\"role\": \"user\", \"content\": query})\n",
"\n",
"    response = client.chat.completions.create(\n",
"        model=MODEL,\n",
"        messages=messages,\n",
"        temperature=0,\n",
"        tools=preprocess_tools,\n",
"    )\n",
"\n",
"    conversation_messages.append([tool_call.function for tool_call in response.choices[0].message.tool_calls])\n",
"    execute_tool(response.choices[0].message.tool_calls, conversation_messages)\n",
"\n",
"def handle_analysis_agent(query, conversation_messages):\n",
"    messages = [{\"role\": \"system\", \"content\": analysis_system_prompt}]\n",
"    messages.append({\"role\": \"user\", \"content\": query})\n",
"\n",
"    response = client.chat.completions.create(\n",
"        model=MODEL,\n",
"        messages=messages,\n",
"        temperature=0,\n",
"        tools=analysis_tools,\n",
"    )\n",
"\n",
"    conversation_messages.append([tool_call.function for tool_call in response.choices[0].message.tool_calls])\n",
"    execute_tool(response.choices[0].message.tool_calls, conversation_messages)\n",
"\n",
"def handle_visualization_agent(query, conversation_messages):\n",
"    messages = [{\"role\": \"system\", \"content\": visualization_system_prompt}]\n",
"    messages.append({\"role\": \"user\", \"content\": query})\n",
"\n",
"    response = client.chat.completions.create(\n",
"        model=MODEL,\n",
"        messages=messages,\n",
"        temperature=0,\n",
"        tools=visualization_tools,\n",
"    )\n",
"\n",
"    conversation_messages.append([tool_call.function for tool_call in response.choices[0].message.tool_calls])\n",
"    execute_tool(response.choices[0].message.tool_calls, conversation_messages)\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
"Finally, we create the overarching function that handles the user query.\n",
|
|
"\n",
|
|
"This function takes the user query, gets a response from the model and handles passing it to the other agents to execute. In addition to this, we will keep the state of the ongoing conversation."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 59,
|
|
"metadata": {
|
|
"id": "4skE5-KYI9Tw"
|
|
},
|
|
"outputs": [],
|
|
"source": [
"# Function to handle user input and triaging\n",
"def handle_user_message(user_query, conversation_messages=None):\n",
"    # Start a fresh history per call; a mutable default list would be shared across calls\n",
"    if conversation_messages is None:\n",
"        conversation_messages = []\n",
"\n",
"    user_message = {\"role\": \"user\", \"content\": user_query}\n",
"    conversation_messages.append(user_message)\n",
"\n",
"    messages = [{\"role\": \"system\", \"content\": triaging_system_prompt}]\n",
"    messages.extend(conversation_messages)\n",
"\n",
"    response = client.chat.completions.create(\n",
"        model=MODEL,\n",
"        messages=messages,\n",
"        temperature=0,\n",
"        tools=triage_tools,\n",
"    )\n",
"\n",
"    # tool_calls is None when the model answers in plain text instead of calling a tool\n",
"    tool_calls = response.choices[0].message.tool_calls or []\n",
"    conversation_messages.append([tool_call.function for tool_call in tool_calls])\n",
"\n",
"    for tool_call in tool_calls:\n",
"        if tool_call.function.name == 'send_query_to_agents':\n",
"            arguments = json.loads(tool_call.function.arguments)\n",
"            agents = arguments['agents']\n",
"            query = arguments['query']\n",
"            for agent in agents:\n",
"                if agent == \"Data Processing Agent\":\n",
"                    handle_data_processing_agent(query, conversation_messages)\n",
"                elif agent == \"Analysis Agent\":\n",
"                    handle_analysis_agent(query, conversation_messages)\n",
"                elif agent == \"Visualization Agent\":\n",
"                    handle_visualization_agent(query, conversation_messages)\n",
"\n",
"    return conversation_messages"
|
|
]
|
|
},
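{
"cell_type": "markdown",
"metadata": {},
"source": [
"The routing loop compares each entry of `agents` to exact strings such as \"Data Processing Agent\". Because the schema declares `agents` as an array of free-form strings, a name the model spells differently would match no branch and be skipped silently.\n",
"\n",
"A possible refinement, sketched below as an assumption rather than as part of the original design, is to restrict the items with a JSON Schema `enum` and to route through a lookup table. `AGENT_NAMES`, `agents_property`, `route` and `handlers` are hypothetical names, and the handlers are placeholders that make no API calls."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"AGENT_NAMES = ['Data Processing Agent', 'Analysis Agent', 'Visualization Agent']\n",
"\n",
"# Schema fragment for the agents parameter, limited to the known names.\n",
"agents_property = {\n",
"    'type': 'array',\n",
"    'items': {'type': 'string', 'enum': AGENT_NAMES},\n",
"    'description': 'An array of agent names to send the query to.',\n",
"}\n",
"\n",
"def route(agents, query, handlers):\n",
"    # Call the handler for each requested agent and fail loudly on unknown names.\n",
"    unknown = [agent for agent in agents if agent not in handlers]\n",
"    if unknown:\n",
"        raise ValueError(f'unknown agents: {unknown}')\n",
"    return [handlers[agent](query) for agent in agents]\n",
"\n",
"# Placeholder handlers that only echo which agent received the query.\n",
"handlers = {name: (lambda query, name=name: f'{name} <- {query}') for name in AGENT_NAMES}\n",
"print(route(['Analysis Agent'], 'describe the data', handlers))"
]
},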
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "jzQAwIW_WL3k"
|
|
},
|
|
"source": [
|
|
"## Multi-agent system execution\n",
|
|
"\n",
|
|
"Finally, we run the overarching `handle_user_message` function on the user query and view the output."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 60,
|
|
"metadata": {
|
|
"colab": {
|
|
"base_uri": "https://localhost:8080/"
|
|
},
|
|
"id": "a0h10s_W49ct",
|
|
"outputId": "7e340af9-dc3d-44ba-aa0c-e613fbdcc153"
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Cleaned data: house_size (m3) house_price ($)\n",
|
|
"0 90 100\n",
|
|
"1 80 90\n",
|
|
"2 100 120\n",
|
|
"Statistical Analysis: house_size house_price\n",
|
|
"count 4.000000 4.000000\n",
|
|
"mean 90.000000 102.500000\n",
|
|
"std 8.164966 12.583057\n",
|
|
"min 80.000000 90.000000\n",
|
|
"25% 87.500000 97.500000\n",
|
|
"50% 90.000000 100.000000\n",
|
|
"75% 92.500000 105.000000\n",
|
|
"max 100.000000 120.000000\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"image/png": "",
|
|
"text/plain": [
|
|
"<Figure size 1000x600 with 1 Axes>"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"[{'role': 'user',\n",
|
|
" 'content': '\\nBelow is some data. I want you to first remove the duplicates then analyze the statistics of the data as well as plot a line chart.\\n\\nhouse_size (m3), house_price ($)\\n90, 100\\n80, 90\\n100, 120\\n90, 100\\n'},\n",
|
|
" [Function(arguments='{\"agents\": [\"Data Processing Agent\"], \"query\": \"Remove duplicates from the data: house_size (m3), house_price ($)\\\\n90, 100\\\\n80, 90\\\\n100, 120\\\\n90, 100\"}', name='send_query_to_agents'),\n",
|
|
" Function(arguments='{\"agents\": [\"Analysis Agent\"], \"query\": \"Analyze the statistics of the data: house_size (m3), house_price ($)\\\\n90, 100\\\\n80, 90\\\\n100, 120\\\\n90, 100\"}', name='send_query_to_agents'),\n",
|
|
" Function(arguments='{\"agents\": [\"Visualization Agent\"], \"query\": \"Plot a line chart for the data: house_size (m3), house_price ($)\\\\n90, 100\\\\n80, 90\\\\n100, 120\\\\n90, 100\"}', name='send_query_to_agents')],\n",
|
|
" [Function(arguments='{\"data\":\"house_size (m3), house_price ($)\\\\n90, 100\\\\n80, 90\\\\n100, 120\\\\n90, 100\"}', name='clean_data')],\n",
|
|
" {'role': 'tool',\n",
|
|
" 'name': 'clean_data',\n",
|
|
" 'content': '{\"cleaned_data\": {\"house_size (m3)\": {\"0\": 90, \"1\": 80, \"2\": 100}, \" house_price ($)\": {\"0\": 100, \"1\": 90, \"2\": 120}}}'},\n",
|
|
" [Function(arguments='{\"data\":\"house_size,house_price\\\\n90,100\\\\n80,90\\\\n100,120\\\\n90,100\"}', name='stat_analysis')],\n",
|
|
" {'role': 'tool',\n",
|
|
" 'name': 'stat_analysis',\n",
|
|
" 'content': '{\"stats\": {\"house_size\": {\"count\": 4.0, \"mean\": 90.0, \"std\": 8.16496580927726, \"min\": 80.0, \"25%\": 87.5, \"50%\": 90.0, \"75%\": 92.5, \"max\": 100.0}, \"house_price\": {\"count\": 4.0, \"mean\": 102.5, \"std\": 12.583057392117917, \"min\": 90.0, \"25%\": 97.5, \"50%\": 100.0, \"75%\": 105.0, \"max\": 120.0}}}'},\n",
|
|
" [Function(arguments='{\"data\":\"house_size,house_price\\\\n90,100\\\\n80,90\\\\n100,120\\\\n90,100\",\"x\":\"house_size\",\"y\":\"house_price\"}', name='create_line_chart')],\n",
|
|
" {'role': 'tool',\n",
|
|
" 'name': 'create_line_chart',\n",
|
|
" 'content': '{\"line_chart\": \"sample_line_chart\"}'}]"
|
|
]
|
|
},
|
|
"execution_count": 60,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"handle_user_message(user_query)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Conclusion\n",
|
|
"\n",
|
|
"In this cookbook, we've explored how to leverage Structured Outputs to build more robust multi-agent systems.\n",
|
|
"\n",
"Using this new feature ensures that tool calls follow the specified schema, which reduces the need to handle edge cases or validate arguments yourself.\n",
|
|
"\n",
|
|
"This can be applied to many more use cases, and we hope you can take inspiration from this to build your own use case!"
|
|
]
|
|
}
|
|
],
|
|
"metadata": {
|
|
"colab": {
|
|
"provenance": []
|
|
},
|
|
"kernelspec": {
|
|
"display_name": "Python 3 (ipykernel)",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.11.9"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 1
|
|
}
|