"""Streamlit entry point: authentication, data-source selection and chat UI.

Streamlit re-executes this whole script on every user interaction, so all
state that must survive a rerun lives in ``st.session_state``.
"""

import streamlit as st
from streamlit_chat import message

from datachad.chain import generate_response, update_chain
from datachad.constants import (
    ACTIVELOOP_HELP,
    APP_NAME,
    AUTHENTICATION_HELP,
    CHUNK_OVERLAP,
    CHUNK_SIZE,
    DEFAULT_DATA_SOURCE,
    ENABLE_ADVANCED_OPTIONS,
    ENABLE_LOCAL_MODE,
    FETCH_K,
    LOCAL_MODE_DISABLED_HELP,
    MAX_TOKENS,
    MODE_HELP,
    MODEL_N_CTX,
    OPENAI_HELP,
    PAGE_ICON,
    PROJECT_URL,
    TEMPERATURE,
    USAGE_HELP,
    K,
)
from datachad.models import MODELS, MODES
from datachad.utils import (
    authenticate,
    delete_uploaded_file,
    logger,
    save_uploaded_file,
)

# Page options and header
st.set_option("client.showErrorDetails", True)
st.set_page_config(
    page_title=APP_NAME, page_icon=PAGE_ICON, initial_sidebar_state="expanded"
)
# NOTE(review): the HTML tags of this header were lost in the source that was
# reviewed (only the text survived, while unsafe_allow_html=True remained).
# The centred <h1> below is a reconstruction -- confirm against the intended
# markup.
st.markdown(
    f"<h1 style='text-align: center;'>{APP_NAME} {PAGE_ICON} <br> "
    "I know all about your data!</h1>",
    unsafe_allow_html=True,
)

# Default value for every session-state key this script reads or writes.
SESSION_DEFAULTS = {
    "past": [],
    "usage": {},
    "chat_history": [],
    "generated": [],
    "auth_ok": False,
    "openai_api_key": None,
    "activeloop_token": None,
    "activeloop_org_name": None,
    "uploaded_file": None,
    "data_source": DEFAULT_DATA_SOURCE,
    "mode": MODES.OPENAI,
    "model": MODELS.GPT35TURBO,
    "k": K,
    "fetch_k": FETCH_K,
    "chunk_size": CHUNK_SIZE,
    "chunk_overlap": CHUNK_OVERLAP,
    "temperature": TEMPERATURE,
    "max_tokens": MAX_TOKENS,
    "model_n_ctx": MODEL_N_CTX,
}

# Initialise session state variables (only keys that are not set yet, so
# values from previous reruns are preserved)
for k, v in SESSION_DEFAULTS.items():
    if k not in st.session_state:
        st.session_state[k] = v


def authentication_form() -> None:
    """Render the credentials form and call ``authenticate`` on submit.

    Collects an API key plus an ActiveLoop token and organisation name and
    forwards the three entered values to ``authenticate`` once the form's
    "Submit" button is pressed.
    """
    st.title("Authentication", help=AUTHENTICATION_HELP)
    with st.form("authentication"):
        openai_api_key = st.text_input(
            f"{st.session_state['mode']} API Key",
            type="password",
            help=OPENAI_HELP,
            placeholder="This field is mandatory",
        )
        activeloop_token = st.text_input(
            "ActiveLoop Token",
            type="password",
            help=ACTIVELOOP_HELP,
            placeholder="Optional, using ours if empty",
        )
        activeloop_org_name = st.text_input(
            "ActiveLoop Organisation Name",
            type="password",
            help=ACTIVELOOP_HELP,
            placeholder="Optional, using ours if empty",
        )
        submitted = st.form_submit_button("Submit")
        if submitted:
            authenticate(openai_api_key, activeloop_token, activeloop_org_name)


