talk-codebase/talk_codebase/config.py

163 lines
3.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import gpt4all
import openai
import questionary
import yaml
from talk_codebase.consts import MODEL_TYPES
# Per-user settings file, stored directly in the user's home directory.
_CONFIG_FILENAME = ".talk_codebase_config.yaml"
config_path = os.path.join(os.path.expanduser("~"), _CONFIG_FILENAME)
def get_config():
    """Load the saved configuration from ``config_path``.

    Returns:
        dict: The parsed YAML mapping. An empty dict is returned when the
        file does not exist, is empty, or does not contain a mapping at the
        top level.
    """
    if not os.path.exists(config_path):
        return {}
    with open(config_path, "r") as f:
        config = yaml.safe_load(f)
    # safe_load yields None for an empty document (and a list/scalar for a
    # malformed one). Callers use ``config.get(...)`` and item assignment, so
    # anything that is not a mapping is treated as "no configuration yet".
    if not isinstance(config, dict):
        return {}
    return config
def save_config(config):
    """Serialize ``config`` as YAML into ``config_path``, replacing the file.

    Args:
        config: The configuration object to write out.
    """
    with open(config_path, "w") as config_file:
        yaml.dump(config, stream=config_file)
def api_key_is_invalid(api_key):
    """Report whether ``api_key`` fails a live check against OpenAI.

    A falsy key is rejected immediately. Otherwise the key is installed as
    ``openai.api_key`` and an engine listing is requested.

    Args:
        api_key: Candidate key (may be ``None`` or empty).

    Returns:
        bool: ``True`` if the key is missing or the listing call raised,
        ``False`` if the call succeeded.
    """
    if not api_key:
        return True
    try:
        openai.api_key = api_key
        openai.Engine.list()
    except Exception:
        # Any failure of the probe call (not only an authentication error)
        # is reported as an invalid key.
        return True
    else:
        return False
def get_gpt_models(openai):
    """Return the ids of the available models whose id contains ``"gpt"``.

    Args:
        openai: Object exposing ``Model.list()``; the result is indexed as
            ``result['data']``, an iterable of entries with an ``'id'`` key.

    Returns:
        list: Matching model ids in listing order, or an empty list if the
        listing call raised (a failure message is printed in that case).
    """
    try:
        listing = openai.Model.list()
    except Exception:
        print("✘ Failed to retrieve model list")
        return []

    gpt_ids = []
    for entry in listing['data']:
        model_id = entry['id']
        if 'gpt' in model_id:
            gpt_ids.append(model_id)
    return gpt_ids
def configure_model_name_openai(config):
    """Interactively choose an OpenAI model and store it in ``config``.

    Does nothing unless ``config["model_type"]`` is the OPENAI type and no
    ``openai_model_name`` is stored yet.

    Args:
        config: Configuration mapping; on a successful selection it is
            updated in place and written out via ``save_config``.
    """
    if config.get("model_type") != MODEL_TYPES["OPENAI"]:
        return
    if config.get("openai_model_name"):
        return

    # The model listing below is made with the stored key.
    openai.api_key = config.get("api_key")
    available = get_gpt_models(openai)
    if not available:
        print(" No GPT models available")
        return

    menu = []
    for model_id in available:
        menu.append({"name": model_id, "value": model_id})

    selected = questionary.select("🤖 Select model name:", menu).ask()
    if not selected:
        print("✘ No model selected")
        return

    config["openai_model_name"] = selected
    save_config(config)
    print("🤖 Model name saved!")
def remove_model_name_openai():
    """Clear the stored OpenAI model name and persist the change."""
    stored = get_config()
    stored["openai_model_name"] = None
    save_config(stored)
def configure_model_name_local(config):
    """Interactively choose a local GPT4All model and store its filename.

    Does nothing unless ``config["model_type"]`` is the LOCAL type and no
    ``local_model_name`` is stored yet. Nothing is saved when there are no
    models to offer or when the prompt returns no selection, mirroring
    ``configure_model_name_openai``.

    Args:
        config: Configuration mapping; on a successful selection it is
            updated in place and written out via ``save_config``.
    """
    if config.get("model_type") != MODEL_TYPES["LOCAL"] or config.get("local_model_name"):
        return

    list_models = gpt4all.GPT4All.list_models()

    def get_model_info(model):
        """Render one model entry as a ``|``-separated menu label."""
        return (
            f"{model['name']} "
            f"| {model['filename']} "
            f"| {model['filesize']} "
            f"| {model['parameters']} "
            f"| {model['quant']} "
            f"| {model['type']}"
        )

    choices = [
        {"name": get_model_info(model), "value": model['filename']} for model in list_models
    ]
    if not choices:
        # Do not hand an empty menu to the select prompt; report and bail
        # out the same way the OpenAI variant does.
        print("✘ No local models available")
        return

    model_name = questionary.select("🤖 Select model name:", choices).ask()
    if not model_name:
        # A cancelled prompt yields no value; do not store it or claim
        # that a model name was saved.
        print("✘ No model selected")
        return

    config["local_model_name"] = model_name
    save_config(config)
    print("🤖 Model name saved!")
def remove_model_name_local():
    """Clear the stored local model name and persist the change."""
    stored = get_config()
    stored["local_model_name"] = None
    save_config(stored)
def get_and_validate_api_key():
    """Prompt on stdin until a key passes ``api_key_is_invalid``.

    Returns:
        str: The first entered key that was not reported as invalid.
    """
    prompt = "🤖 Enter your OpenAI API key: "
    while True:
        candidate = input(prompt)
        if not api_key_is_invalid(candidate):
            return candidate
        print("✘ Invalid API key")
def configure_api_key(config):
    """Ensure ``config`` holds a working OpenAI API key.

    Does nothing when the model type is not OPENAI or when the stored key
    already passes validation; otherwise prompts for a key, stores it under
    ``"api_key"`` and saves the configuration.

    Args:
        config: Configuration mapping, updated in place.
    """
    if config.get("model_type") != MODEL_TYPES["OPENAI"]:
        return
    if not api_key_is_invalid(config.get("api_key")):
        return
    config["api_key"] = get_and_validate_api_key()
    save_config(config)
def remove_api_key():
    """Clear the stored API key and persist the change."""
    stored = get_config()
    stored["api_key"] = None
    save_config(stored)
def remove_model_type():
    """Clear the stored model type and persist the change."""
    stored = get_config()
    stored["model_type"] = None
    save_config(stored)
def configure_model_type(config):
    """Ask which model backend to use, unless one is already configured.

    Args:
        config: Configuration mapping; the answer is stored under
            ``"model_type"`` and the mapping is saved.
    """
    if config.get("model_type"):
        return

    backend_options = [
        {"name": "Local", "value": MODEL_TYPES["LOCAL"]},
        {"name": "OpenAI", "value": MODEL_TYPES["OPENAI"]},
    ]
    prompt = questionary.select("🤖 Select model type:", choices=backend_options)
    config["model_type"] = prompt.ask()
    save_config(config)
# Configuration steps, each taking the config mapping as its only argument.
# The order is significant: the model type is chosen first because every later
# step checks ``config["model_type"]``, and the API key step precedes the
# OpenAI model step, which uses the stored key to list models.
# NOTE(review): presumably the consumer calls these in list order with one
# shared config dict; the caller is not visible in this file, so confirm.
CONFIGURE_STEPS = [
    configure_model_type,
    configure_api_key,
    configure_model_name_openai,
    configure_model_name_local,
]