commit 2300731e566c1475abe729687b12244f497a817f Author: Gustav von Zitzewitz Date: Wed May 10 15:55:45 2023 +0200 init commit diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..e9be3b6 --- /dev/null +++ b/.gitignore @@ -0,0 +1,3 @@ +data +__pycache__ +.streamlit/secrets.toml \ No newline at end of file diff --git a/.streamlit/secrets.toml.template b/.streamlit/secrets.toml.template new file mode 100644 index 0000000..496401a --- /dev/null +++ b/.streamlit/secrets.toml.template @@ -0,0 +1,3 @@ +OPENAI_API_KEY = "your openai key" +ACTIVELOOP_TOKEN = "your activeloop key" +ACTIVELOOP_ORG_NAME = "your activeloop organization name" \ No newline at end of file diff --git a/README.md b/README.md new file mode 100644 index 0000000..9372dc1 --- /dev/null +++ b/README.md @@ -0,0 +1,23 @@ +# DataChad 🤖 + +This is an app that lets you ask questions about any data source by leveraging [embeddings](https://platform.openai.com/docs/guides/embeddings), [vector databases](https://www.activeloop.ai/), [large language models](https://platform.openai.com/docs/models/gpt-3-5) and last but not least [langchains](https://github.com/hwchase17/langchain) + +## How does it work? + +1. Upload any `file` or enter any `path` or `url` +2. The data source is detected and loaded into text documents +3. The text documents are embedded using openai embeddings +4. The embeddings are stored as a vector dataset to a datalake +5. A langchain is created consisting of an LLM model (`gpt-3.5-turbo` by default) and the embedding database index as retriever +6. When sending questions to the bot this chain is used as context to answer your questions +7. Finally the chat history is cached locally to enable a [ChatGPT](https://chat.openai.com/) like Q&A conversation + +## Good to know + +- As default context this git repository is taken so you can directly start asking questions about its functionality without choosing your own data source. 
+- To run locally or deploy somewhere, execute: + + ```cp .streamlit/secret.toml.template .streamlit/secret.toml``` + + and set necessary keys in the newly created secrets file. Another option is to manually set environment variables +- Yes, Chad in `DataChad` refers to the well-known [meme](https://www.google.com/search?q=chad+meme) diff --git a/app.py b/app.py new file mode 100644 index 0000000..b404f9c --- /dev/null +++ b/app.py @@ -0,0 +1,97 @@ +import streamlit as st +from streamlit_chat import message + +from constants import APP_NAME, DEFAULT_DATA_SOURCE, PAGE_ICON +from utils import ( + generate_response, + get_chain, + reset_data_source, + save_uploaded_file, + validate_keys, +) + + +# Page options and header +st.set_option("client.showErrorDetails", True) +st.set_page_config(page_title=APP_NAME, page_icon=PAGE_ICON) +st.markdown( + f"

{APP_NAME} {PAGE_ICON}
I know all about your data!

", + unsafe_allow_html=True, +) + +# Initialise session state variables +if "chat_history" not in st.session_state: + st.session_state["chat_history"] = [] +if "generated" not in st.session_state: + st.session_state["generated"] = [] +if "past" not in st.session_state: + st.session_state["past"] = [] +if "auth_ok" not in st.session_state: + st.session_state["auth_ok"] = False + + +# Sidebar +with st.sidebar: + st.title("Authentication") + with st.form("authentication"): + openai_key = st.text_input("OpenAI API Key", type="password", key="openai_key") + activeloop_token = st.text_input( + "ActiveLoop Token", type="password", key="activeloop_token" + ) + activeloop_org_name = st.text_input( + "ActiveLoop Organisation Name", type="password", key="activeloop_org_name" + ) + submitted = st.form_submit_button("Submit") + if submitted: + validate_keys(openai_key, activeloop_token, activeloop_org_name) + + if not st.session_state["auth_ok"]: + st.stop() + + clear_button = st.button("Clear Conversation and Reset Data", key="clear") + +# the chain can only be initialized after authentication is OK +if "chain" not in st.session_state: + st.session_state["chain"] = get_chain(DEFAULT_DATA_SOURCE) + +if clear_button: + # reset everything + reset_data_source(DEFAULT_DATA_SOURCE) + +# upload file or enter data source +uploaded_file = st.file_uploader("Upload a file") +data_source = st.text_input( + "Enter any data source", + placeholder="Any path or url pointing to a file or directory of files", +) + +if uploaded_file: + print(f"uploaded file: '{uploaded_file.name}'") + data_source = save_uploaded_file(uploaded_file) + reset_data_source(data_source) + +if data_source: + print(f"data source provided: '{data_source}'") + reset_data_source(data_source) + +# container for chat history +response_container = st.container() +# container for text box +container = st.container() + +with container: + with st.form(key="prompt_input", clear_on_submit=True): + user_input = st.text_area("You:", 
key="input", height=100) + submit_button = st.form_submit_button(label="Send") + + if submit_button and user_input: + output = generate_response(user_input) + st.session_state["past"].append(user_input) + st.session_state["generated"].append(output) + + +if st.session_state["generated"]: + with response_container: + for i in range(len(st.session_state["generated"])): + message(st.session_state["past"][i], is_user=True, key=str(i) + "_user") + message(st.session_state["generated"][i], key=str(i)) diff --git a/constants.py b/constants.py new file mode 100644 index 0000000..9b41ad3 --- /dev/null +++ b/constants.py @@ -0,0 +1,8 @@ +from pathlib import Path + +APP_NAME = "DataChad" +MODEL = "gpt-3.5-turbo" +PAGE_ICON = "🤖" + +DATA_PATH = Path.cwd() / "data" +DEFAULT_DATA_SOURCE = "git@github.com:gustavz/DataChad.git" diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..ac52191 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,11 @@ +streamlit==1.22.0 +streamlit-chat==0.0.2.2 +deeplake==3.4.1 +openai==0.27.6 +langchain==0.0.164 +tiktoken==0.4.0 +unstructured==0.6.5 +pdf2image==1.16.3 +pytesseract==0.3.10 +beautifulsoup4==4.12.2 +bs4==0.0.1 \ No newline at end of file diff --git a/utils.py b/utils.py new file mode 100644 index 0000000..a8e5465 --- /dev/null +++ b/utils.py @@ -0,0 +1,214 @@ +import os +import re + +import deeplake +import streamlit as st +from langchain.chains import ConversationalRetrievalChain +from langchain.chat_models import ChatOpenAI +from langchain.document_loaders import ( + CSVLoader, + DirectoryLoader, + GitLoader, + NotebookLoader, + OnlinePDFLoader, + PythonLoader, + TextLoader, + UnstructuredFileLoader, + UnstructuredHTMLLoader, + UnstructuredPDFLoader, + UnstructuredWordDocumentLoader, + WebBaseLoader, +) +from langchain.embeddings.openai import OpenAIEmbeddings +from langchain.text_splitter import RecursiveCharacterTextSplitter +from langchain.vectorstores import DeepLake + +from constants import DATA_PATH, 
MODEL, PAGE_ICON + + +def validate_keys(openai_key, activeloop_token, activeloop_org_name): + # Validate all API related variables are set and correct + # TODO: Do proper token/key validation, currently activeloop has none + all_keys = [openai_key, activeloop_token, activeloop_org_name] + if any(all_keys): + print(f"{openai_key=}\n{activeloop_token=}\n{activeloop_org_name=}") + if not all(all_keys): + st.session_state["auth_ok"] = False + st.error("Authentication failed", icon=PAGE_ICON) + st.stop() + os.environ["OPENAI_API_KEY"] = openai_key + os.environ["ACTIVELOOP_TOKEN"] = activeloop_token + os.environ["ACTIVELOOP_ORG_NAME"] = activeloop_org_name + else: + # Fallback for local development or deployments with provided credentials + # either env variables or streamlit secrets need to be set + try: + assert os.environ.get("OPENAI_API_KEY") + assert os.environ.get("ACTIVELOOP_TOKEN") + assert os.environ.get("ACTIVELOOP_ORG_NAME") + except: + assert st.secrets.get("OPENAI_API_KEY") + assert st.secrets.get("ACTIVELOOP_TOKEN") + assert st.secrets.get("ACTIVELOOP_ORG_NAME") + + os.environ["OPENAI_API_KEY"] = st.secrets.get("OPENAI_API_KEY") + os.environ["ACTIVELOOP_TOKEN"] = st.secrets.get("ACTIVELOOP_TOKEN") + os.environ["ACTIVELOOP_ORG_NAME"] = st.secrets.get("ACTIVELOOP_ORG_NAME") + st.session_state["auth_ok"] = True + + +def save_uploaded_file(uploaded_file): + # streamlit uploaded files need to be stored locally before + # TODO: delete local files after they are uploaded to the datalake + if not os.path.exists(DATA_PATH): + os.makedirs(DATA_PATH) + file_path = str(DATA_PATH / uploaded_file.name) + uploaded_file.seek(0) + file_bytes = uploaded_file.read() + file = open(file_path, "wb") + file.write(file_bytes) + file.close() + return file_path + + +def load_git(data_source): + # Thank you github for the "master" to "main" switch + repo_name = data_source.split("/")[-1].split(".")[0] + repo_path = str(DATA_PATH / repo_name) + if os.path.exists(repo_path): + 
data_source = None + + text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=0) + branches = ["main", "master"] + for branch in branches: + try: + docs = GitLoader(repo_path, data_source, branch).load_and_split( + text_splitter + ) + except Exception as e: + print(f"error loading git: {e}") + return docs + + +def load_any_data_source(data_source): + # ugly thing that decides how to load data + is_text = data_source.endswith(".txt") + is_web = data_source.startswith("http") + is_pdf = data_source.endswith(".pdf") + is_csv = data_source.endswith("csv") + is_html = data_source.endswith(".html") + is_git = data_source.endswith(".git") + is_notebook = data_source.endswith(".ipynb") + is_doc = data_source.endswith(".doc") + is_py = data_source.endswith(".py") + is_dir = os.path.isdir(data_source) + is_file = os.path.isfile(data_source) + + loader = None + if is_dir: + loader = DirectoryLoader(data_source, recursive=True) + if is_git: + return load_git(data_source) + if is_web: + if is_pdf: + loader = OnlinePDFLoader(data_source) + else: + loader = WebBaseLoader(data_source) + if is_file: + if is_text: + loader = TextLoader(data_source) + elif is_notebook: + loader = NotebookLoader(data_source) + elif is_pdf: + loader = UnstructuredPDFLoader(data_source) + elif is_html: + loader = UnstructuredHTMLLoader(data_source) + elif is_doc: + loader = UnstructuredWordDocumentLoader(data_source) + elif is_csv: + loader = CSVLoader(data_source, encoding="utf-8") + elif is_py: + loader = PythonLoader(data_source) + else: + loader = UnstructuredFileLoader(data_source) + if loader: + text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=0) + docs = loader.load_and_split(text_splitter) + print(f"loaded {len(docs)} document chucks") + return docs + + error_msg = f"Failed to load {data_source}" + st.error(error_msg, icon=PAGE_ICON) + print(error_msg) + st.stop() + + +def clean_data_source_string(data_source): + # replace all non-word 
characters with dashes + # to get a string that can be used to create a datalake dataset + dashed_string = re.sub(r"\W+", "-", data_source) + cleaned_string = re.sub(r"--+", "- ", dashed_string).strip("-") + return cleaned_string + + +def setup_vector_store(data_source): + # either load existing vector store or upload a new one to the datalake + embeddings = OpenAIEmbeddings(disallowed_special=()) + data_source_name = clean_data_source_string(data_source) + dataset_path = f"hub://{os.environ['ACTIVELOOP_ORG_NAME']}/{data_source_name}" + if deeplake.exists(dataset_path): + print(f"{dataset_path} exists -> loading") + vector_store = DeepLake( + dataset_path=dataset_path, read_only=True, embedding_function=embeddings + ) + else: + print(f"{dataset_path} does not exist -> uploading") + docs = load_any_data_source(data_source) + vector_store = DeepLake.from_documents( + docs, + embeddings, + dataset_path=f"hub://{os.environ['ACTIVELOOP_ORG_NAME']}/{data_source_name}", + ) + return vector_store + + +def get_chain(data_source): + # create the langchain that will be called to generate responses + vector_store = setup_vector_store(data_source) + retriever = vector_store.as_retriever() + search_kwargs = { + "distance_metric": "cos", + "fetch_k": 20, + "maximal_marginal_relevance": True, + "k": 10, + } + retriever.search_kwargs.update(search_kwargs) + model = ChatOpenAI(model_name=MODEL) + chain = ConversationalRetrievalChain.from_llm( + model, + retriever=retriever, + chain_type="stuff", + verbose=True, + max_tokens_limit=3375, + ) + print(f"{data_source} is ready to go!") + return chain + + +def reset_data_source(data_source): + # we need to reset all caches if a new data source is loaded + # otherwise the langchain is confused and produces garbage + st.session_state["past"] = [] + st.session_state["generated"] = [] + st.session_state["chat_history"] = [] + st.session_state["chain"] = get_chain(data_source) + + +def generate_response(prompt): + # call the chain to generate 
responses and add them to the chat history + response = st.session_state["chain"]( + {"question": prompt, "chat_history": st.session_state["chat_history"]} + ) + print(f"{response=}") + st.session_state["chat_history"].append((prompt, response["answer"])) + return response["answer"]