Refactored CLI and LLM classes
- Refactored the CLI and LLM classes to improve code organization and readability. - Added a `factory_llm` function to create an LLM instance based on the config. - Moved vector-store creation and question handling into the `BaseLLM`, `LocalLLM` and `OpenAILLM` classes. - Added a function to handle loading an existing vector store. - Added a function to estimate the cost of creating a vector store for OpenAI models. - Updated the CLI to prompt for the model type and a model path or API key depending on the type. - Updated the chat loop to use the `factory_llm` function and the `send_question` method of the LLM instance. - Updated the default config to include default values for `chunk_size` and `chunk_overlap`. - Added a constant to store the default config values. - Added a constant to store the default model path.
parent
9b9a834941
commit
f9a31937bb
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,108 @@
|
|||||||
|
import os
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import questionary
|
||||||
|
from halo import Halo
|
||||||
|
from langchain import FAISS
|
||||||
|
from langchain.callbacks.manager import CallbackManager
|
||||||
|
from langchain.chains import RetrievalQA
|
||||||
|
from langchain.chat_models import ChatOpenAI
|
||||||
|
from langchain.embeddings import HuggingFaceEmbeddings, OpenAIEmbeddings
|
||||||
|
from langchain.llms import GPT4All
|
||||||
|
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
||||||
|
|
||||||
|
from talk_codebase.utils import load_files, get_local_vector_store, calculate_cost, StreamStdOut
|
||||||
|
|
||||||
|
|
||||||
|
class BaseLLM:
    """Common behaviour for local and OpenAI-backed LLM wrappers.

    Subclasses provide the concrete model and vector store by implementing
    ``_create_model`` and ``_create_store``.
    """

    def __init__(self, root_dir, config):
        # config: dict-like with keys "k", "chunk_size", "chunk_overlap",
        # "max_tokens" and (per backend) "model_name"/"model_path"/"api_key".
        self.config = config
        self.llm = self._create_model()
        self.root_dir = root_dir
        self.vector_store = self._create_store(root_dir)

    def _create_store(self, root_dir):
        """Return the FAISS vector store for *root_dir* (backend-specific)."""
        raise NotImplementedError("Subclasses must implement this method.")

    def _create_model(self):
        """Return the underlying LLM instance (backend-specific)."""
        raise NotImplementedError("Subclasses must implement this method.")

    def send_question(self, question):
        """Run a retrieval-QA query, print the source files used, and return the answer."""
        k = self.config.get("k")
        qa = RetrievalQA.from_chain_type(llm=self.llm, chain_type="stuff",
                                         retriever=self.vector_store.as_retriever(search_kwargs={"k": int(k)}),
                                         return_source_documents=True)
        answer = qa(question)
        print('\n' + '\n'.join([f'📄 {os.path.abspath(s.metadata["source"])}:' for s in answer["source_documents"]]))
        # Return the result so callers can inspect it (parity with the old
        # module-level send_question API, which returned the answer).
        return answer

    def _create_vector_store(self, embeddings, index, root_dir):
        """Load an existing FAISS index for *index*, or build one from the project files.

        Prompts the user before reusing an existing store and, for the OpenAI
        backend, before spending money on embeddings. Exits the process when
        there is nothing to index or the user declines the cost.
        """
        index_path = os.path.join(root_dir, f"vector_store/{index}")
        new_db = get_local_vector_store(embeddings, index_path)
        if new_db is not None:
            approve = questionary.select(
                "Found existing vector store. Do you want to use it?",
                choices=[
                    {"name": "Yes", "value": True},
                    {"name": "No", "value": False},
                ]
            ).ask()
            if approve:
                return new_db

        docs = load_files(root_dir)
        if not docs:
            print("✘ No documents found")
            exit(0)
        text_splitter = RecursiveCharacterTextSplitter(chunk_size=int(self.config.get("chunk_size")),
                                                       chunk_overlap=int(self.config.get("chunk_overlap")))
        texts = text_splitter.split_documents(docs)

        if index == "openai":
            # OpenAI embeddings cost money — confirm before spending it.
            cost = calculate_cost(docs, self.config.get("model_name"))
            approve = questionary.select(
                f"Creating a vector store for {len(docs)} documents will cost ~${cost:.5f}. Do you want to continue?",
                choices=[
                    {"name": "Yes", "value": True},
                    {"name": "No", "value": False},
                ]
            ).ask()
            if not approve:
                exit(0)

        spinners = Halo(text=f"Creating vector store for {len(docs)} documents", spinner='dots').start()
        db = FAISS.from_documents(texts, embeddings)
        # BUG FIX: the previous code also called db.add_documents(texts) here,
        # which embedded and indexed every chunk a second time (doubling cost
        # and duplicating retrieval results). from_documents already adds all.
        db.save_local(index_path)
        spinners.succeed(f"Created vector store for {len(docs)} documents")
        return db
|
||||||
|
|
||||||
|
|
||||||
|
class LocalLLM(BaseLLM):
    """LLM wrapper backed by a local GPT4All model with HuggingFace embeddings."""

    def _create_store(self, root_dir: str) -> Optional[FAISS]:
        # Local sentence-transformer embeddings: no API key, no network cost.
        hf_embeddings = HuggingFaceEmbeddings(model_name='all-MiniLM-L6-v2')
        return self._create_vector_store(hf_embeddings, "local", root_dir)

    def _create_model(self):
        """Build the GPT4All model from the configured path and token limit."""
        return GPT4All(
            model=self.config.get("model_path"),
            n_ctx=int(self.config.get("max_tokens")),
            streaming=True,
        )
