openai-cookbook/examples/Backtranslation_of_SQL_queries.py
2022-06-03 12:56:03 -07:00

190 lines
6.6 KiB
Python

from typing import List, Optional, Tuple, Union

import openai
from smokey import Smokey
def get_candidates(
    prompt: str,
    stop: List[str],
    temperature: float,
    priming_prefix: str,
    engine: str,
    n: int = 5,
) -> List[str]:
    """
    Sample ``n`` completions of ``prompt`` from the Completions endpoint.

    Each returned string is ``priming_prefix`` followed by the sampled text.
    Sampling is capped at 150 tokens per completion, with ``top_p=1`` and no
    frequency/presence penalty.

    :param prompt: Fully formatted prompt sent to the model.
    :param stop: Stop sequences that end each completion.
    :param temperature: Sampling temperature.
    :param priming_prefix: Text prepended to every completion; presumably the
        same text the prompt ends with, so the result is a whole statement
        (confirm at call sites).
    :param engine: Name of the engine used for sampling.
    :param n: How many completions to request.
    :return: One string per returned choice, in the order the API lists them.
    """
    request_kwargs = dict(
        engine=engine,
        prompt=prompt,
        temperature=temperature,
        max_tokens=150,
        top_p=1,
        frequency_penalty=0,
        presence_penalty=0,
        stop=stop,
        n=n,
    )
    completion = openai.Completion.create(**request_kwargs)
    # The API returns only the continuation, so re-attach the priming prefix.
    return [priming_prefix + sampled.text for sampled in completion.choices]
def rindex(lst: List, value: str) -> int:
    """
    Return the index of the last occurrence of a value in a list.

    :param lst: The list to search in.
    :param value: The value to search for.
    :return: The index of the last occurrence of ``value`` in ``lst``.
    :raises ValueError: If ``value`` does not occur in ``lst``. The message
        mentions the answer start token because this helper is used to locate
        that token among the echoed tokens of the evaluation prompt.
    """
    try:
        # Search a reversed copy so the first hit is the last occurrence,
        # then map that position back to an index into the original list.
        return len(lst) - 1 - lst[::-1].index(value)
    except ValueError:
        # ``from None`` drops list.index's uninformative "x is not in list"
        # context from the traceback; this message is the useful one.
        raise ValueError(
            f"Answer start token `{value}` not found in the eval template"
        ) from None
def eval_candidate(
    candidate_answer: str,
    original_instruction: str,
    eval_template: str,
    answer_start_token: str,
    engine: str,
) -> float:
    """
    Score a candidate answer by the average log probability the model assigns
    to the original instruction when asked to explain the candidate.

    ``candidate_answer`` and ``original_instruction`` are substituted, in that
    order, into ``eval_template``. The filled prompt is sent with ``echo=True``
    and ``max_tokens=0``, so nothing new is generated and the per-token log
    probabilities of the prompt itself are read back. The score is the mean
    log probability of the tokens after the last ``answer_start_token``.

    :param candidate_answer: The candidate answer to evaluate.
    :param original_instruction: The original instruction to reconstruct.
    :param eval_template: Template with two positional ``{}`` slots: the
        candidate, then the instruction.
    :param answer_start_token: Token marking where the instruction begins.
    :param engine: The engine to use for the evaluation.
    :return: Mean per-token log probability; higher means a better candidate.
    :raises ValueError: If ``answer_start_token`` is not among the echoed
        tokens, or if no tokens follow it (e.g. an empty instruction).
    """
    response = openai.Completion.create(
        engine=engine,
        prompt=eval_template.format(candidate_answer, original_instruction),
        temperature=0,
        max_tokens=0,
        top_p=1,
        frequency_penalty=0,
        presence_penalty=0,
        logprobs=1,
        echo=True,
    )
    logprob_info = response["choices"][0]["logprobs"]
    # Use the LAST occurrence: the candidate text itself may contain the token.
    answer_start = rindex(logprob_info["tokens"], answer_start_token)
    answer_logprobs = logprob_info["token_logprobs"][answer_start + 1 :]
    if not answer_logprobs:
        # Without this guard the mean below fails with an opaque
        # ZeroDivisionError.
        raise ValueError(
            f"No tokens follow the answer start token `{answer_start_token}`; "
            "cannot score an empty instruction"
        )
    return sum(answer_logprobs) / len(answer_logprobs)
