diff --git a/app.py b/app.py index b404f9c..382ccd2 100644 --- a/app.py +++ b/app.py @@ -3,10 +3,10 @@ from streamlit_chat import message from constants import APP_NAME, DEFAULT_DATA_SOURCE, PAGE_ICON from utils import ( + delete_uploaded_file, generate_response, - get_chain, - reset_data_source, save_uploaded_file, + build_chain_and_clear_history, validate_keys, ) @@ -28,6 +28,10 @@ if "past" not in st.session_state: st.session_state["past"] = [] if "auth_ok" not in st.session_state: st.session_state["auth_ok"] = False +if "data_source" not in st.session_state: + st.session_state["data_source"] = "" +if "uploaded_file" not in st.session_state: + st.session_state["uploaded_file"] = None # Sidebar @@ -48,31 +52,39 @@ with st.sidebar: if not st.session_state["auth_ok"]: st.stop() - clear_button = st.button("Clear Conversation and Reset Data", key="clear") + clear_button = st.button("Clear Conversation", key="clear") # the chain can only be initialized after authentication is OK if "chain" not in st.session_state: - st.session_state["chain"] = get_chain(DEFAULT_DATA_SOURCE) + build_chain_and_clear_history(DEFAULT_DATA_SOURCE) if clear_button: - # reset everything - reset_data_source(DEFAULT_DATA_SOURCE) + # reset chat history + st.session_state["past"] = [] + st.session_state["generated"] = [] + st.session_state["chat_history"] = [] -# upload file or enter data source +# file upload and data source inputs uploaded_file = st.file_uploader("Upload a file") data_source = st.text_input( "Enter any data source", placeholder="Any path or url pointing to a file or directory of files", ) -if uploaded_file: +# generate new chain for new data source / uploaded file +# make sure to do this only once per input / on change +if data_source and data_source != st.session_state["data_source"]: + print(f"data source provided: '{data_source}'") + build_chain_and_clear_history(data_source) + st.session_state["data_source"] = data_source + +if uploaded_file and uploaded_file != 
st.session_state["uploaded_file"]: print(f"uploaded file: '{uploaded_file.name}'") data_source = save_uploaded_file(uploaded_file) - reset_data_source(data_source) + build_chain_and_clear_history(data_source) + delete_uploaded_file(uploaded_file) + st.session_state["uploaded_file"] = uploaded_file -if data_source: - print(f"data source provided: '{data_source}'") - reset_data_source(data_source) # container for chat history response_container = st.container() diff --git a/utils.py b/utils.py index a8e5465..c8cb69b 100644 --- a/utils.py +++ b/utils.py @@ -1,7 +1,9 @@ import os import re +import openai import deeplake +import shutil import streamlit as st from langchain.chains import ConversationalRetrievalChain from langchain.chat_models import ChatOpenAI @@ -28,19 +30,18 @@ from constants import DATA_PATH, MODEL, PAGE_ICON def validate_keys(openai_key, activeloop_token, activeloop_org_name): # Validate all API related variables are set and correct - # TODO: Do proper token/key validation, currently activeloop has none all_keys = [openai_key, activeloop_token, activeloop_org_name] if any(all_keys): print(f"{openai_key=}\n{activeloop_token=}\n{activeloop_org_name=}") if not all(all_keys): st.session_state["auth_ok"] = False - st.error("Authentication failed", icon=PAGE_ICON) + st.error("You need to fill all fields", icon=PAGE_ICON) st.stop() os.environ["OPENAI_API_KEY"] = openai_key os.environ["ACTIVELOOP_TOKEN"] = activeloop_token os.environ["ACTIVELOOP_ORG_NAME"] = activeloop_org_name else: - # Fallback for local development or deployments with provided credentials + # Bypass for local development or deployments with stored credentials # either env variables or streamlit secrets need to be set try: assert os.environ.get("OPENAI_API_KEY") @@ -54,12 +55,26 @@ def validate_keys(openai_key, activeloop_token, activeloop_org_name): os.environ["OPENAI_API_KEY"] = st.secrets.get("OPENAI_API_KEY") os.environ["ACTIVELOOP_TOKEN"] = st.secrets.get("ACTIVELOOP_TOKEN") 
os.environ["ACTIVELOOP_ORG_NAME"] = st.secrets.get("ACTIVELOOP_ORG_NAME") + try: + # Try to access openai and deeplake + with st.spinner("Authenticating..."): + openai.Model.list() + deeplake.exists( + f"hub://{os.environ['ACTIVELOOP_ORG_NAME']}/DataChad-Authentication-Check", + ) + except Exception as e: + print(f"Authentication failed with {e}") + st.session_state["auth_ok"] = False + st.error("Authentication failed", icon=PAGE_ICON) + st.stop() + + print("Authentication successful!") st.session_state["auth_ok"] = True def save_uploaded_file(uploaded_file): - # streamlit uploaded files need to be stored locally before - # TODO: delete local files after they are uploaded to the datalake + # streamlit uploaded files need to be stored locally + # before being embedded and uploaded to the hub if not os.path.exists(DATA_PATH): os.makedirs(DATA_PATH) file_path = str(DATA_PATH / uploaded_file.name) @@ -68,25 +83,37 @@ file = open(file_path, "wb") file.write(file_bytes) file.close() + print(f"saved {file_path}") return file_path +def delete_uploaded_file(uploaded_file): + # cleanup locally stored files + file_path = DATA_PATH / uploaded_file.name + if os.path.exists(file_path): + os.remove(file_path) + print(f"removed {file_path}") + + + def load_git(data_source): # Thank you github for the "master" to "main" switch repo_name = data_source.split("/")[-1].split(".")[0] repo_path = str(DATA_PATH / repo_name) - if os.path.exists(repo_path): - data_source = None - text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=0) branches = ["main", "master"] for branch in branches: + if os.path.exists(repo_path): + data_source = None try: docs = GitLoader(repo_path, data_source, branch).load_and_split( text_splitter ) + break except Exception as e: print(f"error loading git: {e}") + if os.path.exists(repo_path): + # cleanup repo afterwards + shutil.rmtree(repo_path) return docs @@ -145,30 +172,32 @@ def
load_any_data_source(data_source): def clean_data_source_string(data_source): # replace all non-word characters with dashes - # to get a string that can be used to create a datalake dataset + # to get a string that can be used to create a new dataset dashed_string = re.sub(r"\W+", "-", data_source) cleaned_string = re.sub(r"--+", "- ", dashed_string).strip("-") return cleaned_string def setup_vector_store(data_source): - # either load existing vector store or upload a new one to the datalake + # either load existing vector store or upload a new one to the hub embeddings = OpenAIEmbeddings(disallowed_special=()) data_source_name = clean_data_source_string(data_source) dataset_path = f"hub://{os.environ['ACTIVELOOP_ORG_NAME']}/{data_source_name}" if deeplake.exists(dataset_path): - print(f"{dataset_path} exists -> loading") - vector_store = DeepLake( - dataset_path=dataset_path, read_only=True, embedding_function=embeddings - ) + with st.spinner("Loading vector store..."): + print(f"{dataset_path} exists -> loading") + vector_store = DeepLake( + dataset_path=dataset_path, read_only=True, embedding_function=embeddings + ) else: - print(f"{dataset_path} does not exist -> uploading") - docs = load_any_data_source(data_source) - vector_store = DeepLake.from_documents( - docs, - embeddings, - dataset_path=f"hub://{os.environ['ACTIVELOOP_ORG_NAME']}/{data_source_name}", - ) + with st.spinner("Reading, embedding and uploading data to hub..."): + print(f"{dataset_path} does not exist -> uploading") + docs = load_any_data_source(data_source) + vector_store = DeepLake.from_documents( + docs, + embeddings, + dataset_path=f"hub://{os.environ['ACTIVELOOP_ORG_NAME']}/{data_source_name}", + ) return vector_store @@ -184,31 +213,31 @@ def get_chain(data_source): } retriever.search_kwargs.update(search_kwargs) model = ChatOpenAI(model_name=MODEL) - chain = ConversationalRetrievalChain.from_llm( - model, - retriever=retriever, - chain_type="stuff", - verbose=True, - 
max_tokens_limit=3375, - ) - print(f"{data_source} is ready to go!") + with st.spinner("Building langchain..."): + chain = ConversationalRetrievalChain.from_llm( + model, + retriever=retriever, + chain_type="stuff", + verbose=True, + max_tokens_limit=3375, + ) + print(f"{data_source} is ready to go!") return chain -def reset_data_source(data_source): - # we need to reset all caches if a new data source is loaded - # otherwise the langchain is confused and produces garbage - st.session_state["past"] = [] - st.session_state["generated"] = [] - st.session_state["chat_history"] = [] +def build_chain_and_clear_history(data_source): + # Get chain and store it in the session state + # Also delete chat history to not confuse the bot with old context st.session_state["chain"] = get_chain(data_source) + st.session_state["chat_history"] = [] def generate_response(prompt): # call the chain to generate responses and add them to the chat history - response = st.session_state["chain"]( - {"question": prompt, "chat_history": st.session_state["chat_history"]} - ) - print(f"{response=}") - st.session_state["chat_history"].append((prompt, response["answer"])) + with st.spinner("Generating response"): + response = st.session_state["chain"]( + {"question": prompt, "chat_history": st.session_state["chat_history"]} + ) + print(f"{response=}") + st.session_state["chat_history"].append((prompt, response["answer"])) return response["answer"]