You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

262 lines
8.1 KiB
Python

import streamlit as st
from streamlit_chat import message
from datachad.chain import generate_response, update_chain
from datachad.constants import (
ACTIVELOOP_HELP,
APP_NAME,
AUTHENTICATION_HELP,
CHUNK_OVERLAP,
CHUNK_SIZE,
DEFAULT_DATA_SOURCE,
ENABLE_ADVANCED_OPTIONS,
ENABLE_LOCAL_MODE,
FETCH_K,
LOCAL_MODE_DISABLED_HELP,
MAX_TOKENS,
MODE_HELP,
MODEL_N_CTX,
OPENAI_HELP,
PAGE_ICON,
PROJECT_URL,
TEMPERATURE,
USAGE_HELP,
K,
)
from datachad.models import MODELS, MODES
from datachad.utils import (
authenticate,
delete_uploaded_file,
logger,
save_uploaded_file,
)
# Page options and header
st.set_option("client.showErrorDetails", True)
st.set_page_config(
    page_icon=PAGE_ICON,
    page_title=APP_NAME,
    initial_sidebar_state="expanded",
)
# Centered page header; rendered as raw HTML, hence unsafe_allow_html.
_header_html = (
    f"<h1 style='text-align: center;'>{APP_NAME} {PAGE_ICON} <br> "
    "I know all about your data!</h1>"
)
st.markdown(_header_html, unsafe_allow_html=True)
# Default value for every session-state entry the app reads or writes.
# Streamlit reruns the whole script on each interaction, so these are only
# applied for keys that are not present yet.
SESSION_DEFAULTS = dict(
    past=[],
    usage={},
    chat_history=[],
    generated=[],
    auth_ok=False,
    openai_api_key=None,
    activeloop_token=None,
    activeloop_org_name=None,
    uploaded_file=None,
    data_source=DEFAULT_DATA_SOURCE,
    mode=MODES.OPENAI,
    model=MODELS.GPT35TURBO,
    k=K,
    fetch_k=FETCH_K,
    chunk_size=CHUNK_SIZE,
    chunk_overlap=CHUNK_OVERLAP,
    temperature=TEMPERATURE,
    max_tokens=MAX_TOKENS,
    model_n_ctx=MODEL_N_CTX,
)

# Initialise session state variables that are still missing.
for key, default in SESSION_DEFAULTS.items():
    if key not in st.session_state:
        st.session_state[key] = default
def authentication_form() -> None:
    """Render the credentials form and run authentication on submit.

    Collects the API key (label follows the currently selected mode) plus
    the optional ActiveLoop token and organisation name, then delegates to
    ``authenticate`` when the form is submitted.
    """
    st.title("Authentication", help=AUTHENTICATION_HELP)
    with st.form("authentication"):
        api_key = st.text_input(
            f"{st.session_state['mode']} API Key",
            type="password",
            help=OPENAI_HELP,
            placeholder="This field is mandatory",
        )
        token = st.text_input(
            "ActiveLoop Token",
            type="password",
            help=ACTIVELOOP_HELP,
            placeholder="Optional, using ours if empty",
        )
        org_name = st.text_input(
            "ActiveLoop Organisation Name",
            type="password",
            help=ACTIVELOOP_HELP,
            placeholder="Optional, using ours if empty",
        )
        if st.form_submit_button("Submit"):
            authenticate(api_key, token, org_name)
def advanced_options_form() -> None:
    """Render the advanced-options form and rebuild the chain on submit.

    Every widget writes its value straight into ``st.session_state`` via its
    ``key``, so ``update_chain()`` picks the new settings up after "Apply".
    """
    advanced_options = st.checkbox(
        "Advanced Options", help="Caution! This may break things!"
    )
    if advanced_options:
        with st.form("advanced_options"):
            col1, col2 = st.columns(2)
            col1.selectbox(
                "model",
                options=MODELS.for_mode(st.session_state["mode"]),
                help=f"Learn more about which models are supported [here]({PROJECT_URL})",
                key="model",
            )
            col2.number_input(
                "temperature",
                min_value=0.0,
                max_value=1.0,
                value=TEMPERATURE,
                help="Controls the randomness of the language model output",
                key="temperature",
            )
            col1.number_input(
                "k_fetch",
                min_value=1,
                max_value=1000,
                value=FETCH_K,
                help="The number of documents to pull from the vector database",
                # BUG FIX: was key="k_fetch", but SESSION_DEFAULTS seeds the
                # state entry "fetch_k" — with mismatched keys the widget wrote
                # to a slot nothing else in this file ever reads, so changing
                # the value had no effect on the chain.
                key="fetch_k",
            )
            col2.number_input(
                "k",
                min_value=1,
                max_value=100,
                value=K,
                help="The number of most similar documents to build the context from",
                key="k",
            )
            col1.number_input(
                "chunk_size",
                min_value=1,
                max_value=100000,
                value=CHUNK_SIZE,
                help=(
                    "The size at which the text is divided into smaller chunks "
                    "before being embedded.\n\nChanging this parameter makes re-embedding "
                    "and re-uploading the data to the database necessary "
                ),
                key="chunk_size",
            )
            col2.number_input(
                "max_tokens",
                min_value=1,
                max_value=30000,
                value=MAX_TOKENS,
                help="Limits the documents returned from database based on number of tokens",
                key="max_tokens",
            )
            applied = st.form_submit_button("Apply")
            if applied:
                update_chain()
