add authentification, remove tmp files and fix chat

pull/1/head
Gustav von Zitzewitz 1 year ago
parent 2300731e56
commit 954a2a5859

@ -3,10 +3,10 @@ from streamlit_chat import message
from constants import APP_NAME, DEFAULT_DATA_SOURCE, PAGE_ICON
from utils import (
delete_uploaded_file,
generate_response,
get_chain,
reset_data_source,
save_uploaded_file,
build_chain_and_clear_history,
validate_keys,
)
@ -28,6 +28,10 @@ if "past" not in st.session_state:
st.session_state["past"] = []
if "auth_ok" not in st.session_state:
st.session_state["auth_ok"] = False
if "data_source" not in st.session_state:
st.session_state["data_source"] = ""
if "uploaded_file" not in st.session_state:
st.session_state["uploaded_file"] = None
# Sidebar
@ -48,31 +52,39 @@ with st.sidebar:
if not st.session_state["auth_ok"]:
st.stop()
clear_button = st.button("Clear Conversation and Reset Data", key="clear")
clear_button = st.button("Clear Conversation", key="clear")
# the chain can only be initialized after authentication is OK
if "chain" not in st.session_state:
st.session_state["chain"] = get_chain(DEFAULT_DATA_SOURCE)
build_chain_and_clear_history(DEFAULT_DATA_SOURCE)
if clear_button:
# reset everything
reset_data_source(DEFAULT_DATA_SOURCE)
# reset chat history
st.session_state["past"] = []
st.session_state["generated"] = []
st.session_state["chat_history"] = []
# upload file or enter data source
# file upload and data source inputs
uploaded_file = st.file_uploader("Upload a file")
data_source = st.text_input(
"Enter any data source",
placeholder="Any path or url pointing to a file or directory of files",
)
if uploaded_file:
# generate new chain for new data source / uploaded file
# make sure to do this only once per input / on change
if data_source and data_source != st.session_state["data_source"]:
print(f"data source provided: '{data_source}'")
build_chain_and_clear_history(data_source)
st.session_state["data_source"] = data_source
if uploaded_file and uploaded_file != st.session_state["uploaded_file"]:
print(f"uploaded file: '{uploaded_file.name}'")
data_source = save_uploaded_file(uploaded_file)
reset_data_source(data_source)
build_chain_and_clear_history(data_source)
delete_uploaded_file(uploaded_file)
st.session_state["uploaded_file"] = uploaded_file
if data_source:
print(f"data source provided: '{data_source}'")
reset_data_source(data_source)
# container for chat history
response_container = st.container()

