diff --git a/app.py b/app.py index 3e18ea8..4674628 100644 --- a/app.py +++ b/app.py @@ -52,7 +52,8 @@ if "activeloop_token" not in st.session_state: if "activeloop_org_name" not in st.session_state: st.session_state["activeloop_org_name"] = None -# Sidebar +# Sidebar with Authentication +# Only start App if authentication is OK with st.sidebar: st.title("Authentication", help=AUTHENTICATION_HELP) with st.form("authentication"): @@ -82,6 +83,7 @@ with st.sidebar: if not st.session_state["auth_ok"]: st.stop() + # Clear button to reset all chat communication clear_button = st.button("Clear Conversation", key="clear") @@ -90,7 +92,7 @@ if "chain" not in st.session_state: build_chain_and_clear_history(DEFAULT_DATA_SOURCE) if clear_button: - # reset chat history + # resets all chat history related caches st.session_state["past"] = [] st.session_state["generated"] = [] st.session_state["chat_history"] = [] @@ -121,6 +123,8 @@ response_container = st.container() # container for text box container = st.container() +# As streamlit reruns the whole script on each change +# it is necessary to repopulate the chat containers with container: with st.form(key="prompt_input", clear_on_submit=True): user_input = st.text_area("You:", key="input", height=100) @@ -138,8 +142,8 @@ if st.session_state["generated"]: message(st.session_state["generated"][i], key=str(i)) -# Usage sidebar -# Put at the end to display even the first input +# Usage sidebar with total used tokens and costs +# We put this at the end to be able to show usage starting with the first response with st.sidebar: if st.session_state["usage"]: st.divider() diff --git a/utils.py b/utils.py index b6d8204..33228e1 100644 --- a/utils.py +++ b/utils.py @@ -34,6 +34,7 @@ logger = logging.getLogger(APP_NAME) def configure_logger(debug=0): + # boilerplate code to enable logging in the streamlit app console log_level = logging.DEBUG if debug == 1 else logging.INFO logger.setLevel(log_level) @@ -115,6 +116,7 @@ def 
delete_uploaded_file(uploaded_file): def load_git(data_source): + # We need to try both common main branches # Thank you github for the "master" to "main" switch repo_name = data_source.split("/")[-1].split(".")[0] repo_path = str(DATA_PATH / repo_name) @@ -137,7 +139,8 @@ def load_git(data_source): def load_any_data_source(data_source): - # ugly thing that decides how to load data + # Ugly thing that decides how to load data + # It ain't much, but it's honest work is_text = data_source.endswith(".txt") is_web = data_source.startswith("http") is_pdf = data_source.endswith(".pdf") @@ -178,6 +181,7 @@ def load_any_data_source(data_source): else: loader = UnstructuredFileLoader(data_source) if loader: + # Chunk size is a major trade-off parameter to control result accuracy over computation text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=0) docs = loader.load_and_split(text_splitter) logger.info(f"Loaded: {len(docs)} document chucks") @@ -233,10 +237,13 @@ def get_chain(data_source): # create the langchain that will be called to generate responses vector_store = setup_vector_store(data_source) retriever = vector_store.as_retriever() + # Search params "fetch_k" and "k" define how many documents are pulled from the hub + # and selected after the document matching to build the context + # that is fed to the model together with your prompt search_kwargs = { + "maximal_marginal_relevance": True, "distance_metric": "cos", "fetch_k": 20, - "maximal_marginal_relevance": True, "k": 10, } retriever.search_kwargs.update(search_kwargs) @@ -249,6 +256,8 @@ def get_chain(data_source): retriever=retriever, chain_type="stuff", verbose=True, + # we limit the maximum number of used tokens + # to prevent running into the model's token limit of 4096 max_tokens_limit=3375, ) logger.info(f"Data source '{data_source}' is ready to go!")