talk-codebase/talk_codebase/llm.py

126 lines
5.0 KiB
Python

import os
import time
from typing import Optional
import gpt4all
import questionary
from halo import Halo
from langchain import FAISS
from langchain.callbacks.manager import CallbackManager
from langchain.chains import RetrievalQA
from langchain.chat_models import ChatOpenAI
from langchain.embeddings import HuggingFaceEmbeddings, OpenAIEmbeddings
from langchain.llms import LlamaCpp
from langchain.text_splitter import RecursiveCharacterTextSplitter
from talk_codebase.consts import MODEL_TYPES
from talk_codebase.utils import load_files, get_local_vector_store, calculate_cost, StreamStdOut
class BaseLLM:
    """Base class pairing a language model with a retriever over the codebase.

    Subclasses provide the concrete model (``_create_model``) and the
    retriever (``_create_store``); this class implements index construction
    and question answering on top of them.
    """

    def __init__(self, root_dir, config):
        """Create the model, then the retriever for ``root_dir``.

        Args:
            root_dir: Directory under which the ``vector_store`` index lives.
            config: Object read through ``config.get(<key>)``.
        """
        self.config = config
        self.llm = self._create_model()
        self.root_dir = root_dir
        self.vector_store = self._create_store(root_dir)

    def _create_store(self, root_dir):
        """Return the retriever for ``root_dir``; must be overridden."""
        raise NotImplementedError("Subclasses must implement this method.")

    def _create_model(self):
        """Return the language model instance; must be overridden."""
        raise NotImplementedError("Subclasses must implement this method.")

    def embedding_search(self, query, k):
        """Run a similarity search for ``query`` and return ``k`` results.

        ``_create_vector_store`` hands back ``as_retriever(...)`` objects, so
        ``self.vector_store`` is normally a retriever wrapping the real store
        in its ``vectorstore`` attribute. The search is issued against that
        underlying store; an object without the attribute is used directly.
        """
        store = getattr(self.vector_store, "vectorstore", self.vector_store)
        return store.search(query, search_type="similarity", k=k)

    def _create_vector_store(self, embeddings, index, root_dir):
        """Load the saved index for ``index`` or build and persist a new one.

        Args:
            embeddings: Embedding object passed to the vector store.
            index: Index name; also compared against ``MODEL_TYPES["OPENAI"]``
                to decide whether to show the cost confirmation prompt.
            root_dir: Directory that contains the ``vector_store`` folder.

        Returns:
            The store's ``as_retriever`` result configured with the ``k``
            value from the config.

        Raises:
            SystemExit: With code 0 when there is nothing to index or the
                user declines the cost prompt.
        """
        k = int(self.config.get("k"))
        index_path = os.path.join(root_dir, f"vector_store/{index}")

        # Presumably None means no saved index exists yet - confirm in utils.
        new_db = get_local_vector_store(embeddings, index_path)
        if new_db is not None:
            return new_db.as_retriever(search_kwargs={"k": k})

        docs = load_files()
        if not docs:
            print("✘ No documents found")
            # ``exit`` is injected by the ``site`` module and may be absent;
            # raising SystemExit is the equivalent that always works.
            raise SystemExit(0)

        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=int(self.config.get("chunk_size")),
            chunk_overlap=int(self.config.get("chunk_overlap")),
        )
        texts = text_splitter.split_documents(docs)
        if not texts:
            # Splitting produced no chunks; ``texts[0]`` below would fail.
            print("✘ No documents found")
            raise SystemExit(0)

        if index == MODEL_TYPES["OPENAI"]:
            cost = calculate_cost(docs, self.config.get("openai_model_name"))
            approve = questionary.select(
                f"Creating a vector store will cost ~${cost:.5f}. Do you want to continue?",
                choices=[
                    {"name": "Yes", "value": True},
                    {"name": "No", "value": False},
                ]
            ).ask()
            if not approve:
                raise SystemExit(0)

        total = len(texts)
        spinners = Halo(text="Creating vector store", spinner='dots').start()
        # The first chunk creates the store; the rest are added one at a time
        # so the spinner can report progress.
        db = FAISS.from_documents([texts[0]], embeddings)
        # Counting starts at 2 because chunk 1 is already in the store, which
        # lets the display finish at total/total.
        for position, text in enumerate(texts[1:], start=2):
            spinners.text = f"Creating vector store ({position}/{total})"
            db.add_documents([text])
        db.save_local(index_path)
        time.sleep(1.5)
        spinners.succeed("Created vector store")
        return db.as_retriever(search_kwargs={"k": k})

    def send_query(self, query):
        """Answer ``query`` with a "stuff" RetrievalQA chain and list sources.

        Prints the absolute path of every source document returned by the
        chain, one per line. Returns nothing.
        """
        # The retriever is rebuilt on each call rather than reusing
        # ``self.vector_store``.
        retriever = self._create_store(self.root_dir)
        qa = RetrievalQA.from_chain_type(
            llm=self.llm,
            chain_type="stuff",
            retriever=retriever,
            return_source_documents=True
        )
        docs = qa(query)
        file_paths = [os.path.abspath(s.metadata["source"]) for s in docs['source_documents']]
        print('\n'.join([f'📄 {file_path}:' for file_path in file_paths]))
