diff --git a/README.md b/README.md
index 2a551dc..b37ead7 100644
--- a/README.md
+++ b/README.md
@@ -26,7 +26,7 @@ This is an app that let's you ask questions about any data source by leveraging
 
 ## TODO LIST
 If you like to contribute, feel free to grab any task
+- [x] Refactor utils, especially the loaders
 - [ ] Add option to choose model and embeddings
 - [ ] Enable fully local / private mode
-- [ ] Refactor utils, especially the loaders
 - [ ] Add Image caption and Audio transcription support
\ No newline at end of file
diff --git a/app.py b/app.py
index d44c1fc..871f29e 100644
--- a/app.py
+++ b/app.py
@@ -5,20 +5,22 @@ from constants import (
     ACTIVELOOP_HELP,
     APP_NAME,
     AUTHENTICATION_HELP,
+    CHUNK_OVERLAP,
     CHUNK_SIZE,
     DEFAULT_DATA_SOURCE,
+    EMBEDDINGS,
     ENABLE_ADVANCED_OPTIONS,
     FETCH_K,
     MAX_TOKENS,
+    MODEL,
     OPENAI_HELP,
     PAGE_ICON,
-    REPO_URL,
+    PROJECT_URL,
     TEMPERATURE,
     USAGE_HELP,
     K,
 )
 from utils import (
-    advanced_options_form,
     authenticate,
     delete_uploaded_file,
     generate_response,
@@ -49,9 +51,12 @@ SESSION_DEFAULTS = {
     "openai_api_key": None,
     "activeloop_token": None,
     "activeloop_org_name": None,
+    "model": MODEL,
+    "embeddings": EMBEDDINGS,
     "k": K,
     "fetch_k": FETCH_K,
     "chunk_size": CHUNK_SIZE,
+    "chunk_overlap": CHUNK_OVERLAP,
     "temperature": TEMPERATURE,
     "max_tokens": MAX_TOKENS,
 }
@@ -61,9 +66,7 @@ for k, v in SESSION_DEFAULTS.items():
         st.session_state[k] = v
 
 
-# Sidebar with Authentication
-# Only start App if authentication is OK
-with st.sidebar:
+def authentication_form() -> None:
     st.title("Authentication", help=AUTHENTICATION_HELP)
     with st.form("authentication"):
         openai_api_key = st.text_input(
@@ -88,7 +91,70 @@ with st.sidebar:
         if submitted:
             authenticate(openai_api_key, activeloop_token, activeloop_org_name)
 
-    st.info(f"Learn how it works [here]({REPO_URL})")
+
+def advanced_options_form() -> None:
+    # Input Form that takes advanced options and rebuilds chain with them
+    advanced_options = st.checkbox(
+        "Advanced Options", help="Caution! This may break things!"
+    )
+    if advanced_options:
+        with st.form("advanced_options"):
+            temperature = st.slider(
+                "temperature",
+                min_value=0.0,
+                max_value=1.0,
+                value=TEMPERATURE,
+                help="Controls the randomness of the language model output",
+            )
+            col1, col2 = st.columns(2)
+            fetch_k = col1.number_input(
+                "k_fetch",
+                min_value=1,
+                max_value=1000,
+                value=FETCH_K,
+                help="The number of documents to pull from the vector database",
+            )
+            k = col2.number_input(
+                "k",
+                min_value=1,
+                max_value=100,
+                value=K,
+                help="The number of most similar documents to build the context from",
+            )
+            chunk_size = col1.number_input(
+                "chunk_size",
+                min_value=1,
+                max_value=100000,
+                value=CHUNK_SIZE,
+                help=(
+                    "The size at which the text is divided into smaller chunks "
+                    "before being embedded.\n\nChanging this parameter makes re-embedding "
+                    "and re-uploading the data to the database necessary "
+                ),
+            )
+            max_tokens = col2.number_input(
+                "max_tokens",
+                min_value=1,
+                max_value=4069,
+                value=MAX_TOKENS,
+                help="Limits the documents returned from database based on number of tokens",
+            )
+            applied = st.form_submit_button("Apply")
+            if applied:
+                st.session_state["k"] = k
+                st.session_state["fetch_k"] = fetch_k
+                st.session_state["chunk_size"] = chunk_size
+                st.session_state["temperature"] = temperature
+                st.session_state["max_tokens"] = max_tokens
+                update_chain()
+
+
+# Sidebar with Authentication and Advanced Options
+with st.sidebar:
+    authentication_form()
+
+    st.info(f"Learn how it works [here]({PROJECT_URL})")
+
+    # Only start App if authentication is OK
     if not st.session_state["auth_ok"]:
         st.stop()
@@ -99,11 +165,6 @@ with st.sidebar:
     if ENABLE_ADVANCED_OPTIONS:
         advanced_options_form()
 
-
-# the chain can only be initialized after authentication is OK
-if "chain" not in st.session_state:
-    update_chain()
-
 if clear_button:
     # resets all chat history related caches
     st.session_state["past"] = []
@@ -118,6 +179,10 @@ data_source = st.text_input(
     placeholder="Any path or url pointing to a file or directory of files",
 )
 
+# the chain can only be initialized after authentication is OK
+if "chain" not in st.session_state:
+    update_chain()
+
 # generate new chain for new data source / uploaded file
 # make sure to do this only once per input / on change
 if data_source and data_source != st.session_state["data_source"]:
@@ -128,28 +193,28 @@ if data_source and data_source != st.session_state["data_source"]:
 if uploaded_file and uploaded_file != st.session_state["uploaded_file"]:
     logger.info(f"Uploaded file: '{uploaded_file.name}'")
     st.session_state["uploaded_file"] = uploaded_file
-    data_source = save_uploaded_file(uploaded_file)
+    data_source = save_uploaded_file()
     st.session_state["data_source"] = data_source
     update_chain()
-    delete_uploaded_file(uploaded_file)
+    delete_uploaded_file()
 
 # container for chat history
 response_container = st.container()
 # container for text box
-container = st.container()
+text_container = st.container()
 
 # As streamlit reruns the whole script on each change
 # it is necessary to repopulate the chat containers
-with container:
+with text_container:
     with st.form(key="prompt_input", clear_on_submit=True):
         user_input = st.text_area("You:", key="input", height=100)
         submit_button = st.form_submit_button(label="Send")
 
-    if submit_button and user_input:
-        output = generate_response(user_input)
-        st.session_state["past"].append(user_input)
-        st.session_state["generated"].append(output)
+if submit_button and user_input:
+    output = generate_response(user_input)
+    st.session_state["past"].append(user_input)
+    st.session_state["generated"].append(output)
 
 if st.session_state["generated"]:
     with response_container:
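Note on the app.py restructuring above: Streamlit reruns the entire script on every widget interaction, which is why the refactor keeps all tunable parameters in `st.session_state` and only writes form values back on submit. A minimal, self-contained sketch of that pattern (the names here are illustrative, not taken from the diff):

```python
import streamlit as st

# Seed defaults exactly once; st.session_state survives script reruns
DEFAULTS = {"temperature": 0.7}
for key, value in DEFAULTS.items():
    if key not in st.session_state:
        st.session_state[key] = value

with st.form("options"):
    temperature = st.slider("temperature", 0.0, 1.0, st.session_state["temperature"])
    applied = st.form_submit_button("Apply")

# Mirror of the diff's "if applied:" block: persist values only on submit,
# so downstream code can read them from st.session_state on the next rerun
if applied:
    st.session_state["temperature"] = temperature
```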
diff --git a/constants.py b/constants.py
index 2971944..215d6c0 100644
--- a/constants.py
+++ b/constants.py
@@ -2,11 +2,15 @@ from pathlib import Path
 
 APP_NAME = "DataChad"
 MODEL = "gpt-3.5-turbo"
+EMBEDDINGS = "openai"
 PAGE_ICON = "🤖"
 
+PROJECT_URL = "https://github.com/gustavz/DataChad"
+
 K = 10
 FETCH_K = 20
 CHUNK_SIZE = 1000
+CHUNK_OVERLAP = 0
 TEMPERATURE = 0.7
 MAX_TOKENS = 3357
 ENABLE_ADVANCED_OPTIONS = True
@@ -14,12 +18,10 @@ ENABLE_ADVANCED_OPTIONS = True
 
 DATA_PATH = Path.cwd() / "data"
 DEFAULT_DATA_SOURCE = "https://github.com/gustavz/DataChad.git"
 
-REPO_URL = "https://github.com/gustavz/DataChad"
-
 AUTHENTICATION_HELP = f"""
 Your credentials are only stored in your session state.\n
 The keys are neither exposed nor made visible or stored permanently in any way.\n
-Feel free to check out [the code base]({REPO_URL}) to validate how things work.
+Feel free to check out [the code base]({PROJECT_URL}) to validate how things work.
 """
 
 USAGE_HELP = f"""
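The new `CHUNK_OVERLAP` constant is consumed by the text splitter in the utils.py diff below. A rough sketch of what the two chunking knobs control, using the same langchain splitter the diff imports (values are illustrative):

```python
from langchain.text_splitter import RecursiveCharacterTextSplitter

# chunk_size caps the characters per chunk; chunk_overlap repeats the tail of
# each chunk at the head of the next so context is not cut off at boundaries
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
chunks = text_splitter.split_text("some long document text " * 500)
print(len(chunks), max(len(c) for c in chunks))  # chunk count, largest chunk
```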
diff --git a/utils.py b/utils.py
index 08d1821..63f94c6 100644
--- a/utils.py
+++ b/utils.py
@@ -9,41 +9,36 @@ import deeplake
 import openai
 import streamlit as st
 from dotenv import load_dotenv
+from langchain.base_language import BaseLanguageModel
 from langchain.callbacks import OpenAICallbackHandler, get_openai_callback
 from langchain.chains import ConversationalRetrievalChain
 from langchain.chat_models import ChatOpenAI
 from langchain.document_loaders import (
     CSVLoader,
     DirectoryLoader,
+    EverNoteLoader,
     GitLoader,
     NotebookLoader,
     OnlinePDFLoader,
+    PDFMinerLoader,
     PythonLoader,
     TextLoader,
+    UnstructuredEPubLoader,
     UnstructuredFileLoader,
     UnstructuredHTMLLoader,
-    UnstructuredPDFLoader,
+    UnstructuredMarkdownLoader,
+    UnstructuredODTLoader,
+    UnstructuredPowerPointLoader,
     UnstructuredWordDocumentLoader,
     WebBaseLoader,
 )
-from langchain.embeddings.openai import OpenAIEmbeddings
+from langchain.document_loaders.base import BaseLoader
+from langchain.embeddings.openai import Embeddings, OpenAIEmbeddings
 from langchain.schema import Document
 from langchain.text_splitter import RecursiveCharacterTextSplitter
 from langchain.vectorstores import DeepLake, VectorStore
-from streamlit.runtime.uploaded_file_manager import UploadedFile
-
-from constants import (
-    APP_NAME,
-    CHUNK_SIZE,
-    DATA_PATH,
-    FETCH_K,
-    MAX_TOKENS,
-    MODEL,
-    PAGE_ICON,
-    REPO_URL,
-    TEMPERATURE,
-    K,
-)
+
+from constants import APP_NAME, DATA_PATH, PAGE_ICON, PROJECT_URL
 
 # loads environment variables
 load_dotenv()
@@ -116,66 +111,10 @@ def authenticate(
     logger.info("Authentification successful!")
 
 
-def advanced_options_form() -> None:
-    # Input Form that takes advanced options and rebuilds chain with them
-    advanced_options = st.checkbox(
-        "Advanced Options", help="Caution! This may break things!"
-    )
-    if advanced_options:
-        with st.form("advanced_options"):
-            temperature = st.slider(
-                "temperature",
-                min_value=0.0,
-                max_value=1.0,
-                value=TEMPERATURE,
-                help="Controls the randomness of the language model output",
-            )
-            col1, col2 = st.columns(2)
-            fetch_k = col1.number_input(
-                "k_fetch",
-                min_value=1,
-                max_value=1000,
-                value=FETCH_K,
-                help="The number of documents to pull from the vector database",
-            )
-            k = col2.number_input(
-                "k",
-                min_value=1,
-                max_value=100,
-                value=K,
-                help="The number of most similar documents to build the context from",
-            )
-            chunk_size = col1.number_input(
-                "chunk_size",
-                min_value=1,
-                max_value=100000,
-                value=CHUNK_SIZE,
-                help=(
-                    "The size at which the text is divided into smaller chunks "
-                    "before being embedded.\n\nChanging this parameter makes re-embedding "
-                    "and re-uploading the data to the database necessary "
-                ),
-            )
-            max_tokens = col2.number_input(
-                "max_tokens",
-                min_value=1,
-                max_value=4069,
-                value=MAX_TOKENS,
-                help="Limits the documents returned from database based on number of tokens",
-            )
-            applied = st.form_submit_button("Apply")
-            if applied:
-                st.session_state["k"] = k
-                st.session_state["fetch_k"] = fetch_k
-                st.session_state["chunk_size"] = chunk_size
-                st.session_state["temperature"] = temperature
-                st.session_state["max_tokens"] = max_tokens
-                update_chain()
-
-
-def save_uploaded_file(uploaded_file: UploadedFile) -> str:
+def save_uploaded_file() -> str:
     # streamlit uploaded files need to be stored locally
     # before embedded and uploaded to the hub
+    uploaded_file = st.session_state["uploaded_file"]
     if not os.path.exists(DATA_PATH):
         os.makedirs(DATA_PATH)
     file_path = str(DATA_PATH / uploaded_file.name)
@@ -188,128 +127,162 @@ def save_uploaded_file(uploaded_file: UploadedFile) -> str:
     return file_path
 
 
-def delete_uploaded_file(uploaded_file: UploadedFile) -> None:
+def delete_uploaded_file() -> None:
     # cleanup locally stored files
-    file_path = DATA_PATH / uploaded_file.name
+    file_path = DATA_PATH / st.session_state["uploaded_file"].name
     if os.path.exists(DATA_PATH):
         os.remove(file_path)
         logger.info(f"Removed: {file_path}")
 
 
-def handle_load_error(e: str = None) -> None:
-    error_msg = f"Failed to load '{st.session_state['data_source']}':\n\n{e}"
-    st.error(error_msg, icon=PAGE_ICON)
-    logger.error(error_msg)
-    st.stop()
-
-
-def load_git(data_source: str, chunk_size: int = CHUNK_SIZE) -> List[Document]:
-    # We need to try both common main branches
-    # Thank you github for the "master" to "main" switch
-    # we need to make sure the data path exists
-    if not os.path.exists(DATA_PATH):
-        os.makedirs(DATA_PATH)
-    repo_name = data_source.split("/")[-1].split(".")[0]
-    repo_path = str(DATA_PATH / repo_name)
-    clone_url = data_source
-    if os.path.exists(repo_path):
-        clone_url = None
-    text_splitter = RecursiveCharacterTextSplitter(
-        chunk_size=chunk_size, chunk_overlap=0
-    )
-    branches = ["main", "master"]
-    for branch in branches:
+class AutoGitLoader:
+    def __init__(self, data_source: str) -> None:
+        self.data_source = data_source
+
+    def load(self) -> List[Document]:
+        # We need to try both common main branches
+        # Thank you github for the "master" to "main" switch
+        # we need to make sure the data path exists
+        if not os.path.exists(DATA_PATH):
+            os.makedirs(DATA_PATH)
+        repo_name = self.data_source.split("/")[-1].split(".")[0]
+        repo_path = str(DATA_PATH / repo_name)
+        clone_url = self.data_source
+        if os.path.exists(repo_path):
+            clone_url = None
+        branches = ["main", "master"]
+        for branch in branches:
+            try:
+                docs = GitLoader(repo_path, clone_url, branch).load()
+                break
+            except Exception as e:
+                logger.error(f"Error loading git: {e}")
+        if os.path.exists(repo_path):
+            # cleanup repo afterwards
+            shutil.rmtree(repo_path)
         try:
-            docs = GitLoader(repo_path, clone_url, branch).load_and_split(text_splitter)
-            break
-        except Exception as e:
-            logger.error(f"Error loading git: {e}")
-    if os.path.exists(repo_path):
-        # cleanup repo afterwards
-        shutil.rmtree(repo_path)
-    try:
-        return docs
-    except:
-        msg = "Make sure to use HTTPS git repo links"
-        handle_load_error(msg)
+            return docs
+        except NameError:
+            raise RuntimeError("Make sure to use HTTPS GitHub repo links")
+
+
+FILE_LOADER_MAPPING = {
+    ".csv": (CSVLoader, {"encoding": "utf-8"}),
+    ".doc": (UnstructuredWordDocumentLoader, {}),
+    ".docx": (UnstructuredWordDocumentLoader, {}),
+    ".enex": (EverNoteLoader, {}),
+    ".epub": (UnstructuredEPubLoader, {}),
+    ".html": (UnstructuredHTMLLoader, {}),
+    ".md": (UnstructuredMarkdownLoader, {}),
+    ".odt": (UnstructuredODTLoader, {}),
+    ".pdf": (PDFMinerLoader, {}),
+    ".ppt": (UnstructuredPowerPointLoader, {}),
+    ".pptx": (UnstructuredPowerPointLoader, {}),
+    ".txt": (TextLoader, {"encoding": "utf8"}),
+    ".ipynb": (NotebookLoader, {}),
+    ".py": (PythonLoader, {}),
+    # Add more mappings for other file extensions and loaders as needed
+}
+
+WEB_LOADER_MAPPING = {
+    ".git": (AutoGitLoader, {}),
+    ".pdf": (OnlinePDFLoader, {}),
+}
+
+
+def get_loader(file_path: str, mapping: dict, default_loader: BaseLoader) -> BaseLoader:
+    # Choose loader from mapping, load default if no match found
+    ext = "." + file_path.rsplit(".", 1)[-1]
+    if ext in mapping:
+        loader_class, loader_args = mapping[ext]
+        loader = loader_class(file_path, **loader_args)
+    else:
+        loader = default_loader(file_path)
+    return loader
 
 
-def load_any_data_source(
-    data_source: str, chunk_size: int = CHUNK_SIZE
-) -> List[Document]:
+def load_data_source() -> List[Document]:
     # Ugly thing that decides how to load data
     # It aint much, but it's honest work
-    is_text = data_source.endswith(".txt")
+    data_source = st.session_state["data_source"]
     is_web = data_source.startswith("http")
-    is_pdf = data_source.endswith(".pdf")
-    is_csv = data_source.endswith("csv")
-    is_html = data_source.endswith(".html")
-    is_git = data_source.endswith(".git")
-    is_notebook = data_source.endswith(".ipynb")
-    is_doc = data_source.endswith(".doc")
-    is_py = data_source.endswith(".py")
     is_dir = os.path.isdir(data_source)
     is_file = os.path.isfile(data_source)
 
     loader = None
     if is_dir:
         loader = DirectoryLoader(data_source, recursive=True, silent_errors=True)
-    elif is_git:
-        return load_git(data_source, chunk_size)
     elif is_web:
-        if is_pdf:
-            loader = OnlinePDFLoader(data_source)
-        else:
-            loader = WebBaseLoader(data_source)
+        loader = get_loader(data_source, WEB_LOADER_MAPPING, WebBaseLoader)
     elif is_file:
-        if is_text:
-            loader = TextLoader(data_source, encoding="utf-8")
-        elif is_notebook:
-            loader = NotebookLoader(data_source)
-        elif is_pdf:
-            loader = UnstructuredPDFLoader(data_source)
-        elif is_html:
-            loader = UnstructuredHTMLLoader(data_source)
-        elif is_doc:
-            loader = UnstructuredWordDocumentLoader(data_source)
-        elif is_csv:
-            loader = CSVLoader(data_source, encoding="utf-8")
-        elif is_py:
-            loader = PythonLoader(data_source)
-        else:
-            loader = UnstructuredFileLoader(data_source)
+        loader = get_loader(data_source, FILE_LOADER_MAPPING, UnstructuredFileLoader)
 
     try:
         # Chunk size is a major trade-off parameter to control result accuracy over computaion
         text_splitter = RecursiveCharacterTextSplitter(
-            chunk_size=chunk_size, chunk_overlap=0
+            chunk_size=st.session_state["chunk_size"],
+            chunk_overlap=st.session_state["chunk_overlap"],
         )
-        docs = loader.load_and_split(text_splitter)
+        docs = loader.load()
+        docs = text_splitter.split_documents(docs)
         logger.info(f"Loaded: {len(docs)} document chucks")
         return docs
     except Exception as e:
         msg = (
             e
             if loader
-            else f"No Loader found for your data source. Consider contributing:  {REPO_URL}!"
+            else f"No Loader found for your data source. Consider contributing:  {PROJECT_URL}!"
         )
-        handle_load_error(msg)
+        error_msg = f"Failed to load '{st.session_state['data_source']}':\n\n{msg}"
+        st.error(error_msg, icon=PAGE_ICON)
+        logger.error(error_msg)
+        st.stop()
 
 
-def clean_data_source_string(data_source_string: str) -> str:
+def get_data_source_string() -> str:
     # replace all non-word characters with dashes
     # to get a string that can be used to create a new dataset
-    dashed_string = re.sub(r"\W+", "-", data_source_string)
+    dashed_string = re.sub(r"\W+", "-", st.session_state["data_source"])
     cleaned_string = re.sub(r"--+", "- ", dashed_string).strip("-")
     return cleaned_string
 
 
-def setup_vector_store(data_source: str, chunk_size: int = CHUNK_SIZE) -> VectorStore:
+def get_model() -> BaseLanguageModel:
+    match st.session_state["model"]:
+        case "gpt-3.5-turbo":
+            model = ChatOpenAI(
+                model_name=st.session_state["model"],
+                temperature=st.session_state["temperature"],
+                openai_api_key=st.session_state["openai_api_key"],
+            )
+        # Add more models as needed
+        case _:
+            msg = f"Model {st.session_state['model']} not supported!"
+            logger.error(msg)
+            st.error(msg)
+            st.stop()
+    return model
+
+
+def get_embeddings() -> Embeddings:
+    match st.session_state["embeddings"]:
+        case "openai":
+            embeddings = OpenAIEmbeddings(
+                disallowed_special=(), openai_api_key=st.session_state["openai_api_key"]
+            )
+        # Add more embeddings as needed
+        case _:
+            msg = f"Embeddings {st.session_state['embeddings']} not supported!"
+            logger.error(msg)
+            st.error(msg)
+            st.stop()
+    return embeddings
+
+
+def get_vector_store() -> VectorStore:
     # either load existing vector store or upload a new one to the hub
-    embeddings = OpenAIEmbeddings(
-        disallowed_special=(), openai_api_key=st.session_state["openai_api_key"]
-    )
-    data_source_name = clean_data_source_string(data_source)
-    dataset_path = f"hub://{st.session_state['activeloop_org_name']}/{data_source_name}-{chunk_size}"
+    embeddings = get_embeddings()
+    data_source_name = get_data_source_string()
+    dataset_path = f"hub://{st.session_state['activeloop_org_name']}/{data_source_name}-{st.session_state['chunk_size']}"
     if deeplake.exists(dataset_path, token=st.session_state["activeloop_token"]):
         with st.spinner("Loading vector store..."):
             logger.info(f"Dataset '{dataset_path}' exists -> loading")
@@ -322,7 +295,7 @@ def setup_vector_store(data_source: str, chunk_size: int = CHUNK_SIZE) -> Vector
     else:
         with st.spinner("Reading, embedding and uploading data to hub..."):
             logger.info(f"Dataset '{dataset_path}' does not exist -> uploading")
-            docs = load_any_data_source(data_source, chunk_size)
+            docs = load_data_source()
             vector_store = DeepLake.from_documents(
                 docs,
                 embeddings,
@@ -332,16 +305,9 @@ def setup_vector_store(data_source: str, chunk_size: int = CHUNK_SIZE) -> Vector
     return vector_store
 
 
-def build_chain(
-    data_source: str,
-    k: int = K,
-    fetch_k: int = FETCH_K,
-    chunk_size: int = CHUNK_SIZE,
-    temperature: float = TEMPERATURE,
-    max_tokens: int = MAX_TOKENS,
-) -> ConversationalRetrievalChain:
+def get_chain() -> ConversationalRetrievalChain:
     # create the langchain that will be called to generate responses
-    vector_store = setup_vector_store(data_source, chunk_size)
+    vector_store = get_vector_store()
     retriever = vector_store.as_retriever()
     # Search params "fetch_k" and "k" define how many documents are pulled from the hub
     # and selected after the document matching to build the context
@@ -349,15 +315,11 @@ def build_chain(
     search_kwargs = {
         "maximal_marginal_relevance": True,
         "distance_metric": "cos",
-        "fetch_k": fetch_k,
-        "k": k,
+        "fetch_k": st.session_state["fetch_k"],
+        "k": st.session_state["k"],
     }
     retriever.search_kwargs.update(search_kwargs)
-    model = ChatOpenAI(
-        model_name=MODEL,
-        temperature=temperature,
-        openai_api_key=st.session_state["openai_api_key"],
-    )
+    model = get_model()
     chain = ConversationalRetrievalChain.from_llm(
         model,
         retriever=retriever,
@@ -365,9 +327,8 @@ def build_chain(
         verbose=True,
         # we limit the maximum number of used tokens
         # to prevent running into the models token limit of 4096
-        max_tokens_limit=max_tokens,
+        max_tokens_limit=st.session_state["max_tokens"],
     )
-    logger.info(f"Data source '{data_source}' is ready to go!")
     return chain
 
 
@@ -375,15 +336,11 @@ def update_chain() -> None:
     # Build chain with parameters from session state and store it back
     # Also delete chat history to not confuse the bot with old context
     try:
-        st.session_state["chain"] = build_chain(
-            data_source=st.session_state["data_source"],
-            k=st.session_state["k"],
-            fetch_k=st.session_state["fetch_k"],
-            chunk_size=st.session_state["chunk_size"],
-            temperature=st.session_state["temperature"],
-            max_tokens=st.session_state["max_tokens"],
-        )
+        st.session_state["chain"] = get_chain()
         st.session_state["chat_history"] = []
+        msg = f"Data source '{st.session_state['data_source']}' is ready to go!"
+        logger.info(msg)
+        st.info(msg, icon=PAGE_ICON)
     except Exception as e:
         msg = f"Failed to build chain for data source '{st.session_state['data_source']}' with error: {e}"
         logger.error(msg)
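For readers skimming the utils.py rewrite: its core is the table-driven dispatch in `FILE_LOADER_MAPPING` / `get_loader`, which replaces the old `elif` ladder. A standalone sketch of that pattern with stand-in loader classes (the real code plugs in langchain loaders):

```python
class TextLoader:
    # stand-in for a langchain loader: a file path plus constructor kwargs
    def __init__(self, file_path: str, encoding: str = "utf8") -> None:
        self.file_path, self.encoding = file_path, encoding


class FallbackLoader:
    def __init__(self, file_path: str) -> None:
        self.file_path = file_path


# extension -> (loader class, loader kwargs), as in FILE_LOADER_MAPPING
LOADER_MAPPING = {".txt": (TextLoader, {"encoding": "utf8"})}


def get_loader(file_path: str, mapping: dict, default_loader: type):
    # same logic as the diff's get_loader: dispatch on the file extension,
    # fall back to a generic loader when no mapping entry matches
    ext = "." + file_path.rsplit(".", 1)[-1]
    if ext in mapping:
        loader_class, loader_args = mapping[ext]
        return loader_class(file_path, **loader_args)
    return default_loader(file_path)


print(type(get_loader("notes.txt", LOADER_MAPPING, FallbackLoader)).__name__)  # TextLoader
print(type(get_loader("data.bin", LOADER_MAPPING, FallbackLoader)).__name__)   # FallbackLoader
```

Adding support for a new file type then only requires a new mapping entry, which is exactly what the `# Add more mappings ...` comment in the diff invites.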