Refactored CLI and LLM classes
- Refactored the CLI and LLM classes to improve code organization and readability.
- Added a factory_llm function to create an LLM instance based on the config.
- Moved the send_question function into the LLM classes (BaseLLM and its LocalLLM / OpenAILLM subclasses).
- Added a get_local_vector_store function to handle loading an existing vector store.
- Added a calculate_cost function to estimate the cost of creating a vector store for OpenAI models.
- Updated the configuration prompt to ask for the model type and for a model path or API key depending on the type.
- Updated the CLI to use the factory_llm function and the send_question method of the LLM instance.
- Updated the default config to include the new default values.
- Added a constant to store the default config values.
- Added a constant to store the default model path.
parent 9b9a834941
commit f9a31937bb
File diff suppressed because it is too large
@@ -0,0 +1,108 @@
import os
from typing import Optional

import questionary
from halo import Halo
from langchain import FAISS
from langchain.callbacks.manager import CallbackManager
from langchain.chains import RetrievalQA
from langchain.chat_models import ChatOpenAI
from langchain.embeddings import HuggingFaceEmbeddings, OpenAIEmbeddings
from langchain.llms import GPT4All
from langchain.text_splitter import RecursiveCharacterTextSplitter

from talk_codebase.utils import load_files, get_local_vector_store, calculate_cost, StreamStdOut


class BaseLLM:

    def __init__(self, root_dir, config):
        self.config = config
        self.llm = self._create_model()
        self.root_dir = root_dir
        self.vector_store = self._create_store(root_dir)

    def _create_store(self, root_dir):
        raise NotImplementedError("Subclasses must implement this method.")

    def _create_model(self):
        raise NotImplementedError("Subclasses must implement this method.")

    def send_question(self, question):
        k = self.config.get("k")
        qa = RetrievalQA.from_chain_type(llm=self.llm, chain_type="stuff",
                                         retriever=self.vector_store.as_retriever(search_kwargs={"k": int(k)}),
                                         return_source_documents=True)
        answer = qa(question)
        print('\n' + '\n'.join([f'📄 {os.path.abspath(s.metadata["source"])}:' for s in answer["source_documents"]]))

    def _create_vector_store(self, embeddings, index, root_dir):
        index_path = os.path.join(root_dir, f"vector_store/{index}")
        new_db = get_local_vector_store(embeddings, index_path)
        if new_db is not None:
            approve = questionary.select(
                f"Found existing vector store. Do you want to use it?",
                choices=[
                    {"name": "Yes", "value": True},
                    {"name": "No", "value": False},
                ]
            ).ask()
            if approve:
                return new_db

        docs = load_files(root_dir)
        if len(docs) == 0:
            print("✘ No documents found")
            exit(0)
        text_splitter = RecursiveCharacterTextSplitter(chunk_size=int(self.config.get("chunk_size")),
                                                       chunk_overlap=int(self.config.get("chunk_overlap")))
        texts = text_splitter.split_documents(docs)

        if index == "openai":
            cost = calculate_cost(docs, self.config.get("model_name"))
            approve = questionary.select(
                f"Creating a vector store for {len(docs)} documents will cost ~${cost:.5f}. Do you want to continue?",
                choices=[
                    {"name": "Yes", "value": True},
                    {"name": "No", "value": False},
                ]
            ).ask()
            if not approve:
                exit(0)

        spinners = Halo(text=f"Creating vector store for {len(docs)} documents", spinner='dots').start()
        db = FAISS.from_documents(texts, embeddings)
        db.save_local(index_path)
        spinners.succeed(f"Created vector store for {len(docs)} documents")
        return db


class LocalLLM(BaseLLM):

    def _create_store(self, root_dir: str) -> Optional[FAISS]:
        embeddings = HuggingFaceEmbeddings(model_name='all-MiniLM-L6-v2')
        return self._create_vector_store(embeddings, "local", root_dir)

    def _create_model(self):
        llm = GPT4All(model=self.config.get("model_path"), n_ctx=int(self.config.get("max_tokens")), streaming=True)
        return llm


class OpenAILLM(BaseLLM):

    def _create_store(self, root_dir: str) -> Optional[FAISS]:
        embeddings = OpenAIEmbeddings(openai_api_key=self.config.get("api_key"))
        return self._create_vector_store(embeddings, "openai", root_dir)

    def _create_model(self):
        return ChatOpenAI(model_name=self.config.get("model_name"), openai_api_key=self.config.get("api_key"),
                          streaming=True,
                          max_tokens=int(self.config.get("max_tokens")),
                          callback_manager=CallbackManager([StreamStdOut()]))


def factory_llm(root_dir, config):
    if config.get("model_type") == "openai":
        return OpenAILLM(root_dir, config)
    else:
        return LocalLLM(root_dir, config)
@@ -1,81 +0,0 @@
import os

import questionary
import tiktoken
from halo import Halo
from langchain import FAISS
from langchain.callbacks.manager import CallbackManager
from langchain.chains import ConversationalRetrievalChain
from langchain.chat_models import ChatOpenAI
from langchain.embeddings import OpenAIEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter

from talk_codebase.utils import StreamStdOut, load_files


def calculate_cost(texts, model_name):
    enc = tiktoken.encoding_for_model(model_name)
    all_text = ''.join([text.page_content for text in texts])
    tokens = enc.encode(all_text)
    token_count = len(tokens)
    cost = (token_count / 1000) * 0.0004
    return cost


def get_local_vector_store(embeddings):
    try:
        return FAISS.load_local("vector_store", embeddings)
    except:
        return None


def create_vector_store(root_dir, openai_api_key, model_name):
    embeddings = OpenAIEmbeddings(openai_api_key=openai_api_key)
    new_db = get_local_vector_store(embeddings)
    if new_db is not None:
        approve = questionary.select(
            f"Found existing vector store. Do you want to use it?",
            choices=[
                {"name": "Yes", "value": True},
                {"name": "No", "value": False},
            ]
        ).ask()
        if approve:
            return new_db

    docs = load_files(root_dir)
    if len(docs) == 0:
        print("✘ No documents found")
        exit(0)
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
    texts = text_splitter.split_documents(docs)

    cost = calculate_cost(docs, model_name)
    approve = questionary.select(
        f"Creating a vector store for {len(docs)} documents will cost ~${cost:.5f}. Do you want to continue?",
        choices=[
            {"name": "Yes", "value": True},
            {"name": "No", "value": False},
        ]
    ).ask()

    if not approve:
        exit(0)

    spinners = Halo(text='Creating vector store', spinner='dots').start()
    db = FAISS.from_documents(texts, embeddings)
    db.save_local("vector_store")
    spinners.succeed(f"Created vector store with {len(docs)} documents")

    return db


def send_question(question, vector_store, openai_api_key, model_name):
    model = ChatOpenAI(model_name=model_name, openai_api_key=openai_api_key, streaming=True,
                       callback_manager=CallbackManager([StreamStdOut()]))
    qa = ConversationalRetrievalChain.from_llm(model,
                                               retriever=vector_store.as_retriever(search_kwargs={"k": 4}),
                                               return_source_documents=True)
    answer = qa({"question": question, "chat_history": []})
    print('\n' + '\n'.join([f'📄 {os.path.abspath(s.metadata["source"])}:' for s in answer["source_documents"]]))
    return answer