@ -1,7 +1,9 @@
import os
import re
import openai
import deeplake
import shutil
import streamlit as st
from langchain.chains import ConversationalRetrievalChain
from langchain.chat_models import ChatOpenAI
@ -28,19 +30,18 @@ from constants import DATA_PATH, MODEL, PAGE_ICON
def validate_keys(openai_key, activeloop_token, activeloop_org_name):
# Validate all API related variables are set and correct
# TODO: Do proper token/key validation, currently activeloop has none
all_keys = [openai_key, activeloop_token, activeloop_org_name]
if any(all_keys):
print(f"{openai_key=}\n{activeloop_token=}\n{activeloop_org_name=}")
if not all(all_keys):
st.session_state["auth_ok"] = False
st.error("Authentication failed", icon=PAGE_ICON)
st.error("You need to fill all fields", icon=PAGE_ICON)
st.stop()
os.environ["OPENAI_API_KEY"] = openai_key
os.environ["ACTIVELOOP_TOKEN"] = activeloop_token
os.environ["ACTIVELOOP_ORG_NAME"] = activeloop_org_name
else:
# Fallback for local development or deployments with provided credentials
# Bypass for local development or deployments with stored credentials
# either env variables or streamlit secrets need to be set
try:
assert os.environ.get("OPENAI_API_KEY")
@ -54,12 +55,26 @@ def validate_keys(openai_key, activeloop_token, activeloop_org_name):
os.environ["OPENAI_API_KEY"] = st.secrets.get("OPENAI_API_KEY")
os.environ["ACTIVELOOP_TOKEN"] = st.secrets.get("ACTIVELOOP_TOKEN")
os.environ["ACTIVELOOP_ORG_NAME"] = st.secrets.get("ACTIVELOOP_ORG_NAME")
try:
# Try to access openai and deeplake
with st.spinner("Authentifying..."):
openai.Model.list()
deeplake.exists(
f"hub://{os.environ['ACTIVELOOP_ORG_NAME']}/DataChad-Authentication-Check",
)
except Exception as e:
print(f"Authentication failed with {e}")
st.session_state["auth_ok"] = False
st.error("Authentication failed", icon=PAGE_ICON)
st.stop()
print("Authentification successful!")
st.session_state["auth_ok"] = True
def save_uploaded_file(uploaded_file):
# streamlit uploaded files need to be stored locally before
# TODO: delete local files after they are uploaded to the datalake
# streamlit uploaded files need to be stored locally
# before embedded and uploaded to the hub
if not os.path.exists(DATA_PATH):
os.makedirs(DATA_PATH)
file_path = str(DATA_PATH / uploaded_file.name)
@ -68,25 +83,37 @@ def save_uploaded_file(uploaded_file):
file = open(file_path, "wb")
file.write(file_bytes)
file.close()
print(f"saved {file_path}")
return file_path
def delete_uploaded_file(uploaded_file):
# cleanup locally stored files
file_path = DATA_PATH / uploaded_file.name
if os.path.exists(DATA_PATH):
os.remove(file_path)
print(f"removed {file_path}")
def load_git(data_source):
# Thank you github for the "master" to "main" switch
repo_name = data_source.split("/")[-1].split(".")[0]
repo_path = str(DATA_PATH / repo_name)
if os.path.exists(repo_path):
data_source = None
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
branches = ["main", "master"]
for branch in branches:
if os.path.exists(repo_path):
data_source = None
try:
docs = GitLoader(repo_path, data_source, branch).load_and_split(
text_splitter
)
break
except Exception as e:
print(f"error loading git: {e}")
if os.path.exists(repo_path):
# cleanup repo afterwards
shutil.rmtree(repo_path)
return docs
@ -145,30 +172,32 @@ def load_any_data_source(data_source):
def clean_data_source_string(data_source):
# replace all non-word characters with dashes
# to get a string that can be used to create a datalake dataset
# to get a string that can be used to create a new dataset
dashed_string = re.sub(r"\W+", "-", data_source)
cleaned_string = re.sub(r"--+", "- ", dashed_string).strip("-")
return cleaned_string
def setup_vector_store(data_source):
# either load existing vector store or upload a new one to the datalake
# either load existing vector store or upload a new one to the hub
embeddings = OpenAIEmbeddings(disallowed_special=())
data_source_name = clean_data_source_string(data_source)
dataset_path = f"hub://{os.environ['ACTIVELOOP_ORG_NAME']}/{data_source_name}"
if deeplake.exists(dataset_path):
print(f"{dataset_path} exists -> loading")
vector_store = DeepLake(
dataset_path=dataset_path, read_only=True, embedding_function=embeddings
)
with st.spinner("Loading vector store..."):
print(f"{dataset_path} exists -> loading")
vector_store = DeepLake(
dataset_path=dataset_path, read_only=True, embedding_function=embeddings
)
else:
print(f"{dataset_path} does not exist -> uploading")
docs = load_any_data_source(data_source)
vector_store = DeepLake.from_documents(
docs,
embeddings,
dataset_path=f"hub://{os.environ['ACTIVELOOP_ORG_NAME']}/{data_source_name}",
)
with st.spinner("Reading, embedding and uploading data to hub..."):
print(f"{dataset_path} does not exist -> uploading")
docs = load_any_data_source(data_source)
vector_store = DeepLake.from_documents(
docs,
embeddings,
dataset_path=f"hub://{os.environ['ACTIVELOOP_ORG_NAME']}/{data_source_name}",
)
return vector_store
@ -184,31 +213,31 @@ def get_chain(data_source):
}
retriever.search_kwargs.update(search_kwargs)
model = ChatOpenAI(model_name=MODEL)
chain = ConversationalRetrievalChain.from_llm(
model,
retriever=retriever,
chain_type="stuff",
verbose=True,
max_tokens_limit=3375,
)
print(f"{data_source} is ready to go!")
with st.spinner("Building langchain..."):
chain = ConversationalRetrievalChain.from_llm(
model,
retriever=retriever,
chain_type="stuff",
verbose=True,
max_tokens_limit=3375,
)
print(f"{data_source} is ready to go!")
return chain
def reset_data_source(data_source):
# we need to reset all caches if a new data source is loaded
# otherwise the langchain is confused and produces garbage
st.session_state["past"] = []
st.session_state["generated"] = []
st.session_state["chat_history"] = []
def build_chain_and_clear_history(data_source):
# Get chain and store it in the session state
# Also delete chat history to not confuse the bot with old context
st.session_state["chain"] = get_chain(data_source)
st.session_state["chat_history"] = []
def generate_response(prompt):
# call the chain to generate responses and add them to the chat history
response = st.session_state["chain"](
{"question": prompt, "chat_history": st.session_state["chat_history"]}
)
print(f"{response=}")
st.session_state["chat_history"].append((prompt, response["answer"]))
with st.spinner("Generating response"):
response = st.session_state["chain"](
{"question": prompt, "chat_history": st.session_state["chat_history"]}
)
print(f"{response=}")
st.session_state["chat_history"].append((prompt, response["answer"]))
return response["answer"]

Loading…
Cancel
Save