class LocalLLM(BaseLLM):
    """Backend that runs a GPT4All model file locally through LlamaCpp."""

    def _create_store(self, root_dir: str) -> Optional[FAISS]:
        """Return the retriever for the ``MODEL_TYPES["LOCAL"]`` index.

        Embeddings come from the ``all-MiniLM-L6-v2`` HuggingFace model. The
        value returned is whatever ``_create_vector_store`` produces, i.e. an
        ``as_retriever(...)`` result despite the annotation.
        """
        hf_embeddings = HuggingFaceEmbeddings(model_name='all-MiniLM-L6-v2')
        index_name = MODEL_TYPES["LOCAL"]
        return self._create_vector_store(hf_embeddings, index_name, root_dir)

    def _create_model(self):
        """Make sure the model file is available and wrap it in LlamaCpp.

        Reads ``model_path``, ``local_model_name``, ``max_tokens`` and
        ``n_batch`` from the config. Output is streamed through
        ``StreamStdOut`` and verbose logging is turned off.
        """
        cfg = self.config

        # Ensure the target directory exists before gpt4all is asked to
        # retrieve the model into it (presumably a download if missing).
        os.makedirs(cfg.get("model_path"), exist_ok=True)
        gpt4all.GPT4All.retrieve_model(
            model_name=cfg.get("local_model_name"),
            model_path=cfg.get("model_path"),
        )

        weights_file = os.path.join(cfg.get("model_path"), cfg.get("local_model_name"))
        llama_kwargs = {
            "model_path": weights_file,
            "n_ctx": int(cfg.get("max_tokens")),
            "n_batch": int(cfg.get("n_batch")),
            "callbacks": CallbackManager([StreamStdOut()]),
            "verbose": False,
        }
        model = LlamaCpp(**llama_kwargs)
        # Also silence the wrapped client object.
        model.client.verbose = False
        return model
class OpenAILLM(BaseLLM):
    """Backend that uses OpenAI for both embeddings and chat completion."""

    def _create_store(self, root_dir: str) -> Optional[FAISS]:
        """Return the retriever for the ``MODEL_TYPES["OPENAI"]`` index.

        Embeddings are created with the ``api_key`` value from the config. The
        value returned is whatever ``_create_vector_store`` produces, i.e. an
        ``as_retriever(...)`` result despite the annotation.
        """
        api_key = self.config.get("api_key")
        openai_embeddings = OpenAIEmbeddings(openai_api_key=api_key)
        return self._create_vector_store(openai_embeddings, MODEL_TYPES["OPENAI"], root_dir)

    def _create_model(self):
        """Build a streaming ``ChatOpenAI`` model from the config.

        Reads ``openai_model_name``, ``api_key``, ``max_tokens`` and
        ``temperature``; output is streamed through ``StreamStdOut``.
        """
        cfg = self.config
        chat_kwargs = dict(
            model_name=cfg.get("openai_model_name"),
            openai_api_key=cfg.get("api_key"),
            streaming=True,
            max_tokens=int(cfg.get("max_tokens")),
            callback_manager=CallbackManager([StreamStdOut()]),
            temperature=float(cfg.get("temperature")),
        )
        return ChatOpenAI(**chat_kwargs)
def factory_llm(root_dir, config):
    """Instantiate the LLM backend selected by ``config``'s ``model_type``.

    ``"openai"`` yields an ``OpenAILLM``; any other value (including a
    missing one) yields a ``LocalLLM``.
    """
    backend_cls = OpenAILLM if config.get("model_type") == "openai" else LocalLLM
    return backend_cls(root_dir, config)