From d1bc2aa6c861387ed756b4e253c44129fbb6ed8c Mon Sep 17 00:00:00 2001
From: sean1832
Date: Sat, 4 Mar 2023 03:38:43 +1100
Subject: [PATCH] fix: non stream mode cannot use gpt-3.5

---
 GPT/gpt_tools.py | 16 ++++++++++++++++
 GPT/query.py     | 10 ++++++++--
 2 files changed, 24 insertions(+), 2 deletions(-)

diff --git a/GPT/gpt_tools.py b/GPT/gpt_tools.py
index 9512f64..7f0d32f 100644
--- a/GPT/gpt_tools.py
+++ b/GPT/gpt_tools.py
@@ -48,6 +48,22 @@ def gpt3(prompt, model, params):
     return text
 
 
+def gpt35(prompt, params, system_role_content: str = 'You are a helpful assistant.'):
+    completions = openai.ChatCompletion.create(
+        model="gpt-3.5-turbo",
+        max_tokens=params.max_tokens,
+        temperature=params.temp,
+        top_p=params.top_p,
+        frequency_penalty=params.frequency_penalty,
+        presence_penalty=params.present_penalty,
+        messages=[
+            {"role": "system", "content": system_role_content},
+            {"role": "user", "content": prompt}
+        ])
+    text = completions['choices'][0]['message']['content']
+    return text
+
+
 def gpt3_stream(prompt, model, params):
     response = openai.Completion.create(
         model=model,
diff --git a/GPT/query.py b/GPT/query.py
index 642d40d..e25b75d 100644
--- a/GPT/query.py
+++ b/GPT/query.py
@@ -47,7 +47,10 @@ def run(query, model, prompt_file, isQuestion, params, info_file=None):
             prompt = prompt.replace('<<QUERY>>', query)
             prompt = prompt.replace('<<INFO>>', my_info)
 
-            answer = GPT.gpt_tools.gpt3(prompt, model, params)
+            if model == 'gpt-3.5-turbo':
+                answer = GPT.gpt_tools.gpt35(prompt, params)
+            else:
+                answer = GPT.gpt_tools.gpt3(prompt, model, params)
             answers.append(answer)
         all_response = '\n\n'.join(answers)
     else:
@@ -55,7 +58,10 @@ def run(query, model, prompt_file, isQuestion, params, info_file=None):
         responses = []
         for chunk in chunks:
             prompt = util.read_file(prompt_file).replace('<<DATA>>', chunk)
-            response = GPT.gpt_tools.gpt3(prompt, model, params)
+            if model == 'gpt-3.5-turbo':
+                response = GPT.gpt_tools.gpt35(prompt, params)
+            else:
+                response = GPT.gpt_tools.gpt3(prompt, model, params)
             responses.append(response)
         all_response = '\n\n'.join(responses)
     return all_response