# Mirror of https://github.com/rsaryev/talk-codebase
# Synced 2024-11-10 07:10:31 +00:00
# 158 lines, 3.7 KiB, Python
import os
|
||
|
||
import gpt4all
|
||
import openai
|
||
import questionary
|
||
import yaml
|
||
|
||
from talk_codebase.consts import MODEL_TYPES
|
||
|
||
# Location of the per-user YAML settings file, directly inside the home directory.
config_path = os.path.join(
    os.path.expanduser("~"),
    ".talk_codebase_config.yaml",
)
|
||
|
||
|
||
def get_config():
    """Load the saved settings from the YAML file at ``config_path``.

    Returns:
        dict: The parsed settings. An empty dict is returned when the file
        does not exist, is empty, or does not hold a YAML mapping, so callers
        can always use ``.get()`` and item assignment on the result.
    """
    if not os.path.exists(config_path):
        return {}

    with open(config_path, "r") as f:
        loaded = yaml.safe_load(f)

    # ``yaml.safe_load`` yields ``None`` for an empty document and a scalar or
    # list for non-mapping content; normalise both to an empty config instead
    # of letting callers crash on ``config.get(...)``.
    return loaded if isinstance(loaded, dict) else {}
|
||
|
||
|
||
def save_config(config):
    """Serialise ``config`` as YAML into the file at ``config_path``.

    The file is opened in write mode, so any previous contents are replaced.
    """
    with open(config_path, "w") as config_file:
        yaml.dump(config, stream=config_file)
|
||
|
||
|
||
def api_key_is_invalid(api_key):
    """Report whether ``api_key`` fails validation against the OpenAI API.

    A falsy key (``None`` or empty) is rejected without contacting the API.
    Otherwise the key is assigned to ``openai.api_key`` and
    ``openai.Engine.list()`` is called as a probe.

    Returns:
        bool: ``True`` when the key is missing or the probe raises any
        exception, ``False`` when the probe completes.
    """
    if api_key:
        try:
            # The key stays set on the module even if the probe then fails.
            openai.api_key = api_key
            openai.Engine.list()
        except Exception:
            return True
        else:
            return False
    return True
|
||
|
||
|
||
def get_gpt_models(openai):
    """Return the ids of the listed models whose id contains ``"gpt"``.

    Args:
        openai: Object exposing ``Model.list()``; the caller in this file
            passes the ``openai`` module itself.

    Returns:
        list: Matching model ids in the order they were listed, or an empty
        list (after printing a failure message) if the listing call raises.
    """
    try:
        listing = openai.Model.list()
    except Exception:
        print("✘ Failed to retrieve model list")
        return []

    gpt_ids = []
    for entry in listing['data']:
        model_id = entry['id']
        if 'gpt' in model_id:
            gpt_ids.append(model_id)
    return gpt_ids
|
||
|
||
|
||
def configure_model_name_openai(config):
    """Prompt for an OpenAI model name and persist it in ``config``.

    Does nothing unless ``config["model_type"]`` is the OpenAI type and no
    ``openai_model_name`` is stored yet. Otherwise the stored API key is set
    on the ``openai`` module, the GPT model ids are fetched, and the user
    picks one; the choice is written to ``config["openai_model_name"]`` and
    saved. If no models are available or nothing is selected, a message is
    printed and the config is left unchanged.
    """
    api_key = config.get("api_key")

    wants_openai = config.get("model_type") == MODEL_TYPES["OPENAI"]
    if not wants_openai or config.get("openai_model_name"):
        return

    openai.api_key = api_key
    choices = [
        {"name": model_id, "value": model_id}
        for model_id in get_gpt_models(openai)
    ]
    if not choices:
        print("ℹ No GPT models available")
        return

    selected = questionary.select("🤖 Select model name:", choices).ask()
    if selected:
        config["openai_model_name"] = selected
        save_config(config)
        print("🤖 Model name saved!")
    else:
        print("✘ No model selected")
|
||
|
||
|
||
def remove_model_name_openai():
    """Reset the stored OpenAI model name to ``None`` and save the config."""
    settings = get_config()
    settings["openai_model_name"] = None
    save_config(settings)
