|
|
|
@ -1,6 +1,7 @@
|
|
|
|
|
import os
|
|
|
|
|
|
|
|
|
|
import gpt4all
|
|
|
|
|
import openai
|
|
|
|
|
import questionary
|
|
|
|
|
import yaml
|
|
|
|
|
|
|
|
|
@ -23,6 +24,58 @@ def save_config(config):
|
|
|
|
|
yaml.dump(config, f)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def api_key_is_invalid(api_key):
|
|
|
|
|
if not api_key:
|
|
|
|
|
return True
|
|
|
|
|
try:
|
|
|
|
|
openai.api_key = api_key
|
|
|
|
|
openai.Engine.list()
|
|
|
|
|
except Exception:
|
|
|
|
|
return True
|
|
|
|
|
return False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_gpt_models(openai):
|
|
|
|
|
try:
|
|
|
|
|
model_lst = openai.Model.list()
|
|
|
|
|
except Exception:
|
|
|
|
|
print("✘ Failed to retrieve model list")
|
|
|
|
|
return []
|
|
|
|
|
|
|
|
|
|
return [i['id'] for i in model_lst['data'] if 'gpt' in i['id']]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def configure_model_name_openai(config):
|
|
|
|
|
api_key = config.get("api_key")
|
|
|
|
|
|
|
|
|
|
if config.get("model_type") != MODEL_TYPES["OPENAI"] or config.get("openai_model_name"):
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
openai.api_key = api_key
|
|
|
|
|
gpt_models = get_gpt_models(openai)
|
|
|
|
|
choices = [{"name": model, "value": model} for model in gpt_models]
|
|
|
|
|
|
|
|
|
|
if not choices:
|
|
|
|
|
print("ℹ No GPT models available")
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
model_name = questionary.select("🤖 Select model name:", choices).ask()
|
|
|
|
|
|
|
|
|
|
if not model_name:
|
|
|
|
|
print("✘ No model selected")
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
config["openai_model_name"] = model_name
|
|
|
|
|
save_config(config)
|
|
|
|
|
print("🤖 Model name saved!")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def remove_model_name_openai():
|
|
|
|
|
config = get_config()
|
|
|
|
|
config["openai_model_name"] = None
|
|
|
|
|
save_config(config)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def configure_model_name_local(config):
|
|
|
|
|
if config.get("model_type") != MODEL_TYPES["LOCAL"] or config.get("local_model_name"):
|
|
|
|
|
return
|
|
|
|
@ -49,13 +102,22 @@ def configure_model_name_local(config):
|
|
|
|
|
print("🤖 Model name saved!")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_and_validate_api_key():
|
|
|
|
|
prompt = "🤖 Enter your OpenAI API key: "
|
|
|
|
|
api_key = input(prompt)
|
|
|
|
|
while api_key_is_invalid(api_key):
|
|
|
|
|
print("✘ Invalid API key")
|
|
|
|
|
api_key = input(prompt)
|
|
|
|
|
return api_key
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def configure_api_key(config):
|
|
|
|
|
if config.get("model_type") != MODEL_TYPES["OPENAI"] or config.get("api_key"):
|
|
|
|
|
if config.get("model_type") != MODEL_TYPES["OPENAI"]:
|
|
|
|
|
return
|
|
|
|
|
api_key = input("🤖 Enter your OpenAI API key: ")
|
|
|
|
|
|
|
|
|
|
api_key = get_and_validate_api_key()
|
|
|
|
|
config["api_key"] = api_key
|
|
|
|
|
save_config(config)
|
|
|
|
|
print("🤖 API key saved!")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def remove_api_key():
|
|
|
|
@ -88,6 +150,7 @@ def configure_model_type(config):
|
|
|
|
|
|
|
|
|
|
CONFIGURE_STEPS = [
|
|
|
|
|
configure_model_type,
|
|
|
|
|
configure_model_name_local,
|
|
|
|
|
configure_api_key,
|
|
|
|
|
configure_model_name_openai,
|
|
|
|
|
configure_model_name_local,
|
|
|
|
|
]
|
|
|
|
|