add function signature types

main
Gustav von Zitzewitz 1 year ago
parent 7f9fab8593
commit fce52ddd1d

@ -18,10 +18,10 @@ from constants import (
K,
)
from utils import (
advanced_options_form,
authenticate,
delete_uploaded_file,
generate_response,
handle_advanced_options,
logger,
save_uploaded_file,
update_chain,
@ -109,7 +109,7 @@ with st.sidebar:
# Advanced Options
if ENABLE_ADVANCED_OPTIONS:
handle_advanced_options()
advanced_options_form()
# the chain can only be initialized after authentication is OK

@ -3,12 +3,13 @@ import os
import re
import shutil
import sys
from typing import List
import deeplake
import openai
import streamlit as st
from dotenv import load_dotenv
from langchain.callbacks import get_openai_callback
from langchain.callbacks import OpenAICallbackHandler, get_openai_callback
from langchain.chains import ConversationalRetrievalChain
from langchain.chat_models import ChatOpenAI
from langchain.document_loaders import (
@ -26,8 +27,10 @@ from langchain.document_loaders import (
WebBaseLoader,
)
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.schema import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import DeepLake
from langchain.vectorstores import DeepLake, VectorStore
from streamlit.runtime.uploaded_file_manager import UploadedFile
from constants import (
APP_NAME,
@ -47,7 +50,7 @@ load_dotenv()
logger = logging.getLogger(APP_NAME)
def configure_logger(debug=0):
def configure_logger(debug: int = 0) -> None:
# boilerplate code to enable logging in the streamlit app console
log_level = logging.DEBUG if debug == 1 else logging.INFO
logger.setLevel(log_level)
@ -66,7 +69,9 @@ def configure_logger(debug=0):
configure_logger(0)
def authenticate(openai_api_key, activeloop_token, activeloop_org_name):
def authenticate(
openai_api_key: str, activeloop_token: str, activeloop_org_name: str
) -> None:
# Validate all credentials are set and correct
# Check for env variables to enable local dev and deployments with shared credentials
openai_api_key = (
@ -110,7 +115,7 @@ def authenticate(openai_api_key, activeloop_token, activeloop_org_name):
logger.info("Authentification successful!")
def handle_advanced_options():
def advanced_options_form() -> None:
# Input Form that takes advanced options and rebuilds chain with them
advanced_options = st.checkbox(
"Advanced Options", help="Caution! This may break things!"
@ -128,7 +133,7 @@ def handle_advanced_options():
fetch_k = col1.number_input(
"k_fetch",
min_value=1,
max_value=100,
max_value=1000,
value=FETCH_K,
help="The number of documents to pull from the vector database",
)
@ -167,7 +172,7 @@ def handle_advanced_options():
update_chain()
def save_uploaded_file(uploaded_file):
def save_uploaded_file(uploaded_file: UploadedFile) -> str:
# streamlit uploaded files need to be stored locally
# before embedded and uploaded to the hub
if not os.path.exists(DATA_PATH):
@ -182,7 +187,7 @@ def save_uploaded_file(uploaded_file):
return file_path
def delete_uploaded_file(uploaded_file):
def delete_uploaded_file(uploaded_file: UploadedFile) -> None:
# cleanup locally stored files
file_path = DATA_PATH / uploaded_file.name
if os.path.exists(DATA_PATH):
@ -190,7 +195,7 @@ def delete_uploaded_file(uploaded_file):
logger.info(f"Removed: {file_path}")
def load_git(data_source, chunk_size=CHUNK_SIZE):
def load_git(data_source: str, chunk_size: int = CHUNK_SIZE) -> List[Document]:
# We need to try both common main branches
# Thank you github for the "master" to "main" switch
repo_name = data_source.split("/")[-1].split(".")[0]
@ -215,7 +220,9 @@ def load_git(data_source, chunk_size=CHUNK_SIZE):
return docs
def load_any_data_source(data_source, chunk_size=CHUNK_SIZE):
def load_any_data_source(
data_source: str, chunk_size: int = CHUNK_SIZE
) -> List[Document]:
# Ugly thing that decides how to load data
# It aint much, but it's honest work
is_text = data_source.endswith(".txt")
@ -272,7 +279,7 @@ def load_any_data_source(data_source, chunk_size=CHUNK_SIZE):
st.stop()
def clean_data_source_string(data_source_string):
def clean_data_source_string(data_source_string: str) -> str:
# replace all non-word characters with dashes
# to get a string that can be used to create a new dataset
dashed_string = re.sub(r"\W+", "-", data_source_string)
@ -280,7 +287,7 @@ def clean_data_source_string(data_source_string):
return cleaned_string
def setup_vector_store(data_source, chunk_size=CHUNK_SIZE):
def setup_vector_store(data_source: str, chunk_size: int = CHUNK_SIZE) -> VectorStore:
# either load existing vector store or upload a new one to the hub
embeddings = OpenAIEmbeddings(
disallowed_special=(), openai_api_key=st.session_state["openai_api_key"]
@ -310,13 +317,13 @@ def setup_vector_store(data_source, chunk_size=CHUNK_SIZE):
def build_chain(
data_source,
k=K,
fetch_k=FETCH_K,
chunk_size=CHUNK_SIZE,
temperature=TEMPERATURE,
max_tokens=MAX_TOKENS,
):
data_source: str,
k: int = K,
fetch_k: int = FETCH_K,
chunk_size: int = CHUNK_SIZE,
temperature: float = TEMPERATURE,
max_tokens: int = MAX_TOKENS,
) -> ConversationalRetrievalChain:
# create the langchain that will be called to generate responses
vector_store = setup_vector_store(data_source, chunk_size)
retriever = vector_store.as_retriever()
@ -348,8 +355,8 @@ def build_chain(
return chain
def update_chain():
# Build chain with parameters from session state and store it there
def update_chain() -> None:
# Build chain with parameters from session state and store it back
# Also delete chat history to not confuse the bot with old context
st.session_state["chain"] = build_chain(
data_source=st.session_state["data_source"],
@ -362,7 +369,7 @@ def update_chain():
st.session_state["chat_history"] = []
def update_usage(cb):
def update_usage(cb: OpenAICallbackHandler) -> None:
# Accumulate API call usage via callbacks
logger.info(f"Usage: {cb}")
callback_properties = [
@ -377,7 +384,7 @@ def update_usage(cb):
st.session_state["usage"][prop] += value
def generate_response(prompt):
def generate_response(prompt: str) -> str:
# call the chain to generate responses and add them to the chat history
with st.spinner("Generating response"), get_openai_callback() as cb:
response = st.session_state["chain"](

Loading…
Cancel
Save