# Mirror of https://github.com/openai/openai-cookbook
# (synced 2024-11-15 18:13:18 +00:00; original file: 190 lines, 6.6 KiB, Python)

from typing import List, Tuple, Union

import openai
from smokey import Smokey


def get_candidates(
    prompt: str,
    stop: List[str],
    temperature: float,
    priming_prefix: str,
    engine: str,
    n: int = 5,
    max_tokens: int = 150,
) -> List[str]:
    """
    Generate N candidate completions based on the prompt, generated with a specific temperature.

    Uses the legacy ``openai.Completion`` endpoint (openai-python < 1.0).

    :param prompt: The prompt to complete. The model's output is treated as a
        continuation of ``priming_prefix``, so the prompt should end with it.
    :param stop: A list of tokens that indicate the end of the generation.
    :param temperature: The temperature of the generation.
    :param priming_prefix: The prefix used for priming; it is prepended to
        every returned completion.
    :param engine: The engine to use for the generation.
    :param n: The number of completions to generate.
    :param max_tokens: Maximum number of tokens generated per completion.
        Defaults to 150.
    :return: A list of completions, one per returned choice, each starting
        with ``priming_prefix``.
    """
    response = openai.Completion.create(
        engine=engine,
        prompt=prompt,
        temperature=temperature,
        max_tokens=max_tokens,
        top_p=1,
        frequency_penalty=0,
        presence_penalty=0,
        stop=stop,
        n=n,
    )
    # The completion text only continues *after* the priming prefix, so
    # re-attach it to obtain the full answer (e.g. a query starting "SELECT").
    return [priming_prefix + choice.text for choice in response.choices]


def rindex(lst: List, value: str) -> int:
    """
    Return the index of the last occurrence of a value in a list.

    :param lst: The list to search in.
    :param value: The value to search for.
    :return: The index of the last occurrence of the value.
    :raises ValueError: If ``value`` does not occur in ``lst``.
    """
    # Scan from the end instead of copying the list with lst[::-1]; this also
    # avoids raising from inside an ``except`` block, which used to attach a
    # misleading "During handling of the above exception" context.
    for index in range(len(lst) - 1, -1, -1):
        if lst[index] == value:
            return index
    raise ValueError(f"Answer start token `{value}` not found in the eval template")


def eval_candidate(
    candidate_answer: str,
    original_instruction: str,
    eval_template: str,
    answer_start_token: str,
    engine: str,
) -> float:
    """
    Evaluate a candidate answer by calculating the average log probability
    of the original instruction, given the candidate answer with a specific
    evaluation template, aimed at reconstructing the original instruction.

    :param candidate_answer: The candidate answer to evaluate.
    :param original_instruction: The original instruction.
    :param eval_template: The template to use for the evaluation; formatted
        positionally with ``candidate_answer`` then ``original_instruction``.
    :param answer_start_token: The token to use to indicate the start of the answer.
    :param engine: The engine to use for the evaluation.
    :return: The mean token log probability of the tokens that follow the last
        ``answer_start_token`` in the formatted prompt.
    :raises ValueError: If ``answer_start_token`` is not among the echoed
        tokens, or if no tokens follow its last occurrence.
    """
    response = openai.Completion.create(
        engine=engine,
        prompt=eval_template.format(candidate_answer, original_instruction),
        temperature=0,
        # Generate nothing: together with echo=True and logprobs, this only
        # returns per-token log probabilities for the prompt itself.
        max_tokens=0,
        top_p=1,
        frequency_penalty=0,
        presence_penalty=0,
        logprobs=1,
        echo=True,
    )
    token_info = response["choices"][0]["logprobs"]

    # Score only what follows the *last* answer_start_token. NOTE(review): this
    # assumes the tokenizer emits answer_start_token as a standalone token, and
    # if the instruction itself contains it only the tail after it is scored.
    answer_start = rindex(token_info["tokens"], answer_start_token)
    logprobs = token_info["token_logprobs"][answer_start + 1 :]
    if not logprobs:
        # Previously surfaced as an opaque ZeroDivisionError.
        raise ValueError(
            f"No tokens follow answer start token `{answer_start_token}`; "
            "nothing to score"
        )
    return sum(logprobs) / len(logprobs)