def backtranslation(
    prompt_template: str,
    additional_info: str,
    instruction: str,
    eval_template: str,
    priming_prefix: str = "SELECT",
    stop1: Optional[List[str]] = None,
    answer_start_token: str = "--",
    n: int = 5,
    temperature: float = 0.5,
    return_all_results: bool = False,
    engine: str = "davinci-codex",
) -> Union[str, List[Tuple[str, float]]]:
    """
    Generate a number of SQL queries given a natural language instruction,
    and pick the best one based on the average log probability of explaining the
    candidate SQL query with the exact original instruction, when prompted for
    a natural language explanation of the candidate SQL query.

    :param prompt_template: Template with three positional ``{}`` slots:
        ``additional_info``, ``instruction`` and ``priming_prefix``.
    :param additional_info: Additional information to include in the prompt
        (SQL Tables, and their properties).
    :param instruction: The instruction in natural language.
    :param eval_template: The template to use for the evaluation.
    :param priming_prefix: The prefix to use for the priming of the SQL query.
    :param stop1: Stop sequences that end each generated query. Defaults to
        ``["#", ";"]`` when ``None``.
    :param answer_start_token: The token to use to indicate the start of the
        natural answer.
    :param n: The number of candidates to generate.
    :param temperature: The temperature of the generation.
    :param return_all_results: Whether to return all results or just the best one.
    :param engine: The engine to use for the generation and evaluation.
    :return: The best SQL query or, if ``return_all_results`` is set, a list
        of ``(query, score)`` tuples sorted by descending score.
    :raises ValueError: If no candidate queries were generated.
    """
    # A None sentinel avoids sharing one mutable list across calls.
    if stop1 is None:
        stop1 = ["#", ";"]
    prompt = prompt_template.format(additional_info, instruction, priming_prefix)
    responses = get_candidates(
        prompt, stop1, temperature, priming_prefix, engine=engine, n=n
    )
    candidates: List[Tuple[str, float]] = []
    # Iterate over what the API actually returned rather than assuming
    # exactly n responses.
    for response in responses:
        score = eval_candidate(
            response,
            instruction,
            eval_template,
            answer_start_token,
            engine=engine,
        )
        candidates.append((response, score))
    if not candidates:
        raise ValueError("No candidate queries were generated (is n >= 1?)")
    candidates.sort(key=lambda pair: pair[1], reverse=True)
    if return_all_results:
        return candidates
    return candidates[0][0]
def main(
    nl_query: str = "Return the name of each department that had more than 10 employees in June 2021",
    eval_template: str = "{};\n-- Explanation of the above query in human readable format\n-- {}",
    table_definitions: str = "# Employee(id, name, department_id)\n# Department(id, name, address)\n# Salary_Payments(id, employee_id, amount, date)\n",
    prompt_template: str = "### Postgres SQL tables, with their properties:\n#\n{}#\n### {}\n{}",
    n: int = 3,
    temperature: float = 0.3,
    engine: str = "davinci-codex",
) -> None:
    """
    Generate candidate SQL queries for a natural language instruction and
    print the one with the highest backtranslation score.

    :param nl_query: The natural language query.
    :param eval_template: The template to use for the evaluation. Its final
        ``-- {}`` line places the instruction right after a ``--`` token,
        which matches backtranslation's default ``answer_start_token``.
    :param table_definitions: The definitions of the tables used in the query.
    :param prompt_template: The template to use for the prompt to generate SQL.
    :param n: The number of candidates to generate.
    :param temperature: The temperature of the generation.
    :param engine: The engine to use for the generation and evaluation.
    :return: None. The best SQL query is printed to stdout, not returned.
    """
    result = backtranslation(
        prompt_template=prompt_template,
        additional_info=table_definitions,
        instruction=nl_query,
        eval_template=eval_template,
        # The prompt ends with this prefix so the model continues a SELECT.
        priming_prefix="SELECT",
        temperature=temperature,
        n=n,
        engine=engine,
    )
    print(result)
# Script entry point: hand `main` to Smokey when run directly.
# NOTE(review): Smokey is presumably a CLI wrapper that maps command-line
# flags onto main's keyword arguments — confirm against the smokey package.
if __name__ == "__main__":
    Smokey(main)