add authentication, remove tmp files and fix chat

pull/1/head
Gustav von Zitzewitz 1 year ago
parent 2300731e56
commit 954a2a5859

@@ -3,10 +3,10 @@ from streamlit_chat import message
 from constants import APP_NAME, DEFAULT_DATA_SOURCE, PAGE_ICON
 from utils import (
+    delete_uploaded_file,
     generate_response,
-    get_chain,
-    reset_data_source,
     save_uploaded_file,
+    build_chain_and_clear_history,
     validate_keys,
 )
@@ -28,6 +28,10 @@ if "past" not in st.session_state:
     st.session_state["past"] = []
 if "auth_ok" not in st.session_state:
     st.session_state["auth_ok"] = False
+if "data_source" not in st.session_state:
+    st.session_state["data_source"] = ""
+if "uploaded_file" not in st.session_state:
+    st.session_state["uploaded_file"] = None
 
 # Sidebar
@@ -48,31 +52,39 @@ with st.sidebar:
     if not st.session_state["auth_ok"]:
         st.stop()
 
-clear_button = st.button("Clear Conversation and Reset Data", key="clear")
+clear_button = st.button("Clear Conversation", key="clear")
 
 # the chain can only be initialized after authentication is OK
 if "chain" not in st.session_state:
-    st.session_state["chain"] = get_chain(DEFAULT_DATA_SOURCE)
+    build_chain_and_clear_history(DEFAULT_DATA_SOURCE)
 
 if clear_button:
-    # reset everything
-    reset_data_source(DEFAULT_DATA_SOURCE)
+    # reset chat history
+    st.session_state["past"] = []
+    st.session_state["generated"] = []
+    st.session_state["chat_history"] = []
 
-# upload file or enter data source
+# file upload and data source inputs
 uploaded_file = st.file_uploader("Upload a file")
 data_source = st.text_input(
     "Enter any data source",
     placeholder="Any path or url pointing to a file or directory of files",
 )
 
-if uploaded_file:
+# generate new chain for new data source / uploaded file
+# make sure to do this only once per input / on change
+if data_source and data_source != st.session_state["data_source"]:
+    print(f"data source provided: '{data_source}'")
+    build_chain_and_clear_history(data_source)
+    st.session_state["data_source"] = data_source
+
+if uploaded_file and uploaded_file != st.session_state["uploaded_file"]:
     print(f"uploaded file: '{uploaded_file.name}'")
     data_source = save_uploaded_file(uploaded_file)
-    reset_data_source(data_source)
+    build_chain_and_clear_history(data_source)
+    delete_uploaded_file(uploaded_file)
+    st.session_state["uploaded_file"] = uploaded_file
-
-if data_source:
-    print(f"data source provided: '{data_source}'")
-    reset_data_source(data_source)
 
 # container for chat history
 response_container = st.container()
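The on-change guards introduced above lean on Streamlit's execution model: the whole script reruns on every widget interaction, so expensive work has to be skipped unless the input actually changed since the last run. A minimal standalone sketch of the pattern (the `expensive_setup` helper is illustrative, not part of this commit):

```python
import streamlit as st

# remember the last processed input across script reruns
if "data_source" not in st.session_state:
    st.session_state["data_source"] = ""

def expensive_setup(source: str) -> None:
    # stand-in for building the vector store and chain
    print(f"building resources for {source!r}")

data_source = st.text_input("Enter any data source")

# Streamlit reruns this whole script on every interaction,
# so only rebuild when the value differs from the stored one
if data_source and data_source != st.session_state["data_source"]:
    expensive_setup(data_source)
    st.session_state["data_source"] = data_source
```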

@@ -1,7 +1,9 @@
 import os
 import re
+import openai
 import deeplake
+import shutil
 import streamlit as st
 from langchain.chains import ConversationalRetrievalChain
 from langchain.chat_models import ChatOpenAI
@@ -28,19 +30,18 @@ from constants import DATA_PATH, MODEL, PAGE_ICON
 def validate_keys(openai_key, activeloop_token, activeloop_org_name):
     # Validate all API related variables are set and correct
-    # TODO: Do proper token/key validation, currently activeloop has none
     all_keys = [openai_key, activeloop_token, activeloop_org_name]
     if any(all_keys):
         print(f"{openai_key=}\n{activeloop_token=}\n{activeloop_org_name=}")
         if not all(all_keys):
             st.session_state["auth_ok"] = False
-            st.error("Authentication failed", icon=PAGE_ICON)
+            st.error("You need to fill all fields", icon=PAGE_ICON)
             st.stop()
         os.environ["OPENAI_API_KEY"] = openai_key
         os.environ["ACTIVELOOP_TOKEN"] = activeloop_token
         os.environ["ACTIVELOOP_ORG_NAME"] = activeloop_org_name
     else:
-        # Fallback for local development or deployments with provided credentials
+        # Bypass for local development or deployments with stored credentials
         # either env variables or streamlit secrets need to be set
         try:
             assert os.environ.get("OPENAI_API_KEY")
@@ -54,12 +55,26 @@ def validate_keys(openai_key, activeloop_token, activeloop_org_name):
             os.environ["OPENAI_API_KEY"] = st.secrets.get("OPENAI_API_KEY")
             os.environ["ACTIVELOOP_TOKEN"] = st.secrets.get("ACTIVELOOP_TOKEN")
             os.environ["ACTIVELOOP_ORG_NAME"] = st.secrets.get("ACTIVELOOP_ORG_NAME")
+    try:
+        # Try to access openai and deeplake
+        with st.spinner("Authenticating..."):
+            openai.Model.list()
+            deeplake.exists(
+                f"hub://{os.environ['ACTIVELOOP_ORG_NAME']}/DataChad-Authentication-Check",
+            )
+    except Exception as e:
+        print(f"Authentication failed with {e}")
+        st.session_state["auth_ok"] = False
+        st.error("Authentication failed", icon=PAGE_ICON)
+        st.stop()
+    print("Authentication successful!")
     st.session_state["auth_ok"] = True
 
 
 def save_uploaded_file(uploaded_file):
-    # streamlit uploaded files need to be stored locally before
-    # TODO: delete local files after they are uploaded to the datalake
+    # streamlit uploaded files need to be stored locally
+    # before being embedded and uploaded to the hub
     if not os.path.exists(DATA_PATH):
         os.makedirs(DATA_PATH)
     file_path = str(DATA_PATH / uploaded_file.name)
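Rather than parsing the tokens, the new check makes one cheap authenticated call per service and treats any exception as an authentication failure. A trimmed sketch of the idea, assuming the pre-1.0 `openai` SDK (which exposes `openai.Model.list()`) and a `deeplake` version providing `deeplake.exists`; the `auth-check` dataset name below is arbitrary, only the token is being exercised:

```python
import os
import openai
import deeplake

def credentials_work() -> bool:
    # requires a valid OPENAI_API_KEY in the environment
    try:
        openai.Model.list()
    except Exception as e:
        print(f"OpenAI authentication failed: {e}")
        return False
    # an authenticated hub lookup; the dataset name is arbitrary,
    # only the ACTIVELOOP_TOKEN is being exercised
    try:
        deeplake.exists(f"hub://{os.environ['ACTIVELOOP_ORG_NAME']}/auth-check")
    except Exception as e:
        print(f"Deep Lake authentication failed: {e}")
        return False
    return True
```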
@@ -68,25 +83,37 @@ def save_uploaded_file(uploaded_file):
     file = open(file_path, "wb")
     file.write(file_bytes)
     file.close()
+    print(f"saved {file_path}")
     return file_path
 
 
+def delete_uploaded_file(uploaded_file):
+    # cleanup locally stored files
+    file_path = DATA_PATH / uploaded_file.name
+    if os.path.exists(file_path):
+        os.remove(file_path)
+        print(f"removed {file_path}")
+
+
 def load_git(data_source):
     # Thank you github for the "master" to "main" switch
     repo_name = data_source.split("/")[-1].split(".")[0]
     repo_path = str(DATA_PATH / repo_name)
+    if os.path.exists(repo_path):
+        data_source = None
     text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
     branches = ["main", "master"]
     for branch in branches:
-        if os.path.exists(repo_path):
-            data_source = None
         try:
             docs = GitLoader(repo_path, data_source, branch).load_and_split(
                 text_splitter
             )
+            break
         except Exception as e:
             print(f"error loading git: {e}")
+    if os.path.exists(repo_path):
+        # cleanup repo afterwards
+        shutil.rmtree(repo_path)
     return docs
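The `data_source = None` assignment works because `GitLoader(repo_path, clone_url, branch)` only clones when a clone URL is given; passing `None` makes it read the existing checkout instead of cloning into a non-empty directory. A condensed sketch of the fallback-and-cleanup flow (the `load_repo` helper is illustrative, and the text splitting is omitted for brevity):

```python
import os
import shutil
from langchain.document_loaders import GitLoader

def load_repo(repo_path: str, clone_url: str):
    # reuse an existing checkout instead of cloning again
    if os.path.exists(repo_path):
        clone_url = None
    docs = None
    for branch in ("main", "master"):
        try:
            docs = GitLoader(repo_path, clone_url, branch).load()
            break  # stop at the first branch that exists
        except Exception as e:
            print(f"branch {branch!r} failed: {e}")
    # the clone is only a temporary artifact, remove it afterwards
    if os.path.exists(repo_path):
        shutil.rmtree(repo_path)
    return docs
```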
@@ -145,30 +172,32 @@ def load_any_data_source(data_source):
 def clean_data_source_string(data_source):
     # replace all non-word characters with dashes
-    # to get a string that can be used to create a datalake dataset
+    # to get a string that can be used to create a new dataset
     dashed_string = re.sub(r"\W+", "-", data_source)
     cleaned_string = re.sub(r"--+", "-", dashed_string).strip("-")
     return cleaned_string
 
 
 def setup_vector_store(data_source):
-    # either load existing vector store or upload a new one to the datalake
+    # either load existing vector store or upload a new one to the hub
     embeddings = OpenAIEmbeddings(disallowed_special=())
     data_source_name = clean_data_source_string(data_source)
     dataset_path = f"hub://{os.environ['ACTIVELOOP_ORG_NAME']}/{data_source_name}"
     if deeplake.exists(dataset_path):
-        print(f"{dataset_path} exists -> loading")
-        vector_store = DeepLake(
-            dataset_path=dataset_path, read_only=True, embedding_function=embeddings
-        )
+        with st.spinner("Loading vector store..."):
+            print(f"{dataset_path} exists -> loading")
+            vector_store = DeepLake(
+                dataset_path=dataset_path, read_only=True, embedding_function=embeddings
+            )
     else:
-        print(f"{dataset_path} does not exist -> uploading")
-        docs = load_any_data_source(data_source)
-        vector_store = DeepLake.from_documents(
-            docs,
-            embeddings,
-            dataset_path=f"hub://{os.environ['ACTIVELOOP_ORG_NAME']}/{data_source_name}",
-        )
+        with st.spinner("Reading, embedding and uploading data to hub..."):
+            print(f"{dataset_path} does not exist -> uploading")
+            docs = load_any_data_source(data_source)
+            vector_store = DeepLake.from_documents(
+                docs,
+                embeddings,
+                dataset_path=f"hub://{os.environ['ACTIVELOOP_ORG_NAME']}/{data_source_name}",
+            )
     return vector_store
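`clean_data_source_string` exists because hub dataset paths only tolerate word characters and dashes, and the same input must always map to the same name so an existing store can be found and reused. For illustration, the transformation applied to a typical URL (hypothetical input, not from this commit):

```python
import re

def clean_data_source_string(data_source: str) -> str:
    # collapse every run of non-word characters into a single dash
    dashed_string = re.sub(r"\W+", "-", data_source)
    return re.sub(r"--+", "-", dashed_string).strip("-")

print(clean_data_source_string("https://github.com/user/repo.git"))
# -> https-github-com-user-repo-git
```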
@@ -184,31 +213,31 @@ def get_chain(data_source):
     }
     retriever.search_kwargs.update(search_kwargs)
     model = ChatOpenAI(model_name=MODEL)
-    chain = ConversationalRetrievalChain.from_llm(
-        model,
-        retriever=retriever,
-        chain_type="stuff",
-        verbose=True,
-        max_tokens_limit=3375,
-    )
-    print(f"{data_source} is ready to go!")
+    with st.spinner("Building langchain..."):
+        chain = ConversationalRetrievalChain.from_llm(
+            model,
+            retriever=retriever,
+            chain_type="stuff",
+            verbose=True,
+            max_tokens_limit=3375,
+        )
+        print(f"{data_source} is ready to go!")
     return chain
 
 
-def reset_data_source(data_source):
-    # we need to reset all caches if a new data source is loaded
-    # otherwise the langchain is confused and produces garbage
-    st.session_state["past"] = []
-    st.session_state["generated"] = []
-    st.session_state["chat_history"] = []
+def build_chain_and_clear_history(data_source):
+    # Get chain and store it in the session state
+    # Also delete chat history to not confuse the bot with old context
     st.session_state["chain"] = get_chain(data_source)
+    st.session_state["chat_history"] = []
 
 
 def generate_response(prompt):
     # call the chain to generate responses and add them to the chat history
-    response = st.session_state["chain"](
-        {"question": prompt, "chat_history": st.session_state["chat_history"]}
-    )
-    print(f"{response=}")
-    st.session_state["chat_history"].append((prompt, response["answer"]))
+    with st.spinner("Generating response"):
+        response = st.session_state["chain"](
+            {"question": prompt, "chat_history": st.session_state["chat_history"]}
+        )
+        print(f"{response=}")
+        st.session_state["chat_history"].append((prompt, response["answer"]))
     return response["answer"]
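Resetting `chat_history` together with the chain matters because `ConversationalRetrievalChain` condenses the history and the new question into a standalone retrieval query; tuples left over from a previous data source would steer retrieval toward stale content. A minimal sketch of the history round-trip, assuming `chain` is the object stored in `st.session_state["chain"]` (the `ask` helper is illustrative):

```python
# history is a list of (question, answer) tuples fed back on each call
chat_history = []

def ask(chain, prompt: str) -> str:
    result = chain({"question": prompt, "chat_history": chat_history})
    # append the new exchange so follow-up questions get context
    chat_history.append((prompt, result["answer"]))
    return result["answer"]
```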