|
||||||
|
|
||||||
|
|
||||||
|
class OpenAILLM(BaseLLM):
    """LLM wrapper backed by OpenAI chat models and OpenAI embeddings."""

    def _create_store(self, root_dir: str) -> Optional[FAISS]:
        oa_embeddings = OpenAIEmbeddings(openai_api_key=self.config.get("api_key"))
        return self._create_vector_store(oa_embeddings, "openai", root_dir)

    def _create_model(self):
        """Build the streaming ChatOpenAI model from the config."""
        cfg = self.config
        return ChatOpenAI(
            model_name=cfg.get("model_name"),
            openai_api_key=cfg.get("api_key"),
            streaming=True,
            max_tokens=int(cfg.get("max_tokens")),
            callback_manager=CallbackManager([StreamStdOut()]),
        )
|
||||||
|
|
||||||
|
|
||||||
|
def factory_llm(root_dir, config):
    """Instantiate the LLM wrapper matching ``config["model_type"]``.

    "openai" selects :class:`OpenAILLM`; anything else falls back to the
    local GPT4All-backed :class:`LocalLLM`.
    """
    llm_cls = OpenAILLM if config.get("model_type") == "openai" else LocalLLM
    return llm_cls(root_dir, config)
|
@ -1,81 +0,0 @@
|
|||||||
import os
|
|
||||||
|
|
||||||
import questionary
|
|
||||||
import tiktoken
|
|
||||||
from halo import Halo
|
|
||||||
from langchain import FAISS
|
|
||||||
from langchain.callbacks.manager import CallbackManager
|
|
||||||
from langchain.chains import ConversationalRetrievalChain
|
|
||||||
from langchain.chat_models import ChatOpenAI
|
|
||||||
from langchain.embeddings import OpenAIEmbeddings
|
|
||||||
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
|
||||||
|
|
||||||
from talk_codebase.utils import StreamStdOut, load_files
|
|
||||||
|
|
||||||
|
|
||||||
def calculate_cost(texts, model_name):
    """Estimate the embedding cost in USD for the documents in *texts*.

    Tokenizes the concatenated document contents with the encoding for
    *model_name* and prices the total at $0.0004 per 1K tokens.
    """
    encoder = tiktoken.encoding_for_model(model_name)
    combined = ''.join(doc.page_content for doc in texts)
    token_count = len(encoder.encode(combined))
    return (token_count / 1000) * 0.0004
|
|
||||||
|
|
||||||
|
|
||||||
def get_local_vector_store(embeddings):
    """Load the FAISS index saved under "vector_store", or None if unavailable.

    Returns None when the index is missing or unreadable so the caller can
    rebuild it.
    """
    try:
        return FAISS.load_local("vector_store", embeddings)
    except Exception:
        # BUG FIX: was a bare `except:`, which also swallowed SystemExit and
        # KeyboardInterrupt. A missing/corrupt index just means "rebuild".
        return None
|
|
||||||
|
|
||||||
|
|
||||||
def create_vector_store(root_dir, openai_api_key, model_name):
    """Build (or reuse) a FAISS vector store for the files under *root_dir*.

    Offers to reuse an existing local store, confirms the estimated OpenAI
    embedding cost with the user, then embeds, saves and returns the store.
    Exits the process when no documents are found or the user declines.
    """
    embeddings = OpenAIEmbeddings(openai_api_key=openai_api_key)

    existing = get_local_vector_store(embeddings)
    if existing is not None:
        reuse = questionary.select(
            "Found existing vector store. Do you want to use it?",
            choices=[{"name": "Yes", "value": True},
                     {"name": "No", "value": False}],
        ).ask()
        if reuse:
            return existing

    docs = load_files(root_dir)
    if not docs:
        print("✘ No documents found")
        exit(0)

    splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
    texts = splitter.split_documents(docs)

    # Embedding costs money — confirm with the user before proceeding.
    cost = calculate_cost(docs, model_name)
    proceed = questionary.select(
        f"Creating a vector store for {len(docs)} documents will cost ~${cost:.5f}. Do you want to continue?",
        choices=[{"name": "Yes", "value": True},
                 {"name": "No", "value": False}],
    ).ask()
    if not proceed:
        exit(0)

    spinner = Halo(text='Creating vector store', spinner='dots').start()
    db = FAISS.from_documents(texts, embeddings)
    db.save_local("vector_store")
    spinner.succeed(f"Created vector store with {len(docs)} documents")

    return db
|
|
||||||
|
|
||||||
|
|
||||||
def send_question(question, vector_store, openai_api_key, model_name):
    """Ask *question* against *vector_store*, print the source files, return the answer."""
    chat = ChatOpenAI(model_name=model_name, openai_api_key=openai_api_key, streaming=True,
                      callback_manager=CallbackManager([StreamStdOut()]))
    chain = ConversationalRetrievalChain.from_llm(
        chat,
        retriever=vector_store.as_retriever(search_kwargs={"k": 4}),
        return_source_documents=True,
    )
    # Stateless call: an empty chat_history means each question stands alone.
    answer = chain({"question": question, "chat_history": []})
    sources = '\n'.join(f'📄 {os.path.abspath(doc.metadata["source"])}:'
                        for doc in answer["source_documents"])
    print('\n' + sources)
    return answer
|
|
Loading…
Reference in New Issue