refactor: simplify chat loop

- Simplify the chat loop by moving it to a separate function called 'loop'.
- This improves code readability and makes it easier to test the chat function without needing to mock user input.
pull/1/head
Saryev Rustam 1 year ago
parent 5f08927c76
commit 11185c079f

@ -1,10 +1,10 @@
[tool.poetry]
name = "talk-codebase"
version = "0.1.14"
version = "0.1.15"
description = "talk-codebase is a powerful tool for querying and analyzing codebases."
authors = ["Saryev Rustam <rustam1997@gmail.com>"]
readme = "README.md"
packages = [{include = "talk_codebase"}]
packages = [{ include = "talk_codebase" }]
keywords = ["chatgpt", "openai", "cli"]
[tool.poetry.dependencies]

@ -33,6 +33,17 @@ def configure():
save_config(config)
def loop(vector_store, api_key, model_name):
while True:
question = input("👉 ")
if not question:
print("🤖 Please enter a question.")
continue
if question.lower() in ('exit', 'quit'):
break
send_question(question, vector_store, api_key, model_name)
def chat(root_dir):
try:
config = get_config()
@ -42,14 +53,7 @@ def chat(root_dir):
configure()
chat(root_dir)
vector_store = create_vector_store(root_dir, api_key)
while True:
question = input("👉 ")
if not question:
print("🤖 Please enter a question.")
continue
if question.lower() in ('exit', 'quit'):
break
send_question(question, vector_store, api_key, model_name)
loop(vector_store, api_key, model_name)
except KeyboardInterrupt:
print("\n🤖 Bye!")
except Exception as e:

@ -14,6 +14,9 @@ from talk_codebase.utils import StreamStdOut, load_files
@Halo(text='Creating vector store', spinner='dots')
def create_vector_store(root_dir, openai_api_key):
docs = load_files(root_dir)
if len(docs) == 0:
print("✘ No documents found")
exit(0)
text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
texts = text_splitter.split_documents(docs)

@ -9,19 +9,15 @@ from langchain.document_loaders import TextLoader
from talk_codebase.consts import EXCLUDE_DIRS, EXCLUDE_FILES, ALLOW_FILES
def get_repo():
def get_repo(root_dir):
try:
return Repo()
return Repo(root_dir)
except:
return None
def has_repo():
return get_repo() is not None
def is_ignored(path):
repo = get_repo()
def is_ignored(path, root_dir):
repo = get_repo(root_dir)
if repo is None:
return False
if not os.path.exists(path):
@ -47,14 +43,14 @@ def load_files(root_dir):
spinners = Halo(text='Loading files', spinner='dots')
docs = []
for dirpath, dirnames, filenames in os.walk(root_dir):
if is_ignored(dirpath):
if is_ignored(dirpath, root_dir):
continue
if any(exclude_dir in dirpath for exclude_dir in EXCLUDE_DIRS):
continue
if not filenames:
continue
for file in filenames:
if is_ignored(os.path.join(dirpath, file)):
if is_ignored(os.path.join(dirpath, file), root_dir):
continue
if any(file.endswith(allow_file) for allow_file in ALLOW_FILES) and not any(
file == exclude_file for exclude_file in EXCLUDE_FILES):

Loading…
Cancel
Save