refactor utils, prepare for new features

main
Gustav von Zitzewitz 1 year ago
parent ba6b376d56
commit 2a0e0bd4cc

@@ -26,7 +26,7 @@ This is an app that lets you ask questions about any data source by leveraging
## TODO LIST
If you'd like to contribute, feel free to grab any task
- [x] Refactor utils, especially the loaders
- [ ] Add option to choose model and embeddings
- [ ] Enable fully local / private mode
- [ ] Refactor utils, especially the loaders
- [ ] Add Image caption and Audio transcription support

app.py

@@ -5,20 +5,22 @@ from constants import (
ACTIVELOOP_HELP,
APP_NAME,
AUTHENTICATION_HELP,
CHUNK_OVERLAP,
CHUNK_SIZE,
DEFAULT_DATA_SOURCE,
EMBEDDINGS,
ENABLE_ADVANCED_OPTIONS,
FETCH_K,
MAX_TOKENS,
MODEL,
OPENAI_HELP,
PAGE_ICON,
REPO_URL,
PROJECT_URL,
TEMPERATURE,
USAGE_HELP,
K,
)
from utils import (
advanced_options_form,
authenticate,
delete_uploaded_file,
generate_response,
@@ -49,9 +51,12 @@ SESSION_DEFAULTS = {
"openai_api_key": None,
"activeloop_token": None,
"activeloop_org_name": None,
"model": MODEL,
"embeddings": EMBEDDINGS,
"k": K,
"fetch_k": FETCH_K,
"chunk_size": CHUNK_SIZE,
"chunk_overlap": CHUNK_OVERLAP,
"temperature": TEMPERATURE,
"max_tokens": MAX_TOKENS,
}
@@ -61,9 +66,7 @@ for k, v in SESSION_DEFAULTS.items():
st.session_state[k] = v
# Sidebar with Authentication
# Only start App if authentication is OK
with st.sidebar:
def authentication_form() -> None:
st.title("Authentication", help=AUTHENTICATION_HELP)
with st.form("authentication"):
openai_api_key = st.text_input(
@@ -88,7 +91,70 @@ with st.sidebar:
if submitted:
authenticate(openai_api_key, activeloop_token, activeloop_org_name)
st.info(f"Learn how it works [here]({REPO_URL})")
def advanced_options_form() -> None:
# Input Form that takes advanced options and rebuilds chain with them
advanced_options = st.checkbox(
"Advanced Options", help="Caution! This may break things!"
)
if advanced_options:
with st.form("advanced_options"):
temperature = st.slider(
"temperature",
min_value=0.0,
max_value=1.0,
value=TEMPERATURE,
help="Controls the randomness of the language model output",
)
col1, col2 = st.columns(2)
fetch_k = col1.number_input(
"k_fetch",
min_value=1,
max_value=1000,
value=FETCH_K,
help="The number of documents to pull from the vector database",
)
k = col2.number_input(
"k",
min_value=1,
max_value=100,
value=K,
help="The number of most similar documents to build the context from",
)
chunk_size = col1.number_input(
"chunk_size",
min_value=1,
max_value=100000,
value=CHUNK_SIZE,
help=(
"The size at which the text is divided into smaller chunks "
"before being embedded.\n\nChanging this parameter makes re-embedding "
"and re-uploading the data to the database necessary "
),
)
max_tokens = col2.number_input(
"max_tokens",
min_value=1,
max_value=4096,
value=MAX_TOKENS,
help="Limits the documents returned from database based on number of tokens",
)
applied = st.form_submit_button("Apply")
if applied:
st.session_state["k"] = k
st.session_state["fetch_k"] = fetch_k
st.session_state["chunk_size"] = chunk_size
st.session_state["temperature"] = temperature
st.session_state["max_tokens"] = max_tokens
update_chain()
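One caveat with these options: under maximal marginal relevance retrieval, "k_fetch" candidates are pulled from the vector store first and "k" of them are kept, so "k" should never exceed "k_fetch". A minimal guard sketch (hypothetical, not part of this commit) that the form could grow:
# hypothetical sanity check before applying the form values
if applied and k > fetch_k:
    st.warning('"k" should not exceed "k_fetch": MMR selects k documents out of k_fetch candidates')
    st.stop()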
# Sidebar with Authentication and Advanced Options
with st.sidebar:
authentication_form()
st.info(f"Learn how it works [here]({PROJECT_URL})")
# Only start App if authentication is OK
if not st.session_state["auth_ok"]:
st.stop()
@@ -99,11 +165,6 @@ with st.sidebar:
if ENABLE_ADVANCED_OPTIONS:
advanced_options_form()
# the chain can only be initialized after authentication is OK
if "chain" not in st.session_state:
update_chain()
if clear_button:
# resets all chat history related caches
st.session_state["past"] = []
@@ -118,6 +179,10 @@ data_source = st.text_input(
placeholder="Any path or url pointing to a file or directory of files",
)
# the chain can only be initialized after authentication is OK
if "chain" not in st.session_state:
update_chain()
# generate new chain for new data source / uploaded file
# make sure to do this only once per input / on change
if data_source and data_source != st.session_state["data_source"]:
@@ -128,28 +193,28 @@ if data_source and data_source != st.session_state["data_source"]:
if uploaded_file and uploaded_file != st.session_state["uploaded_file"]:
logger.info(f"Uploaded file: '{uploaded_file.name}'")
st.session_state["uploaded_file"] = uploaded_file
data_source = save_uploaded_file(uploaded_file)
data_source = save_uploaded_file()
st.session_state["data_source"] = data_source
update_chain()
delete_uploaded_file(uploaded_file)
delete_uploaded_file()
# container for chat history
response_container = st.container()
# container for text box
container = st.container()
text_container = st.container()
# As streamlit reruns the whole script on each change
# it is necessary to repopulate the chat containers
with container:
with text_container:
with st.form(key="prompt_input", clear_on_submit=True):
user_input = st.text_area("You:", key="input", height=100)
submit_button = st.form_submit_button(label="Send")
if submit_button and user_input:
output = generate_response(user_input)
st.session_state["past"].append(user_input)
st.session_state["generated"].append(output)
if submit_button and user_input:
output = generate_response(user_input)
st.session_state["past"].append(user_input)
st.session_state["generated"].append(output)
if st.session_state["generated"]:
with response_container:
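Because Streamlit reruns the whole script on every interaction, both containers are repopulated from session state on each pass. A minimal sketch of what that repopulation can look like (rendering via st.markdown is an assumption; the actual rendering code falls outside this hunk):
# hypothetical repopulation loop; the real rendering code is not shown in this diff
with response_container:
    for i in range(len(st.session_state["generated"])):
        st.markdown(f"**You:** {st.session_state['past'][i]}")
        st.markdown(f"**Bot:** {st.session_state['generated'][i]}")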

constants.py

@@ -2,11 +2,15 @@ from pathlib import Path
APP_NAME = "DataChad"
MODEL = "gpt-3.5-turbo"
EMBEDDINGS = "openai"
PAGE_ICON = "🤖"
PROJECT_URL = "https://github.com/gustavz/DataChad"
K = 10
FETCH_K = 20
CHUNK_SIZE = 1000
CHUNK_OVERLAP = 0
TEMPERATURE = 0.7
MAX_TOKENS = 3357
ENABLE_ADVANCED_OPTIONS = True
@@ -14,12 +18,10 @@ ENABLE_ADVANCED_OPTIONS = True
DATA_PATH = Path.cwd() / "data"
DEFAULT_DATA_SOURCE = "https://github.com/gustavz/DataChad.git"
REPO_URL = "https://github.com/gustavz/DataChad"
AUTHENTICATION_HELP = f"""
Your credentials are only stored in your session state.\n
The keys are never exposed, displayed, or stored permanently in any way.\n
Feel free to check out [the code base]({REPO_URL}) to validate how things work.
Feel free to check out [the code base]({PROJECT_URL}) to validate how things work.
"""
USAGE_HELP = f"""

utils.py

@@ -9,41 +9,36 @@ import deeplake
import openai
import streamlit as st
from dotenv import load_dotenv
from langchain.base_language import BaseLanguageModel
from langchain.callbacks import OpenAICallbackHandler, get_openai_callback
from langchain.chains import ConversationalRetrievalChain
from langchain.chat_models import ChatOpenAI
from langchain.document_loaders import (
CSVLoader,
DirectoryLoader,
EverNoteLoader,
GitLoader,
NotebookLoader,
OnlinePDFLoader,
PDFMinerLoader,
PythonLoader,
TextLoader,
UnstructuredEPubLoader,
UnstructuredFileLoader,
UnstructuredHTMLLoader,
UnstructuredPDFLoader,
UnstructuredMarkdownLoader,
UnstructuredODTLoader,
UnstructuredPowerPointLoader,
UnstructuredWordDocumentLoader,
WebBaseLoader,
)
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.document_loaders.base import BaseLoader
from langchain.embeddings.openai import Embeddings, OpenAIEmbeddings
from langchain.schema import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import DeepLake, VectorStore
from streamlit.runtime.uploaded_file_manager import UploadedFile
from constants import (
APP_NAME,
CHUNK_SIZE,
DATA_PATH,
FETCH_K,
MAX_TOKENS,
MODEL,
PAGE_ICON,
REPO_URL,
TEMPERATURE,
K,
)
from constants import APP_NAME, DATA_PATH, PAGE_ICON, PROJECT_URL
# loads environment variables
load_dotenv()
@@ -116,66 +111,10 @@ def authenticate(
logger.info("Authentification successful!")
def advanced_options_form() -> None:
# Input Form that takes advanced options and rebuilds chain with them
advanced_options = st.checkbox(
"Advanced Options", help="Caution! This may break things!"
)
if advanced_options:
with st.form("advanced_options"):
temperature = st.slider(
"temperature",
min_value=0.0,
max_value=1.0,
value=TEMPERATURE,
help="Controls the randomness of the language model output",
)
col1, col2 = st.columns(2)
fetch_k = col1.number_input(
"k_fetch",
min_value=1,
max_value=1000,
value=FETCH_K,
help="The number of documents to pull from the vector database",
)
k = col2.number_input(
"k",
min_value=1,
max_value=100,
value=K,
help="The number of most similar documents to build the context from",
)
chunk_size = col1.number_input(
"chunk_size",
min_value=1,
max_value=100000,
value=CHUNK_SIZE,
help=(
"The size at which the text is divided into smaller chunks "
"before being embedded.\n\nChanging this parameter makes re-embedding "
"and re-uploading the data to the database necessary "
),
)
max_tokens = col2.number_input(
"max_tokens",
min_value=1,
max_value=4096,
value=MAX_TOKENS,
help="Limits the documents returned from database based on number of tokens",
)
applied = st.form_submit_button("Apply")
if applied:
st.session_state["k"] = k
st.session_state["fetch_k"] = fetch_k
st.session_state["chunk_size"] = chunk_size
st.session_state["temperature"] = temperature
st.session_state["max_tokens"] = max_tokens
update_chain()
def save_uploaded_file(uploaded_file: UploadedFile) -> str:
def save_uploaded_file() -> str:
# streamlit uploaded files need to be stored locally
# before being embedded and uploaded to the hub
uploaded_file = st.session_state["uploaded_file"]
if not os.path.exists(DATA_PATH):
os.makedirs(DATA_PATH)
file_path = str(DATA_PATH / uploaded_file.name)
@@ -188,128 +127,162 @@ def save_uploaded_file(uploaded_file: UploadedFile) -> str:
return file_path
def delete_uploaded_file(uploaded_file: UploadedFile) -> None:
def delete_uploaded_file() -> None:
# cleanup locally stored files
file_path = DATA_PATH / uploaded_file.name
file_path = DATA_PATH / st.session_state["uploaded_file"].name
if os.path.exists(DATA_PATH):
os.remove(file_path)
logger.info(f"Removed: {file_path}")
def handle_load_error(e: str = None) -> None:
error_msg = f"Failed to load '{st.session_state['data_source']}':\n\n{e}"
st.error(error_msg, icon=PAGE_ICON)
logger.error(error_msg)
st.stop()
def load_git(data_source: str, chunk_size: int = CHUNK_SIZE) -> List[Document]:
# We need to try both common main branches
# Thank you github for the "master" to "main" switch
# we need to make sure the data path exists
if not os.path.exists(DATA_PATH):
os.makedirs(DATA_PATH)
repo_name = data_source.split("/")[-1].split(".")[0]
repo_path = str(DATA_PATH / repo_name)
clone_url = data_source
if os.path.exists(repo_path):
clone_url = None
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=chunk_size, chunk_overlap=0
)
branches = ["main", "master"]
for branch in branches:
class AutoGitLoader:
def __init__(self, data_source: str) -> None:
self.data_source = data_source
def load(self) -> List[Document]:
# We need to try both common main branches
# Thank you github for the "master" to "main" switch
# we need to make sure the data path exists
if not os.path.exists(DATA_PATH):
os.makedirs(DATA_PATH)
repo_name = self.data_source.split("/")[-1].split(".")[0]
repo_path = str(DATA_PATH / repo_name)
clone_url = self.data_source
if os.path.exists(repo_path):
clone_url = None
branches = ["main", "master"]
for branch in branches:
try:
docs = GitLoader(repo_path, clone_url, branch).load()
break
except Exception as e:
logger.error(f"Error loading git: {e}")
if os.path.exists(repo_path):
# cleanup repo afterwards
shutil.rmtree(repo_path)
try:
docs = GitLoader(repo_path, clone_url, branch).load_and_split(text_splitter)
break
except Exception as e:
logger.error(f"Error loading git: {e}")
if os.path.exists(repo_path):
# cleanup repo afterwards
shutil.rmtree(repo_path)
try:
return docs
except:
msg = "Make sure to use HTTPS git repo links"
handle_load_error(msg)
try:
return docs
except:
raise RuntimeError("Make sure to use HTTPS GitHub repo links")
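AutoGitLoader deliberately mirrors the loader interface (a load() method returning List[Document]), which is what lets it sit in WEB_LOADER_MAPPING next to the stock langchain loaders. A hedged usage sketch (the repo URL is illustrative):
# illustrative only: clone a public repo over HTTPS and load its files as Documents
docs = AutoGitLoader("https://github.com/gustavz/DataChad.git").load()
logger.info(f"Loaded {len(docs)} documents from git")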
FILE_LOADER_MAPPING = {
".csv": (CSVLoader, {"encoding": "utf-8"}),
".doc": (UnstructuredWordDocumentLoader, {}),
".docx": (UnstructuredWordDocumentLoader, {}),
".enex": (EverNoteLoader, {}),
".epub": (UnstructuredEPubLoader, {}),
".html": (UnstructuredHTMLLoader, {}),
".md": (UnstructuredMarkdownLoader, {}),
".odt": (UnstructuredODTLoader, {}),
".pdf": (PDFMinerLoader, {}),
".ppt": (UnstructuredPowerPointLoader, {}),
".pptx": (UnstructuredPowerPointLoader, {}),
".txt": (TextLoader, {"encoding": "utf8"}),
".ipynb": (NotebookLoader, {}),
".py": (PythonLoader, {}),
# Add more mappings for other file extensions and loaders as needed
}
WEB_LOADER_MAPPING = {
".git": (AutoGitLoader, {}),
".pdf": (OnlinePDFLoader, {}),
}
def get_loader(file_path: str, mapping: dict, default_loader: BaseLoader) -> BaseLoader:
# Choose loader from mapping, load default if no match found
ext = "." + file_path.rsplit(".", 1)[-1]
if ext in mapping:
loader_class, loader_args = mapping[ext]
loader = loader_class(file_path, **loader_args)
else:
loader = default_loader(file_path)
return loader
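For reference, extension resolution works like this; a minimal sketch with an illustrative file name:
# ".md" matches FILE_LOADER_MAPPING and yields UnstructuredMarkdownLoader;
# an unknown extension falls back to the default_loader argument
loader = get_loader("notes.md", FILE_LOADER_MAPPING, UnstructuredFileLoader)
docs = loader.load()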
def load_any_data_source(
data_source: str, chunk_size: int = CHUNK_SIZE
) -> List[Document]:
def load_data_source() -> List[Document]:
# Ugly thing that decides how to load data
# It ain't much, but it's honest work
is_text = data_source.endswith(".txt")
data_source = st.session_state["data_source"]
is_web = data_source.startswith("http")
is_pdf = data_source.endswith(".pdf")
is_csv = data_source.endswith("csv")
is_html = data_source.endswith(".html")
is_git = data_source.endswith(".git")
is_notebook = data_source.endswith(".ipynb")
is_doc = data_source.endswith(".doc")
is_py = data_source.endswith(".py")
is_dir = os.path.isdir(data_source)
is_file = os.path.isfile(data_source)
loader = None
if is_dir:
loader = DirectoryLoader(data_source, recursive=True, silent_errors=True)
elif is_git:
return load_git(data_source, chunk_size)
elif is_web:
if is_pdf:
loader = OnlinePDFLoader(data_source)
else:
loader = WebBaseLoader(data_source)
loader = get_loader(data_source, WEB_LOADER_MAPPING, WebBaseLoader)
elif is_file:
if is_text:
loader = TextLoader(data_source, encoding="utf-8")
elif is_notebook:
loader = NotebookLoader(data_source)
elif is_pdf:
loader = UnstructuredPDFLoader(data_source)
elif is_html:
loader = UnstructuredHTMLLoader(data_source)
elif is_doc:
loader = UnstructuredWordDocumentLoader(data_source)
elif is_csv:
loader = CSVLoader(data_source, encoding="utf-8")
elif is_py:
loader = PythonLoader(data_source)
else:
loader = UnstructuredFileLoader(data_source)
loader = get_loader(data_source, FILE_LOADER_MAPPING, UnstructuredFileLoader)
try:
# Chunk size is a major trade-off parameter to control result accuracy over computation
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=chunk_size, chunk_overlap=0
chunk_size=st.session_state["chunk_size"],
chunk_overlap=st.session_state["chunk_overlap"],
)
docs = loader.load_and_split(text_splitter)
docs = loader.load()
docs = text_splitter.split_documents(docs)
logger.info(f"Loaded: {len(docs)} document chucks")
return docs
except Exception as e:
msg = (
e
if loader
else f"No Loader found for your data source. Consider contributing:  {REPO_URL}!"
else f"No Loader found for your data source. Consider contributing:  {PROJECT_URL}!"
)
handle_load_error(msg)
error_msg = f"Failed to load '{st.session_state['data_source']}':\n\n{msg}"
st.error(error_msg, icon=PAGE_ICON)
logger.error(error_msg)
st.stop()
def clean_data_source_string(data_source_string: str) -> str:
def get_data_source_string() -> str:
# replace all non-word characters with dashes
# to get a string that can be used to create a new dataset
dashed_string = re.sub(r"\W+", "-", data_source_string)
dashed_string = re.sub(r"\W+", "-", st.session_state["data_source"])
cleaned_string = re.sub(r"--+", "-", dashed_string).strip("-")
return cleaned_string
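A worked example of the cleaning, assuming the data source below:
# st.session_state["data_source"] = "https://github.com/gustavz/DataChad.git"
# get_data_source_string() -> "https-github-com-gustavz-DataChad-git"
# i.e. every run of non-word characters collapses to a single dash,
# which keeps the string safe for use inside a hub:// dataset path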
def setup_vector_store(data_source: str, chunk_size: int = CHUNK_SIZE) -> VectorStore:
def get_model() -> BaseLanguageModel:
match st.session_state["model"]:
case "gpt-3.5-turbo":
model = ChatOpenAI(
model_name=st.session_state["model"],
temperature=st.session_state["temperature"],
openai_api_key=st.session_state["openai_api_key"],
)
# Add more models as needed
case _:
msg = f"Model {st.session_state['model']} not supported!"
logger.error(msg)
st.error(msg)
st.stop()
return model
def get_embeddings() -> Embeddings:
match st.session_state["embeddings"]:
case "openai":
embeddings = OpenAIEmbeddings(
disallowed_special=(), openai_api_key=st.session_state["openai_api_key"]
)
# Add more embeddings as needed
case _:
msg = f"Embeddings {st.session_state['embeddings']} not supported!"
logger.error(msg)
st.error(msg)
st.stop()
return embeddings
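The match statements make the extension points explicit: supporting another model or embedding is one more case. A hedged sketch of what a new case could look like (the "gpt-4" option is illustrative, not part of this commit):
# hypothetical additional case inside get_model()
case "gpt-4":
    model = ChatOpenAI(
        model_name=st.session_state["model"],
        temperature=st.session_state["temperature"],
        openai_api_key=st.session_state["openai_api_key"],
    )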
def get_vector_store() -> VectorStore:
# either load existing vector store or upload a new one to the hub
embeddings = OpenAIEmbeddings(
disallowed_special=(), openai_api_key=st.session_state["openai_api_key"]
)
data_source_name = clean_data_source_string(data_source)
dataset_path = f"hub://{st.session_state['activeloop_org_name']}/{data_source_name}-{chunk_size}"
embeddings = get_embeddings()
data_source_name = get_data_source_string()
dataset_path = f"hub://{st.session_state['activeloop_org_name']}/{data_source_name}-{st.session_state['chunk_size']}"
if deeplake.exists(dataset_path, token=st.session_state["activeloop_token"]):
with st.spinner("Loading vector store..."):
logger.info(f"Dataset '{dataset_path}' exists -> loading")
@@ -322,7 +295,7 @@ def setup_vector_store(data_source: str, chunk_size: int = CHUNK_SIZE) -> VectorStore:
else:
with st.spinner("Reading, embedding and uploading data to hub..."):
logger.info(f"Dataset '{dataset_path}' does not exist -> uploading")
docs = load_any_data_source(data_source, chunk_size)
docs = load_data_source()
vector_store = DeepLake.from_documents(
docs,
embeddings,
@@ -332,16 +305,9 @@ def setup_vector_store(data_source: str, chunk_size: int = CHUNK_SIZE) -> VectorStore:
return vector_store
def build_chain(
data_source: str,
k: int = K,
fetch_k: int = FETCH_K,
chunk_size: int = CHUNK_SIZE,
temperature: float = TEMPERATURE,
max_tokens: int = MAX_TOKENS,
) -> ConversationalRetrievalChain:
def get_chain() -> ConversationalRetrievalChain:
# create the LangChain chain that will be called to generate responses
vector_store = setup_vector_store(data_source, chunk_size)
vector_store = get_vector_store()
retriever = vector_store.as_retriever()
# Search params "fetch_k" and "k" define how many documents are pulled from the hub
# and how many are selected after document matching to build the context
@@ -349,15 +315,11 @@ def build_chain(
search_kwargs = {
"maximal_marginal_relevance": True,
"distance_metric": "cos",
"fetch_k": fetch_k,
"k": k,
"fetch_k": st.session_state["fetch_k"],
"k": st.session_state["k"],
}
retriever.search_kwargs.update(search_kwargs)
model = ChatOpenAI(
model_name=MODEL,
temperature=temperature,
openai_api_key=st.session_state["openai_api_key"],
)
model = get_model()
chain = ConversationalRetrievalChain.from_llm(
model,
retriever=retriever,
@@ -365,9 +327,8 @@ def build_chain(
verbose=True,
# we limit the maximum number of used tokens
# to prevent running into the model's token limit of 4096
max_tokens_limit=max_tokens,
max_tokens_limit=st.session_state["max_tokens"],
)
logger.info(f"Data source '{data_source}' is ready to go!")
return chain
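Downstream, generate_response (not shown in this diff) presumably invokes this chain; the ConversationalRetrievalChain calling convention takes the question plus accumulated history and returns a dict containing the answer:
# illustrative call; the exact usage in generate_response is outside this diff
result = st.session_state["chain"](
    {"question": "What is this repo about?", "chat_history": st.session_state["chat_history"]}
)
answer = result["answer"]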
@@ -375,15 +336,11 @@ def update_chain() -> None:
# Build chain with parameters from session state and store it back
# Also delete chat history to not confuse the bot with old context
try:
st.session_state["chain"] = build_chain(
data_source=st.session_state["data_source"],
k=st.session_state["k"],
fetch_k=st.session_state["fetch_k"],
chunk_size=st.session_state["chunk_size"],
temperature=st.session_state["temperature"],
max_tokens=st.session_state["max_tokens"],
)
st.session_state["chain"] = get_chain()
st.session_state["chat_history"] = []
msg = f"Data source '{st.session_state['data_source']}' is ready to go!"
logger.info(msg)
st.info(msg, icon=PAGE_ICON)
except Exception as e:
msg = f"Failed to build chain for data source '{st.session_state['data_source']}' with error: {e}"
logger.error(msg)
