2023-08-15 22:25:02 +00:00
|
|
|
import sys
|
|
|
|
|
2023-05-26 01:15:30 +00:00
|
|
|
import fire
|
2023-05-27 01:13:34 +00:00
|
|
|
|
2023-07-05 01:54:25 +00:00
|
|
|
from talk_codebase.config import CONFIGURE_STEPS, save_config, get_config, config_path, remove_api_key, \
|
2023-08-15 22:25:02 +00:00
|
|
|
remove_model_type, remove_model_name_local
|
2023-05-29 21:48:41 +00:00
|
|
|
from talk_codebase.consts import DEFAULT_CONFIG
|
2023-07-05 01:54:25 +00:00
|
|
|
from talk_codebase.llm import factory_llm
|
2023-08-21 21:09:22 +00:00
|
|
|
from talk_codebase.utils import get_repo
|
2023-05-26 01:15:30 +00:00
|
|
|
|
|
|
|
|
2023-08-15 22:25:02 +00:00
|
|
|
def check_python_version():
    """Exit the process when the interpreter is older than Python 3.8.1.

    On a supported interpreter this returns ``None`` and does nothing else.
    Otherwise it prints a short notice and terminates with exit status 1.
    """
    minimum_supported = (3, 8, 1)
    if sys.version_info >= minimum_supported:
        return
    print("🤖 Please use Python 3.8.1 or higher")
    sys.exit(1)
|
|
|
|
|
|
|
|
|
2023-07-12 00:41:09 +00:00
|
|
|
def update_config(config):
    """Backfill *config* with any ``DEFAULT_CONFIG`` entries it lacks.

    Keys already present in *config* keep their current values; only missing
    keys are added, in ``DEFAULT_CONFIG`` order. The mapping is modified in
    place and the same object is returned for convenience.
    """
    missing_defaults = {
        name: default
        for name, default in DEFAULT_CONFIG.items()
        if name not in config
    }
    config.update(missing_defaults)
    return config
|
|
|
|
|
|
|
|
|
2023-07-05 01:54:25 +00:00
|
|
|
def configure(reset=True):
    """Apply every step in ``CONFIGURE_STEPS`` to the stored config and save it.

    :param reset: when true, first call ``remove_api_key``,
        ``remove_model_type`` and ``remove_model_name_local`` (presumably so
        the configuration steps ask for those values again — confirm against
        ``talk_codebase.config``).
    """
    if reset:
        for clear_setting in (remove_api_key, remove_model_type, remove_model_name_local):
            clear_setting()
    # Load the saved settings and fill in any keys that are still missing.
    settings = update_config(get_config())
    for apply_step in CONFIGURE_STEPS:
        apply_step(settings)
    save_config(settings)
|
|
|
|
|
|
|
|
|
2023-07-12 00:41:09 +00:00
|
|
|
def chat_loop(llm, *, read_input=None):
    """Read queries from the user and forward each one to *llm* until stopped.

    Blank input re-prompts. ``exit`` or ``quit`` (in any letter case) ends the
    loop, as does end-of-input (Ctrl-D or a closed stdin).

    :param llm: object whose ``send_query(query)`` method handles one query.
    :param read_input: optional callable taking a prompt string and returning
        one line of text. When omitted, the built-in :func:`input` is used
        (looked up at call time).
    """
    reader = input if read_input is None else read_input
    while True:
        try:
            query = reader("👉 ").strip()
        except EOFError:
            # stdin was closed (e.g. Ctrl-D): end the session cleanly instead
            # of letting the error surface as a traceback.
            print()
            break
        if not query:
            print("🤖 Please enter a query")
            continue
        # Only the command check is case-insensitive. The query itself keeps
        # its original casing because code identifiers are case-sensitive.
        if query.lower() in ('exit', 'quit'):
            break
        llm.send_query(query)
|
2023-05-29 21:48:41 +00:00
|
|
|
|
|
|
|
|
2023-08-21 21:09:22 +00:00
|
|
|
def chat():
    """Start a chat session about the Git repository located by ``get_repo()``.

    Runs ``configure`` without resetting stored values, reloads the config,
    and builds an LLM for the repository's working directory. It then hands
    control to ``chat_loop``. If no repository is found, it prints a notice
    and exits with status 1.
    """
    configure(reset=False)
    settings = get_config()
    repository = get_repo()
    if repository:
        chat_loop(factory_llm(repository.working_dir, settings))
        return
    print("🤖 Git repository not found")
    sys.exit(1)
|
|
|
|
|
|
|
|
|
2023-05-29 21:48:41 +00:00
|
|
|
def main():
    """Command-line entry point: dispatch to the ``chat`` or ``configure`` command.

    Verifies the Python version and prints the config file location. It then
    hands argument parsing to Python Fire with two commands:
    ``chat`` and ``configure`` (the latter always resets stored values).
    A Ctrl-C (``KeyboardInterrupt``) ends the program with a goodbye message
    instead of a traceback. Any other exception propagates unchanged.
    """
    check_python_version()
    print(f"🤖 Config path: {config_path}")
    try:
        fire.Fire({
            "chat": chat,
            "configure": lambda: configure(True),
        })
    except KeyboardInterrupt:
        # Leave the interrupted prompt line before saying goodbye.
        print("\n🤖 Bye!")
|
2023-05-26 01:15:30 +00:00
|
|
|
|
|
|
|
|
|
|
|
# Run the CLI only when this module is executed directly, not when imported.
if __name__ == '__main__':
    main()
|