diff --git a/app.py b/app.py
index 4399404..2ca4a61 100644
--- a/app.py
+++ b/app.py
@@ -18,10 +18,10 @@ from constants import (
     K,
 )
 from utils import (
+    advanced_options_form,
     authenticate,
     delete_uploaded_file,
     generate_response,
-    handle_advanced_options,
     logger,
     save_uploaded_file,
     update_chain,
@@ -109,7 +109,7 @@ with st.sidebar:
 
     # Advanced Options
     if ENABLE_ADVANCED_OPTIONS:
-        handle_advanced_options()
+        advanced_options_form()
 
 
 # the chain can only be initialized after authentication is OK
diff --git a/utils.py b/utils.py
index 644c4e6..db1466c 100644
--- a/utils.py
+++ b/utils.py
@@ -3,12 +3,13 @@ import os
 import re
 import shutil
 import sys
+from typing import List
 
 import deeplake
 import openai
 import streamlit as st
 from dotenv import load_dotenv
-from langchain.callbacks import get_openai_callback
+from langchain.callbacks import OpenAICallbackHandler, get_openai_callback
 from langchain.chains import ConversationalRetrievalChain
 from langchain.chat_models import ChatOpenAI
 from langchain.document_loaders import (
@@ -26,8 +27,10 @@ from langchain.document_loaders import (
     WebBaseLoader,
 )
 from langchain.embeddings.openai import OpenAIEmbeddings
+from langchain.schema import Document
 from langchain.text_splitter import RecursiveCharacterTextSplitter
-from langchain.vectorstores import DeepLake
+from langchain.vectorstores import DeepLake, VectorStore
+from streamlit.runtime.uploaded_file_manager import UploadedFile
 
 from constants import (
     APP_NAME,
@@ -47,7 +50,7 @@ load_dotenv()
 logger = logging.getLogger(APP_NAME)
 
 
-def configure_logger(debug=0):
+def configure_logger(debug: int = 0) -> None:
     # boilerplate code to enable logging in the streamlit app console
     log_level = logging.DEBUG if debug == 1 else logging.INFO
     logger.setLevel(log_level)
@@ -66,7 +69,9 @@ def configure_logger(debug=0):
 configure_logger(0)
 
 
-def authenticate(openai_api_key, activeloop_token, activeloop_org_name):
+def authenticate(
+    openai_api_key: str, activeloop_token: str, activeloop_org_name: str
+) -> None:
     # Validate all credentials are set and correct
     # Check for env variables to enable local dev and deployments with shared credentials
     openai_api_key = (
@@ -110,7 +115,7 @@ def authenticate(openai_api_key, activeloop_token, activeloop_org_name):
     logger.info("Authentification successful!")
 
 
-def handle_advanced_options():
+def advanced_options_form() -> None:
     # Input Form that takes advanced options and rebuilds chain with them
     advanced_options = st.checkbox(
         "Advanced Options", help="Caution! This may break things!"
@@ -128,7 +133,7 @@ def handle_advanced_options():
         fetch_k = col1.number_input(
             "k_fetch",
             min_value=1,
-            max_value=100,
+            max_value=1000,
             value=FETCH_K,
             help="The number of documents to pull from the vector database",
         )
@@ -167,7 +172,7 @@ def handle_advanced_options():
         update_chain()
 
 
-def save_uploaded_file(uploaded_file):
+def save_uploaded_file(uploaded_file: UploadedFile) -> str:
     # streamlit uploaded files need to be stored locally
     # before embedded and uploaded to the hub
     if not os.path.exists(DATA_PATH):
@@ -182,7 +187,7 @@ def save_uploaded_file(uploaded_file):
     return file_path
 
 
-def delete_uploaded_file(uploaded_file):
+def delete_uploaded_file(uploaded_file: UploadedFile) -> None:
     # cleanup locally stored files
     file_path = DATA_PATH / uploaded_file.name
     if os.path.exists(DATA_PATH):
@@ -190,7 +195,7 @@
         logger.info(f"Removed: {file_path}")
 
 
-def load_git(data_source, chunk_size=CHUNK_SIZE):
+def load_git(data_source: str, chunk_size: int = CHUNK_SIZE) -> List[Document]:
     # We need to try both common main branches
     # Thank you github for the "master" to "main" switch
     repo_name = data_source.split("/")[-1].split(".")[0]
@@ -215,7 +220,9 @@
     return docs
 
 
-def load_any_data_source(data_source, chunk_size=CHUNK_SIZE):
+def load_any_data_source(
+    data_source: str, chunk_size: int = CHUNK_SIZE
+) -> List[Document]:
     # Ugly thing that decides how to load data
     # It aint much, but it's honest work
     is_text = data_source.endswith(".txt")
@@ -272,7 +279,7 @@ def load_any_data_source(data_source, chunk_size=CHUNK_SIZE):
     st.stop()
 
 
-def clean_data_source_string(data_source_string):
+def clean_data_source_string(data_source_string: str) -> str:
     # replace all non-word characters with dashes
     # to get a string that can be used to create a new dataset
     dashed_string = re.sub(r"\W+", "-", data_source_string)
@@ -280,7 +287,7 @@
     return cleaned_string
 
 
-def setup_vector_store(data_source, chunk_size=CHUNK_SIZE):
+def setup_vector_store(data_source: str, chunk_size: int = CHUNK_SIZE) -> VectorStore:
    # either load existing vector store or upload a new one to the hub
     embeddings = OpenAIEmbeddings(
         disallowed_special=(), openai_api_key=st.session_state["openai_api_key"]
     )
@@ -310,13 +317,13 @@
 
 
 def build_chain(
-    data_source,
-    k=K,
-    fetch_k=FETCH_K,
-    chunk_size=CHUNK_SIZE,
-    temperature=TEMPERATURE,
-    max_tokens=MAX_TOKENS,
-):
+    data_source: str,
+    k: int = K,
+    fetch_k: int = FETCH_K,
+    chunk_size: int = CHUNK_SIZE,
+    temperature: float = TEMPERATURE,
+    max_tokens: int = MAX_TOKENS,
+) -> ConversationalRetrievalChain:
     # create the langchain that will be called to generate responses
     vector_store = setup_vector_store(data_source, chunk_size)
     retriever = vector_store.as_retriever()
@@ -348,8 +355,8 @@
     return chain
 
 
-def update_chain():
-    # Build chain with parameters from session state and store it there
+def update_chain() -> None:
+    # Build chain with parameters from session state and store it back
     # Also delete chat history to not confuse the bot with old context
     st.session_state["chain"] = build_chain(
         data_source=st.session_state["data_source"],
@@ -362,7 +369,7 @@
     st.session_state["chat_history"] = []
 
 
-def update_usage(cb):
+def update_usage(cb: OpenAICallbackHandler) -> None:
     # Accumulate API call usage via callbacks
     logger.info(f"Usage: {cb}")
     callback_properties = [
@@ -377,7 +384,7 @@ def update_usage(cb):
         st.session_state["usage"][prop] += value
 
 
-def generate_response(prompt):
+def generate_response(prompt: str) -> str:
     # call the chain to generate responses and add them to the chat history
     with st.spinner("Generating response"), get_openai_callback() as cb:
         response = st.session_state["chain"](