openai-cookbook/examples/Code_search_using_embeddings.ipynb

431 lines
15 KiB
Plaintext
Raw Normal View History

2022-03-11 02:08:53 +00:00
{
"cells": [
{
"attachments": {},
2022-03-11 02:08:53 +00:00
"cell_type": "markdown",
"metadata": {},
"source": [
2023-10-18 00:39:42 +00:00
"## Code search using embeddings\n",
2022-03-11 02:08:53 +00:00
"\n",
"This notebook shows how Ada embeddings can be used to implement semantic code search. For this demonstration, we use our own [openai-python code repository](https://github.com/openai/openai-python). We implement a simple version of file parsing and function extraction for Python files; the extracted functions can then be embedded, indexed, and queried."
]
},
{
2023-10-18 00:39:42 +00:00
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"### Helper Functions\n",
"\n",
"We first set up some simple parsing functions that extract the relevant information from our codebase: each function's code, name, and file path."
2022-03-11 02:08:53 +00:00
]
},
{
"cell_type": "code",
"execution_count": null,
2022-03-11 02:08:53 +00:00
"metadata": {},
"outputs": [],
2022-03-11 02:08:53 +00:00
"source": [
"import pandas as pd\n",
"from pathlib import Path\n",
"\n",
"DEF_PREFIXES = ['def ', 'async def ']\n",
"NEWLINE = '\\n'\n",
"\n",
"def get_function_name(code):\n",
"    \"\"\"\n",
"    Return the name of the function whose definition starts `code`.\n",
"\n",
"    `code` is expected to begin with one of DEF_PREFIXES ('def ' or 'async def ');\n",
"    the name is the text between that prefix and the first '('.\n",
"    Returns None when `code` starts with neither prefix.\n",
"    \"\"\"\n",
"    matching = [prefix for prefix in DEF_PREFIXES if code.startswith(prefix)]\n",
"    if not matching:\n",
"        return None\n",
"    name_start = len(matching[0])\n",
"    return code[name_start:code.index('(')]\n",
"\n",
2022-03-11 02:08:53 +00:00
"\n",
"def get_until_no_space(all_lines, i):\n",
"    \"\"\"\n",
"    Return the source of the definition that starts at line `i` of `all_lines`.\n",
"\n",
"    Following lines are collected while they are empty, indented (space or tab),\n",
"    or start with ')' (the closing line of a multi-line signature). Collection\n",
"    stops at the first other line, i.e. the next top-level statement.\n",
"    Trailing blank lines are dropped so the snippet ends at its last line of code.\n",
"\n",
"    Args:\n",
"        all_lines: the file's lines, without line terminators.\n",
"        i: index of the line holding the 'def' / 'async def' statement.\n",
"\n",
"    Returns:\n",
"        The collected lines joined with NEWLINE.\n",
"    \"\"\"\n",
"    ret = [all_lines[i]]\n",
"    for j in range(i + 1, len(all_lines)):\n",
"        line = all_lines[j]\n",
"        if len(line) == 0 or line[0] in [' ', '\\t', ')']:\n",
"            ret.append(line)\n",
"        else:\n",
"            break\n",
"    # Blank lines separating this definition from the next one (or trailing\n",
"    # at end of file) are not part of the function; don't embed them.\n",
"    while len(ret) > 1 and not ret[-1].strip():\n",
"        ret.pop()\n",
"    return NEWLINE.join(ret)\n",
"\n",
2022-03-11 02:08:53 +00:00
"\n",
"def get_functions(filepath):\n",
"    \"\"\"\n",
"    Yield every top-level function defined in the Python file at `filepath`.\n",
"\n",
"    Only lines starting at column 0 with 'def ' or 'async def ' begin a match;\n",
"    indented definitions (methods, nested functions) are not yielded separately.\n",
"\n",
"    Yields:\n",
"        dicts with keys 'code', 'function_name' and 'filepath'.\n",
"    \"\"\"\n",
"    # Python source files are UTF-8 by default (PEP 3120); don't rely on the locale.\n",
"    # Read everything up front so the file is closed before the first yield,\n",
"    # instead of staying open while this generator is suspended.\n",
"    with open(filepath, 'r', encoding='utf-8') as file:\n",
"        all_lines = file.read().replace('\\r', NEWLINE).split(NEWLINE)\n",
"\n",
"    for i, line in enumerate(all_lines):\n",
"        if line.startswith(tuple(DEF_PREFIXES)):\n",
"            code = get_until_no_space(all_lines, i)\n",
"            yield {\n",
"                'code': code,\n",
"                'function_name': get_function_name(code),\n",
"                'filepath': filepath,\n",
"            }\n",
"\n",
"\n",
"def extract_functions_from_repo(code_root):\n",
"    \"\"\"\n",
"    Extract all top-level functions from every .py file under `code_root`.\n",
"\n",
"    Args:\n",
"        code_root: path (str or Path) of the repository to scan recursively.\n",
"\n",
"    Returns:\n",
"        A list of dicts as produced by get_functions, or None when no .py\n",
"        files are found (e.g. code_root is set incorrectly).\n",
"    \"\"\"\n",
"    # Sort so the extraction order (and the saved CSV) doesn't depend on filesystem order\n",
"    code_files = sorted(Path(code_root).rglob('*.py'))\n",
"\n",
"    num_files = len(code_files)\n",
"    print(f'Total number of .py files: {num_files}')\n",
"\n",
"    if num_files == 0:\n",
"        print('Verify openai-python repo exists and code_root is set correctly.')\n",
"        return None\n",
"\n",
"    all_funcs = [\n",
"        func\n",
"        for code_file in code_files\n",
"        for func in get_functions(str(code_file))\n",
"    ]\n",
"\n",
"    num_funcs = len(all_funcs)\n",
"    print(f'Total number of functions extracted: {num_funcs}')\n",
"\n",
"    return all_funcs"
]
},
{
2023-10-18 00:39:42 +00:00
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"### Data Loading\n",
"\n",
"We'll first load the openai-python folder and extract the needed information using the functions we defined above."
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Total number of .py files: 57\n",
"Total number of functions extracted: 118\n"
]
}
],
"source": [
"# The openai-python repository is expected to be cloned directly into the home directory\n",
"root_dir = Path.home()\n",
"code_root = root_dir / 'openai-python'\n",
"\n",
"# Parse every .py file under code_root and collect its top-level functions\n",
"all_funcs = extract_functions_from_repo(code_root)"
2022-03-11 02:08:53 +00:00
]
},
{
2023-10-18 00:39:42 +00:00
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"Now that we have our content, we can pass each function's code to the `text-embedding-ada-002` model to get back its vector embedding."
]
},
2022-03-11 02:08:53 +00:00
{
"cell_type": "code",
"execution_count": 11,
2022-03-11 02:08:53 +00:00
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>code</th>\n",
" <th>function_name</th>\n",
" <th>filepath</th>\n",
" <th>code_embedding</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>def _console_log_level():\\n if openai.log i...</td>\n",
" <td>_console_log_level</td>\n",
" <td>openai/util.py</td>\n",
" <td>[0.033906757831573486, -0.00418944051489234, 0...</td>\n",
2022-03-11 02:08:53 +00:00
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>def log_debug(message, **params):\\n msg = l...</td>\n",
" <td>log_debug</td>\n",
" <td>openai/util.py</td>\n",
" <td>[-0.004059609025716782, 0.004895503632724285, ...</td>\n",
2022-03-11 02:08:53 +00:00
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>def log_info(message, **params):\\n msg = lo...</td>\n",
" <td>log_info</td>\n",
" <td>openai/util.py</td>\n",
" <td>[0.0048639848828315735, 0.0033139237202703953,...</td>\n",
2022-03-11 02:08:53 +00:00
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>def log_warn(message, **params):\\n msg = lo...</td>\n",
" <td>log_warn</td>\n",
" <td>openai/util.py</td>\n",
" <td>[0.0024026145692914724, -0.010721310041844845,...</td>\n",
2022-03-11 02:08:53 +00:00
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>def logfmt(props):\\n def fmt(key, val):\\n ...</td>\n",
" <td>logfmt</td>\n",
" <td>openai/util.py</td>\n",
" <td>[0.01664826273918152, 0.01730910874903202, 0.0...</td>\n",
2022-03-11 02:08:53 +00:00
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" code function_name \\\n",
"0 def _console_log_level():\\n if openai.log i... _console_log_level \n",
"1 def log_debug(message, **params):\\n msg = l... log_debug \n",
"2 def log_info(message, **params):\\n msg = lo... log_info \n",
"3 def log_warn(message, **params):\\n msg = lo... log_warn \n",
"4 def logfmt(props):\\n def fmt(key, val):\\n ... logfmt \n",
2022-03-11 02:08:53 +00:00
"\n",
" filepath code_embedding \n",
"0 openai/util.py [0.033906757831573486, -0.00418944051489234, 0... \n",
"1 openai/util.py [-0.004059609025716782, 0.004895503632724285, ... \n",
"2 openai/util.py [0.0048639848828315735, 0.0033139237202703953,... \n",
"3 openai/util.py [0.0024026145692914724, -0.010721310041844845,... \n",
"4 openai/util.py [0.01664826273918152, 0.01730910874903202, 0.0... "
2022-03-11 02:08:53 +00:00
]
},
"execution_count": 11,
2022-03-11 02:08:53 +00:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from utils.embeddings_utils import get_embedding\n",
2022-03-11 02:08:53 +00:00
"\n",
"# Fail early with a clear message if the previous step found no functions\n",
"assert all_funcs, 'No functions were extracted - check that code_root points to openai-python'\n",
"\n",
"df = pd.DataFrame(all_funcs)\n",
"# get_embedding is called once per function, on the function's full source text\n",
"df['code_embedding'] = df['code'].apply(lambda x: get_embedding(x, model='text-embedding-ada-002'))\n",
"# Store paths relative to the repo root, with '/' separators, so the CSV\n",
"# contains no local home directory and looks the same on every OS\n",
"df['filepath'] = df['filepath'].map(lambda x: Path(x).relative_to(code_root).as_posix())\n",
"\n",
"output_path = Path('data') / 'code_search_openai-python.csv'\n",
"output_path.parent.mkdir(parents=True, exist_ok=True)\n",
"df.to_csv(output_path, index=False)\n",
"df.head()"
]
},
{
2023-10-18 00:39:42 +00:00
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"### Testing\n",
"\n",
"Let's test our search with some simple queries. If you're familiar with the `openai-python` repository, you'll see that we can easily find the functions we're looking for using only a simple English description.\n",
"\n",
"We define a `search_functions` function that takes the DataFrame containing our embeddings, a query string, and a few display options. Searching our database works as follows:\n",
"\n",
"1. We first embed our query string (`code_query`) with `text-embedding-ada-002`. The reasoning is that a query like 'a function that reverses a string' and code like 'def reverse(string): return string[::-1]' should have very similar embeddings.\n",
"2. We then calculate the cosine similarity between the query embedding and every function embedding in our database. A higher score means the function is closer in meaning to the query.\n",
"3. Finally, we sort all data points by their similarity to the query (highest first) and return the top `n` results requested in the function parameters."
]
},
2022-03-11 02:08:53 +00:00
{
"cell_type": "code",
"execution_count": 1,
2022-03-11 02:08:53 +00:00
"metadata": {},
"outputs": [],
2022-03-11 02:08:53 +00:00
"source": [
"from utils.embeddings_utils import cosine_similarity\n",
2022-03-11 02:08:53 +00:00
"\n",
"def search_functions(df, code_query, n=3, pprint=True, n_lines=7):\n",
"    \"\"\"\n",
"    Return the `n` functions in `df` most similar to a natural-language query.\n",
"\n",
"    Args:\n",
"        df: DataFrame with 'code', 'function_name', 'filepath' and 'code_embedding' columns.\n",
"        code_query: free-text description of the code to find.\n",
"        n: number of results to return.\n",
"        pprint: if True, print each match's location, score and first lines of code.\n",
"        n_lines: number of code lines to print per match.\n",
"\n",
"    Returns:\n",
"        The top-`n` rows of `df`, highest cosine similarity first, with an added\n",
"        'similarities' column. `df` itself is not modified.\n",
"    \"\"\"\n",
"    embedding = get_embedding(code_query, model='text-embedding-ada-002')\n",
"    # assign() returns a new frame, so searching doesn't silently add a column to the caller's df\n",
"    scored = df.assign(similarities=df.code_embedding.apply(lambda x: cosine_similarity(x, embedding)))\n",
"    res = scored.sort_values('similarities', ascending=False).head(n)\n",
"\n",
"    if pprint:\n",
"        for _, row in res.iterrows():\n",
"            print(f\"{row.filepath}:{row.function_name} score={round(row.similarities, 3)}\")\n",
"            print(\"\\n\".join(row.code.split(\"\\n\")[:n_lines]))\n",
"            print('-' * 70)\n",
"\n",
"    return res"
2022-03-11 02:08:53 +00:00
]
},
{
"cell_type": "code",
"execution_count": 13,
2022-03-11 02:08:53 +00:00
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"openai/validators.py:format_inferrer_validator score=0.751\n",
2022-03-11 02:08:53 +00:00
"def format_inferrer_validator(df):\n",
" \"\"\"\n",
" This validator will infer the likely fine-tuning format of the data, and display it to the user if it is classification.\n",
" It will also suggest to use ada and explain train/validation split benefits.\n",
" \"\"\"\n",
" ft_type = infer_task_type(df)\n",
" immediate_msg = None\n",
"----------------------------------------------------------------------\n",
"openai/validators.py:get_validators score=0.748\n",
"def get_validators():\n",
" return [\n",
" num_examples_validator,\n",
" lambda x: necessary_column_validator(x, \"prompt\"),\n",
" lambda x: necessary_column_validator(x, \"completion\"),\n",
" additional_column_validator,\n",
" non_empty_field_validator,\n",
2022-03-11 02:08:53 +00:00
"----------------------------------------------------------------------\n",
"openai/validators.py:infer_task_type score=0.739\n",
"def infer_task_type(df):\n",
2022-03-11 02:08:53 +00:00
" \"\"\"\n",
" Infer the likely fine-tuning task type from the data\n",
2022-03-11 02:08:53 +00:00
" \"\"\"\n",
" CLASSIFICATION_THRESHOLD = 3 # min_average instances of each class\n",
" if sum(df.prompt.str.len()) == 0:\n",
" return \"open-ended generation\"\n",
2022-03-11 02:08:53 +00:00
"----------------------------------------------------------------------\n"
]
}
],
"source": [
"# Plain-English description of the code we want; print the 3 best matches\n",
"res = search_functions(df, code_query='fine-tuning input data validation logic', n=3)"
]
},
{
"cell_type": "code",
"execution_count": 14,
2022-03-11 02:08:53 +00:00
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"openai/validators.py:get_common_xfix score=0.794\n",
"def get_common_xfix(series, xfix=\"suffix\"):\n",
" \"\"\"\n",
" Finds the longest common suffix or prefix of all the values in a series\n",
" \"\"\"\n",
" common_xfix = \"\"\n",
" while True:\n",
" common_xfixes = (\n",
" series.str[-(len(common_xfix) + 1) :]\n",
" if xfix == \"suffix\"\n",
" else series.str[: len(common_xfix) + 1]\n",
"----------------------------------------------------------------------\n",
"openai/validators.py:common_completion_suffix_validator score=0.778\n",
2022-03-11 02:08:53 +00:00
"def common_completion_suffix_validator(df):\n",
" \"\"\"\n",
" This validator will suggest to add a common suffix to the completion if one doesn't already exist in case of classification or conditional generation.\n",
" \"\"\"\n",
" error_msg = None\n",
" immediate_msg = None\n",
" optional_msg = None\n",
" optional_fn = None\n",
"\n",
" ft_type = infer_task_type(df)\n",
"----------------------------------------------------------------------\n"
]
}
],
"source": [
"# Only the 2 best matches, printing the first 10 lines of each\n",
"res = search_functions(df, code_query='find common suffix', n=2, n_lines=10)"
]
},
{
"cell_type": "code",
"execution_count": 15,
2022-03-11 02:08:53 +00:00
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"openai/cli.py:tools_register score=0.78\n",
2022-03-11 02:08:53 +00:00
"def tools_register(parser):\n",
" subparsers = parser.add_subparsers(\n",
" title=\"Tools\", help=\"Convenience client side tools\"\n",
" )\n",
"\n",
" def help(args):\n",
" parser.print_help()\n",
"\n",
" parser.set_defaults(func=help)\n",
"\n",
" sub = subparsers.add_parser(\"fine_tunes.prepare_data\")\n",
" sub.add_argument(\n",
" \"-f\",\n",
" \"--file\",\n",
" required=True,\n",
" help=\"JSONL, JSON, CSV, TSV, TXT or XLSX file containing prompt-completion examples to be analyzed.\"\n",
" \"This should be the local file path.\",\n",
" )\n",
" sub.add_argument(\n",
" \"-q\",\n",
"----------------------------------------------------------------------\n"
]
}
],
"source": [
"# Just the single best match, with a longer 20-line excerpt\n",
"res = search_functions(df, code_query='Command line interface for fine-tuning', n=1, n_lines=20)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "openai",
"language": "python",
"name": "python3"
2022-03-11 02:08:53 +00:00
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.9"
2022-03-11 02:08:53 +00:00
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}