def update_model_on_mode_change() -> None:
    """Mode-selectbox callback: pick the first model valid for the new mode."""
    valid_models = MODELS.for_mode(st.session_state["mode"])
    st.session_state["model"] = valid_models[0]
# Sidebar with Authentication and Advanced Options
with st.sidebar:
    mode = st.selectbox(
        "Mode",
        MODES.all(),
        key="mode",
        help=MODE_HELP,
        # BUG FIX: was `on_change=update_model_on_mode_change()` — the
        # parentheses invoked the callback immediately on every script rerun
        # (forcibly resetting the model each time) and registered its return
        # value, None, as the on_change handler. Pass the callable itself.
        on_change=update_model_on_mode_change,
    )
    if mode == MODES.LOCAL and not ENABLE_LOCAL_MODE:
        st.error(LOCAL_MODE_DISABLED_HELP, icon=PAGE_ICON)
        st.stop()
    if mode != MODES.LOCAL:
        authentication_form()
    st.info(f"Learn how it works [here]({PROJECT_URL})")

    # Only start App if authentication is OK (local mode needs no credentials)
    if not (st.session_state["auth_ok"] or mode == MODES.LOCAL):
        st.stop()

    # Clear button to reset all chat communication
    clear_button = st.button("Clear Conversation")

    # Advanced Options
    if ENABLE_ADVANCED_OPTIONS:
        advanced_options_form()

if clear_button:
    # resets all chat history related caches
    st.session_state["past"] = []
    st.session_state["generated"] = []
    st.session_state["chat_history"] = []
# file upload and data source inputs
uploaded_file = st.file_uploader("Upload a file")
data_source = st.text_input(
    "Enter any data source",
    placeholder="Any path or url pointing to a file or directory of files",
)

# the chain can only be initialized after authentication is OK
# (execution only reaches here once the sidebar's st.stop() guards passed)
if "chain" not in st.session_state:
    update_chain()

# generate new chain for new data source / uploaded file
# make sure to do this only once per input / on change
if data_source and data_source != st.session_state["data_source"]:
    logger.info(f"Data source provided: '{data_source}'")
    st.session_state["data_source"] = data_source
    update_chain()

# NOTE(review): an uploaded file is handled after (and thus overrides) a typed
# data source. The file is saved first so update_chain() can index it, then
# deleted once the chain has been rebuilt — order is load-bearing here.
# save_uploaded_file() presumably persists the file and returns its path —
# TODO confirm against datachad.utils.
if uploaded_file and uploaded_file != st.session_state["uploaded_file"]:
    logger.info(f"Uploaded file: '{uploaded_file.name}'")
    st.session_state["uploaded_file"] = uploaded_file
    data_source = save_uploaded_file()
    st.session_state["data_source"] = data_source
    update_chain()
    delete_uploaded_file()
# container for chat history
response_container = st.container()
# container for text box
text_container = st.container()

# As streamlit reruns the whole script on each change
# it is necessary to repopulate the chat containers
with text_container:
    with st.form(key="prompt_input", clear_on_submit=True):
        user_input = st.text_area("You:", key="input", height=100)
        submit_button = st.form_submit_button(label="Send")
    if submit_button and user_input:
        text_container.empty()
        output = generate_response(user_input)
        st.session_state["past"].append(user_input)
        st.session_state["generated"].append(output)

if st.session_state["generated"]:
    with response_container:
        # replay the whole conversation; past/generated grow in lockstep
        history = zip(st.session_state["past"], st.session_state["generated"])
        for idx, (question, answer) in enumerate(history):
            message(question, is_user=True, key=f"{idx}_user")
            message(answer, key=str(idx))
# Usage sidebar with total used tokens and costs
# We put this at the end to be able to show usage starting with the first response
with st.sidebar:
    usage = st.session_state["usage"]
    if usage:
        st.divider()
        st.title("Usage", help=USAGE_HELP)
        tokens_col, cost_col = st.columns(2)
        tokens_col.metric("Total Tokens", usage["total_tokens"])
        cost_col.metric("Total Costs in $", usage["total_cost"])