From b978a7640276ec35dc1ec897e97c13ef6c4b62c4 Mon Sep 17 00:00:00 2001
From: rsaryev
Date: Tue, 22 Aug 2023 00:09:22 +0300
Subject: [PATCH] Refactor get_repo() and load_files() to use Repo() without
 root_dir and add a delay when creating the vector store

---
 .gitignore              |  3 ++-
 README.md               |  1 +
 pyproject.toml          |  2 +-
 talk_codebase/cli.py    |  9 +++++--
 talk_codebase/config.py |  1 -
 talk_codebase/llm.py    | 16 ++++++++----
 talk_codebase/utils.py  | 55 +++++++++++++++--------------------------
 7 files changed, 42 insertions(+), 45 deletions(-)

diff --git a/.gitignore b/.gitignore
index 6d4fc84..c2f4ace 100644
--- a/.gitignore
+++ b/.gitignore
@@ -3,4 +3,5 @@
 /.vscode/
 /.venv/
 /talk_codebase/__pycache__/
-.DS_Store
\ No newline at end of file
+.DS_Store
+/vector_store/
diff --git a/README.md b/README.md
index 03db6a1..238b7c7 100644
--- a/README.md
+++ b/README.md
@@ -14,6 +14,7 @@ talk-codebase is still under development and is recommended for educational purp
 ## Installation
 
 Requirement Python 3.8.1 or higher
+Your project must be in a git repository
 
 ```bash
 pip install talk-codebase
diff --git a/pyproject.toml b/pyproject.toml
index 9828a0b..ef2a499 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "talk-codebase"
-version = "0.1.46"
+version = "0.1.47"
 description = "talk-codebase is a powerful tool for querying and analyzing codebases."
 authors = ["Saryev Rustam "]
 readme = "README.md"
diff --git a/talk_codebase/cli.py b/talk_codebase/cli.py
index 7abb561..063154c 100644
--- a/talk_codebase/cli.py
+++ b/talk_codebase/cli.py
@@ -6,6 +6,7 @@ from talk_codebase.config import CONFIGURE_STEPS, save_config, get_config, confi
     remove_model_type, remove_model_name_local
 from talk_codebase.consts import DEFAULT_CONFIG
 from talk_codebase.llm import factory_llm
+from talk_codebase.utils import get_repo
 
 
 def check_python_version():
@@ -44,10 +45,14 @@ def chat_loop(llm):
         llm.send_query(query)
 
 
-def chat(root_dir=None):
+def chat():
     configure(False)
     config = get_config()
-    llm = factory_llm(root_dir, config)
+    repo = get_repo()
+    if not repo:
+        print("🤖 Git repository not found")
+        sys.exit(1)
+    llm = factory_llm(repo.working_dir, config)
     chat_loop(llm)
 
 
diff --git a/talk_codebase/config.py b/talk_codebase/config.py
index 57c0f55..622c524 100644
--- a/talk_codebase/config.py
+++ b/talk_codebase/config.py
@@ -152,7 +152,6 @@ def configure_model_type(config):
     ).ask()
     config["model_type"] = model_type
     save_config(config)
-    print("🤖 Model type saved!")
 
 
 CONFIGURE_STEPS = [
diff --git a/talk_codebase/llm.py b/talk_codebase/llm.py
index d04710b..2303336 100644
--- a/talk_codebase/llm.py
+++ b/talk_codebase/llm.py
@@ -1,4 +1,5 @@
 import os
+import time
 from typing import Optional
 
 import gpt4all
@@ -40,7 +41,7 @@ class BaseLLM:
         if new_db is not None:
             return new_db.as_retriever(search_kwargs={"k": k})
 
-        docs = load_files(root_dir)
+        docs = load_files()
         if len(docs) == 0:
             print("✘ No documents found")
             exit(0)
@@ -60,9 +61,13 @@ class BaseLLM:
                 exit(0)
 
         spinners = Halo(text=f"Creating vector store", spinner='dots').start()
-        db = FAISS.from_documents(texts, embeddings)
-        db.add_documents(texts)
-        db.save_local(index_path)
+        db = FAISS.from_documents([texts[0]], embeddings)
+        for i, text in enumerate(texts[1:]):
+            spinners.text = f"Creating vector store ({i + 1}/{len(texts)})"
+            db.add_documents([text])
+            db.save_local(index_path)
+            time.sleep(1.5)
+
         spinners.succeed(f"Created vector store")
         return db.as_retriever(search_kwargs={"k": k})
 
@@ -93,7 +98,8 @@ class LocalLLM(BaseLLM):
         model_n_ctx = int(self.config.get("max_tokens"))
         model_n_batch = int(self.config.get("n_batch"))
         callbacks = CallbackManager([StreamStdOut()])
-        llm = LlamaCpp(model_path=model_path, n_ctx=model_n_ctx, n_batch=model_n_batch, callbacks=callbacks, verbose=False)
+        llm = LlamaCpp(model_path=model_path, n_ctx=model_n_ctx, n_batch=model_n_batch, callbacks=callbacks,
+                       verbose=False)
         llm.client.verbose = False
         return llm
 
diff --git a/talk_codebase/utils.py b/talk_codebase/utils.py
index a0e970f..2622e6a 100644
--- a/talk_codebase/utils.py
+++ b/talk_codebase/utils.py
@@ -1,6 +1,3 @@
-import glob
-import multiprocessing
-import os
 import sys
 
 import tiktoken
@@ -11,23 +8,13 @@ from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
 from talk_codebase.consts import LOADER_MAPPING, EXCLUDE_FILES
 
 
-def get_repo(root_dir):
+def get_repo():
     try:
-        return Repo(root_dir)
+        return Repo()
     except:
         return None
 
 
-def is_ignored(path, root_dir):
-    repo = get_repo(root_dir)
-    if repo is None:
-        return False
-    if not os.path.exists(path):
-        return False
-    ignored = repo.ignored(path)
-    return len(ignored) > 0
-
-
 class StreamStdOut(StreamingStdOutCallbackHandler):
     def on_llm_new_token(self, token: str, **kwargs) -> None:
         sys.stdout.write(token)
@@ -41,26 +28,24 @@ class StreamStdOut(StreamingStdOutCallbackHandler):
         sys.stdout.flush()
 
 
-def load_files(root_dir):
-    num_cpus = multiprocessing.cpu_count()
-    with multiprocessing.Pool(num_cpus) as pool:
-        futures = []
-        for file_path in glob.glob(os.path.join(root_dir, '**/*'), recursive=True):
-            if is_ignored(file_path, root_dir):
-                continue
-            if any(
-                    file_path.endswith(exclude_file) for exclude_file in EXCLUDE_FILES):
-                continue
-            for ext in LOADER_MAPPING:
-                if file_path.endswith(ext):
-                    print('\r' + f'📂 Loading files: {file_path}')
-                    args = LOADER_MAPPING[ext]['args']
-                    loader = LOADER_MAPPING[ext]['loader'](file_path, *args)
-                    futures.append(pool.apply_async(loader.load))
-        docs = []
-        for future in futures:
-            docs.extend(future.get())
-        return docs
+def load_files():
+    repo = get_repo()
+    if repo is None:
+        return []
+    files = []
+    tree = repo.tree()
+    for blob in tree.traverse():
+        path = blob.path
+        if any(
+                path.endswith(exclude_file) for exclude_file in EXCLUDE_FILES):
+            continue
+        for ext in LOADER_MAPPING:
+            if path.endswith(ext):
+                print('\r' + f'📂 Loading files: {path}')
+                args = LOADER_MAPPING[ext]['args']
+                loader = LOADER_MAPPING[ext]['loader'](path, *args)
+                files.extend(loader.load())
+    return files
 
 
 def calculate_cost(texts, model_name):
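
Note: the incremental vector-store build this patch introduces in talk_codebase/llm.py can be exercised on its own. The sketch below is illustrative only and not part of the patch; it reuses the same LangChain FAISS calls as the diff, while the sample documents, the "vector_store/" save path, and the choice of OpenAIEmbeddings (which requires OPENAI_API_KEY and the faiss-cpu package) are assumptions made for the example.

    import time

    from langchain.docstore.document import Document
    from langchain.embeddings import OpenAIEmbeddings
    from langchain.vectorstores import FAISS

    # Hypothetical stand-ins for the chunks that talk-codebase's load_files()
    # and text splitter would normally produce.
    texts = [
        Document(page_content="def get_repo(): ..."),
        Document(page_content="def load_files(): ..."),
        Document(page_content="def calculate_cost(texts, model_name): ..."),
    ]

    embeddings = OpenAIEmbeddings()  # assumes OPENAI_API_KEY is set

    # Seed the index with the first document, then add the rest one at a time,
    # saving after each addition and pausing 1.5 s between requests, mirroring
    # the loop the patch adds in BaseLLM (the pause presumably keeps the
    # embeddings API under its rate limit).
    db = FAISS.from_documents([texts[0]], embeddings)
    for text in texts[1:]:
        db.add_documents([text])
        db.save_local("vector_store/")
        time.sleep(1.5)

    retriever = db.as_retriever(search_kwargs={"k": 2})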