|
||
|
||
|
||
def configure_model_name_local(config):
    """Prompt for a local GPT4All model and persist its filename in ``config``.

    Does nothing unless ``config["model_type"]`` is the local type and no
    ``local_model_name`` is stored yet. The candidates come from
    ``gpt4all.GPT4All.list_models()``; the chosen entry's ``filename`` is
    written to ``config["local_model_name"]`` and saved.

    If no models are listed, or the prompt is dismissed without a selection,
    a message is printed and the config is left unchanged.
    """
    if config.get("model_type") != MODEL_TYPES["LOCAL"] or config.get("local_model_name"):
        return

    list_models = gpt4all.GPT4All.list_models()

    # Fields shown, in order, in each picker label.
    info_fields = ("name", "filename", "filesize", "parameters", "quant", "type")

    def get_model_info(model):
        """Build the one-line ``a | b | c`` label shown for ``model``."""
        return " | ".join(str(model[field]) for field in info_fields)

    choices = [
        {"name": get_model_info(model), "value": model['filename']} for model in list_models
    ]

    # A select prompt needs at least one choice; bail out with a message
    # instead of handing it an empty list.
    if not choices:
        print("ℹ No local models available")
        return

    model_name = questionary.select("🤖 Select model name:", choices).ask()

    # ``ask()`` gives back None when the prompt is aborted; do not store that
    # or claim the name was saved (mirrors configure_model_name_openai).
    if not model_name:
        print("✘ No model selected")
        return

    config["local_model_name"] = model_name
    save_config(config)
    print("🤖 Model name saved!")
|
||
|
||
|
||
def get_and_validate_api_key():
    """Keep asking on stdin for an OpenAI API key until one validates.

    Each rejected entry prints an error before the prompt is shown again.

    Returns:
        str: The first entered key for which ``api_key_is_invalid`` is false.
    """
    prompt = "🤖 Enter your OpenAI API key: "
    while True:
        candidate = input(prompt)
        if not api_key_is_invalid(candidate):
            return candidate
        print("✘ Invalid API key")
|
||
|
||
|
||
def configure_api_key(config):
    """Ensure ``config`` holds a working OpenAI API key.

    Only acts when ``config["model_type"]`` is the OpenAI type. If the stored
    key does not validate, the user is prompted until a valid key is entered;
    that key is then written to ``config["api_key"]`` and the config is saved.
    """
    uses_openai = config.get("model_type") == MODEL_TYPES["OPENAI"]
    if not uses_openai:
        return

    if not api_key_is_invalid(config.get("api_key")):
        # Stored key already works; nothing to change.
        return

    config["api_key"] = get_and_validate_api_key()
    save_config(config)
|
||
|
||
|
||
def remove_api_key():
    """Reset the stored API key to ``None`` and save the config."""
    settings = get_config()
    settings["api_key"] = None
    save_config(settings)
|
||
|
||
|
||
def remove_model_type():
    """Reset the stored model type to ``None`` and save the config."""
    settings = get_config()
    settings["model_type"] = None
    save_config(settings)
|
||
|
||
|
||
def configure_model_type(config):
    """Prompt for the model backend (local or OpenAI) and persist it.

    Does nothing when ``config["model_type"]`` is already set. Otherwise the
    user picks between the ``MODEL_TYPES["LOCAL"]`` and
    ``MODEL_TYPES["OPENAI"]`` values; the choice is written to
    ``config["model_type"]`` and saved.

    If the prompt is dismissed without a selection, a message is printed and
    the config is left unchanged.
    """
    if config.get("model_type"):
        return

    model_type = questionary.select(
        "🤖 Select model type:",
        choices=[
            {"name": "Local", "value": MODEL_TYPES["LOCAL"]},
            {"name": "OpenAI", "value": MODEL_TYPES["OPENAI"]},
        ]
    ).ask()

    # ``ask()`` gives back None when the prompt is aborted; do not store that
    # or claim the type was saved (mirrors configure_model_name_openai).
    if not model_type:
        print("✘ No model type selected")
        return

    config["model_type"] = model_type
    save_config(config)
    print("🤖 Model type saved!")
|
||
|
||
|
||
# Configuration steps, each a callable taking the config dict.
# Order matters within this list: the model type is chosen first because the
# later steps check config["model_type"] before acting, and the API key step
# precedes the OpenAI model-name step, which reads config["api_key"].
# NOTE(review): presumably the caller runs these in list order — confirm.
CONFIGURE_STEPS = [
    configure_model_type,  # pick local vs. OpenAI
    configure_api_key,  # OpenAI only
    configure_model_name_openai,  # OpenAI only
    configure_model_name_local,  # local only
]
|