Add advanced options

main
Gustav von Zitzewitz 1 year ago
parent 3ce0e8a292
commit 8a9afae083

@@ -5,23 +5,29 @@ from constants import (
    ACTIVELOOP_HELP,
    APP_NAME,
    AUTHENTICATION_HELP,
    CHUNK_SIZE,
    DEFAULT_DATA_SOURCE,
    ENABLE_ADVANCED_OPTIONS,
    FETCH_K,
    MAX_TOKENS,
    OPENAI_HELP,
    PAGE_ICON,
    REPO_URL,
    TEMPERATURE,
    USAGE_HELP,
    K,
)
from utils import (
    authenticate,
    build_chain_and_clear_history,
    delete_uploaded_file,
    generate_response,
    logger,
    save_uploaded_file,
    update_chain,
)
# Page options and header
st.set_option("client.showErrorDetails", False)
st.set_option("client.showErrorDetails", True)
st.set_page_config(
    page_title=APP_NAME, page_icon=PAGE_ICON, initial_sidebar_state="expanded"
)
@@ -31,26 +37,40 @@ st.markdown(
)
# Initialise session state variables
# Chat and Data Source
if "past" not in st.session_state:
st.session_state["past"] = []
if "usage" not in st.session_state:
st.session_state["usage"] = {}
if "generated" not in st.session_state:
st.session_state["generated"] = []
if "auth_ok" not in st.session_state:
st.session_state["auth_ok"] = False
if "chat_history" not in st.session_state:
st.session_state["chat_history"] = []
if "generated" not in st.session_state:
st.session_state["generated"] = []
if "data_source" not in st.session_state:
st.session_state["data_source"] = ""
st.session_state["data_source"] = DEFAULT_DATA_SOURCE
if "uploaded_file" not in st.session_state:
st.session_state["uploaded_file"] = None
# Authentication and Credentials
if "auth_ok" not in st.session_state:
st.session_state["auth_ok"] = False
if "openai_api_key" not in st.session_state:
st.session_state["openai_api_key"] = None
if "activeloop_token" not in st.session_state:
st.session_state["activeloop_token"] = None
if "activeloop_org_name" not in st.session_state:
st.session_state["activeloop_org_name"] = None
# Advanced Options
if "k" not in st.session_state:
st.session_state["k"] = K
if "fetch_k" not in st.session_state:
st.session_state["fetch_k"] = FETCH_K
if "chunk_size" not in st.session_state:
st.session_state["chunk_size"] = CHUNK_SIZE
if "temperature" not in st.session_state:
st.session_state["temperature"] = TEMPERATURE
if "max_tokens" not in st.session_state:
st.session_state["max_tokens"] = MAX_TOKENS
# Sidebar with Authentication
# Only start App if authentication is OK
@@ -86,10 +106,62 @@ with st.sidebar:
    # Clear button to reset all chat communication
    clear_button = st.button("Clear Conversation", key="clear")

    # Advanced Options
    if ENABLE_ADVANCED_OPTIONS:
        advanced_options = st.checkbox(
            "Advanced Options", help="Caution! This may break things!"
        )
        if advanced_options:
            with st.form("advanced_options"):
                temperature = st.slider(
                    "temperature",
                    min_value=0.0,
                    max_value=1.0,
                    value=TEMPERATURE,
                    help="Controls the randomness of the language model output",
                )
                col1, col2 = st.columns(2)
                fetch_k = col1.number_input(
                    "fetch_k",
                    min_value=1,
                    max_value=100,
                    value=FETCH_K,
                    help="The number of documents to pull from the vector database",
                )
                k = col2.number_input(
                    "k",
                    min_value=1,
                    max_value=100,
                    value=K,
                    help="The number of most similar documents returned from the vector store",
                )
                chunk_size = col1.number_input(
                    "chunk_size",
                    min_value=1,
                    max_value=100000,
                    value=CHUNK_SIZE,
                    help="The number of characters per chunk the text is split into before embedding",
                )
                max_tokens = col2.number_input(
                    "max_tokens",
                    min_value=1,
                    max_value=4096,
                    value=MAX_TOKENS,
                    help="Limits the total number of document tokens passed to the model as context",
                )
                applied = st.form_submit_button("Apply")
                if applied:
                    st.session_state["k"] = k
                    st.session_state["fetch_k"] = fetch_k
                    st.session_state["chunk_size"] = chunk_size
                    st.session_state["temperature"] = temperature
                    st.session_state["max_tokens"] = max_tokens
                    update_chain()

    # the chain can only be initialized after authentication is OK
    if "chain" not in st.session_state:
        build_chain_and_clear_history(DEFAULT_DATA_SOURCE)
        update_chain()

    if clear_button:
        # resets all chat history related caches
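For reference, st.form batches widget input so the app only reruns and applies values once the submit button is pressed; a minimal standalone sketch of the pattern used above (widget and key names here are illustrative):

import streamlit as st

with st.form("options"):
    temperature = st.slider("temperature", min_value=0.0, max_value=1.0, value=0.7)
    applied = st.form_submit_button("Apply")

if applied:
    # committed once per submit, not on every slider move
    st.session_state["temperature"] = temperature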
@@ -108,15 +180,16 @@ data_source = st.text_input(
# make sure to do this only once per input / on change
if data_source and data_source != st.session_state["data_source"]:
    logger.info(f"Data source provided: '{data_source}'")
    build_chain_and_clear_history(data_source)
    st.session_state["data_source"] = data_source
    update_chain()

if uploaded_file and uploaded_file != st.session_state["uploaded_file"]:
    logger.info(f"Uploaded file: '{uploaded_file.name}'")
    st.session_state["uploaded_file"] = uploaded_file
    data_source = save_uploaded_file(uploaded_file)
    build_chain_and_clear_history(data_source)
    st.session_state["data_source"] = data_source
    update_chain()
    delete_uploaded_file(uploaded_file)
    st.session_state["uploaded_file"] = uploaded_file
# container for chat history
response_container = st.container()

@@ -4,6 +4,13 @@ APP_NAME = "DataChad"
MODEL = "gpt-3.5-turbo"
PAGE_ICON = "🤖"
K = 10
FETCH_K = 20
CHUNK_SIZE = 1000
TEMPERATURE = 0.7
MAX_TOKENS = 3357
ENABLE_ADVANCED_OPTIONS = True
DATA_PATH = Path.cwd() / "data"
DEFAULT_DATA_SOURCE = "git@github.com:gustavz/DataChad.git"

@@ -29,7 +29,17 @@ from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import DeepLake
from constants import APP_NAME, DATA_PATH, MODEL, PAGE_ICON
from constants import (
    APP_NAME,
    CHUNK_SIZE,
    DATA_PATH,
    FETCH_K,
    MAX_TOKENS,
    MODEL,
    PAGE_ICON,
    TEMPERATURE,
    K,
)
# loads environment variables
load_dotenv()
@@ -123,12 +133,14 @@ def delete_uploaded_file(uploaded_file):
logger.info(f"Removed: {file_path}")
def load_git(data_source):
def load_git(data_source, chunk_size=CHUNK_SIZE):
    # We need to try both common main branches
    # Thank you github for the "master" to "main" switch
    repo_name = data_source.split("/")[-1].split(".")[0]
    repo_path = str(DATA_PATH / repo_name)
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size, chunk_overlap=0
    )
    branches = ["main", "master"]
    for branch in branches:
        if os.path.exists(repo_path):
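The loop body is cut off by this hunk; the idea is to clone with one branch name and fall back to the other. A hedged sketch of that pattern, assuming langchain's GitLoader is the loader in use (the actual loader is not shown in this diff):

from langchain.document_loaders import GitLoader

docs = None
for branch in ["main", "master"]:
    try:
        loader = GitLoader(repo_path=repo_path, clone_url=data_source, branch=branch)
        docs = loader.load_and_split(text_splitter)
        break
    except Exception:
        logger.info(f"Branch '{branch}' not found, trying next")
if docs is None:
    raise ValueError(f"Could not load git repository '{data_source}'")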
@@ -146,7 +158,7 @@ def load_git(data_source):
    return docs

def load_any_data_source(data_source):
def load_any_data_source(data_source, chunk_size=CHUNK_SIZE):
    # Ugly thing that decides how to load data
    # It ain't much, but it's honest work
    is_text = data_source.endswith(".txt")
@@ -165,7 +177,7 @@ def load_any_data_source(data_source):
    if is_dir:
        loader = DirectoryLoader(data_source, recursive=True, silent_errors=True)
    elif is_git:
        return load_git(data_source)
        return load_git(data_source, chunk_size)
    elif is_web:
        if is_pdf:
            loader = OnlinePDFLoader(data_source)
@@ -190,7 +202,9 @@ def load_any_data_source(data_source):
        loader = UnstructuredFileLoader(data_source)
    if loader:
        # Chunk size is a major trade-off parameter, balancing result accuracy against computation
        text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=chunk_size, chunk_overlap=0
        )
        docs = loader.load_and_split(text_splitter)
        logger.info(f"Loaded: {len(docs)} document chunks")
        return docs
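To make the chunk_size trade-off concrete, a small self-contained sketch showing how the number of chunks (and with it embedding cost and retrieval granularity) scales with the setting:

from langchain.text_splitter import RecursiveCharacterTextSplitter

sample = "All work and no play makes Jack a dull boy. " * 200
for size in (500, 1000, 2000):
    splitter = RecursiveCharacterTextSplitter(chunk_size=size, chunk_overlap=0)
    print(size, len(splitter.split_text(sample)))
# smaller chunks -> more, finer-grained documents to embed and retrieve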
@@ -201,21 +215,21 @@ def load_any_data_source(data_source):
    st.stop()
def clean_data_source_string(data_source):
def clean_data_source_string(data_source_string):
    # replace all non-word characters with dashes
    # to get a string that can be used to create a new dataset
    dashed_string = re.sub(r"\W+", "-", data_source)
    dashed_string = re.sub(r"\W+", "-", data_source_string)
    cleaned_string = re.sub(r"--+", "-", dashed_string).strip("-")
    return cleaned_string
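A worked example of the cleaning above, applied to the default data source:

clean_data_source_string("git@github.com:gustavz/DataChad.git")
# -> 'git-github-com-gustavz-DataChad-git'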
def setup_vector_store(data_source):
def setup_vector_store(data_source, chunk_size=CHUNK_SIZE):
    # either load existing vector store or upload a new one to the hub
    embeddings = OpenAIEmbeddings(
        disallowed_special=(), openai_api_key=st.session_state["openai_api_key"]
    )
    data_source_name = clean_data_source_string(data_source)
    dataset_path = f"hub://{st.session_state['activeloop_org_name']}/{data_source_name}"
    dataset_path = f"hub://{st.session_state['activeloop_org_name']}/{data_source_name}-{chunk_size}"
    if deeplake.exists(dataset_path, token=st.session_state["activeloop_token"]):
        with st.spinner("Loading vector store..."):
            logger.info(f"Dataset '{dataset_path}' exists -> loading")
@@ -226,24 +240,28 @@ def setup_vector_store(data_source):
            token=st.session_state["activeloop_token"],
        )
    else:
        with st.spinner(
            "Reading, embedding and uploading data to hub..."
        ), get_openai_callback() as cb:
        with st.spinner("Reading, embedding and uploading data to hub..."):
            logger.info(f"Dataset '{dataset_path}' does not exist -> uploading")
            docs = load_any_data_source(data_source)
            docs = load_any_data_source(data_source, chunk_size)
            vector_store = DeepLake.from_documents(
                docs,
                embeddings,
                dataset_path=f"hub://{st.session_state['activeloop_org_name']}/{data_source_name}",
                dataset_path=dataset_path,
                token=st.session_state["activeloop_token"],
            )
            update_usage(cb)
    return vector_store
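Appending the chunk size to the dataset path gives every chunking setting its own hub dataset instead of silently reusing one that was embedded with different chunks. With a hypothetical org name "my-org" and the default source:

data_source_name = clean_data_source_string("git@github.com:gustavz/DataChad.git")
dataset_path = f"hub://my-org/{data_source_name}-1000"
# -> "hub://my-org/git-github-com-gustavz-DataChad-git-1000"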
def get_chain(data_source):
def build_chain(
    data_source,
    k=K,
    fetch_k=FETCH_K,
    chunk_size=CHUNK_SIZE,
    temperature=TEMPERATURE,
    max_tokens=MAX_TOKENS,
):
    # create the langchain that will be called to generate responses
    vector_store = setup_vector_store(data_source)
    vector_store = setup_vector_store(data_source, chunk_size)
    retriever = vector_store.as_retriever()
    # Search params "fetch_k" and "k" define how many documents are pulled from the hub
    # and how many of them are kept after matching to build the context
@@ -251,31 +269,39 @@ def get_chain(data_source):
    search_kwargs = {
        "maximal_marginal_relevance": True,
        "distance_metric": "cos",
        "fetch_k": 20,
        "k": 10,
        "fetch_k": fetch_k,
        "k": k,
    }
    retriever.search_kwargs.update(search_kwargs)
    model = ChatOpenAI(
        model_name=MODEL, openai_api_key=st.session_state["openai_api_key"]
        model_name=MODEL,
        temperature=temperature,
        openai_api_key=st.session_state["openai_api_key"],
    )
    with st.spinner("Building langchain..."):
        chain = ConversationalRetrievalChain.from_llm(
            model,
            retriever=retriever,
            chain_type="stuff",
            verbose=True,
            # we limit the maximum number of tokens used
            # to prevent running into the model's token limit of 4096
            max_tokens_limit=3375,
        )
        logger.info(f"Data source '{data_source}' is ready to go!")
    chain = ConversationalRetrievalChain.from_llm(
        model,
        retriever=retriever,
        chain_type="stuff",
        verbose=True,
        # we limit the maximum number of tokens used
        # to prevent running into the model's token limit of 4096
        max_tokens_limit=max_tokens,
    )
    logger.info(f"Data source '{data_source}' is ready to go!")
    return chain
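A hedged usage sketch of the new build_chain signature (parameter values are illustrative, and credentials are assumed to already be in session state via authenticate):

chain = build_chain(
    "git@github.com:gustavz/DataChad.git",
    k=5,              # keep the 5 best matches...
    fetch_k=10,       # ...out of 10 fetched candidates
    chunk_size=500,   # finer-grained chunks
    temperature=0.0,  # deterministic answers
    max_tokens=2000,  # tighter context budget
)
result = chain({"question": "What does this repo do?", "chat_history": []})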
def build_chain_and_clear_history(data_source):
    # Get chain and store it in the session state
def update_chain():
    # Build chain with parameters from session state and store it there
    # Also delete chat history to not confuse the bot with old context
    st.session_state["chain"] = get_chain(data_source)
    st.session_state["chain"] = build_chain(
        data_source=st.session_state["data_source"],
        k=st.session_state["k"],
        fetch_k=st.session_state["fetch_k"],
        chunk_size=st.session_state["chunk_size"],
        temperature=st.session_state["temperature"],
        max_tokens=st.session_state["max_tokens"],
    )
    st.session_state["chat_history"] = []
