add OpenAI model selection and API key validation functionalities

pull/9/head
Saryev Rustam 12 months ago
parent 079e69dabd
commit 4aa2df1c56

@ -1,6 +1,7 @@
import os
import gpt4all
import openai
import questionary
import yaml
@ -23,6 +24,58 @@ def save_config(config):
yaml.dump(config, f)
def api_key_is_invalid(api_key):
if not api_key:
return True
try:
openai.api_key = api_key
openai.Engine.list()
except Exception:
return True
return False
def get_gpt_models(openai):
try:
model_lst = openai.Model.list()
except Exception:
print("✘ Failed to retrieve model list")
return []
return [i['id'] for i in model_lst['data'] if 'gpt' in i['id']]
def configure_model_name_openai(config):
api_key = config.get("api_key")
if config.get("model_type") != MODEL_TYPES["OPENAI"] or config.get("openai_model_name"):
return
openai.api_key = api_key
gpt_models = get_gpt_models(openai)
choices = [{"name": model, "value": model} for model in gpt_models]
if not choices:
print(" No GPT models available")
return
model_name = questionary.select("🤖 Select model name:", choices).ask()
if not model_name:
print("✘ No model selected")
return
config["openai_model_name"] = model_name
save_config(config)
print("🤖 Model name saved!")
def remove_model_name_openai():
config = get_config()
config["openai_model_name"] = None
save_config(config)
def configure_model_name_local(config):
if config.get("model_type") != MODEL_TYPES["LOCAL"] or config.get("local_model_name"):
return
@ -49,13 +102,22 @@ def configure_model_name_local(config):
print("🤖 Model name saved!")
def get_and_validate_api_key():
prompt = "🤖 Enter your OpenAI API key: "
api_key = input(prompt)
while api_key_is_invalid(api_key):
print("✘ Invalid API key")
api_key = input(prompt)
return api_key
def configure_api_key(config):
if config.get("model_type") != MODEL_TYPES["OPENAI"] or config.get("api_key"):
if config.get("model_type") != MODEL_TYPES["OPENAI"]:
return
api_key = input("🤖 Enter your OpenAI API key: ")
api_key = get_and_validate_api_key()
config["api_key"] = api_key
save_config(config)
print("🤖 API key saved!")
def remove_api_key():
@ -88,6 +150,7 @@ def configure_model_type(config):
CONFIGURE_STEPS = [
configure_model_type,
configure_model_name_local,
configure_api_key,
configure_model_name_openai,
configure_model_name_local,
]

@ -15,7 +15,6 @@ MODEL_TYPES = {
"LOCAL": "local",
}
DEFAULT_LOCAL_MODEL = "orca-mini-3b.ggmlv3.q4_0.bin"
DEFAULT_OPENAI_MODEL = "gpt-3.5-turbo"
DEFAULT_MODEL_DIRECTORY = os.path.join(str(Path.home()), ".cache", "gpt4all").replace("\\", "\\\\")
DEFAULT_CONFIG = {
@ -24,7 +23,6 @@ DEFAULT_CONFIG = {
"chunk_overlap": "256",
"k": "1",
"temperature": "0.7",
"openai_model_name": DEFAULT_OPENAI_MODEL,
"local_model_name": DEFAULT_LOCAL_MODEL,
"model_path": DEFAULT_MODEL_DIRECTORY,
"n_batch": "8",

Loading…
Cancel
Save