mirror of
https://github.com/rsaryev/talk-codebase
synced 2024-11-10 07:10:31 +00:00
80 lines
2.2 KiB
Python
80 lines
2.2 KiB
Python
import glob
|
|
import multiprocessing
|
|
import os
|
|
import sys
|
|
|
|
import tiktoken
|
|
from git import Repo
|
|
from langchain import FAISS
|
|
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
|
|
|
|
from talk_codebase.consts import LOADER_MAPPING, EXCLUDE_FILES
|
|
|
|
|
|
def get_repo(root_dir):
    """Open the git repository rooted at *root_dir*, or return None.

    ``Repo(root_dir)`` is called without parent-directory search, so
    *root_dir* must be the top of a git working tree (or a git dir).

    :param root_dir: directory to open as a git repository.
    :return: a ``git.Repo`` instance, or ``None`` when *root_dir* does not
        exist or is not a git repository.
    """
    # Local import keeps the file's top-level import block untouched.
    from git.exc import InvalidGitRepositoryError, NoSuchPathError

    try:
        return Repo(root_dir)
    except (InvalidGitRepositoryError, NoSuchPathError):
        # Not a repo / missing path: callers treat this as "no git info".
        # Other exceptions (including KeyboardInterrupt) now propagate
        # instead of being silently swallowed by a bare ``except:``.
        return None
|
|
|
|
|
|
def is_ignored(path, root_dir, repo=None):
    """Return True if git reports *path* as ignored in the repo at *root_dir*.

    Returns False when *path* does not exist or *root_dir* is not a git
    repository.

    :param path: file or directory path (absolute, or relative to the CWD).
    :param root_dir: repository root, passed to ``get_repo`` when *repo* is
        not supplied.
    :param repo: optional already-opened ``git.Repo``. Pass it when checking
        many paths to avoid re-opening the repository for every call.
    :return: whether ``repo.ignored`` (``git check-ignore``) lists *path*.
    """
    if not os.path.exists(path):
        return False
    if repo is None:
        repo = get_repo(root_dir)
        if repo is None:
            return False
    # GitPython runs git with the repo's working tree as its CWD, so a
    # CWD-relative path like "myrepo/a.py" would be checked as
    # "myrepo/myrepo/a.py". An absolute path is unambiguous.
    ignored = repo.ignored(os.path.abspath(path))
    return len(ignored) > 0
|
|
|
|
|
|
class StreamStdOut(StreamingStdOutCallbackHandler):
    """Echo an LLM's streamed response to standard output as it arrives.

    A robot marker opens the response, each token is written and flushed
    immediately, and a newline (followed by a flush) closes it.
    """

    def on_llm_start(self, serialized, prompts, **kwargs):
        # Marker only; it is not flushed here, so it reaches the terminal on
        # the next flush in on_llm_new_token / on_llm_end.
        sys.stdout.write("🤖 ")

    def on_llm_new_token(self, token: str, **kwargs) -> None:
        stream = sys.stdout
        stream.write(token)
        stream.flush()

    def on_llm_end(self, response, **kwargs):
        stream = sys.stdout
        stream.write("\n")
        stream.flush()
|
|
|
|
|
|
def load_files(root_dir):
    """Load every supported, non-ignored file under *root_dir* as documents.

    Paths are discovered with a recursive ``glob`` of ``root_dir/**/*``.
    A path is loaded when all of the following hold:

    * it is a regular file (a directory whose name happens to end with a
      supported extension, e.g. ``chart.js/``, is skipped);
    * it does not end with any entry of ``EXCLUDE_FILES``;
    * it ends with a key of ``LOADER_MAPPING``;
    * ``is_ignored`` does not report it as git-ignored.

    Each match's loader runs in a process pool. Results are gathered in
    discovery order.

    :param root_dir: directory to scan.
    :return: flat list of everything returned by the loaders' ``load()``.
    """
    num_cpus = multiprocessing.cpu_count()
    with multiprocessing.Pool(num_cpus) as pool:
        futures = []
        for file_path in glob.glob(os.path.join(root_dir, '**/*'), recursive=True):
            if not os.path.isfile(file_path):
                continue
            if any(file_path.endswith(exclude_file) for exclude_file in EXCLUDE_FILES):
                continue
            matching_exts = [ext for ext in LOADER_MAPPING if file_path.endswith(ext)]
            if not matching_exts:
                continue
            # is_ignored opens the repo and shells out to git, so only pay
            # that cost for files that would otherwise be loaded.
            if is_ignored(file_path, root_dir):
                continue
            for ext in matching_exts:
                print('\r' + f'📂 Loading files: {file_path}')
                args = LOADER_MAPPING[ext]['args']
                loader = LOADER_MAPPING[ext]['loader'](file_path, *args)
                futures.append(pool.apply_async(loader.load))
        # Collect inside the ``with``: Pool.__exit__ terminates the workers.
        docs = []
        for future in futures:
            docs.extend(future.get())
    return docs
|
|
|
|
|
|
def calculate_cost(texts, model_name, price_per_1k_tokens=0.0004):
    """Estimate the cost of processing *texts* with *model_name*.

    All ``page_content`` strings are concatenated, tokenized with the
    model's tiktoken encoding, and billed at *price_per_1k_tokens* per
    1,000 tokens.

    :param texts: iterable of documents exposing a ``page_content`` string.
    :param model_name: model name passed to ``tiktoken.encoding_for_model``.
    :param price_per_1k_tokens: price per 1,000 tokens. The default, 0.0004,
        is the value this function has always hard-coded.
        NOTE(review): presumably a USD embedding-model price; confirm it is
        still current.
    :return: estimated cost as a float.
    """
    enc = tiktoken.encoding_for_model(model_name)
    all_text = ''.join(text.page_content for text in texts)
    # Source files may contain literal special-token strings such as
    # "<|endoftext|>". By default tiktoken raises ValueError on those, so
    # encode them as ordinary text instead.
    token_count = len(enc.encode(all_text, disallowed_special=()))
    return (token_count / 1000) * price_per_1k_tokens
|
|
|
|
|
|
def get_local_vector_store(embeddings, path):
    """Load a FAISS vector store saved at *path*, or return None.

    Best-effort: any ``Exception`` raised while loading yields ``None``, so
    the caller can build a fresh store. ``KeyboardInterrupt`` and
    ``SystemExit`` are not caught.

    :param embeddings: embeddings object passed through to
        ``FAISS.load_local``.
    :param path: directory previously written by the store's save method.
    :return: the loaded ``FAISS`` store, or ``None`` on failure.
    """
    try:
        # SECURITY NOTE(review): LangChain's FAISS.load_local deserializes a
        # pickled docstore from *path*. Only load from trusted locations.
        return FAISS.load_local(path, embeddings)
    except Exception:
        return None
|