{
"cells": [
{
"cell_type": "markdown",
"id": "d5dd7679-5ed6-4243-8e0f-549f8118bff7",
"metadata": {},
"source": [
"# How to test and evaluate LLMs for SQL generation\n",
"\n",
"LLMs are fundamentally non-deterministic in their responses. This attribute makes them wonderfully creative and dynamic, but it also poses significant challenges in achieving consistency, a crucial aspect of integrating LLMs into production environments.\n",
"\n",
"The key to harnessing the potential of LLMs in practical applications lies in consistent and systematic evaluation. This enables the identification and rectification of inconsistencies and helps in monitoring progress over time as the application evolves.\n",
"\n",
"## Scope of this notebook\n",
"\n",
"This notebook aims to demonstrate a framework for evaluating LLMs, particularly focusing on:\n",
"\n",
"* **Unit Testing:** Essential for assessing individual components of the application.\n",
"* **Evaluation Metrics:** Methods to quantitatively measure the model's effectiveness.\n",
"* **Runbook Documentation:** A record of historical evaluations to track progress and regression.\n",
"\n",
"This example focuses on a natural language to SQL use case. Code generation use cases fit well with this approach when you combine **code validation** with **code execution**, so your application can test code for real as it is generated to ensure consistency.\n",
"\n",
"Although this notebook uses a SQL generation use case to demonstrate the concept, the approach is generic and can be applied to a wide variety of LLM-driven applications.\n",
"\n",
"We will use two versions of a prompt to perform SQL generation. We will then use the unit tests and evaluation functions to test the performance of the prompts. Specifically, in this demonstration, we will evaluate:\n",
"\n",
"1. The consistency of the JSON response.\n",
"2. The syntactic correctness of the SQL in the response.\n",
"\n",
"\n",
"## Table of contents\n",
"\n",
"1. **[Setup](#Setup):** Install required libraries and download data consisting of SQL queries and corresponding natural language translations.\n",
"2. **[Test Development](#Test-development):** Create unit tests and define evaluation metrics for the SQL generation process.\n",
"3. **[Evaluation](#Evaluation):** Conduct tests using different prompts to assess the impact on performance.\n",
"4. **[Reporting](#Report):** Compile a report that succinctly presents the performance differences observed across various tests."
]
},
{
"cell_type": "markdown",
"id": "2913d615",
"metadata": {},
"source": [
"## Setup\n",
"\n",
"Import our libraries and the dataset we'll use, which is the natural language to SQL [b-mc2/sql-create-context](https://huggingface.co/datasets/b-mc2/sql-create-context) dataset from HuggingFace."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "245fcedb",
"metadata": {},
"outputs": [],
"source": [
"from datasets import load_dataset\n",
"from openai import OpenAI\n",
"import pandas as pd\n",
"import pydantic\n",
"import os\n",
"import sqlite3\n",
"from sqlite3 import Error\n",
"from pprint import pprint\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"from dotenv import load_dotenv\n",
"\n",
"# Loads key from local .env file to setup API KEY in env variables\n",
"%reload_ext dotenv\n",
"%dotenv\n",
" \n",
"GPT_MODEL = 'gpt-3.5-turbo'\n",
"dataset = load_dataset(\"b-mc2/sql-create-context\")"
]
},
{
"cell_type": "markdown",
"id": "04c7fde6-d7dc-4a0d-b9a0-32858f3bac25",
"metadata": {},
"source": [
"### Looking at the dataset\n",
"\n",
"We use the Hugging Face `datasets` library to download the SQL create context dataset. This dataset consists of:\n",
"\n",
"1. Question, expressed in natural language.\n",
"2. Answer, expressed in SQL designed to answer the question in natural language.\n",
"3. Context, expressed as a CREATE SQL statement, that describes the table that may be used to answer the question.\n",
"\n",
"In our demonstration today, we will use an LLM to attempt to answer the question (in natural language). The LLM will be expected to generate a CREATE SQL statement to create a context suitable to answer the user question and a corresponding SELECT SQL query designed to answer the user question completely.\n",
"\n",
"The dataset looks like this:"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "f8027115",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>question</th>\n",
" <th>answer</th>\n",
" <th>context</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>How many heads of the departments are older th...</td>\n",
" <td>SELECT COUNT(*) FROM head WHERE age &gt; 56</td>\n",
" <td>CREATE TABLE head (age INTEGER)</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>List the name, born state and age of the heads...</td>\n",
" <td>SELECT name, born_state, age FROM head ORDER B...</td>\n",
" <td>CREATE TABLE head (name VARCHAR, born_state VA...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>List the creation year, name and budget of eac...</td>\n",
" <td>SELECT creation, name, budget_in_billions FROM...</td>\n",
" <td>CREATE TABLE department (creation VARCHAR, nam...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>What are the maximum and minimum budget of the...</td>\n",
" <td>SELECT MAX(budget_in_billions), MIN(budget_in_...</td>\n",
" <td>CREATE TABLE department (budget_in_billions IN...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>What is the average number of employees of the...</td>\n",
" <td>SELECT AVG(num_employees) FROM department WHER...</td>\n",
" <td>CREATE TABLE department (num_employees INTEGER...</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" question \\\n",
"0 How many heads of the departments are older th... \n",
"1 List the name, born state and age of the heads... \n",
"2 List the creation year, name and budget of eac... \n",
"3 What are the maximum and minimum budget of the... \n",
"4 What is the average number of employees of the... \n",
"\n",
" answer \\\n",
"0 SELECT COUNT(*) FROM head WHERE age > 56 \n",
"1 SELECT name, born_state, age FROM head ORDER B... \n",
"2 SELECT creation, name, budget_in_billions FROM... \n",
"3 SELECT MAX(budget_in_billions), MIN(budget_in_... \n",
"4 SELECT AVG(num_employees) FROM department WHER... \n",
"\n",
" context \n",
"0 CREATE TABLE head (age INTEGER) \n",
"1 CREATE TABLE head (name VARCHAR, born_state VA... \n",
"2 CREATE TABLE department (creation VARCHAR, nam... \n",
"3 CREATE TABLE department (budget_in_billions IN... \n",
"4 CREATE TABLE department (num_employees INTEGER... "
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"sql_df = dataset['train'].to_pandas()\n",
"sql_df.head()"
]
},
{
"cell_type": "markdown",
"id": "b04cb5eb",
"metadata": {},
"source": [
"\n",
"## Test development\n",
"\n",
"To test the output of the LLM generations, we'll develop two unit tests and an evaluation, which will combine to give us a basic evaluation framework to grade the quality of our LLM iterations.\n",
"\n",
"To reiterate, our purpose is to measure the correctness and consistency of LLM output given our questions.\n",
"\n",
"### Unit tests\n",
"\n",
"Unit tests should test the most granular components of your LLM application.\n",
"\n",
"For this section we'll develop unit tests to test the following:\n",
"- `test_valid_schema` will check that a parseable `create` and `select` statement are returned by the LLM.\n",
"- `test_llm_sql` will execute both the `create` and `select` statements on a `sqlite` database to ensure they are syntactically correct."
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "eb811101",
"metadata": {},
"outputs": [],
"source": [
"from pydantic import BaseModel\n",
"\n",
"class LLMResponse(BaseModel):\n",
"    \"\"\"This simple class expects to receive a JSON string that can be parsed into a `create` and `select` statement.\"\"\"\n",
"    create: str\n",
"    select: str"
]
},
{
"cell_type": "markdown",
"id": "19fadf67-8b2f-4e17-95df-030a36aad90b",
"metadata": {},
"source": [
"#### Prompt\n",
"\n",
"For demonstration purposes, we use a fairly simple prompt requesting GPT to generate a pair of statements: a context CREATE SQL statement and an answering SELECT SQL query. We supply the natural language question as part of the prompt. We request the response in JSON format so that it can be parsed easily."
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "c2be3ba4",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"('Translate this natural language request into a JSON object containing two '\n",
" 'SQL queries. \\n'\n",
" 'The first query should be a CREATE statement for a table answering the '\n",
" \"user's request, while the second should be a SELECT query answering their \"\n",
" 'question. \\n'\n",
" 'This should be returned as a JSON with keys \"create\" and \"select\". \\n'\n",
" 'For example:\\n'\n",
" '\\n'\n",
" '{\"create\": \"\"\"CREATE_QUERY\"\"\",\"select\": \"\"\"SELECT_QUERY\"\"\"}\"')\n"
]
}
],
"source": [
"system_prompt = '''Translate this natural language request into a JSON object containing two SQL queries. \n",
"The first query should be a CREATE statement for a table answering the user's request, while the second should be a SELECT query answering their question. \n",
"This should be returned as a JSON with keys \"create\" and \"select\". \n",
"For example:\\n\\n{\"create\": \"\"\"CREATE_QUERY\"\"\",\"select\": \"\"\"SELECT_QUERY\"\"\"}\"'''\n",
"\n",
"pprint(system_prompt)\n"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "3a20d712",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[{'content': 'Translate this natural language request into a JSON object '\n",
" 'containing two SQL queries. \\n'\n",
" 'The first query should be a CREATE statement for a table '\n",
" \"answering the user's request, while the second should be a \"\n",
" 'SELECT query answering their question. \\n'\n",
" 'This should be returned as a JSON with keys \"create\" and '\n",
" '\"select\". \\n'\n",
" 'For example:\\n'\n",
" '\\n'\n",
" '{\"create\": \"\"\"CREATE_QUERY\"\"\",\"select\": \"\"\"SELECT_QUERY\"\"\"}\"',\n",
" 'role': 'system'},\n",
" {'content': 'How many heads of the departments are older than 56 ?',\n",
" 'role': 'user'}]\n"
]
}
],
"source": [
"# Compiling the system prompt and user question into message array\n",
"\n",
"messages = []\n",
"messages.append({\"role\": \"system\", \"content\": system_prompt})\n",
"messages.append({\"role\":\"user\",\"content\": sql_df.iloc[0]['question']})\n",
"pprint(messages)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "38b704b3-6f0e-4708-bc70-96723d69da6f",
"metadata": {},
"outputs": [],
"source": [
"# Sending the message array to GPT, requesting a response (ensure that you have API key loaded to Env for this step)\n",
"\n",
"client = OpenAI()\n",
"completion = client.chat.completions.create(model = GPT_MODEL, messages = messages)"
]
},
{
"cell_type": "markdown",
"id": "901e3bb7",
"metadata": {},
"source": [
"#### Check JSON formatting\n",
"\n",
"Our first simple unit test checks that the LLM response is parseable into the `LLMResponse` Pydantic class that we've defined.\n",
"\n",
"We'll test that our first response passes, then create a failing example to check that the check fails. This logic will be wrapped in a simple function `test_valid_schema`."
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "2b057391-4f83-4b5a-8843-a9ee74bee871",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"('{\"create\": \"CREATE TABLE departments (department_id INTEGER PRIMARY KEY, '\n",
" 'department_name TEXT, head_name TEXT, age INTEGER)\", \"select\": \"SELECT '\n",
" 'COUNT(*) FROM departments WHERE age > 56\"}')\n"
]
}
],
"source": [
"# Viewing the output from GPT\n",
"\n",
"content = completion.choices[0].message.content\n",
"pprint(content)"
]
},
{
"cell_type": "markdown",
"id": "4b98bbb4-dd17-49bc-828c-e561abf5b481",
"metadata": {},
"source": [
"#### Validating the output schema\n",
"\n",
"We expect GPT to respond with valid JSON containing the `create` and `select` keys. We can validate this using the `LLMResponse` base model; `test_valid_schema` is designed to help us do so."
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "4c7133f1-74d6-43f1-9443-09a3f8308c35",
"metadata": {},
"outputs": [],
"source": [
"def test_valid_schema(content):\n",
"    \"\"\"Tests whether the content provided can be parsed into our Pydantic model.\"\"\"\n",
"    try:\n",
"        LLMResponse.model_validate_json(content)\n",
"        return True\n",
"    # Catch pydantic's validation errors:\n",
"    except pydantic.ValidationError as exc:\n",
"        print(f\"ERROR: Invalid schema: {exc}\")\n",
"        return False"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "6a9a9128",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"True"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"test_valid_schema(content)"
]
},
{
"cell_type": "markdown",
"id": "78f1af23-4dd0-4860-8a1a-88e5146be6ed",
"metadata": {},
"source": [
"#### Testing negative scenario\n",
"\n",
"To simulate a scenario in which we get an invalid JSON response from GPT, we hardcode an invalid JSON string as the response. We expect the `test_valid_schema` function to catch the validation error, print it, and return `False`."
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "a0a26690",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"ERROR: Invalid schema: 1 validation error for LLMResponse\n",
" Invalid JSON: expected value at line 1 column 1 [type=json_invalid, input_value='CREATE departments, select * from departments', input_type=str]\n",
" For further information visit https://errors.pydantic.dev/2.5/v/json_invalid\n"
]
},
{
"data": {
"text/plain": [
"False"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"failing_query = 'CREATE departments, select * from departments'\n",
"test_valid_schema(failing_query)"
]
},
{
"cell_type": "markdown",
"id": "a5fdc420-94cc-4e47-80e1-82bf51e44f2a",
"metadata": {},
"source": [
"As expected, the `test_valid_schema` function reports the validation error and returns `False`."
]
},
{
"cell_type": "markdown",
"id": "a4e972cd-5734-43c0-a103-b9ceb41552fd",
"metadata": {},
"source": [
"### Test SQL queries\n",
"\n",
"Next we'll validate the correctness of the SQL. This test will be designed to validate that:\n",
"\n",
"1. The CREATE SQL returned in the GPT response is syntactically correct.\n",
"2. The SELECT SQL returned in the GPT response is syntactically correct.\n",
"\n",
"To achieve this, we will use a sqlite instance. We will send the returned SQL statements to the sqlite instance. If the SQL statements are valid, the sqlite instance will accept and execute them; otherwise we expect an exception to be thrown.\n",
"\n",
"The `create_connection` function below will set up a sqlite instance (in-memory by default) and create a connection to be used later."
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "9cc95481",
"metadata": {},
"outputs": [],
"source": [
"# Set up SQLite to act as our test database\n",
"def create_connection(db_file=\":memory:\"):\n",
"    \"\"\"create a database connection to a SQLite database\"\"\"\n",
"    try:\n",
"        conn = sqlite3.connect(db_file)\n",
"        # print(sqlite3.version)\n",
"    except Error as e:\n",
"        print(e)\n",
"        return None\n",
"\n",
"    return conn\n",
"\n",
"def close_connection(conn):\n",
"    \"\"\"close a database connection\"\"\"\n",
"    try:\n",
"        conn.close()\n",
"    except Error as e:\n",
"        print(e)\n",
"\n",
"\n",
"conn = create_connection()"
]
},
{
"cell_type": "markdown",
"id": "aa5c5cb8-1c81-403b-a3f2-f2784d8235fc",
"metadata": {},
"source": [
"Next, we will create the following functions to carry out the syntactical correctness checks.\n",
"\n",
"\n",
"- `test_create`: Function testing if the CREATE SQL statement succeeds.\n",
"- `test_select`: Function testing if the SELECT SQL statement succeeds.\n",
"- `test_llm_sql`: Wrapper function executing the two tests above."
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "c6d2573d",
"metadata": {},
"outputs": [],
"source": [
"def test_select(conn, cursor, select):\n",
"    \"\"\"Tests that a SQLite select query can be executed successfully.\"\"\"\n",
"    try:\n",
"        print(f\"Testing select query: {select}\")\n",
"        cursor.execute(select)\n",
"        record = cursor.fetchall()\n",
"        print(record)\n",
"\n",
"        return True\n",
"\n",
"    except sqlite3.Error as error:\n",
"        print(\"Error while executing select query:\", error)\n",
"\n",
"        return False\n",
"\n",
"\n",
"def test_create(conn, cursor, create):\n",
"    \"\"\"Tests that a SQLite create query can be executed successfully\"\"\"\n",
"    try:\n",
"        print(f\"Testing create query: {create}\")\n",
"        cursor.execute(create)\n",
"        conn.commit()\n",
"\n",
"        return True\n",
"\n",
"    except sqlite3.Error as error:\n",
"        print(\"Error while creating the SQLite table:\", error)\n",
"\n",
"        return False\n",
"\n",
"\n",
"def test_llm_sql(llm_response):\n",
"    \"\"\"Runs a suite of SQLite tests on a parsed LLMResponse instance\"\"\"\n",
"    try:\n",
"        conn = create_connection()\n",
"        cursor = conn.cursor()\n",
"\n",
"        create_response = test_create(conn, cursor, llm_response.create)\n",
"\n",
"        select_response = test_select(conn, cursor, llm_response.select)\n",
"\n",
"        if conn:\n",
"            close_connection(conn)\n",
"\n",
"        if create_response is not True:\n",
"            return False\n",
"\n",
"        elif select_response is not True:\n",
"            return False\n",
"\n",
"        else:\n",
"            return True\n",
"\n",
"    except sqlite3.Error as error:\n",
"        print(\"Error while creating a sqlite table\", error)\n",
"\n",
"        return False"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "a9266753-4646-4901-bc14-632d3bf47aaa",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CREATE SQL is: CREATE TABLE departments (department_id INTEGER PRIMARY KEY, department_name TEXT, head_name TEXT, age INTEGER)\n",
"SELECT SQL is: SELECT COUNT(*) FROM departments WHERE age > 56\n"
]
}
],
"source": [
"# Viewing CREATE and SELECT sqls returned by GPT\n",
"\n",
"test_query = LLMResponse.model_validate_json(content)\n",
"print(f\"CREATE SQL is: {test_query.create}\")\n",
"print(f\"SELECT SQL is: {test_query.select}\")"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "83bc1f1b",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Testing create query: CREATE TABLE departments (department_id INTEGER PRIMARY KEY, department_name TEXT, head_name TEXT, age INTEGER)\n",
"Testing select query: SELECT COUNT(*) FROM departments WHERE age > 56\n",
"[(0,)]\n"
]
},
{
"data": {
"text/plain": [
"True"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Testing the CREATE and SELECT sqls are valid (we expect this to be successful)\n",
"\n",
"test_llm_sql(test_query)"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "589c7cc7",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Testing create query: CREATE TABLE departments (id INT, name VARCHAR(255), head_of_department VARCHAR(255))\n",
"Testing select query: SELECT COUNT(*) FROM departments WHERE age > 56\n",
"Error while executing select query: no such column: age\n"
]
},
{
"data": {
"text/plain": [
"False"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Again we'll perform a negative test to confirm that a failing SELECT will return an error.\n",
"\n",
"test_failure_query = '{\"create\": \"CREATE TABLE departments (id INT, name VARCHAR(255), head_of_department VARCHAR(255))\", \"select\": \"SELECT COUNT(*) FROM departments WHERE age > 56\"}'\n",
"test_failure_query = LLMResponse.model_validate_json(test_failure_query)\n",
"test_llm_sql(test_failure_query)"
]
},
{
"cell_type": "markdown",
"id": "8148f820",
"metadata": {},
"source": [
"### Evaluation\n",
"\n",
"The last component is to **evaluate** whether the generated SQL actually answers the user's question. This test will be performed by `gpt-4`, and will assess how **relevant** the produced SQL query is when compared to the initial user request.\n",
"\n",
"This is a simple example which adapts an approach outlined in the [G-Eval paper](https://arxiv.org/abs/2303.16634), and tested in one of our other [cookbooks](https://github.com/openai/openai-cookbook/blob/main/examples/evaluation/How_to_eval_abstractive_summarization.ipynb)."
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "029c8426",
"metadata": {},
"outputs": [],
"source": [
"EVALUATION_MODEL = \"gpt-4\"\n",
"\n",
"EVALUATION_PROMPT_TEMPLATE = \"\"\"\n",
"You will be given a user request and the SQL queries written to answer it. Your task is to rate the queries on one metric.\n",
"Please make sure you read and understand these instructions very carefully. \n",
"Please keep this document open while reviewing, and refer to it as needed.\n",
"\n",
"Evaluation Criteria:\n",
"\n",
"{criteria}\n",
"\n",
"Evaluation Steps:\n",
"\n",
"{steps}\n",
"\n",
"Example:\n",
"\n",
"Request:\n",
"\n",
"{request}\n",
"\n",
"Queries:\n",
"\n",
"{queries}\n",
"\n",
"Evaluation Form (scores ONLY):\n",
"\n",
"- {metric_name}\n",
"\"\"\"\n",
"\n",
"# Relevance\n",
"\n",
"RELEVANCY_SCORE_CRITERIA = \"\"\"\n",
"Relevance(1-5) - review of how relevant the produced SQL queries are to the original question. \\\n",
"The queries should contain all points highlighted in the user's request. \\\n",
"Annotators were instructed to penalize queries which contained redundancies and excess information.\n",
"\"\"\"\n",
"\n",
"RELEVANCY_SCORE_STEPS = \"\"\"\n",
"1. Read the request and the queries carefully.\n",
"2. Compare the queries to the request document and identify the main points of the request.\n",
"3. Assess how well the queries cover the main points of the request, and how much irrelevant or redundant information it contains.\n",
"4. Assign a relevance score from 1 to 5.\n",
"\"\"\""
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "85cfb78d",
"metadata": {},
"outputs": [],
"source": [
"def get_geval_score(\n",
"    criteria: str, steps: str, request: str, queries: str, metric_name: str\n",
"):\n",
"    \"\"\"Given evaluation criteria and an observation, this function uses EVALUATION GPT to evaluate the observation against those criteria.\n",
"    \"\"\"\n",
"    prompt = EVALUATION_PROMPT_TEMPLATE.format(\n",
"        criteria=criteria,\n",
"        steps=steps,\n",
"        request=request,\n",
"        queries=queries,\n",
"        metric_name=metric_name,\n",
"    )\n",
"    response = client.chat.completions.create(\n",
"        model=EVALUATION_MODEL,\n",
"        messages=[{\"role\": \"user\", \"content\": prompt}],\n",
"        temperature=0,\n",
"        max_tokens=5,\n",
"        top_p=1,\n",
"        frequency_penalty=0,\n",
"        presence_penalty=0,\n",
"    )\n",
"    return response.choices[0].message.content"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "607ee304",
"metadata": {},
"outputs": [],
"source": [
"# Test out evaluation on a few records\n",
"\n",
"evaluation_results = []\n",
"\n",
"for x,y in sql_df.head(3).iterrows():\n",
"\n",
"    score = get_geval_score(RELEVANCY_SCORE_CRITERIA,RELEVANCY_SCORE_STEPS,y['question'],y['context'] + '\\n' + y['answer'],'relevancy')\n",
"\n",
"    evaluation_results.append((y['question'],y['context'] + '\\n' + y['answer'],score))"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "bd1002c2",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"User Question \t: How many heads of the departments are older than 56 ?\n",
"CREATE SQL Returned \t: CREATE TABLE head (age INTEGER)\n",
"SELECT SQL Returned \t: SELECT COUNT(*) FROM head WHERE age > 56\n",
"5\n",
"********************\n",
"User Question \t: List the name, born state and age of the heads of departments ordered by age.\n",
"CREATE SQL Returned \t: CREATE TABLE head (name VARCHAR, born_state VARCHAR, age VARCHAR)\n",
"SELECT SQL Returned \t: SELECT name, born_state, age FROM head ORDER BY age\n",
"5\n",
"********************\n",
"User Question \t: List the creation year, name and budget of each department.\n",
"CREATE SQL Returned \t: CREATE TABLE department (creation VARCHAR, name VARCHAR, budget_in_billions VARCHAR)\n",
"SELECT SQL Returned \t: SELECT creation, name, budget_in_billions FROM department\n",
"5\n",
"********************\n"
]
}
],
"source": [
"for result in evaluation_results:\n",
"    print(f\"User Question \\t: {result[0]}\")\n",
"    print(f\"CREATE SQL Returned \\t: {result[1].splitlines()[0]}\")\n",
"    print(f\"SELECT SQL Returned \\t: {result[1].splitlines()[1]}\")\n",
"    print(f\"{result[2]}\")\n",
"    print(\"*\" * 20)"
]
},
{
"cell_type": "markdown",
"id": "fe04c6c7",
"metadata": {},
"source": [
"## Putting it all together\n",
"\n",
"We'll now combine our unit tests and evaluation to compare two system prompts.\n",
"\n",
"Each iteration of input/output and scores should be stored as a **run**. Optionally you can add GPT-4 annotation within your evaluations or as a separate step to review an entire run and highlight the reasons for errors.\n",
"\n",
"For this example, the second system prompt will include an extra line of clarification, so we can assess the impact of this for both SQL validity and quality of solution."
]
},
{
"cell_type": "markdown",
"id": "61b68e2a",
"metadata": {},
"source": [
"### First run - System Prompt 1\n",
"\n",
"The system under test is the first system prompt as shown below. This `run` will generate responses for this system prompt and evaluate the responses using the functions we've created so far."
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "85c44a17",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"('Translate this natural language request into a JSON object containing two SQL queries. \\n'\n",
" \"The first query should be a CREATE statement for a table answering the user's request, while the second should be a \"\n",
" 'SELECT query answering their question. \\n'\n",
" 'This should be returned as a JSON with keys \"create\" and \"select\". \\n'\n",
" 'For example:\\n'\n",
" '\\n'\n",
" '{\"create\": \"CREATE_QUERY\",\"select\": \"SELECT_QUERY\"}\" ')\n"
]
}
],
"source": [
"# Set first system prompt\n",
"system_prompt = \"\"\"Translate this natural language request into a JSON object containing two SQL queries. \n",
"The first query should be a CREATE statement for a table answering the user's request, while the second should be a SELECT query answering their question. \n",
"This should be returned as a JSON with keys \"create\" and \"select\". \n",
"For example:\\n\\n{\"create\": \"CREATE_QUERY\",\"select\": \"SELECT_QUERY\"}\" \"\"\"\n",
"\n",
"pprint(system_prompt, width = 120)"
]
},
{
"cell_type": "code",
"execution_count": 21,
"id": "1244c44e",
"metadata": {},
"outputs": [],
"source": [
"def get_response(system_prompt,user_message,model=GPT_MODEL):\n",
"    messages = []\n",
"    messages.append({\"role\": \"system\", \"content\": system_prompt})\n",
"    messages.append({\"role\":\"user\",\"content\": user_message})\n",
"\n",
"    # Use the model passed by the caller (defaults to GPT_MODEL)\n",
"    response = client.chat.completions.create(model=model,messages=messages,temperature=0)\n",
"\n",
"    return response.choices[0].message.content"
]
},
{
"cell_type": "markdown",
"id": "76c2723b-3060-400f-b6fe-c3c3c9d6907e",
"metadata": {},
"source": [
"#### Run the tests and evaluations\n",
"\n",
"The functions below run the unit tests and evaluate the responses."
]
},
{
"cell_type": "code",
"execution_count": 40,
"id": "a98afa30",
"metadata": {},
"outputs": [],
"source": [
"def execute_unit_tests(input_df,output_list,system_prompt):\n",
"    \"\"\"Unit testing function that takes in a dataframe and appends test results to an output_list.\n",
"    The system prompt is configurable so that we can test several prompts with this framework.\"\"\"\n",
"\n",
"    for x,y in input_df.iterrows():\n",
"        model_response = get_response(system_prompt,y['question'])\n",
"\n",
"        format_valid = test_valid_schema(model_response)\n",
"\n",
"        try:\n",
"            test_query = LLMResponse.model_validate_json(model_response)\n",
"            sql_valid = test_llm_sql(test_query)\n",
"\n",
"        except Exception:\n",
"            # If the response cannot be parsed or tested, count the SQL as invalid\n",
"            sql_valid = False\n",
"\n",
"        output_list.append((y['question'],model_response,format_valid,sql_valid))\n",
"\n",
"def evaluate_row(row):\n",
"    \"\"\"Simple evaluation function to categorize unit testing results.\n",
"    If the format or SQL are flagged it returns a label, otherwise it is correct\"\"\"\n",
"    if row['format'] == False:\n",
"        return 'Format incorrect'\n",
"\n",
"    elif row['sql'] == False:\n",
"        return 'SQL incorrect'\n",
"\n",
"    else:\n",
"        return 'SQL correct'"
]
},
{
"cell_type": "code",
"execution_count": 23,
"id": "898e5069",
"metadata": {},
"outputs": [],
"source": [
"# Select 50 unseen queries (the last 50 rows) to test this one\n",
"test_df = sql_df.tail(50)"
]
},
{
"cell_type": "code",
"execution_count": 24,
"id": "2baec278",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Testing create query: CREATE TABLE partnerships (id INT, player1 VARCHAR(255), player2 VARCHAR(255), venue VARCHAR(255))\n",
"Testing select query: SELECT venue FROM partnerships WHERE player1 = 'shoaib malik' AND player2 = 'misbah-ul-haq'\n",
"[]\n",
"Testing create query: CREATE TABLE partnerships (id INT, player1 VARCHAR(255), player2 VARCHAR(255), venue VARCHAR(255))\n",
"Testing select query: SELECT venue FROM partnerships WHERE player1 = 'herschelle gibbs' AND player2 = 'justin kemp'\n",
"[]\n",
"Testing create query: CREATE TABLE Played (Number INT, Points INT)\n",
"Testing select query: SELECT * FROM Played WHERE Points = 310\n",
"[]\n",
"Testing create query: CREATE TABLE losing_bonus (bonus_id INT, points_against INT)\n",
"Testing select query: SELECT * FROM losing_bonus WHERE points_against = 588\n",
"[]\n",
"Testing create query: CREATE TABLE rugby_teams (team_name VARCHAR(255), tries_against INT, losing_bonus INT);\n",
"Testing select query: SELECT * FROM rugby_teams WHERE tries_against = 7;\n",
"[]\n",
"Testing create query: CREATE TABLE try_bonus (id INT, points_against INT, try_bonus VARCHAR(50))\n",
"Testing select query: SELECT try_bonus FROM try_bonus WHERE points_against = 488\n",
"[]\n",
"Testing create query: CREATE TABLE Points (Team VARCHAR(255), Try_Bonus INT)\n",
"Testing select query: SELECT Team FROM Points WHERE Try_Bonus = 140\n",
"[]\n",
"Testing create query: CREATE TABLE Drawn (Team TEXT, Tries_against INT);\n",
"Testing select query: SELECT * FROM Drawn WHERE Tries_against = 0;\n",
"[]\n",
"Testing create query: CREATE TABLE champion (id INT, name VARCHAR(255), reign_days INT, defenses INT)\n",
"Testing select query: SELECT days_held FROM champion WHERE reign_days > 3 AND defenses = 1\n",
"Error while executing select query: no such column: days_held\n",
"Testing create query: CREATE TABLE champion (id INT, name VARCHAR(255), reign_days INT, defenses INT);\n",
"Testing select query: SELECT days_held FROM champion WHERE reign_days > 3 AND defenses < 1;\n",
"Error while executing select query: no such column: days_held\n",
"Testing create query: CREATE TABLE champion_stats (champion_id INT, days_held INT, reign INT, defenses INT)\n",
"Testing select query: SELECT AVG(defenses) FROM champion_stats WHERE days_held = 404 AND reign > 1\n",
"[(None,)]\n",
"Testing create query: CREATE TABLE champions (name VARCHAR(255), days_held INT, defense INT)\n",
"Testing select query: SELECT MIN(defense) FROM champions WHERE days_held = 345\n",
"[(None,)]\n",
"Testing create query: CREATE TABLE records (date DATE, record VARCHAR(10))\n",
"Testing select query: SELECT date FROM records WHERE record = '76-72'\n",
"[]\n",
"Testing create query: CREATE TABLE attendance (game_id INT, loss VARCHAR(20), attendance INT);\n",
"Testing select query: SELECT attendance FROM attendance WHERE loss = 'Ponson (1-5)';\n",
"[]\n",
"Testing create query: CREATE TABLE records (day VARCHAR(10), record INT);\n",
"Testing select query: SELECT day FROM records WHERE record >= 36 AND record <= 39;\n",
"[]\n",
"Testing create query: CREATE TABLE records (id INT, date DATE);\n",
"Testing select query: SELECT date FROM records WHERE id = '30-31';\n",
"[]\n",
"Testing create query: CREATE TABLE games (game_id INT, opponent VARCHAR(255), result VARCHAR(255))\n",
"Testing select query: SELECT opponent FROM games WHERE result = 'loss' AND game_id = (SELECT game_id FROM games WHERE opponent = 'Leonard' AND result = '78')\n",
"[]\n",
"Testing create query: CREATE TABLE game_scores (record VARCHAR(10), score INT);\n",
"Testing select query: SELECT score FROM game_scores WHERE record = '1843';\n",
"[]\n",
"Testing create query: CREATE TABLE game_scores (game_id INT, opponent VARCHAR(255), score INT)\n",
"Testing select query: SELECT score FROM game_scores WHERE opponent = 'Royals' AND record = '2452'\n",
"Error while executing select query: no such column: record\n",
"Testing create query: CREATE TABLE game_scores (record VARCHAR(10), score INT);\n",
"Testing select query: SELECT score FROM game_scores WHERE record = '2246';\n",
"[]\n",
"Testing create query: CREATE TABLE military_specialties (id INT, specialty_name VARCHAR(255), real_name VARCHAR(255))\n",
"Testing select query: SELECT real_name FROM military_specialties WHERE specialty_name = 'shock paratrooper'\n",
"[]\n",
"Testing create query: CREATE TABLE IF NOT EXISTS people (id INT PRIMARY KEY, name VARCHAR(255), birthplace VARCHAR(255))\n",
"Testing select query: SELECT birthplace FROM people WHERE name = 'Pete Sanderson'\n",
"[]\n",
"Testing create query: CREATE TABLE roles (id INT, name VARCHAR(255), role VARCHAR(255))\n",
"Testing select query: SELECT role FROM roles WHERE name = 'Jean-Luc Bouvier'\n",
"[]\n",
"Testing create query: CREATE TABLE pilots (id INT, name VARCHAR(255), kayak VARCHAR(255))\n",
"Testing select query: SELECT name FROM pilots WHERE kayak = 'silent attack kayak'\n",
"[]\n",
"Testing create query: CREATE TABLE persons (id INT, name VARCHAR(255), birthplace VARCHAR(255), code_name VARCHAR(255))\n",
"Testing select query: SELECT code_name FROM persons WHERE birthplace = 'Liverpool'\n",
"[]\n",
"Testing create query: CREATE TABLE medalists (name VARCHAR(255), sport VARCHAR(255))\n",
"Testing select query: SELECT name FROM medalists WHERE sport = 'canoeing'\n",
"[]\n",
"Testing create query: CREATE TABLE IF NOT EXISTS events (event_id INT, event_name VARCHAR(255), event_category VARCHAR(255), event_gender VARCHAR(255), event_weight VARCHAR(255))\n",
"Testing select query: SELECT event_name FROM events WHERE event_category = 'Women' AND event_weight = 'Half Middleweight'\n",
"[]\n",
"Testing create query: CREATE TABLE medals (id INT, athlete_name VARCHAR(255), medal VARCHAR(255), year INT, city VARCHAR(255))\n",
"Testing select query: SELECT athlete_name FROM medals WHERE medal = 'bronze' AND year = 2000 AND city = 'Sydney'\n",
"[]\n",
"Testing create query: CREATE TABLE attendance (id INT, opponent VARCHAR(255), attendees INT);\n",
"Testing select query: SELECT COUNT(*) FROM attendance WHERE opponent = 'twins';\n",
"[(0,)]\n",
"Testing create query: CREATE TABLE records (id INT, date DATE, record INT);\n",
"Testing select query: SELECT date FROM records WHERE record BETWEEN 41 AND 46;\n",
"[]\n",
"Testing create query: CREATE TABLE scores (id INT, record VARCHAR(10), score INT);\n",
"Testing select query: SELECT score FROM scores WHERE record = '48-55';\n",
"[]\n",
"Testing create query: CREATE TABLE scores (id INT, record VARCHAR(10), score INT);\n",
"Testing select query: SELECT score FROM scores WHERE record = '44-49';\n",
"[]\n",
"Testing create query: CREATE TABLE scores (Opponent TEXT, Record TEXT, Score INTEGER);\n",
"Testing select query: SELECT Score FROM scores WHERE Opponent = 'white sox' AND Record = '2-0';\n",
"[]\n",
"Testing create query: CREATE TABLE votes (candidate_name VARCHAR(255), vote_count INT);\n",
"Testing select query: SELECT vote_count FROM votes WHERE candidate_name = 'candice sjostrom';\n",
"[]\n",
"Testing create query: CREATE TABLE percentages (name VARCHAR(255), percentage FLOAT);\n",
"Testing select query: SELECT percentage FROM percentages WHERE name = 'chris wright';\n",
"[]\n",
"Testing create query: CREATE TABLE votes (year INT, candidate VARCHAR(255), vote_percentage FLOAT, office VARCHAR(255))\n",
"Testing select query: SELECT COUNT(*) FROM votes WHERE year > 1992 AND vote_percentage = 1.59 AND office = 'us representative 4'\n",
"[(0,)]\n",
"Testing create query: CREATE TABLE representatives (id INT, name VARCHAR(255), year_start INT, year_end INT)\n",
"Testing select query: SELECT year_start, year_end FROM representatives WHERE name = 'J. Smith Young'\n",
"[]\n",
"Testing create query: CREATE TABLE politicians (id INT, name VARCHAR(255), party VARCHAR(255))\n",
"Testing select query: SELECT party FROM politicians WHERE name = 'Thomas L. Young'\n",
"[]\n",
"Testing create query: CREATE TABLE medals (country TEXT, gold INTEGER, silver INTEGER, bronze INTEGER);\n",
"Testing select query: SELECT MIN(gold + silver + bronze) AS lowest_total_medals FROM medals WHERE gold = 0 AND silver > 1 AND bronze > 2;\n",
"[(None,)]\n",
"Testing create query: CREATE TABLE medals (country TEXT, rank INTEGER, gold INTEGER, silver INTEGER, bronze INTEGER, total INTEGER);\n",
"Testing select query: SELECT SUM(silver) AS total_silver FROM medals WHERE rank = 14 AND total < 1;\n",
"[(None,)]\n",
"Testing create query: CREATE TABLE player_stats (player_id INT, tackles INT, fumble_recoveries INT, forced_fumbles INT);\n",
"Testing select query: SELECT COUNT(tackles) FROM player_stats WHERE fumble_recoveries > 0 AND forced_fumbles > 0;\n",
"[(0,)]\n",
"Testing create query: CREATE TABLE forced_fumbles (player_name VARCHAR(255), solo_tackles INT, forced_fumbles INT);\n",
"Testing select query: SELECT COUNT(forced_fumbles) FROM forced_fumbles WHERE player_name = 'jim laney' AND solo_tackles < 2;\n",
"[(0,)]\n",
"Testing create query: CREATE TABLE player_stats (player_id INT, solo_tackles INT, total INT);\n",
"Testing select query: SELECT MAX(total) AS high_total FROM player_stats WHERE solo_tackles > 15;\n",
"[(None,)]\n",
"Testing create query: CREATE TABLE fumble_recoveries (player_name VARCHAR(255), forced_fumbles INT, sacks INT, solo_tackles INT, fumble_recoveries INT);\n",
"Testing select query: SELECT COUNT(fumble_recoveries) FROM fumble_recoveries WHERE player_name = 'scott gajos' AND forced_fumbles = 0 AND sacks = 0 AND solo_tackles < 2;\n",
"[(0,)]\n",
"Testing create query: CREATE TABLE matches (id INT, opponent VARCHAR(255), time VARCHAR(255), location VARCHAR(255))\n",
"Testing select query: SELECT opponent FROM matches WHERE time = '20:00 GMT' AND location = 'Camp Nou'\n",
"[]\n",
"Testing create query: CREATE TABLE matches (id INT, match_time TIME, score VARCHAR(5))\n",
"Testing select query: SELECT match_time FROM matches WHERE score = '3-2'\n",
"[]\n",
"Testing create query: CREATE TABLE grounds (id INT, name VARCHAR(255))\n",
"Testing select query: SELECT name FROM grounds WHERE id IN (SELECT ground_id FROM team_grounds WHERE team_id = (SELECT id FROM teams WHERE name = 'Aston Villa'))\n",
"Error while executing select query: no such table: team_grounds\n",
"Testing create query: CREATE TABLE competition (id INT, name VARCHAR(255), type VARCHAR(255), time TIMESTAMP);\n",
"Testing select query: SELECT type FROM competition WHERE time = '18:30 GMT' AND name = 'San Siro';\n",
"[]\n",
"Testing create query: CREATE TABLE decile (school_name VARCHAR(255), locality VARCHAR(255), decile INT)\n",
"Testing select query: SELECT COUNT(decile) AS total_decile FROM decile WHERE locality = 'redwood school'\n",
"[(0,)]\n",
"Testing create query: CREATE TABLE reports (id INT, report_name VARCHAR(255), circuit_name VARCHAR(255))\n",
"Testing select query: SELECT report_name FROM reports WHERE circuit_name = 'Tripoli'\n",
"[]\n"
]
}
],
"source": [
"# Execute unit tests and capture results\n",
"results = []\n",
"\n",
"execute_unit_tests(input_df=test_df,output_list=results,system_prompt=system_prompt)"
]
},
{
"cell_type": "markdown",
"id": "a070a4bf-7435-4059-bf74-eb6129cbab2b",
"metadata": {},
"source": [
"#### Run Evaluation\n",
"\n",
"Now that we have generated SQL with system prompt 1 (run 1), we can evaluate the results. We use the pandas `apply` function to run the evaluation on each generated response."
]
},
{
"cell_type": "code",
"execution_count": 25,
"id": "8fe18367",
"metadata": {},
"outputs": [],
"source": [
"results_df = pd.DataFrame(results)\n",
"results_df.columns = ['question','response','format','sql']\n",
"\n",
"# Score each response for relevancy\n",
"results_df['evaluation_score'] = results_df.apply(lambda x: get_geval_score(RELEVANCY_SCORE_CRITERIA,RELEVANCY_SCORE_STEPS,x['question'],x['response'],'relevancy'),axis=1)\n",
"# Label each row with its unit test outcome\n",
"results_df['unit_test_evaluation'] = results_df.apply(evaluate_row, axis=1)"
]
},
{
"cell_type": "markdown",
"id": "c3dd9b04-44e2-476c-86fd-c0a261b1cbdd",
"metadata": {},
"source": [
"## Viewing unit test results and evaluations - Run 1\n",
"\n",
"We can now count the outcomes of the unit tests (which check the structure of the response and whether the generated SQL executes) and of the evaluation (which scores each response for relevancy)."
]
},
{
"cell_type": "code",
"execution_count": 26,
"id": "650e6159",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"unit_test_evaluation\n",
"correct 46\n",
"SQL incorrect 4\n",
"Name: count, dtype: int64"
]
},
"execution_count": 26,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"results_df['unit_test_evaluation'].value_counts()"
]
},
{
"cell_type": "code",
"execution_count": 27,
"id": "b3f98f81",
"metadata": {
"scrolled": true
},
"outputs": [
{
"data": {
"text/plain": [
"evaluation_score\n",
"5 46\n",
"4 3\n",
"3 1\n",
"Name: count, dtype: int64"
]
},
"execution_count": 27,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"results_df['evaluation_score'].value_counts()"
]
},
{
"cell_type": "markdown",
"id": "019f3a1d",
"metadata": {},
"source": [
"### Second run\n",
"\n",
"We now run the same unit tests and evaluation with a new system prompt. The unit testing and evaluation functions are unchanged; the only thing that differs is the system prompt, which is the component under test."
]
},
{
"cell_type": "code",
"execution_count": 28,
"id": "513a2da1",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"('Translate this natural language request into a JSON object containing two SQL queries. \\n'\n",
" \"The first query should be a CREATE statement for a table answering the user's request, while the second should be a \"\n",
" 'SELECT query answering their question. \\n'\n",
" 'This should be returned as a JSON with keys \"create\" and \"select\". \\n'\n",
" 'For example:\\n'\n",
" '\\n'\n",
" '{\"create\": \"CREATE_QUERY\",\"select\": \"SELECT_QUERY\"}\" \\n'\n",
" 'Ensure the SQL is always generated on one line, never use \\n'\n",
" ' to separate rows.')\n"
]
}
],
"source": [
"system_prompt_2 = \"\"\"Translate this natural language request into a JSON object containing two SQL queries. \n",
"The first query should be a CREATE statement for a table answering the user's request, while the second should be a SELECT query answering their question. \n",
"This should be returned as a JSON with keys \"create\" and \"select\". \n",
"For example:\\n\\n{\"create\": \"CREATE_QUERY\",\"select\": \"SELECT_QUERY\"}\" \n",
"Ensure the SQL is always generated on one line, never use \\n to separate rows.\"\"\"\n",
"\n",
"pprint(system_prompt_2, width=120)"
]
},
{
"cell_type": "code",
"execution_count": 29,
"id": "70bd3e32",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Testing create query: CREATE TABLE partnerships (id INT, player1 VARCHAR(255), player2 VARCHAR(255), venue VARCHAR(255))\n",
"Testing select query: SELECT venue FROM partnerships WHERE player1 = 'shoaib malik' AND player2 = 'misbah-ul-haq'\n",
"[]\n",
"Testing create query: CREATE TABLE partnerships (partnership_id INT, player1 VARCHAR(255), player2 VARCHAR(255), venue VARCHAR(255));\n",
"Testing select query: SELECT venue FROM partnerships WHERE player1 = 'herschelle gibbs' AND player2 = 'justin kemp';\n",
"[]\n",
"Testing create query: CREATE TABLE scores (NumberPlayed INT, Points INT);\n",
"Testing select query: SELECT NumberPlayed FROM scores WHERE Points = 310;\n",
"[]\n",
"Testing create query: CREATE TABLE losing_bonus (bonus_id INT, points_against INT)\n",
"Testing select query: SELECT * FROM losing_bonus WHERE points_against = 588\n",
"[]\n",
"Testing create query: CREATE TABLE rugby_teams (team_name VARCHAR(255), tries_against INT, losing_bonus INT);\n",
"Testing select query: SELECT * FROM rugby_teams WHERE tries_against = 7;\n",
"[]\n",
"Testing create query: CREATE TABLE try_bonus (try_bonus_id INT, points_against INT)\n",
"Testing select query: SELECT try_bonus FROM try_bonus WHERE points_against = 488\n",
"Error while executing select query: no such column: try_bonus\n",
"Testing create query: CREATE TABLE Points (Team VARCHAR(255), Try_Bonus INT)\n",
"Testing select query: SELECT Team FROM Points WHERE Try_Bonus = 140\n",
"[]\n",
"Testing create query: CREATE TABLE Drawn (Team TEXT, Tries_against INT);\n",
"Testing select query: SELECT * FROM Drawn WHERE Tries_against = 0;\n",
"[]\n",
"Testing create query: CREATE TABLE IF NOT EXISTS champions (id INT, name VARCHAR(255), reign_days INT, defenses INT);\n",
"Testing select query: SELECT days_held FROM champions WHERE reign_days > 3 AND defenses = 1;\n",
"Error while executing select query: no such column: days_held\n",
"Testing create query: CREATE TABLE champion (name VARCHAR(255), reign_days INT, defenses INT);\n",
"Testing select query: SELECT days_held FROM champion WHERE reign_days > 3 AND defenses < 1;\n",
"Error while executing select query: no such column: days_held\n",
"Testing create query: CREATE TABLE champion_stats (champion_name VARCHAR(255), days_held INT, reign INT, defenses INT);\n",
"Testing select query: SELECT AVG(defenses) FROM champion_stats WHERE days_held = 404 AND reign > 1;\n",
"[(None,)]\n",
"Testing create query: CREATE TABLE champions (name VARCHAR(255), days_held INT, defense INT);\n",
"Testing select query: SELECT MIN(defense) FROM champions WHERE days_held = 345;\n",
"[(None,)]\n",
"Testing create query: CREATE TABLE records (date DATE, record VARCHAR(10));\n",
"Testing select query: SELECT date FROM records WHERE record = '76-72';\n",
"[]\n",
"Testing create query: CREATE TABLE attendance (game_id INT, loss VARCHAR(10), attendance INT);\n",
"Testing select query: SELECT attendance FROM attendance WHERE loss = 'Ponson (1-5)';\n",
"[]\n",
"Testing create query: CREATE TABLE records (day VARCHAR(10), record INT);\n",
"Testing select query: SELECT day FROM records WHERE record >= 36 AND record <= 39;\n",
"[]\n",
"Testing create query: CREATE TABLE records (id INT, date DATE);\n",
"Testing select query: SELECT date FROM records WHERE record = '30-31';\n",
"Error while executing select query: no such column: record\n",
"Testing create query: CREATE TABLE games (game_id INT, opponent VARCHAR(255), result VARCHAR(255));\n",
"Testing select query: SELECT opponent FROM games WHERE result = 'loss' AND game_id = (SELECT game_id FROM games WHERE opponent = 'Leonard' AND result = '78');\n",
"[]\n",
"Testing create query: CREATE TABLE game_scores (game_id INT, home_team_score INT, away_team_score INT);\n",
"Testing select query: SELECT home_team_score, away_team_score FROM game_scores WHERE home_team_score = 18 AND away_team_score = 43;\n",
"[]\n",
"Testing create query: CREATE TABLE game_scores (game_id INT, opponent VARCHAR(50), score INT)\n",
"Testing select query: SELECT score FROM game_scores WHERE opponent = 'Royals' AND record = '2452'\n",
"Error while executing select query: no such column: record\n",
"Testing create query: CREATE TABLE game_scores (id INT, home_team_score INT, away_team_score INT);\n",
"Testing select query: SELECT * FROM game_scores WHERE home_team_score = 22 AND away_team_score = 46;\n",
"[]\n",
"Testing create query: CREATE TABLE military_specialties (id INT, specialty_name VARCHAR(255), real_name VARCHAR(255));\n",
"Testing select query: SELECT real_name FROM military_specialties WHERE specialty_name = 'shock paratrooper';\n",
"[]\n",
"Testing create query: CREATE TABLE birthplaces (name VARCHAR(255), birthplace VARCHAR(255));\n",
"Testing select query: SELECT birthplace FROM birthplaces WHERE name = 'Pete Sanderson';\n",
"[]\n",
"Testing create query: CREATE TABLE roles (id INT, name VARCHAR(255), role VARCHAR(255))\n",
"Testing select query: SELECT role FROM roles WHERE name = 'Jean-Luc Bouvier'\n",
"[]\n",
"Testing create query: CREATE TABLE pilots (id INT, name VARCHAR(255), kayak_type VARCHAR(255))\n",
"Testing select query: SELECT name FROM pilots WHERE kayak_type = 'silent attack kayak'\n",
"[]\n",
"Testing create query: CREATE TABLE persons (id INT, name VARCHAR(255), birthplace VARCHAR(255), code_name VARCHAR(255));\n",
"Testing select query: SELECT code_name FROM persons WHERE birthplace = 'Liverpool';\n",
"[]\n",
"Testing create query: CREATE TABLE medalists (name VARCHAR(255), sport VARCHAR(255))\n",
"Testing select query: SELECT name FROM medalists WHERE sport = 'canoeing'\n",
"[]\n",
"Testing create query: CREATE TABLE IF NOT EXISTS events (event_id INT, event_name VARCHAR(255), event_category VARCHAR(255))\n",
"Testing select query: SELECT event_name FROM events WHERE event_category = 'women's half middleweight'\n",
"Error while executing select query: near \"s\": syntax error\n",
"Testing create query: CREATE TABLE medals (id INT, athlete VARCHAR(255), event VARCHAR(255), medal VARCHAR(255), year INT);\n",
"Testing select query: SELECT athlete FROM medals WHERE medal = 'bronze' AND year = 2000 AND event = 'Sydney games';\n",
"[]\n",
"Testing create query: CREATE TABLE attendance (opponent VARCHAR(50), people_attended INT);\n",
"Testing select query: SELECT people_attended FROM attendance WHERE opponent = 'twins';\n",
"[]\n",
"Testing create query: CREATE TABLE records (date DATE, record INT);\n",
"Testing select query: SELECT date FROM records WHERE record BETWEEN 41 AND 46;\n",
"[]\n",
"Testing create query: CREATE TABLE scores (id INT, record VARCHAR(10), score INT);\n",
"Testing select query: SELECT score FROM scores WHERE record = '48-55';\n",
"[]\n",
"Testing create query: CREATE TABLE scores (id INT, record VARCHAR(10), score INT);\n",
"Testing select query: SELECT score FROM scores WHERE record = '44-49';\n",
"[]\n",
"Testing create query: CREATE TABLE scores (Opponent VARCHAR(50), Record VARCHAR(10), Score INT);\n",
"Testing select query: SELECT Score FROM scores WHERE Opponent = 'white sox' AND Record = '2-0';\n",
"[]\n",
"Testing create query: CREATE TABLE votes (candidate_name VARCHAR(255), vote_count INT);\n",
"Testing select query: SELECT vote_count FROM votes WHERE candidate_name = 'candice sjostrom';\n",
"[]\n",
"Testing create query: CREATE TABLE percentages (name VARCHAR(255), percentage FLOAT);\n",
"Testing select query: SELECT percentage FROM percentages WHERE name = 'chris wright';\n",
"[]\n",
"Testing create query: CREATE TABLE votes (year INT, candidate VARCHAR(255), vote_percentage FLOAT, office VARCHAR(255));\n",
"Testing select query: SELECT COUNT(*) FROM votes WHERE year > 1992 AND vote_percentage = 1.59 AND office = 'us representative 4';\n",
"[(0,)]\n",
"Testing create query: CREATE TABLE representatives (id INT, name VARCHAR(255), start_year INT, end_year INT)\n",
"Testing select query: SELECT start_year, end_year FROM representatives WHERE name = 'J. Smith Young'\n",
"[]\n",
"Testing create query: CREATE TABLE politicians (id INT, name VARCHAR(255), party VARCHAR(255))\n",
"Testing select query: SELECT party FROM politicians WHERE name = 'Thomas L. Young'\n",
"[]\n",
"Testing create query: CREATE TABLE medal_counts (country TEXT, gold INTEGER, silver INTEGER, bronze INTEGER);\n",
"Testing select query: SELECT MIN(gold + silver + bronze) AS lowest_total_medals FROM medal_counts WHERE gold = 0 AND silver > 1 AND bronze > 2;\n",
"[(None,)]\n",
"Testing create query: CREATE TABLE medals (country TEXT, rank INTEGER, gold INTEGER, silver INTEGER, bronze INTEGER, total INTEGER);\n",
"Testing select query: SELECT SUM(silver) AS total_silver FROM medals WHERE rank = 14 AND total < 1;\n",
"[(None,)]\n",
"Testing create query: CREATE TABLE player_stats (player_id INT, tackles INT, fumble_recoveries INT, forced_fumbles INT);\n",
"Testing select query: SELECT COUNT(tackles) FROM player_stats WHERE fumble_recoveries > 0 AND forced_fumbles = 0;\n",
"[(0,)]\n",
"Testing create query: CREATE TABLE forced_fumbles (player_name VARCHAR(255), solo_tackles INT, forced_fumbles INT);\n",
"Testing select query: SELECT COUNT(forced_fumbles) FROM forced_fumbles WHERE player_name = 'Jim Laney' AND solo_tackles < 2;\n",
"[(0,)]\n",
"Testing create query: CREATE TABLE player_stats (player_id INT, solo_tackles INT, total INT);\n",
"Testing select query: SELECT MAX(total) AS high_total FROM player_stats WHERE solo_tackles > 15;\n",
"[(None,)]\n",
"Testing create query: CREATE TABLE fumble_recoveries (player_name VARCHAR(255), forced_fumbles INT, sacks INT, solo_tackles INT, fumble_recoveries INT);\n",
"Testing select query: SELECT fumble_recoveries FROM fumble_recoveries WHERE player_name = 'scott gajos' AND forced_fumbles = 0 AND sacks = 0 AND solo_tackles < 2;\n",
"[]\n",
"Testing create query: CREATE TABLE matches (id INT, opponent VARCHAR(255), time VARCHAR(255), location VARCHAR(255))\n",
"Testing select query: SELECT opponent FROM matches WHERE time = '20:00 GMT' AND location = 'Camp Nou'\n",
"[]\n",
"Testing create query: CREATE TABLE matches (id INT, match_time TIME, score VARCHAR(5))\n",
"Testing select query: SELECT match_time FROM matches WHERE score = '3-2'\n",
"[]\n",
"Testing create query: CREATE TABLE grounds (id INT, name VARCHAR(255));\n",
"Testing select query: SELECT name FROM grounds WHERE id IN (SELECT ground_id FROM team_grounds WHERE team_id = (SELECT id FROM teams WHERE name = 'Aston Villa'));\n",
"Error while executing select query: no such table: team_grounds\n",
"Testing create query: CREATE TABLE competition (id INT, type VARCHAR(255), location VARCHAR(255), time TIME)\n",
"Testing select query: SELECT type FROM competition WHERE location = 'San Siro' AND time = '18:30:00'\n",
"[]\n",
"Testing create query: CREATE TABLE decile (school_name VARCHAR(255), locality VARCHAR(255), decile INT)\n",
"Testing select query: SELECT COUNT(DISTINCT decile) AS total_decile FROM decile WHERE locality = 'redwood school'\n",
"[(0,)]\n",
"Testing create query: CREATE TABLE reports (id INT, name VARCHAR(255), circuit VARCHAR(255))\n",
"Testing select query: SELECT * FROM reports WHERE circuit = 'Tripoli'\n",
"[]\n"
]
}
],
"source": [
"# Execute unit tests\n",
"results_2 = []\n",
"\n",
"execute_unit_tests(input_df=test_df,output_list=results_2,system_prompt=system_prompt_2)"
]
},
{
"cell_type": "code",
"execution_count": 30,
"id": "04532d59",
"metadata": {},
"outputs": [],
"source": [
"results_2_df = pd.DataFrame(results_2)\n",
"results_2_df.columns = ['question','response','format','sql']\n",
"\n",
"# Score each response for relevancy\n",
"results_2_df['evaluation_score'] = results_2_df.apply(lambda x: get_geval_score(RELEVANCY_SCORE_CRITERIA,RELEVANCY_SCORE_STEPS,x['question'],x['response'],'relevancy'),axis=1)\n",
"# Label each row with its unit test outcome\n",
"results_2_df['unit_test_evaluation'] = results_2_df.apply(evaluate_row, axis=1)"
]
},
{
"cell_type": "markdown",
"id": "cd95c3f9-f90d-451d-a32b-aeb066906779",
"metadata": {},
"source": [
"## Viewing unit test results and evaluations - Run 2\n",
"\n",
"We can now count the outcomes of the unit tests (which check the structure of the response and whether the generated SQL executes) and of the evaluation (which scores each response for relevancy)."
]
},
{
"cell_type": "code",
"execution_count": 31,
"id": "cbaa4bdf",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"unit_test_evaluation\n",
"correct 43\n",
"SQL incorrect 7\n",
"Name: count, dtype: int64"
]
},
"execution_count": 31,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"results_2_df['unit_test_evaluation'].value_counts()"
]
},
{
"cell_type": "code",
"execution_count": 32,
"id": "1ada474e",
"metadata": {
"scrolled": true
},
"outputs": [
{
"data": {
"text/plain": [
"evaluation_score\n",
"5 46\n",
"4 3\n",
"3 1\n",
"Name: count, dtype: int64"
]
},
"execution_count": 32,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"results_2_df['evaluation_score'].value_counts()"
]
},
{
"cell_type": "markdown",
"id": "1908c933",
"metadata": {},
"source": [
"## Report\n",
"\n",
"We'll build a simple dataframe to store and display the performance of each run. This is where you could use tools like Weights & Biases Prompts or Gantry to store the results and analyze your different iterations."
]
},
{
"cell_type": "code",
"execution_count": 33,
"id": "d277222d",
"metadata": {},
"outputs": [],
"source": [
"results_df['run'] = 1\n",
"results_df['Evaluating Model'] = 'gpt-4'\n",
"\n",
"results_2_df['run'] = 2\n",
"results_2_df['Evaluating Model'] = 'gpt-4'"
]
},
{
"cell_type": "code",
"execution_count": 34,
"id": "6da35c99",
"metadata": {},
"outputs": [],
"source": [
"run_df = pd.concat([results_df,results_2_df])"
]
},
{
"cell_type": "code",
"execution_count": 35,
"id": "4116cb37",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>question</th>\n",
" <th>response</th>\n",
" <th>format</th>\n",
" <th>sql</th>\n",
" <th>evaluation_score</th>\n",
" <th>unit_test_evaluation</th>\n",
" <th>run</th>\n",
" <th>Evaluating Model</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>What venue did the parntership of shoaib malik...</td>\n",
" <td>{\"create\": \"CREATE TABLE partnerships (id INT,...</td>\n",
" <td>True</td>\n",
" <td>True</td>\n",
" <td>5</td>\n",
" <td>correct</td>\n",
" <td>1</td>\n",
" <td>gpt-4</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>What venue did the partnership of herschelle g...</td>\n",
" <td>{\"create\": \"CREATE TABLE partnerships (id INT,...</td>\n",
" <td>True</td>\n",
" <td>True</td>\n",
" <td>5</td>\n",
" <td>correct</td>\n",
" <td>1</td>\n",
" <td>gpt-4</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>What is the number Played that has 310 Points ...</td>\n",
" <td>{\"create\": \"CREATE TABLE Played (Number INT, P...</td>\n",
" <td>True</td>\n",
" <td>True</td>\n",
" <td>5</td>\n",
" <td>correct</td>\n",
" <td>1</td>\n",
" <td>gpt-4</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>What Losing bonus has a Points against of 588?</td>\n",
" <td>{\"create\": \"CREATE TABLE losing_bonus (bonus_i...</td>\n",
" <td>True</td>\n",
" <td>True</td>\n",
" <td>5</td>\n",
" <td>correct</td>\n",
" <td>1</td>\n",
" <td>gpt-4</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>What Tries against has a Losing bonus of 7?</td>\n",
" <td>{\"create\": \"CREATE TABLE rugby_teams (team_nam...</td>\n",
" <td>True</td>\n",
" <td>True</td>\n",
" <td>3</td>\n",
" <td>correct</td>\n",
" <td>1</td>\n",
" <td>gpt-4</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" question \\\n",
"0 What venue did the parntership of shoaib malik... \n",
"1 What venue did the partnership of herschelle g... \n",
"2 What is the number Played that has 310 Points ... \n",
"3 What Losing bonus has a Points against of 588? \n",
"4 What Tries against has a Losing bonus of 7? \n",
"\n",
" response format sql \\\n",
"0 {\"create\": \"CREATE TABLE partnerships (id INT,... True True \n",
"1 {\"create\": \"CREATE TABLE partnerships (id INT,... True True \n",
"2 {\"create\": \"CREATE TABLE Played (Number INT, P... True True \n",
"3 {\"create\": \"CREATE TABLE losing_bonus (bonus_i... True True \n",
"4 {\"create\": \"CREATE TABLE rugby_teams (team_nam... True True \n",
"\n",
" evaluation_score unit_test_evaluation run Evaluating Model \n",
"0 5 correct 1 gpt-4 \n",
"1 5 correct 1 gpt-4 \n",
"2 5 correct 1 gpt-4 \n",
"3 5 correct 1 gpt-4 \n",
"4 3 correct 1 gpt-4 "
]
},
"execution_count": 35,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"run_df.head()"
]
},
{
"cell_type": "code",
"execution_count": 36,
"id": "ed800f0c",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th></th>\n",
" <th>Number of records</th>\n",
" </tr>\n",
" <tr>\n",
" <th>run</th>\n",
" <th>unit_test_evaluation</th>\n",
" <th></th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th rowspan=\"2\" valign=\"top\">1</th>\n",
" <th>SQL incorrect</th>\n",
" <td>4</td>\n",
" </tr>\n",
" <tr>\n",
" <th>correct</th>\n",
" <td>46</td>\n",
" </tr>\n",
" <tr>\n",
" <th rowspan=\"2\" valign=\"top\">2</th>\n",
" <th>SQL incorrect</th>\n",
" <td>7</td>\n",
" </tr>\n",
" <tr>\n",
" <th>correct</th>\n",
" <td>43</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" Number of records\n",
"run unit_test_evaluation \n",
"1 SQL incorrect 4\n",
" correct 46\n",
"2 SQL incorrect 7\n",
" correct 43"
]
},
"execution_count": 36,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Unit test results: count records per run and unit test outcome\n",
"unittest_df_pivot = pd.pivot_table(run_df, values='format', index=['run', 'unit_test_evaluation'], aggfunc='count')\n",
"unittest_df_pivot.columns = ['Number of records']\n",
"unittest_df_pivot"
]
},
{
"cell_type": "markdown",
"id": "0162a009-fc43-484c-90f6-d59a8e52f365",
"metadata": {},
"source": [
"#### Plotting the results\n",
"\n",
"We can create a simple bar chart to visualise the results of unit tests for both runs."
]
},
{
"cell_type": "code",
"execution_count": 37,
"id": "e2b4aa03-42f5-4c30-a610-e553937bf160",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 1000x600 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"unittest_df_pivot.reset_index(inplace=True)\n",
"\n",
"# Plotting\n",
"plt.figure(figsize=(10, 6))\n",
"\n",
"# Set the width of each bar\n",
"bar_width = 0.35\n",
"\n",
"# OpenAI brand colors\n",
"openai_colors = ['#00D1B2', '#000000'] # Green and Black\n",
"\n",
"# Get unique runs and unit test evaluations\n",
"unique_runs = unittest_df_pivot['run'].unique()\n",
"unique_unit_test_evaluations = unittest_df_pivot['unit_test_evaluation'].unique()\n",
"\n",
"# Ensure we have enough colors (repeating the pattern if necessary)\n",
"colors = openai_colors * (len(unique_runs) // len(openai_colors) + 1)\n",
"\n",
"# Iterate over each run to plot\n",
"for i, run in enumerate(unique_runs):\n",
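"    run_data = unittest_df_pivot[unittest_df_pivot['run'] == run]\n",
"\n",
"    # Align counts to the full category list so a category missing from this run plots as 0\n",
"    counts = run_data.set_index('unit_test_evaluation')['Number of records'].reindex(unique_unit_test_evaluations, fill_value=0)\n",
"\n",
"    # Position of bars for this run\n",
"    positions = np.arange(len(unique_unit_test_evaluations)) + i * bar_width\n",
"\n",
"    plt.bar(positions, counts, width=bar_width, label=f'Run {run}', color=colors[i])\n",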
"\n",
"# Setting the x-axis labels to be the unit test evaluations, centered under the groups\n",
"plt.xticks(np.arange(len(unique_unit_test_evaluations)) + bar_width / 2, unique_unit_test_evaluations)\n",
"\n",
"plt.xlabel('Unit Test Evaluation')\n",
"plt.ylabel('Number of Records')\n",
"plt.title('Unit Test Evaluations vs Number of Records for Each Run')\n",
"plt.legend()\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 38,
"id": "7228eac7-e0a9-473d-9432-e558bbc91841",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th></th>\n",
" <th>Number of records</th>\n",
" </tr>\n",
" <tr>\n",
" <th>run</th>\n",
" <th>evaluation_score</th>\n",
" <th></th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th rowspan=\"3\" valign=\"top\">1</th>\n",
" <th>3</th>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>3</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5</th>\n",
" <td>46</td>\n",
" </tr>\n",
" <tr>\n",
" <th rowspan=\"3\" valign=\"top\">2</th>\n",
" <th>3</th>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>3</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5</th>\n",
" <td>46</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" Number of records\n",
"run evaluation_score \n",
"1 3 1\n",
" 4 3\n",
" 5 46\n",
"2 3 1\n",
" 4 3\n",
" 5 46"
]
},
"execution_count": 38,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Evaluation score results\n",
"evaluation_df_pivot = pd.pivot_table(run_df, values='format', index=['run', 'evaluation_score'],\n",
"                                     aggfunc='count')\n",
"evaluation_df_pivot.columns = ['Number of records']\n",
"evaluation_df_pivot"
]
},
{
"cell_type": "markdown",
"id": "786515fa-6841-4820-98f9-aa29ae76cf76",
"metadata": {},
"source": [
"#### Plotting the results\n",
"\n",
"We can create a simple bar chart to visualise the evaluation scores for both runs."
]
},
{
"cell_type": "code",
"execution_count": 39,
"id": "b2a18a78-55ec-43f6-9d62-929707a94364",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 1000x600 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"evaluation_df_pivot.reset_index(inplace=True)\n",
"\n",
"# Plotting\n",
"plt.figure(figsize=(10, 6))\n",
"\n",
"# Set the width of each bar\n",
"bar_width = 0.35\n",
"\n",
"# OpenAI brand colors\n",
"openai_colors = ['#00D1B2', '#000000'] # Green and Black\n",
"\n",
"# Get unique runs and evaluation scores\n",
"unique_runs = evaluation_df_pivot['run'].unique()\n",
"unique_evaluation_scores = evaluation_df_pivot['evaluation_score'].unique()\n",
"\n",
"# Ensure we have enough colors (repeating the pattern if necessary)\n",
"colors = openai_colors * (len(unique_runs) // len(openai_colors) + 1)\n",
"\n",
"# Iterate over each run to plot\n",
"for i, run in enumerate(unique_runs):\n",
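"    run_data = evaluation_df_pivot[evaluation_df_pivot['run'] == run]\n",
"\n",
"    # Align counts to the full score list so a score missing from this run plots as 0\n",
"    counts = run_data.set_index('evaluation_score')['Number of records'].reindex(unique_evaluation_scores, fill_value=0)\n",
"\n",
"    # Position of bars for this run\n",
"    positions = np.arange(len(unique_evaluation_scores)) + i * bar_width\n",
"\n",
"    plt.bar(positions, counts, width=bar_width, label=f'Run {run}', color=colors[i])\n",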
"\n",
"# Setting the x-axis labels to be the evaluation scores, centered under the groups\n",
"plt.xticks(np.arange(len(unique_evaluation_scores)) + bar_width / 2, unique_evaluation_scores)\n",
"\n",
"plt.xlabel('Evaluation Score')\n",
"plt.ylabel('Number of Records')\n",
"plt.title('Evaluation Scores vs Number of Records for Each Run')\n",
"plt.legend()\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"id": "992f9aa0",
"metadata": {},
"source": [
"## Conclusion\n",
"\n",
"You now have a framework for testing SQL generation with LLMs, and with some tweaks the approach can be extended to many other code generation use cases. With GPT-4 and engaged human labellers you can aim to automate the evaluation of these test cases, creating an iterative loop in which new examples are added to the test set and the framework detects any performance regressions.\n",
"\n",
"We hope you find this useful, and we welcome any feedback."
]
}
],
"metadata": {
"kernelspec": {
"display_name": "openai_test",
"language": "python",
"name": "openai_test"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.13"
}
},
"nbformat": 4,
"nbformat_minor": 5
}