def advanced_options_form() -> None:
    """Render the optional advanced-settings form and rebuild the chain.

    The form is only shown while the "Advanced Options" checkbox is ticked.
    Each widget stores its value in ``st.session_state`` under its ``key``;
    pressing "Apply" calls ``update_chain``.
    """
    advanced_options = st.checkbox(
        "Advanced Options", help="Caution! This may break things!"
    )
    if not advanced_options:
        return

    with st.form("advanced_options"):
        col1, col2 = st.columns(2)
        col1.selectbox(
            "model",
            options=MODELS.for_mode(st.session_state["mode"]),
            help=f"Learn more about which models are supported [here]({PROJECT_URL})",
            key="model",
        )
        col2.number_input(
            "temperature",
            min_value=0.0,
            max_value=1.0,
            value=TEMPERATURE,
            help="Controls the randomness of the language model output",
            key="temperature",
        )
        # The widget key matches the "fetch_k" entry of SESSION_DEFAULTS so
        # the value entered here replaces that default; the previous key
        # "k_fetch" wrote to an entry SESSION_DEFAULTS never initialises.
        # NOTE(review): assumes consumers outside this file read "fetch_k"
        # -- confirm.
        col1.number_input(
            "k_fetch",
            min_value=1,
            max_value=1000,
            value=FETCH_K,
            help="The number of documents to pull from the vector database",
            key="fetch_k",
        )
        col2.number_input(
            "k",
            min_value=1,
            max_value=100,
            value=K,
            help="The number of most similar documents to build the context from",
            key="k",
        )
        col1.number_input(
            "chunk_size",
            min_value=1,
            max_value=100000,
            value=CHUNK_SIZE,
            help=(
                "The size at which the text is divided into smaller chunks "
                "before being embedded.\n\nChanging this parameter makes re-embedding "
                "and re-uploading the data to the database necessary "
            ),
            key="chunk_size",
        )
        col2.number_input(
            "max_tokens",
            min_value=1,
            max_value=30000,
            value=MAX_TOKENS,
            help="Limits the documents returned from database based on number of tokens",
            key="max_tokens",
        )
        applied = st.form_submit_button("Apply")
        if applied:
            update_chain()


def update_model_on_mode_change():
    """Callback for the mode selectbox: select the first model of the new mode."""
    st.session_state["model"] = MODELS.for_mode(st.session_state["mode"])[0]


# Sidebar with Authentication and Advanced Options
with st.sidebar:
    mode = st.selectbox(
        "Mode",
        MODES.all(),
        key="mode",
        help=MODE_HELP,
        # Pass the function itself. Calling it here would run it on every
        # rerun (resetting the chosen model each time) and register None as
        # the callback.
        on_change=update_model_on_mode_change,
    )
    if mode == MODES.LOCAL and not ENABLE_LOCAL_MODE:
        st.error(LOCAL_MODE_DISABLED_HELP, icon=PAGE_ICON)
        st.stop()
    if mode != MODES.LOCAL:
        authentication_form()

    st.info(f"Learn how it works [here]({PROJECT_URL})")

    # Only start App if authentication is OK
    if not (st.session_state["auth_ok"] or mode == MODES.LOCAL):
        st.stop()

    # Clear button to reset all chat communication
    clear_button = st.button("Clear Conversation")

    # Advanced Options
    if ENABLE_ADVANCED_OPTIONS:
        advanced_options_form()

if clear_button:
    # resets all chat history related caches
    st.session_state["past"] = []
    st.session_state["generated"] = []
    st.session_state["chat_history"] = []

# file upload and data source inputs
uploaded_file = st.file_uploader("Upload a file")
data_source = st.text_input(
    "Enter any data source",
    placeholder="Any path or url pointing to a file or directory of files",
)

# the chain can only be initialized after authentication is OK
if "chain" not in st.session_state:
    update_chain()

# generate new chain for new data source / uploaded file
# make sure to do this only once per input / on change
if data_source and data_source != st.session_state["data_source"]:
    logger.info(f"Data source provided: '{data_source}'")
    st.session_state["data_source"] = data_source
    update_chain()

if uploaded_file and uploaded_file != st.session_state["uploaded_file"]:
    logger.info(f"Uploaded file: '{uploaded_file.name}'")
    st.session_state["uploaded_file"] = uploaded_file
    data_source = save_uploaded_file()
    st.session_state["data_source"] = data_source
    update_chain()
    delete_uploaded_file()

# container for chat history
response_container = st.container()
# container for text box
text_container = st.container()

# As streamlit reruns the whole script on each change
# it is necessary to repopulate the chat containers
with text_container:
    with st.form(key="prompt_input", clear_on_submit=True):
        user_input = st.text_area("You:", key="input", height=100)
        submit_button = st.form_submit_button(label="Send")

    if submit_button and user_input:
        text_container.empty()
        output = generate_response(user_input)
        st.session_state["past"].append(user_input)
        st.session_state["generated"].append(output)

if st.session_state["generated"]:
    with response_container:
        # "past" and "generated" are appended to and cleared together above,
        # so pairing them replays the conversation in order.
        exchanges = zip(st.session_state["past"], st.session_state["generated"])
        for i, (question, answer) in enumerate(exchanges):
            message(question, is_user=True, key=str(i) + "_user")
            message(answer, key=str(i))

# Usage sidebar with total used tokens and costs
# We put this at the end to be able to show usage starting with the first response
with st.sidebar:
    if st.session_state["usage"]:
        st.divider()
        st.title("Usage", help=USAGE_HELP)
        col1, col2 = st.columns(2)
        col1.metric("Total Tokens", st.session_state["usage"]["total_tokens"])
        col2.metric("Total Costs in $", st.session_state["usage"]["total_cost"])