Add some comments for the blog post

main
Gustav von Zitzewitz 1 year ago
parent 38e4a630ee
commit 20deaa79f4

@ -52,7 +52,8 @@ if "activeloop_token" not in st.session_state:
if "activeloop_org_name" not in st.session_state:
st.session_state["activeloop_org_name"] = None
# Sidebar
# Sidebar with Authentication
# Only start App if authentication is OK
with st.sidebar:
st.title("Authentication", help=AUTHENTICATION_HELP)
with st.form("authentication"):
@ -82,6 +83,7 @@ with st.sidebar:
if not st.session_state["auth_ok"]:
st.stop()
# Clear button to reset all chat communication
clear_button = st.button("Clear Conversation", key="clear")
@ -90,7 +92,7 @@ if "chain" not in st.session_state:
build_chain_and_clear_history(DEFAULT_DATA_SOURCE)
if clear_button:
# reset chat history
# resets all chat history related caches
st.session_state["past"] = []
st.session_state["generated"] = []
st.session_state["chat_history"] = []
@ -121,6 +123,8 @@ response_container = st.container()
# container for text box
container = st.container()
# As streamlit reruns the whole script on each change
# it is necessary to repopulate the chat containers
with container:
with st.form(key="prompt_input", clear_on_submit=True):
user_input = st.text_area("You:", key="input", height=100)
@ -138,8 +142,8 @@ if st.session_state["generated"]:
message(st.session_state["generated"][i], key=str(i))
# Usage sidebar
# Put at the end to display even the first input
# Usage sidebar with total used tokens and costs
# We put this at the end to be able to show usage starting with the first response
with st.sidebar:
if st.session_state["usage"]:
st.divider()

@ -34,6 +34,7 @@ logger = logging.getLogger(APP_NAME)
def configure_logger(debug=0):
# boilerplate code to enable logging in the streamlit app console
log_level = logging.DEBUG if debug == 1 else logging.INFO
logger.setLevel(log_level)
@ -115,6 +116,7 @@ def delete_uploaded_file(uploaded_file):
def load_git(data_source):
# We need to try both common main branches
# Thank you github for the "master" to "main" switch
repo_name = data_source.split("/")[-1].split(".")[0]
repo_path = str(DATA_PATH / repo_name)
@ -137,7 +139,8 @@ def load_git(data_source):
def load_any_data_source(data_source):
# ugly thing that decides how to load data
# Ugly thing that decides how to load data
# It ain't much, but it's honest work
is_text = data_source.endswith(".txt")
is_web = data_source.startswith("http")
is_pdf = data_source.endswith(".pdf")
@ -178,6 +181,7 @@ def load_any_data_source(data_source):
else:
loader = UnstructuredFileLoader(data_source)
if loader:
# Chunk size is a major trade-off parameter to control result accuracy over computation
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
docs = loader.load_and_split(text_splitter)
logger.info(f"Loaded: {len(docs)} document chucks")
@ -233,10 +237,13 @@ def get_chain(data_source):
# create the langchain that will be called to generate responses
vector_store = setup_vector_store(data_source)
retriever = vector_store.as_retriever()
# Search params "fetch_k" and "k" define how many documents are pulled from the hub
# and selected after the document matching to build the context
# that is fed to the model together with your prompt
search_kwargs = {
"maximal_marginal_relevance": True,
"distance_metric": "cos",
"fetch_k": 20,
"maximal_marginal_relevance": True,
"k": 10,
}
retriever.search_kwargs.update(search_kwargs)
@ -249,6 +256,8 @@ def get_chain(data_source):
retriever=retriever,
chain_type="stuff",
verbose=True,
# we limit the maximum number of used tokens
# to prevent running into the models token limit of 4096
max_tokens_limit=3375,
)
logger.info(f"Data source '{data_source}' is ready to go!")

Loading…
Cancel
Save