From 56cd7e3ba52fcfae0fb28493511d51fbdaf72d90 Mon Sep 17 00:00:00 2001
From: Gustav von Zitzewitz
Date: Tue, 23 May 2023 15:23:04 +0200
Subject: [PATCH] refactor to enable mode and model selection

---
 app.py                                |  58 ++--
 datachad/__init__.py                  |   0
 datachad/chain.py                     |  77 +++++
 constants.py => datachad/constants.py |  21 +-
 datachad/database.py                  |  51 ++++
 datachad/loader.py                    | 133 +++++++++
 datachad/models.py                    | 111 ++++++++
 datachad/utils.py                     | 104 +++++++
 requirements.txt                      |   2 +-
 utils.py                              | 393 --------------------------
 10 files changed, 522 insertions(+), 428 deletions(-)
 create mode 100644 datachad/__init__.py
 create mode 100644 datachad/chain.py
 rename constants.py => datachad/constants.py (81%)
 create mode 100644 datachad/database.py
 create mode 100644 datachad/loader.py
 create mode 100644 datachad/models.py
 create mode 100644 datachad/utils.py
 delete mode 100644 utils.py

diff --git a/app.py b/app.py
index 501950c..5d0270c 100644
--- a/app.py
+++ b/app.py
@@ -1,33 +1,33 @@
 import streamlit as st
 from streamlit_chat import message
-from constants import (
+from datachad.chain import generate_response, update_chain
+from datachad.constants import (
     ACTIVELOOP_HELP,
     APP_NAME,
     AUTHENTICATION_HELP,
     CHUNK_OVERLAP,
     CHUNK_SIZE,
     DEFAULT_DATA_SOURCE,
-    EMBEDDINGS,
     ENABLE_ADVANCED_OPTIONS,
+    ENABLE_LOCAL_MODE,
     FETCH_K,
+    LOCAL_MODE_DISABLED_HELP,
     MAX_TOKENS,
-    MODEL,
+    MODEL_N_CTX,
     OPENAI_HELP,
     PAGE_ICON,
     PROJECT_URL,
     TEMPERATURE,
     USAGE_HELP,
-    MODEL_N_CTX,
     K,
 )
-from utils import (
+from datachad.models import MODELS, MODES
+from datachad.utils import (
     authenticate,
     delete_uploaded_file,
-    generate_response,
     logger,
     save_uploaded_file,
-    update_chain,
 )
 
 # Page options and header
@@ -52,8 +52,8 @@ SESSION_DEFAULTS = {
     "activeloop_org_name": None,
     "uploaded_file": None,
     "data_source": DEFAULT_DATA_SOURCE,
-    "model": MODEL,
-    "embeddings": EMBEDDINGS,
+    "mode": MODES.OPENAI,
+    "model": MODELS.GPT35TURBO,
     "k": K,
     "fetch_k": FETCH_K,
     "chunk_size": CHUNK_SIZE,
@@ -72,7 +72,7 @@ def authentication_form() -> None:
     st.title("Authentication", help=AUTHENTICATION_HELP)
     with st.form("authentication"):
         openai_api_key = st.text_input(
-            "OpenAI API Key",
+            f"{st.session_state['mode']} API Key",
             type="password",
             help=OPENAI_HELP,
             placeholder="This field is mandatory",
@@ -101,29 +101,34 @@ def advanced_options_form() -> None:
     )
     if advanced_options:
         with st.form("advanced_options"):
-            temperature = st.slider(
+            col1, col2 = st.columns(2)
+            col1.selectbox("model", MODELS.for_mode(st.session_state["mode"]))
+            col2.number_input(
                 "temperature",
                 min_value=0.0,
                 max_value=1.0,
                 value=TEMPERATURE,
                 help="Controls the randomness of the language model output",
+                key="temperature",
             )
-            col1, col2 = st.columns(2)
-            fetch_k = col1.number_input(
+
+            col1.number_input(
                 "k_fetch",
                 min_value=1,
                 max_value=1000,
                 value=FETCH_K,
                 help="The number of documents to pull from the vector database",
+                key="k_fetch",
             )
-            k = col2.number_input(
+            col2.number_input(
                 "k",
                 min_value=1,
                 max_value=100,
                 value=K,
                 help="The number of most similar documents to build the context from",
+                key="k",
             )
-            chunk_size = col1.number_input(
+            col1.number_input(
                 "chunk_size",
                 min_value=1,
                 max_value=100000,
@@ -133,35 +138,37 @@ def advanced_options_form() -> None:
                 value=CHUNK_SIZE,
                 help=(
                     "The size at which the text is divided into smaller chunks "
                     "before being embedded.\n\nChanging this parameter makes re-embedding "
                     "and re-uploading the data to the database necessary "
                 ),
+                key="chunk_size",
             )
-            max_tokens = col2.number_input(
+            col2.number_input(
                 "max_tokens",
                 min_value=1,
-                max_value=4069,
+                max_value=30000,
                 value=MAX_TOKENS,
                 help="Limits the documents returned from database based on number of tokens",
+                key="max_tokens",
             )
             applied = st.form_submit_button("Apply")
             if applied:
-                st.session_state["k"] = k
-                st.session_state["fetch_k"] = fetch_k
-                st.session_state["chunk_size"] = chunk_size
-                st.session_state["temperature"] = temperature
-                st.session_state["max_tokens"] = max_tokens
                 update_chain()
 
 
 # Sidebar with Authentication and Advanced Options
 with st.sidebar:
-    authentication_form()
+    mode = st.selectbox("Mode", MODES.values(), key="mode")
+    if mode == MODES.LOCAL and not ENABLE_LOCAL_MODE:
+        st.error(LOCAL_MODE_DISABLED_HELP, icon=PAGE_ICON)
+        st.stop()
+    if mode != MODES.LOCAL:
+        authentication_form()
 
     st.info(f"Learn how it works [here]({PROJECT_URL})")
 
     # Only start App if authentication is OK
-    if not st.session_state["auth_ok"]:
+    if not (st.session_state["auth_ok"] or mode == MODES.LOCAL):
         st.stop()
 
     # Clear button to reset all chat communication
-    clear_button = st.button("Clear Conversation", key="clear")
+    clear_button = st.button("Clear Conversation")
 
     # Advanced Options
     if ENABLE_ADVANCED_OPTIONS:
@@ -214,6 +221,7 @@ with text_container:
     submit_button = st.form_submit_button(label="Send")
 
 if submit_button and user_input:
+    text_container.empty()
     output = generate_response(user_input)
     st.session_state["past"].append(user_input)
     st.session_state["generated"].append(output)
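One consequence of the form rework above: the old version read each widget's return value and copied it into st.session_state by hand after submit. Passing key=... to a Streamlit widget makes Streamlit store the widget's value under that key automatically, which is what made the manual assignments removable. A minimal standalone sketch of the pattern (demo app, not part of the patch):

    import streamlit as st

    with st.form("demo"):
        # key="temperature" makes Streamlit mirror the widget value into
        # st.session_state["temperature"] on every rerun -- no manual copy needed
        st.number_input("temperature", min_value=0.0, max_value=1.0, value=0.7, key="temperature")
        applied = st.form_submit_button("Apply")

    if applied:
        st.write(st.session_state["temperature"])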
diff --git a/datachad/__init__.py b/datachad/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/datachad/chain.py b/datachad/chain.py
new file mode 100644
index 0000000..2bc91c4
--- /dev/null
+++ b/datachad/chain.py
@@ -0,0 +1,77 @@
+import streamlit as st
+from langchain.callbacks import OpenAICallbackHandler, get_openai_callback
+from langchain.chains import ConversationalRetrievalChain
+
+from datachad.constants import PAGE_ICON
+from datachad.database import get_vector_store
+from datachad.models import get_model
+from datachad.utils import logger
+
+
+def get_chain() -> ConversationalRetrievalChain:
+    # create the langchain that will be called to generate responses
+    vector_store = get_vector_store()
+    retriever = vector_store.as_retriever()
+    # Search params "fetch_k" and "k" define how many documents are pulled from the hub
+    # and selected after the document matching to build the context
+    # that is fed to the model together with your prompt
+    search_kwargs = {
+        "maximal_marginal_relevance": True,
+        "distance_metric": "cos",
+        "fetch_k": st.session_state["fetch_k"],
+        "k": st.session_state["k"],
+    }
+    retriever.search_kwargs.update(search_kwargs)
+    model = get_model()
+    chain = ConversationalRetrievalChain.from_llm(
+        model,
+        retriever=retriever,
+        chain_type="stuff",
+        verbose=True,
+        # we limit the maximum number of used tokens
+        # to prevent running into the model's token limit of 4096
+        max_tokens_limit=st.session_state["max_tokens"],
+    )
+    return chain
+
+
+def update_chain() -> None:
+    # Build the chain with parameters from session state and store it back
+    # Also clear the chat history so old context does not confuse the bot
+    try:
+        st.session_state["chain"] = get_chain()
+        st.session_state["chat_history"] = []
+        msg = f"Data source '{st.session_state['data_source']}' is ready to go!"
+        logger.info(msg)
+        st.info(msg, icon=PAGE_ICON)
+    except Exception as e:
+        msg = f"Failed to build chain for data source '{st.session_state['data_source']}' with error: {e}"
+        logger.error(msg)
+        st.error(msg, icon=PAGE_ICON)
+
+
+def update_usage(cb: OpenAICallbackHandler) -> None:
+    # Accumulate API call usage via callbacks
+    logger.info(f"Usage: {cb}")
+    callback_properties = [
+        "total_tokens",
+        "prompt_tokens",
+        "completion_tokens",
+        "total_cost",
+    ]
+    for prop in callback_properties:
+        value = getattr(cb, prop, 0)
+        st.session_state["usage"].setdefault(prop, 0)
+        st.session_state["usage"][prop] += value
+
+
+def generate_response(prompt: str) -> str:
+    # call the chain to generate responses and add them to the chat history
+    with st.spinner("Generating response"), get_openai_callback() as cb:
+        response = st.session_state["chain"](
+            {"question": prompt, "chat_history": st.session_state["chat_history"]}
+        )
+        update_usage(cb)
+    logger.info(f"Response: '{response}'")
+    st.session_state["chat_history"].append((prompt, response["answer"]))
+    return response["answer"]
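For reference, generate_response() threads the running chat history through every chain call and appends the new question/answer pair afterwards. Roughly, the calling convention is the following (illustrative snippet; chain stands for the ConversationalRetrievalChain built by get_chain()):

    # history is a list of (question, answer) tuples, empty on the first turn
    chat_history = []
    result = chain({"question": "What is DataChad?", "chat_history": chat_history})
    chat_history.append(("What is DataChad?", result["answer"]))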
+""" + AUTHENTICATION_HELP = f""" Your credentials are only stored in your session state.\n The keys are neither exposed nor made visible or stored permanently in any way.\n @@ -31,7 +34,7 @@ Feel free to check out [the code base]({PROJECT_URL}) to validate how things wor USAGE_HELP = f""" These are the accumulated OpenAI API usage metrics.\n -The app uses '{MODEL}' for chat and 'text-embedding-ada-002' for embeddings.\n +The app uses 'gpt-3.5-turbo' for chat and 'text-embedding-ada-002' for embeddings.\n Learn more about OpenAI's pricing [here](https://openai.com/pricing#language-models) """ diff --git a/datachad/database.py b/datachad/database.py new file mode 100644 index 0000000..3ff04b4 --- /dev/null +++ b/datachad/database.py @@ -0,0 +1,51 @@ +import os +import re + +import deeplake +import streamlit as st +from langchain.vectorstores import DeepLake, VectorStore + +from datachad.constants import DATA_PATH +from datachad.loader import load_data_source +from datachad.models import MODES, get_embeddings +from datachad.utils import logger + + +def get_dataset_path() -> str: + # replace all non-word characters with dashes + # to get a string that can be used to create a new dataset + dataset_name = re.sub(r"\W+", "-", st.session_state["data_source"]) + dataset_name = re.sub(r"--+", "- ", dataset_name).strip("-") + if st.session_state["mode"] == MODES.LOCAL: + if not os.path.exists(DATA_PATH): + os.makedirs(DATA_PATH) + dataset_path = str(DATA_PATH / dataset_name) + else: + dataset_path = f"hub://{st.session_state['activeloop_org_name']}/{dataset_name}-{st.session_state['chunk_size']}" + return dataset_path + + +def get_vector_store() -> VectorStore: + # either load existing vector store or upload a new one to the hub + embeddings = get_embeddings() + dataset_path = get_dataset_path() + if deeplake.exists(dataset_path, token=st.session_state["activeloop_token"]): + with st.spinner("Loading vector store..."): + logger.info(f"Dataset '{dataset_path}' exists -> loading") + vector_store = DeepLake( + dataset_path=dataset_path, + read_only=True, + embedding_function=embeddings, + token=st.session_state["activeloop_token"], + ) + else: + with st.spinner("Reading, embedding and uploading data to hub..."): + logger.info(f"Dataset '{dataset_path}' does not exist -> uploading") + docs = load_data_source() + vector_store = DeepLake.from_documents( + docs, + embeddings, + dataset_path=dataset_path, + token=st.session_state["activeloop_token"], + ) + return vector_store diff --git a/datachad/loader.py b/datachad/loader.py new file mode 100644 index 0000000..2491e46 --- /dev/null +++ b/datachad/loader.py @@ -0,0 +1,133 @@ +import os +import shutil +from typing import List + +import streamlit as st +from langchain.document_loaders import ( + CSVLoader, + DirectoryLoader, + EverNoteLoader, + GitLoader, + NotebookLoader, + OnlinePDFLoader, + PDFMinerLoader, + PythonLoader, + TextLoader, + UnstructuredEPubLoader, + UnstructuredFileLoader, + UnstructuredHTMLLoader, + UnstructuredMarkdownLoader, + UnstructuredODTLoader, + UnstructuredPowerPointLoader, + UnstructuredWordDocumentLoader, + WebBaseLoader, +) +from langchain.document_loaders.base import BaseLoader +from langchain.schema import Document +from langchain.text_splitter import RecursiveCharacterTextSplitter + +from datachad.constants import DATA_PATH, PAGE_ICON, PROJECT_URL +from datachad.utils import logger + + +class AutoGitLoader: + def __init__(self, data_source: str) -> None: + self.data_source = data_source + + def load(self) -> List[Document]: 
+        # We need to try both common main branches
+        # Thank you GitHub for the "master" to "main" switch
+        # we need to make sure the data path exists
+        if not os.path.exists(DATA_PATH):
+            os.makedirs(DATA_PATH)
+        repo_name = self.data_source.split("/")[-1].split(".")[0]
+        repo_path = str(DATA_PATH / repo_name)
+        clone_url = self.data_source
+        if os.path.exists(repo_path):
+            clone_url = None
+        branches = ["main", "master"]
+        for branch in branches:
+            try:
+                docs = GitLoader(repo_path, clone_url, branch).load()
+                break
+            except Exception as e:
+                logger.error(f"Error loading git: {e}")
+        if os.path.exists(repo_path):
+            # cleanup repo afterwards
+            shutil.rmtree(repo_path)
+        try:
+            return docs
+        except NameError:
+            raise RuntimeError("Make sure to use HTTPS GitHub repo links")
+
+
+FILE_LOADER_MAPPING = {
+    ".csv": (CSVLoader, {"encoding": "utf-8"}),
+    ".doc": (UnstructuredWordDocumentLoader, {}),
+    ".docx": (UnstructuredWordDocumentLoader, {}),
+    ".enex": (EverNoteLoader, {}),
+    ".epub": (UnstructuredEPubLoader, {}),
+    ".html": (UnstructuredHTMLLoader, {}),
+    ".md": (UnstructuredMarkdownLoader, {}),
+    ".odt": (UnstructuredODTLoader, {}),
+    ".pdf": (PDFMinerLoader, {}),
+    ".ppt": (UnstructuredPowerPointLoader, {}),
+    ".pptx": (UnstructuredPowerPointLoader, {}),
+    ".txt": (TextLoader, {"encoding": "utf8"}),
+    ".ipynb": (NotebookLoader, {}),
+    ".py": (PythonLoader, {}),
+    # Add more mappings for other file extensions and loaders as needed
+}
+
+WEB_LOADER_MAPPING = {
+    ".git": (AutoGitLoader, {}),
+    ".pdf": (OnlinePDFLoader, {}),
+}
+
+
+def get_loader(file_path: str, mapping: dict, default_loader: BaseLoader) -> BaseLoader:
+    # Choose loader from mapping, load default if no match found
+    ext = "." + file_path.rsplit(".", 1)[-1]
+    if ext in mapping:
+        loader_class, loader_args = mapping[ext]
+        loader = loader_class(file_path, **loader_args)
+    else:
+        loader = default_loader(file_path)
+    return loader
+
+
+def load_data_source() -> List[Document]:
+    # Ugly thing that decides how to load data
+    # It ain't much, but it's honest work
+    data_source = st.session_state["data_source"]
+    is_web = data_source.startswith("http")
+    is_dir = os.path.isdir(data_source)
+    is_file = os.path.isfile(data_source)
+
+    loader = None
+    if is_dir:
+        loader = DirectoryLoader(data_source, recursive=True, silent_errors=True)
+    elif is_web:
+        loader = get_loader(data_source, WEB_LOADER_MAPPING, WebBaseLoader)
+    elif is_file:
+        loader = get_loader(data_source, FILE_LOADER_MAPPING, UnstructuredFileLoader)
+    try:
+        # Chunk size is a major trade-off parameter to control result accuracy over computation
+        text_splitter = RecursiveCharacterTextSplitter(
+            chunk_size=st.session_state["chunk_size"],
+            chunk_overlap=st.session_state["chunk_overlap"],
+        )
+        docs = loader.load()
+        docs = text_splitter.split_documents(docs)
+        logger.info(f"Loaded: {len(docs)} document chunks")
+        return docs
+    except Exception as e:
+        msg = (
+            e
+            if loader
+            else f"No Loader found for your data source. Consider contributing: {PROJECT_URL}!"
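As the mapping comment says, new file types are meant to be registered here; since get_loader() dispatches purely on the file extension, adding one is a single line. For example, plain-text logs could reuse the already-imported TextLoader (hypothetical addition, not part of the patch):

    # map an extra extension onto an existing loader class
    FILE_LOADER_MAPPING[".log"] = (TextLoader, {"encoding": "utf8"})

    loader = get_loader("server.log", FILE_LOADER_MAPPING, UnstructuredFileLoader)
    docs = loader.load()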
+        )
+        error_msg = f"Failed to load '{st.session_state['data_source']}':\n\n{msg}"
+        st.error(error_msg, icon=PAGE_ICON)
+        logger.error(error_msg)
+        st.stop()
diff --git a/datachad/models.py b/datachad/models.py
new file mode 100644
index 0000000..308ed28
--- /dev/null
+++ b/datachad/models.py
@@ -0,0 +1,111 @@
+from dataclasses import dataclass
+
+import streamlit as st
+from langchain.base_language import BaseLanguageModel
+from langchain.chat_models import ChatOpenAI
+from langchain.embeddings import HuggingFaceEmbeddings
+from langchain.embeddings.openai import Embeddings, OpenAIEmbeddings
+from langchain.llms import GPT4All, LlamaCpp
+
+from datachad.utils import logger
+
+
+class Enum:
+    @classmethod
+    def values(cls):
+        return [v for k, v in cls.__dict__.items() if not k.startswith("_")]
+
+    @classmethod
+    def dict(cls):
+        return {k: v for k, v in cls.__dict__.items() if not k.startswith("_")}
+
+
+@dataclass
+class Model:
+    name: str
+    mode: str
+    embedding: str
+    path: str = None  # for local models only
+
+    def __str__(self):
+        return self.name
+
+
+class MODES(Enum):
+    OPENAI = "OpenAI"
+    LOCAL = "Local"
+
+
+class EMBEDDINGS(Enum):
+    OPENAI = "openai"
+    HUGGINGFACE = "all-MiniLM-L6-v2"
+
+
+class MODELS(Enum):
+    GPT35TURBO = Model("gpt-3.5-turbo", MODES.OPENAI, EMBEDDINGS.OPENAI)
+    GPT4 = Model("gpt-4", MODES.OPENAI, EMBEDDINGS.OPENAI)
+    LLAMACPP = Model(
+        "LLAMA", MODES.LOCAL, EMBEDDINGS.HUGGINGFACE, "models/llamacpp.bin"
+    )
+    GPT4ALL = Model(
+        "GPT4All", MODES.LOCAL, EMBEDDINGS.HUGGINGFACE, "models/gpt4all.bin"
+    )
+
+    @classmethod
+    def for_mode(cls, mode):
+        return [v for v in cls.values() if isinstance(v, Model) and v.mode == mode]
+
+
+def get_model() -> BaseLanguageModel:
+    match st.session_state["model"].name:
+        case MODELS.GPT35TURBO.name:
+            model = ChatOpenAI(
+                model_name=st.session_state["model"].name,
+                temperature=st.session_state["temperature"],
+                openai_api_key=st.session_state["openai_api_key"],
+            )
+        case MODELS.GPT4.name:
+            model = ChatOpenAI(
+                model_name=st.session_state["model"].name,
+                temperature=st.session_state["temperature"],
+                openai_api_key=st.session_state["openai_api_key"],
+            )
+        case MODELS.LLAMACPP.name:
+            model = LlamaCpp(
+                model_path=st.session_state["model"].path,
+                n_ctx=st.session_state["model_n_ctx"],
+                temperature=st.session_state["temperature"],
+                verbose=True,
+            )
+        case MODELS.GPT4ALL.name:
+            model = GPT4All(
+                model=st.session_state["model"].path,
+                n_ctx=st.session_state["model_n_ctx"],
+                backend="gptj",
+                temp=st.session_state["temperature"],
+                verbose=True,
+            )
+        # Add more models as needed
+        case _:
+            msg = f"Model {st.session_state['model']} not supported!"
+            logger.error(msg)
+            st.error(msg)
+            st.stop()
+    return model
+
+
+def get_embeddings() -> Embeddings:
+    match st.session_state["model"].embedding:
+        case EMBEDDINGS.OPENAI:
+            embeddings = OpenAIEmbeddings(
+                disallowed_special=(), openai_api_key=st.session_state["openai_api_key"]
+            )
+        case EMBEDDINGS.HUGGINGFACE:
+            embeddings = HuggingFaceEmbeddings(model_name=EMBEDDINGS.HUGGINGFACE)
+        # Add more embeddings as needed
+        case _:
+            msg = f"Embeddings {st.session_state['model'].embedding} not supported!"
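The Enum base class above exposes plain class attributes as a list, which is what the sidebar uses to populate its select boxes; MODELS.for_mode() additionally filters out the for_mode classmethod itself via the isinstance check. Expected behavior, given the definitions above:

    from datachad.models import MODELS, MODES

    print(MODES.values())  # ['OpenAI', 'Local']
    for model in MODELS.for_mode(MODES.OPENAI):
        print(model.name, model.embedding)
    # gpt-3.5-turbo openai
    # gpt-4 openai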
+            logger.error(msg)
+            st.error(msg)
+            st.stop()
+    return embeddings
diff --git a/datachad/utils.py b/datachad/utils.py
new file mode 100644
index 0000000..6eee1cc
--- /dev/null
+++ b/datachad/utils.py
@@ -0,0 +1,104 @@
+import logging
+import os
+import sys
+
+import deeplake
+import openai
+import streamlit as st
+from dotenv import load_dotenv
+
+from datachad.constants import APP_NAME, DATA_PATH, PAGE_ICON
+
+# loads environment variables
+load_dotenv()
+
+logger = logging.getLogger(APP_NAME)
+
+
+def configure_logger(debug: int = 0) -> None:
+    # boilerplate code to enable logging in the streamlit app console
+    log_level = logging.DEBUG if debug == 1 else logging.INFO
+    logger.setLevel(log_level)
+
+    stream_handler = logging.StreamHandler(stream=sys.stdout)
+    stream_handler.setLevel(log_level)
+
+    formatter = logging.Formatter("%(message)s")
+
+    stream_handler.setFormatter(formatter)
+
+    logger.addHandler(stream_handler)
+    logger.propagate = False
+
+
+configure_logger(0)
+
+
+def authenticate(
+    openai_api_key: str, activeloop_token: str, activeloop_org_name: str
+) -> None:
+    # Validate that all credentials are set and correct
+    # Check for env variables to enable local dev and deployments with shared credentials
+    openai_api_key = (
+        openai_api_key
+        or os.environ.get("OPENAI_API_KEY")
+        or st.secrets.get("OPENAI_API_KEY")
+    )
+    activeloop_token = (
+        activeloop_token
+        or os.environ.get("ACTIVELOOP_TOKEN")
+        or st.secrets.get("ACTIVELOOP_TOKEN")
+    )
+    activeloop_org_name = (
+        activeloop_org_name
+        or os.environ.get("ACTIVELOOP_ORG_NAME")
+        or st.secrets.get("ACTIVELOOP_ORG_NAME")
+    )
+    if not (openai_api_key and activeloop_token and activeloop_org_name):
+        st.session_state["auth_ok"] = False
+        st.error("Credentials neither set nor stored", icon=PAGE_ICON)
+        return
+    try:
+        # Try to access openai and deeplake
+        with st.spinner("Authenticating..."):
+            openai.api_key = openai_api_key
+            openai.Model.list()
+            deeplake.exists(
+                f"hub://{activeloop_org_name}/DataChad-Authentication-Check",
+                token=activeloop_token,
+            )
+    except Exception as e:
+        logger.error(f"Authentication failed with {e}")
+        st.session_state["auth_ok"] = False
+        st.error("Authentication failed", icon=PAGE_ICON)
+        return
+    # store credentials in the session state
+    st.session_state["auth_ok"] = True
+    st.session_state["openai_api_key"] = openai_api_key
+    st.session_state["activeloop_token"] = activeloop_token
+    st.session_state["activeloop_org_name"] = activeloop_org_name
+    logger.info("Authentication successful!")
+
+
+def save_uploaded_file() -> str:
+    # streamlit uploaded files need to be stored locally
+    # before being embedded and uploaded to the hub
+    uploaded_file = st.session_state["uploaded_file"]
+    if not os.path.exists(DATA_PATH):
+        os.makedirs(DATA_PATH)
+    file_path = str(DATA_PATH / uploaded_file.name)
+    uploaded_file.seek(0)
+    file_bytes = uploaded_file.read()
+    file = open(file_path, "wb")
+    file.write(file_bytes)
+    file.close()
+    logger.info(f"Saved: {file_path}")
+    return file_path
+
+
+def delete_uploaded_file() -> None:
+    # cleanup locally stored files
+    file_path = DATA_PATH / st.session_state["uploaded_file"].name
+    if os.path.exists(file_path):
+        os.remove(file_path)
+        logger.info(f"Removed: {file_path}")
diff --git a/requirements.txt b/requirements.txt
index c408250..0f45702 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -2,7 +2,7 @@ streamlit==1.22.0
 streamlit-chat==0.0.2.2
 deeplake==3.4.1
 openai==0.27.6
-langchain==0.0.164
+langchain==0.0.173
 tiktoken==0.4.0
 unstructured==0.6.5
 pdf2image==1.16.3
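Because authenticate() falls back to environment variables and st.secrets before giving up, local development and shared deployments can skip the sidebar form entirely. A .env file picked up by load_dotenv() would look like this (values illustrative):

    OPENAI_API_KEY=sk-...
    ACTIVELOOP_TOKEN=...
    ACTIVELOOP_ORG_NAME=my-org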
diff --git a/utils.py b/utils.py
deleted file mode 100644
index 7aed86e..0000000
--- a/utils.py
+++ /dev/null
@@ -1,393 +0,0 @@
-import logging
-import os
-import re
-import shutil
-import sys
-from typing import List
-
-import deeplake
-import openai
-import streamlit as st
-from dotenv import load_dotenv
-from langchain.base_language import BaseLanguageModel
-from langchain.callbacks import OpenAICallbackHandler, get_openai_callback
-from langchain.chains import ConversationalRetrievalChain
-from langchain.chat_models import ChatOpenAI
-from langchain.document_loaders import (
-    CSVLoader,
-    DirectoryLoader,
-    EverNoteLoader,
-    GitLoader,
-    NotebookLoader,
-    OnlinePDFLoader,
-    PDFMinerLoader,
-    PythonLoader,
-    TextLoader,
-    UnstructuredEPubLoader,
-    UnstructuredFileLoader,
-    UnstructuredHTMLLoader,
-    UnstructuredMarkdownLoader,
-    UnstructuredODTLoader,
-    UnstructuredPowerPointLoader,
-    UnstructuredWordDocumentLoader,
-    WebBaseLoader,
-)
-from langchain.document_loaders.base import BaseLoader
-from langchain.embeddings.openai import Embeddings, OpenAIEmbeddings
-from langchain.schema import Document
-from langchain.text_splitter import RecursiveCharacterTextSplitter
-from langchain.vectorstores import DeepLake, VectorStore
-from langchain.llms import GPT4All, LlamaCpp
-from langchain.embeddings import HuggingFaceEmbeddings
-
-from constants import APP_NAME, DATA_PATH, PAGE_ICON, PROJECT_URL, LLAMACPP_MODEL_PATH, GPT4ALL_MODEL_PATH
-
-# loads environment variables
-load_dotenv()
-
-logger = logging.getLogger(APP_NAME)
-
-
-def configure_logger(debug: int = 0) -> None:
-    # boilerplate code to enable logging in the streamlit app console
-    log_level = logging.DEBUG if debug == 1 else logging.INFO
-    logger.setLevel(log_level)
-
-    stream_handler = logging.StreamHandler(stream=sys.stdout)
-    stream_handler.setLevel(log_level)
-
-    formatter = logging.Formatter("%(message)s")
-
-    stream_handler.setFormatter(formatter)
-
-    logger.addHandler(stream_handler)
-    logger.propagate = False
-
-
-configure_logger(0)
-
-
-def authenticate(
-    openai_api_key: str, activeloop_token: str, activeloop_org_name: str
-) -> None:
-    # Validate all credentials are set and correct
-    # Check for env variables to enable local dev and deployments with shared credentials
-    openai_api_key = (
-        openai_api_key
-        or os.environ.get("OPENAI_API_KEY")
-        or st.secrets.get("OPENAI_API_KEY")
-    )
-    activeloop_token = (
-        activeloop_token
-        or os.environ.get("ACTIVELOOP_TOKEN")
-        or st.secrets.get("ACTIVELOOP_TOKEN")
-    )
-    activeloop_org_name = (
-        activeloop_org_name
-        or os.environ.get("ACTIVELOOP_ORG_NAME")
-        or st.secrets.get("ACTIVELOOP_ORG_NAME")
-    )
-    if not (openai_api_key and activeloop_token and activeloop_org_name):
-        st.session_state["auth_ok"] = False
-        st.error("Credentials neither set nor stored", icon=PAGE_ICON)
-        return
-    try:
-        # Try to access openai and deeplake
-        with st.spinner("Authentifying..."):
-            openai.api_key = openai_api_key
-            openai.Model.list()
-            deeplake.exists(
-                f"hub://{activeloop_org_name}/DataChad-Authentication-Check",
-                token=activeloop_token,
-            )
-    except Exception as e:
-        logger.error(f"Authentication failed with {e}")
-        st.session_state["auth_ok"] = False
-        st.error("Authentication failed", icon=PAGE_ICON)
-        return
-    # store credentials in the session state
-    st.session_state["auth_ok"] = True
-    st.session_state["openai_api_key"] = openai_api_key
-    st.session_state["activeloop_token"] = activeloop_token
-    st.session_state["activeloop_org_name"] = activeloop_org_name
logger.info("Authentification successful!") - - -def save_uploaded_file() -> str: - # streamlit uploaded files need to be stored locally - # before embedded and uploaded to the hub - uploaded_file = st.session_state["uploaded_file"] - if not os.path.exists(DATA_PATH): - os.makedirs(DATA_PATH) - file_path = str(DATA_PATH / uploaded_file.name) - uploaded_file.seek(0) - file_bytes = uploaded_file.read() - file = open(file_path, "wb") - file.write(file_bytes) - file.close() - logger.info(f"Saved: {file_path}") - return file_path - - -def delete_uploaded_file() -> None: - # cleanup locally stored files - file_path = DATA_PATH / st.session_state["uploaded_file"].name - if os.path.exists(DATA_PATH): - os.remove(file_path) - logger.info(f"Removed: {file_path}") - - -class AutoGitLoader: - def __init__(self, data_source: str) -> None: - self.data_source = data_source - - def load(self) -> List[Document]: - # We need to try both common main branches - # Thank you github for the "master" to "main" switch - # we need to make sure the data path exists - if not os.path.exists(DATA_PATH): - os.makedirs(DATA_PATH) - repo_name = self.data_source.split("/")[-1].split(".")[0] - repo_path = str(DATA_PATH / repo_name) - clone_url = self.data_source - if os.path.exists(repo_path): - clone_url = None - branches = ["main", "master"] - for branch in branches: - try: - docs = GitLoader(repo_path, clone_url, branch).load() - break - except Exception as e: - logger.error(f"Error loading git: {e}") - if os.path.exists(repo_path): - # cleanup repo afterwards - shutil.rmtree(repo_path) - try: - return docs - except: - raise RuntimeError("Make sure to use HTTPS GitHub repo links") - - -FILE_LOADER_MAPPING = { - ".csv": (CSVLoader, {"encoding": "utf-8"}), - ".doc": (UnstructuredWordDocumentLoader, {}), - ".docx": (UnstructuredWordDocumentLoader, {}), - ".enex": (EverNoteLoader, {}), - ".epub": (UnstructuredEPubLoader, {}), - ".html": (UnstructuredHTMLLoader, {}), - ".md": (UnstructuredMarkdownLoader, {}), - ".odt": (UnstructuredODTLoader, {}), - ".pdf": (PDFMinerLoader, {}), - ".ppt": (UnstructuredPowerPointLoader, {}), - ".pptx": (UnstructuredPowerPointLoader, {}), - ".txt": (TextLoader, {"encoding": "utf8"}), - ".ipynb": (NotebookLoader, {}), - ".py": (PythonLoader, {}), - # Add more mappings for other file extensions and loaders as needed -} - -WEB_LOADER_MAPPING = { - ".git": (AutoGitLoader, {}), - ".pdf": (OnlinePDFLoader, {}), -} - - -def get_loader(file_path: str, mapping: dict, default_loader: BaseLoader) -> BaseLoader: - # Choose loader from mapping, load default if no match found - ext = "." 
-    ext = "." + file_path.rsplit(".", 1)[-1]
-    if ext in mapping:
-        loader_class, loader_args = mapping[ext]
-        loader = loader_class(file_path, **loader_args)
-    else:
-        loader = default_loader(file_path)
-    return loader
-
-
-def load_data_source() -> List[Document]:
-    # Ugly thing that decides how to load data
-    # It aint much, but it's honest work
-    data_source = st.session_state["data_source"]
-    is_web = data_source.startswith("http")
-    is_dir = os.path.isdir(data_source)
-    is_file = os.path.isfile(data_source)
-
-    loader = None
-    if is_dir:
-        loader = DirectoryLoader(data_source, recursive=True, silent_errors=True)
-    elif is_web:
-        loader = get_loader(data_source, WEB_LOADER_MAPPING, WebBaseLoader)
-    elif is_file:
-        loader = get_loader(data_source, FILE_LOADER_MAPPING, UnstructuredFileLoader)
-    try:
-        # Chunk size is a major trade-off parameter to control result accuracy over computaion
-        text_splitter = RecursiveCharacterTextSplitter(
-            chunk_size=st.session_state["chunk_size"],
-            chunk_overlap=st.session_state["chunk_overlap"],
-        )
-        docs = loader.load()
-        docs = text_splitter.split_documents(docs)
-        logger.info(f"Loaded: {len(docs)} document chucks")
-        return docs
-    except Exception as e:
-        msg = (
-            e
-            if loader
-            else f"No Loader found for your data source. Consider contributing:  {PROJECT_URL}!"
-        )
-        error_msg = f"Failed to load '{st.session_state['data_source']}':\n\n{msg}"
-        st.error(error_msg, icon=PAGE_ICON)
-        logger.error(error_msg)
-        st.stop()
-
-
-def get_dataset_name() -> str:
-    # replace all non-word characters with dashes
-    # to get a string that can be used to create a new dataset
-    dashed_string = re.sub(r"\W+", "-", st.session_state["data_source"])
-    cleaned_string = re.sub(r"--+", "- ", dashed_string).strip("-")
-    return cleaned_string
-
-
-def get_model() -> BaseLanguageModel:
-    match st.session_state["model"]:
-        case "gpt-3.5-turbo":
-            model = ChatOpenAI(
-                model_name=st.session_state["model"],
-                temperature=st.session_state["temperature"],
-                openai_api_key=st.session_state["openai_api_key"],
-            )
-        case "LlamaCpp":
-            model = LlamaCpp(
-                model_path=LLAMACPP_MODEL_PATH,
-                n_ctx=st.session_state["model_n_ctx"],
-                temperature=st.session_state["temperature"],
-                verbose=True,
-            )
-        case "GPT4All":
-            model = GPT4All(
-                model=GPT4ALL_MODEL_PATH,
-                n_ctx=st.session_state["model_n_ctx"],
-                backend="gptj",
-                temp=st.session_state["temperature"],
-                verbose=True,
-            )
-        # Add more models as needed
-        case _default:
-            msg = f"Model {st.session_state['model']} not supported!"
-            logger.error(msg)
-            st.error(msg)
-            exit
-    return model
-
-
-def get_embeddings() -> Embeddings:
-    match st.session_state["embeddings"]:
-        case "openai":
-            embeddings = OpenAIEmbeddings(
-                disallowed_special=(), openai_api_key=st.session_state["openai_api_key"]
-            )
-        case "huggingface-Fall-MiniLM-L6-v2":
-            embeddings = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
-        # Add more embeddings as needed
-        case _default:
-            msg = f"Embeddings {st.session_state['embeddings']} not supported!"
-            logger.error(msg)
-            st.error(msg)
-            exit
-    return embeddings
-
-
-def get_vector_store() -> VectorStore:
-    # either load existing vector store or upload a new one to the hub
-    embeddings = get_embeddings()
-    dataset_name = get_dataset_name()
-    dataset_path = f"hub://{st.session_state['activeloop_org_name']}/{dataset_name}-{st.session_state['chunk_size']}"
-    if deeplake.exists(dataset_path, token=st.session_state["activeloop_token"]):
-        with st.spinner("Loading vector store..."):
-            logger.info(f"Dataset '{dataset_path}' exists -> loading")
-            vector_store = DeepLake(
-                dataset_path=dataset_path,
-                read_only=True,
-                embedding_function=embeddings,
-                token=st.session_state["activeloop_token"],
-            )
-    else:
-        with st.spinner("Reading, embedding and uploading data to hub..."):
-            logger.info(f"Dataset '{dataset_path}' does not exist -> uploading")
-            docs = load_data_source()
-            vector_store = DeepLake.from_documents(
-                docs,
-                embeddings,
-                dataset_path=dataset_path,
-                token=st.session_state["activeloop_token"],
-            )
-    return vector_store
-
-
-def get_chain() -> ConversationalRetrievalChain:
-    # create the langchain that will be called to generate responses
-    vector_store = get_vector_store()
-    retriever = vector_store.as_retriever()
-    # Search params "fetch_k" and "k" define how many documents are pulled from the hub
-    # and selected after the document matching to build the context
-    # that is fed to the model together with your prompt
-    search_kwargs = {
-        "maximal_marginal_relevance": True,
-        "distance_metric": "cos",
-        "fetch_k": st.session_state["fetch_k"],
-        "k": st.session_state["k"],
-    }
-    retriever.search_kwargs.update(search_kwargs)
-    model = get_model()
-    chain = ConversationalRetrievalChain.from_llm(
-        model,
-        retriever=retriever,
-        chain_type="stuff",
-        verbose=True,
-        # we limit the maximum number of used tokens
-        # to prevent running into the models token limit of 4096
-        max_tokens_limit=st.session_state["max_tokens"],
-    )
-    return chain
-
-
-def update_chain() -> None:
-    # Build chain with parameters from session state and store it back
-    # Also delete chat history to not confuse the bot with old context
-    try:
-        st.session_state["chain"] = get_chain()
-        st.session_state["chat_history"] = []
-        msg = f"Data source '{st.session_state['data_source']}' is ready to go!"
-        logger.info(msg)
-        st.info(msg, icon=PAGE_ICON)
-    except Exception as e:
-        msg = f"Failed to build chain for data source '{st.session_state['data_source']}' with error: {e}"
-        logger.error(msg)
-        st.error(msg, icon=PAGE_ICON)
-
-
-def update_usage(cb: OpenAICallbackHandler) -> None:
-    # Accumulate API call usage via callbacks
-    logger.info(f"Usage: {cb}")
-    callback_properties = [
-        "total_tokens",
-        "prompt_tokens",
-        "completion_tokens",
-        "total_cost",
-    ]
-    for prop in callback_properties:
-        value = getattr(cb, prop, 0)
-        st.session_state["usage"].setdefault(prop, 0)
-        st.session_state["usage"][prop] += value
-
-
-def generate_response(prompt: str) -> str:
-    # call the chain to generate responses and add them to the chat history
-    with st.spinner("Generating response"), get_openai_callback() as cb:
-        response = st.session_state["chain"](
-            {"question": prompt, "chat_history": st.session_state["chat_history"]}
-        )
-        update_usage(cb)
-    logger.info(f"Response: '{response}'")
-    st.session_state["chat_history"].append((prompt, response["answer"]))
-    return response["answer"]
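Taken together, the refactor splits the old utils.py along clear seams: datachad.models owns mode and model selection, datachad.database the vector store, datachad.chain the conversational chain, and datachad.utils the remaining glue. Outside the Streamlit UI the flow reduces to roughly the following (a sketch only; it assumes st.session_state has been populated the way app.py does it):

    from datachad.chain import generate_response, update_chain

    update_chain()  # builds vector store, retriever and model from session state
    answer = generate_response("What is this repo about?")
    print(answer)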