mirror of
https://github.com/openai/openai-cookbook
synced 2024-11-19 15:25:37 +00:00
157 lines
5.2 KiB
Python
157 lines
5.2 KiB
Python
"""
|
|
Note: To answer questions based on text documents, we recommend the procedure in
|
|
[Question Answering using Embeddings](https://github.com/openai/openai-cookbook/blob/main/examples/Question_answering_using_embeddings.ipynb).
|
|
Some of the code below may rely on [deprecated API endpoints](https://github.com/openai/openai-cookbook/tree/main/transition_guides_for_deprecated_API_endpoints).
|
|
"""
|
|
|
|
import argparse
|
|
|
|
import openai
|
|
|
|
|
|
def create_context(
|
|
question, search_file_id, max_len=1800, search_model="ada", max_rerank=10
|
|
):
|
|
"""
|
|
Create a context for a question by finding the most similar context from the search file.
|
|
:param question: The question
|
|
:param search_file_id: The file id of the search file
|
|
:param max_len: The maximum length of the returned context (in tokens)
|
|
:param search_model: The search model to use
|
|
:param max_rerank: The maximum number of reranking
|
|
:return: The context
|
|
"""
|
|
results = openai.Engine(search_model).search(
|
|
search_model=search_model,
|
|
query=question,
|
|
max_rerank=max_rerank,
|
|
file=search_file_id,
|
|
return_metadata=True,
|
|
)
|
|
returns = []
|
|
cur_len = 0
|
|
for result in results["data"]:
|
|
cur_len += int(result["metadata"]) + 4
|
|
if cur_len > max_len:
|
|
break
|
|
returns.append(result["text"])
|
|
return "\n\n###\n\n".join(returns)
|
|
|
|
|
|
def answer_question(
|
|
search_file_id="<SEARCH_FILE_ID>",
|
|
fine_tuned_qa_model="<FT_QA_MODEL_ID>",
|
|
question="Which country won the European Football championship in 2021?",
|
|
max_len=1800,
|
|
search_model="ada",
|
|
max_rerank=10,
|
|
debug=False,
|
|
stop_sequence=["\n", "."],
|
|
max_tokens=100,
|
|
):
|
|
"""
|
|
Answer a question based on the most similar context from the search file, using your fine-tuned model.
|
|
:param question: The question
|
|
:param fine_tuned_qa_model: The fine tuned QA model
|
|
:param search_file_id: The file id of the search file
|
|
:param max_len: The maximum length of the returned context (in tokens)
|
|
:param search_model: The search model to use
|
|
:param max_rerank: The maximum number of reranking
|
|
:param debug: Whether to output debug information
|
|
:param stop_sequence: The stop sequence for Q&A model
|
|
:param max_tokens: The maximum number of tokens to return
|
|
:return: The answer
|
|
"""
|
|
context = create_context(
|
|
question,
|
|
search_file_id,
|
|
max_len=max_len,
|
|
search_model=search_model,
|
|
max_rerank=max_rerank,
|
|
)
|
|
if debug:
|
|
print("Context:\n" + context)
|
|
print("\n\n")
|
|
try:
|
|
# fine-tuned models requires model parameter, whereas other models require engine parameter
|
|
model_param = (
|
|
{"model": fine_tuned_qa_model}
|
|
if ":" in fine_tuned_qa_model
|
|
and fine_tuned_qa_model.split(":")[1].startswith("ft")
|
|
else {"engine": fine_tuned_qa_model}
|
|
)
|
|
response = openai.Completion.create(
|
|
prompt=f"Answer the question based on the context below\n\nText: {context}\n\n---\n\nQuestion: {question}\nAnswer:",
|
|
temperature=0,
|
|
max_tokens=max_tokens,
|
|
top_p=1,
|
|
frequency_penalty=0,
|
|
presence_penalty=0,
|
|
stop=stop_sequence,
|
|
**model_param,
|
|
)
|
|
return response["choices"][0]["text"]
|
|
except Exception as e:
|
|
print(e)
|
|
return ""
|
|
|
|
|
|
if __name__ == "__main__":
|
|
parser = argparse.ArgumentParser(
|
|
description="Rudimentary functionality of the answers endpoint with a fine-tuned Q&A model.",
|
|
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
|
|
)
|
|
parser.add_argument(
|
|
"--search_file_id", help="Search file id", required=True, type=str
|
|
)
|
|
parser.add_argument(
|
|
"--fine_tuned_qa_model", help="Fine-tuned QA model id", required=True, type=str
|
|
)
|
|
parser.add_argument(
|
|
"--question", help="Question to answer", required=True, type=str
|
|
)
|
|
parser.add_argument(
|
|
"--max_len",
|
|
help="Maximum length of the returned context (in tokens)",
|
|
default=1800,
|
|
type=int,
|
|
)
|
|
parser.add_argument(
|
|
"--search_model", help="Search model to use", default="ada", type=str
|
|
)
|
|
parser.add_argument(
|
|
"--max_rerank",
|
|
help="Maximum number of reranking for the search",
|
|
default=10,
|
|
type=int,
|
|
)
|
|
parser.add_argument(
|
|
"--debug", help="Print debug information (context used)", action="store_true"
|
|
)
|
|
parser.add_argument(
|
|
"--stop_sequence",
|
|
help="Stop sequences for the Q&A model",
|
|
default=["\n", "."],
|
|
nargs="+",
|
|
type=str,
|
|
)
|
|
parser.add_argument(
|
|
"--max_tokens",
|
|
help="Maximum number of tokens to return",
|
|
default=100,
|
|
type=int,
|
|
)
|
|
args = parser.parse_args()
|
|
response = answer_question(
|
|
search_file_id=args.search_file_id,
|
|
fine_tuned_qa_model=args.fine_tuned_qa_model,
|
|
question=args.question,
|
|
max_len=args.max_len,
|
|
search_model=args.search_model,
|
|
max_rerank=args.max_rerank,
|
|
debug=args.debug,
|
|
stop_sequence=args.stop_sequence,
|
|
max_tokens=args.max_tokens,
|
|
)
|
|
print(f"Answer:{response}")
|