|
|
@ -1,6 +1,3 @@
|
|
|
|
import glob
|
|
|
|
|
|
|
|
import multiprocessing
|
|
|
|
|
|
|
|
import os
|
|
|
|
|
|
|
|
import sys
|
|
|
|
import sys
|
|
|
|
|
|
|
|
|
|
|
|
import tiktoken
|
|
|
|
import tiktoken
|
|
|
@ -11,23 +8,13 @@ from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
|
|
|
|
from talk_codebase.consts import LOADER_MAPPING, EXCLUDE_FILES
|
|
|
|
from talk_codebase.consts import LOADER_MAPPING, EXCLUDE_FILES
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_repo(root_dir=None):
    """Open the git repository at *root_dir*, or return ``None`` on failure.

    Args:
        root_dir: Path passed straight to ``Repo``. It is optional so that
            both call shapes, ``get_repo(root_dir)`` and ``get_repo()``,
            keep working; with ``None`` GitPython's ``Repo`` resolves the
            repository the same way a no-argument ``Repo()`` call does.

    Returns:
        The ``Repo`` instance, or ``None`` if constructing it raised.
    """
    try:
        return Repo(root_dir)
    except Exception:
        # Best effort: any failure to open a repository is reported to the
        # caller as "no repository". ``Exception`` (rather than a bare
        # ``except``) lets KeyboardInterrupt / SystemExit propagate.
        return None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def is_ignored(path, root_dir):
    """Tell whether *path* is reported as ignored by the repository at *root_dir*.

    Args:
        path: Filesystem path to check.
        root_dir: Directory handed to ``get_repo`` to obtain the repository.

    Returns:
        ``False`` when no repository could be opened or when *path* does not
        exist; otherwise ``True`` exactly when ``repo.ignored(path)`` returns
        a non-empty result (in GitPython that call returns the subset of the
        given paths that git ignores).
    """
    repository = get_repo(root_dir)
    # The repository check comes first so the filesystem is only consulted
    # when a repository is actually available.
    if repository is None or not os.path.exists(path):
        return False
    matches = repository.ignored(path)
    return len(matches) != 0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class StreamStdOut(StreamingStdOutCallbackHandler):
|
|
|
|
class StreamStdOut(StreamingStdOutCallbackHandler):
|
|
|
|
def on_llm_new_token(self, token: str, **kwargs) -> None:
|
|
|
|
def on_llm_new_token(self, token: str, **kwargs) -> None:
|
|
|
|
sys.stdout.write(token)
|
|
|
|
sys.stdout.write(token)
|
|
|
@ -41,26 +28,24 @@ class StreamStdOut(StreamingStdOutCallbackHandler):
|
|
|
|
sys.stdout.flush()
|
|
|
|
sys.stdout.flush()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def load_files(root_dir):
    """Load every supported, non-ignored file under *root_dir*.

    The tree below *root_dir* is walked with a recursive glob. A path is
    skipped when it is not a regular file, when ``is_ignored`` reports it as
    ignored, or when it ends with one of ``EXCLUDE_FILES``. For each key of
    ``LOADER_MAPPING`` that the path ends with, a loader is built from the
    mapping's ``'loader'`` and ``'args'`` entries and its ``load`` method is
    submitted to a process pool.

    Args:
        root_dir: Directory whose contents are loaded.

    Returns:
        A list with the concatenated results of every ``loader.load()`` call,
        in submission order.
    """
    # str.endswith accepts a tuple, so the exclusion test is one call per
    # path instead of a Python-level scan; build the tuple once.
    excluded_suffixes = tuple(EXCLUDE_FILES)
    pattern = os.path.join(root_dir, '**/*')

    with multiprocessing.Pool(multiprocessing.cpu_count()) as pool:
        pending = []
        for file_path in glob.glob(pattern, recursive=True):
            # The glob also yields directories (and dangling symlinks);
            # those must never reach a loader, whatever their name ends with.
            if not os.path.isfile(file_path):
                continue
            if is_ignored(file_path, root_dir):
                continue
            if file_path.endswith(excluded_suffixes):
                continue
            # NOTE: no break on purpose - a path matching several mapped
            # extensions is submitted once per match, as before.
            for ext in LOADER_MAPPING:
                if file_path.endswith(ext):
                    print('\r' + f'📂 Loading files: {file_path}')
                    args = LOADER_MAPPING[ext]['args']
                    loader = LOADER_MAPPING[ext]['loader'](file_path, *args)
                    pending.append(pool.apply_async(loader.load))

        # Collect inside the ``with`` block: leaving it terminates the pool,
        # which would abandon any work that has not finished yet.
        docs = []
        for result in pending:
            docs.extend(result.get())
    return docs
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def calculate_cost(texts, model_name):
|
|
|
|
def calculate_cost(texts, model_name):
|
|
|
|