Refactor get_repo() and load_files() to use Repo() without root_dir. Refactored `load_files` to walk the git tree and added a delay when creating the vector store.

pull/19/head
rsaryev 10 months ago
parent 3d3e2dabd5
commit b978a76402
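
For context: GitPython's `Repo()` with no path argument resolves the repository from the process's current working directory, which is why the README now requires the project to be a git repository and why `chat()` exits when none is found. A minimal sketch of the behavior this relies on (the `search_parent_directories` variant is a hypothetical alternative, not something this commit uses):

```python
from git import Repo, InvalidGitRepositoryError

# Bare Repo() looks for a repository at the current working directory and
# raises InvalidGitRepositoryError when there isn't one.
try:
    repo = Repo()
except InvalidGitRepositoryError:
    repo = None

# Hypothetical alternative (not used here): also accept being launched
# from a subdirectory of the repository.
# repo = Repo(search_parent_directories=True)
```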

.gitignore
@@ -3,4 +3,5 @@
 /.vscode/
 /.venv/
 /talk_codebase/__pycache__/
 .DS_Store
+/vector_store/

README.md
@@ -14,6 +14,7 @@ talk-codebase is still under development and is recommended for educational purp
 ## Installation
 Requirement Python 3.8.1 or higher
+Your project must be in a git repository
 ```bash
 pip install talk-codebase

pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "talk-codebase"
-version = "0.1.46"
+version = "0.1.47"
 description = "talk-codebase is a powerful tool for querying and analyzing codebases."
 authors = ["Saryev Rustam <rustam1997@gmail.com>"]
 readme = "README.md"

talk_codebase/cli.py
@@ -6,6 +6,7 @@ from talk_codebase.config import CONFIGURE_STEPS, save_config, get_config, confi
     remove_model_type, remove_model_name_local
 from talk_codebase.consts import DEFAULT_CONFIG
 from talk_codebase.llm import factory_llm
+from talk_codebase.utils import get_repo


 def check_python_version():
@@ -44,10 +45,14 @@ def chat_loop(llm):
         llm.send_query(query)


-def chat(root_dir=None):
+def chat():
     configure(False)
     config = get_config()
-    llm = factory_llm(root_dir, config)
+    repo = get_repo()
+    if not repo:
+        print("🤖 Git repository not found")
+        sys.exit(1)
+    llm = factory_llm(repo.working_dir, config)
     chat_loop(llm)
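
`repo.working_dir` replaces the user-supplied `root_dir` as the path handed to `factory_llm`. A quick illustration of what it evaluates to, assuming GitPython and a checkout at a hypothetical path:

```python
from git import Repo

repo = Repo()            # run from the repository root
print(repo.working_dir)  # absolute working-tree path, e.g. "/home/user/talk-codebase"
print(repo.git_dir)      # its ".git" directory
```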

talk_codebase/config.py
@@ -152,7 +152,6 @@ def configure_model_type(config):
     ).ask()
     config["model_type"] = model_type
     save_config(config)
-    print("🤖 Model type saved!")


 CONFIGURE_STEPS = [

talk_codebase/llm.py
@@ -1,4 +1,5 @@
 import os
+import time
 from typing import Optional

 import gpt4all
@@ -40,7 +41,7 @@ class BaseLLM:
         if new_db is not None:
             return new_db.as_retriever(search_kwargs={"k": k})
-        docs = load_files(root_dir)
+        docs = load_files()
         if len(docs) == 0:
             print("✘ No documents found")
             exit(0)
@@ -60,9 +61,13 @@ class BaseLLM:
             exit(0)
         spinners = Halo(text=f"Creating vector store", spinner='dots').start()
-        db = FAISS.from_documents(texts, embeddings)
-        db.add_documents(texts)
-        db.save_local(index_path)
+        db = FAISS.from_documents([texts[0]], embeddings)
+        for i, text in enumerate(texts[1:]):
+            spinners.text = f"Creating vector store ({i + 1}/{len(texts)})"
+            db.add_documents([text])
+            db.save_local(index_path)
+            time.sleep(1.5)
         spinners.succeed(f"Created vector store")
         return db.as_retriever(search_kwargs={"k": k})
@@ -93,7 +98,8 @@ class LocalLLM(BaseLLM):
         model_n_ctx = int(self.config.get("max_tokens"))
         model_n_batch = int(self.config.get("n_batch"))
         callbacks = CallbackManager([StreamStdOut()])
-        llm = LlamaCpp(model_path=model_path, n_ctx=model_n_ctx, n_batch=model_n_batch, callbacks=callbacks, verbose=False)
+        llm = LlamaCpp(model_path=model_path, n_ctx=model_n_ctx, n_batch=model_n_batch, callbacks=callbacks,
+                       verbose=False)
         llm.client.verbose = False
         return llm
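
The vector-store change seeds FAISS with the first chunk, then adds the remaining chunks one at a time, checkpointing to disk and sleeping after each; the 1.5 s pause plausibly exists to throttle embedding-API requests. A self-contained sketch of the same pattern, with `texts`, `embeddings`, and `index_path` as stand-ins for the values in `BaseLLM`:

```python
import time
from langchain.vectorstores import FAISS

def build_vector_store(texts, embeddings, index_path):
    # Seed with one document, then grow the index incrementally so each
    # step can be checkpointed and progress reported.
    db = FAISS.from_documents([texts[0]], embeddings)
    for text in texts[1:]:
        db.add_documents([text])   # one embedding request per chunk
        db.save_local(index_path)  # a crash now loses at most one chunk
        time.sleep(1.5)            # crude client-side rate limiting
    return db
```

Saving inside the loop trades speed for resilience; a single save after the loop would be faster but loses the per-chunk checkpoint.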

talk_codebase/utils.py
@@ -1,6 +1,3 @@
-import glob
-import multiprocessing
-import os
 import sys

 import tiktoken
@@ -11,23 +8,13 @@ from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
 from talk_codebase.consts import LOADER_MAPPING, EXCLUDE_FILES


-def get_repo(root_dir):
+def get_repo():
     try:
-        return Repo(root_dir)
+        return Repo()
     except:
         return None


-def is_ignored(path, root_dir):
-    repo = get_repo(root_dir)
-    if repo is None:
-        return False
-    if not os.path.exists(path):
-        return False
-    ignored = repo.ignored(path)
-    return len(ignored) > 0
-
-
 class StreamStdOut(StreamingStdOutCallbackHandler):
     def on_llm_new_token(self, token: str, **kwargs) -> None:
         sys.stdout.write(token)
@@ -41,26 +28,24 @@ class StreamStdOut(StreamingStdOutCallbackHandler):
         sys.stdout.flush()


-def load_files(root_dir):
-    num_cpus = multiprocessing.cpu_count()
-    with multiprocessing.Pool(num_cpus) as pool:
-        futures = []
-        for file_path in glob.glob(os.path.join(root_dir, '**/*'), recursive=True):
-            if is_ignored(file_path, root_dir):
-                continue
-            if any(
-                    file_path.endswith(exclude_file) for exclude_file in EXCLUDE_FILES):
-                continue
-            for ext in LOADER_MAPPING:
-                if file_path.endswith(ext):
-                    print('\r' + f'📂 Loading files: {file_path}')
-                    args = LOADER_MAPPING[ext]['args']
-                    loader = LOADER_MAPPING[ext]['loader'](file_path, *args)
-                    futures.append(pool.apply_async(loader.load))
-        docs = []
-        for future in futures:
-            docs.extend(future.get())
-        return docs
+def load_files():
+    repo = get_repo()
+    if repo is None:
+        return []
+    files = []
+    tree = repo.tree()
+    for blob in tree.traverse():
+        path = blob.path
+        if any(
+                path.endswith(exclude_file) for exclude_file in EXCLUDE_FILES):
+            continue
+        for ext in LOADER_MAPPING:
+            if path.endswith(ext):
+                print('\r' + f'📂 Loading files: {path}')
+                args = LOADER_MAPPING[ext]['args']
+                loader = LOADER_MAPPING[ext]['loader'](path, *args)
+                files.extend(loader.load())
+    return files


 def calculate_cost(texts, model_name):
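
The rewritten `load_files` discovers files from the git tree instead of globbing the filesystem, which is why the `is_ignored` helper could be deleted: untracked and gitignored files never appear in the tree in the first place. One caveat: `tree.traverse()` yields directory (tree) entries as well as file (blob) entries; the extension check makes this mostly harmless, but a type filter is the explicit way to do it. A standalone sketch, assuming GitPython and a repository in the current directory:

```python
from git import Repo

repo = Repo()
for obj in repo.tree().traverse():
    if obj.type != "blob":  # skip directory entries; blobs are the files
        continue
    print(obj.path)         # path relative to the repository root
```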