def backtranslation(
    prompt_template: str,
    additional_info: str,
    instruction: str,
    eval_template: str,
    priming_prefix: str = "SELECT",
    stop1: List[str] = ["#", ";"],  # never mutated here, so sharing is harmless
    answer_start_token: str = "--",
    n: int = 5,
    temperature: float = 0.5,
    return_all_results: bool = False,
    engine: str = "davinci-codex",
) -> Union[str, List[Tuple[str, float]]]:
    """
    Generate a number of SQL queries given a natural language instruction,
    and pick the best one based on the average log probability of explaining the
    candidate SQL query with the exact original instruction, when prompted for
    a natural language explanation of the candidate SQL query.

    :param prompt_template: The template to use for the prompt to generate SQL;
        formatted positionally with ``additional_info``, ``instruction`` and
        ``priming_prefix``.
    :param additional_info: Additional information to include in the prompt
        (SQL Tables, and their properties).
    :param instruction: The instruction in natural language.
    :param eval_template: The template to use for the evaluation.
    :param priming_prefix: The prefix to use for the priming of the SQL query.
    :param stop1: A list of tokens that indicate the end of the generation.
    :param answer_start_token: The token to use to indicate the start of the
        natural answer.
    :param n: The number of candidates to generate.
    :param temperature: The temperature of the generation.
    :param return_all_results: Whether to return all results or just the best one.
    :param engine: The engine to use for the generation and evaluation.
    :return: The best SQL query or, if ``return_all_results`` is true, a list of
        ``(query, score)`` tuples sorted from highest to lowest score.
    """
    prompt = prompt_template.format(additional_info, instruction, priming_prefix)
    responses = get_candidates(
        prompt, stop1, temperature, priming_prefix, engine=engine, n=n
    )

    # Score each candidate by how likely the model is to explain it with the
    # exact original instruction (the "backtranslation" score).
    candidates = [
        (
            response,
            eval_candidate(
                response,
                instruction,
                eval_template,
                answer_start_token,
                engine=engine,
            ),
        )
        for response in responses
    ]
    candidates.sort(key=lambda x: x[1], reverse=True)

    if return_all_results:
        return candidates
    return candidates[0][0]


def main(
    nl_query: str = "Return the name of each department that had more than 10 employees in June 2021",
    eval_template: str = "{};\n-- Explanation of the above query in human readable format\n-- {}",
    table_definitions: str = "# Employee(id, name, department_id)\n# Department(id, name, address)\n# Salary_Payments(id, employee_id, amount, date)\n",
    prompt_template: str = "### Postgres SQL tables, with their properties:\n#\n{}#\n### {}\n{}",
    n: int = 3,
    temperature: float = 0.3,
    engine: str = "davinci-codex",
    return_all_results: bool = False,
) -> None:
    """
    Generate a number of SQL queries given a natural language instruction,
    pick the best one based on the highest backtranslation score, and print it.

    :param nl_query: The natural language query.
    :param eval_template: The template to use for the evaluation; formatted
        with a candidate SQL query and then ``nl_query``.
    :param table_definitions: The definitions of the tables used in the query.
    :param prompt_template: The template to use for the prompt to generate SQL;
        formatted with ``table_definitions``, ``nl_query`` and the "SELECT"
        priming prefix.
    :param n: The number of candidates to generate.
    :param temperature: The temperature of the generation.
    :param engine: The engine to use for the generation and evaluation.
    :param return_all_results: If true, print the list of all
        ``(query, score)`` pairs (best first) instead of only the best query.
    :return: None; the result is printed to stdout.
    """
    result = backtranslation(
        prompt_template,
        table_definitions,
        nl_query,
        eval_template,
        priming_prefix="SELECT",
        temperature=temperature,
        n=n,
        engine=engine,
        return_all_results=return_all_results,
    )
    print(result)


if __name__ == "__main__":
    # Runs only when executed as a script, not on import: hand `main` to Smokey.
    Smokey(main)