|
|
|
@ -3,12 +3,13 @@ import os
|
|
|
|
|
import re
|
|
|
|
|
import shutil
|
|
|
|
|
import sys
|
|
|
|
|
from typing import List
|
|
|
|
|
|
|
|
|
|
import deeplake
|
|
|
|
|
import openai
|
|
|
|
|
import streamlit as st
|
|
|
|
|
from dotenv import load_dotenv
|
|
|
|
|
from langchain.callbacks import get_openai_callback
|
|
|
|
|
from langchain.callbacks import OpenAICallbackHandler, get_openai_callback
|
|
|
|
|
from langchain.chains import ConversationalRetrievalChain
|
|
|
|
|
from langchain.chat_models import ChatOpenAI
|
|
|
|
|
from langchain.document_loaders import (
|
|
|
|
@ -26,8 +27,10 @@ from langchain.document_loaders import (
|
|
|
|
|
WebBaseLoader,
|
|
|
|
|
)
|
|
|
|
|
from langchain.embeddings.openai import OpenAIEmbeddings
|
|
|
|
|
from langchain.schema import Document
|
|
|
|
|
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
|
|
|
|
from langchain.vectorstores import DeepLake
|
|
|
|
|
from langchain.vectorstores import DeepLake, VectorStore
|
|
|
|
|
from streamlit.runtime.uploaded_file_manager import UploadedFile
|
|
|
|
|
|
|
|
|
|
from constants import (
|
|
|
|
|
APP_NAME,
|
|
|
|
@ -47,7 +50,7 @@ load_dotenv()
|
|
|
|
|
logger = logging.getLogger(APP_NAME)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def configure_logger(debug=0):
|
|
|
|
|
def configure_logger(debug: int = 0) -> None:
|
|
|
|
|
# boilerplate code to enable logging in the streamlit app console
|
|
|
|
|
log_level = logging.DEBUG if debug == 1 else logging.INFO
|
|
|
|
|
logger.setLevel(log_level)
|
|
|
|
@ -66,7 +69,9 @@ def configure_logger(debug=0):
|
|
|
|
|
configure_logger(0)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def authenticate(openai_api_key, activeloop_token, activeloop_org_name):
|
|
|
|
|
def authenticate(
|
|
|
|
|
openai_api_key: str, activeloop_token: str, activeloop_org_name: str
|
|
|
|
|
) -> None:
|
|
|
|
|
# Validate all credentials are set and correct
|
|
|
|
|
# Check for env variables to enable local dev and deployments with shared credentials
|
|
|
|
|
openai_api_key = (
|
|
|
|
@ -110,7 +115,7 @@ def authenticate(openai_api_key, activeloop_token, activeloop_org_name):
|
|
|
|
|
logger.info("Authentification successful!")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def handle_advanced_options():
|
|
|
|
|
def advanced_options_form() -> None:
|
|
|
|
|
# Input Form that takes advanced options and rebuilds chain with them
|
|
|
|
|
advanced_options = st.checkbox(
|
|
|
|
|
"Advanced Options", help="Caution! This may break things!"
|
|
|
|
@ -128,7 +133,7 @@ def handle_advanced_options():
|
|
|
|
|
fetch_k = col1.number_input(
|
|
|
|
|
"k_fetch",
|
|
|
|
|
min_value=1,
|
|
|
|
|
max_value=100,
|
|
|
|
|
max_value=1000,
|
|
|
|
|
value=FETCH_K,
|
|
|
|
|
help="The number of documents to pull from the vector database",
|
|
|
|
|
)
|
|
|
|
@ -167,7 +172,7 @@ def handle_advanced_options():
|
|
|
|
|
update_chain()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def save_uploaded_file(uploaded_file):
|
|
|
|
|
def save_uploaded_file(uploaded_file: UploadedFile) -> str:
|
|
|
|
|
# streamlit uploaded files need to be stored locally
|
|
|
|
|
# before embedded and uploaded to the hub
|
|
|
|
|
if not os.path.exists(DATA_PATH):
|
|
|
|
@ -182,7 +187,7 @@ def save_uploaded_file(uploaded_file):
|
|
|
|
|
return file_path
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def delete_uploaded_file(uploaded_file):
|
|
|
|
|
def delete_uploaded_file(uploaded_file: UploadedFile) -> None:
|
|
|
|
|
# cleanup locally stored files
|
|
|
|
|
file_path = DATA_PATH / uploaded_file.name
|
|
|
|
|
if os.path.exists(DATA_PATH):
|
|
|
|
@ -190,7 +195,7 @@ def delete_uploaded_file(uploaded_file):
|
|
|
|
|
logger.info(f"Removed: {file_path}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def load_git(data_source, chunk_size=CHUNK_SIZE):
|
|
|
|
|
def load_git(data_source: str, chunk_size: int = CHUNK_SIZE) -> List[Document]:
|
|
|
|
|
# We need to try both common main branches
|
|
|
|
|
# Thank you github for the "master" to "main" switch
|
|
|
|
|
repo_name = data_source.split("/")[-1].split(".")[0]
|
|
|
|
@ -215,7 +220,9 @@ def load_git(data_source, chunk_size=CHUNK_SIZE):
|
|
|
|
|
return docs
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def load_any_data_source(data_source, chunk_size=CHUNK_SIZE):
|
|
|
|
|
def load_any_data_source(
|
|
|
|
|
data_source: str, chunk_size: int = CHUNK_SIZE
|
|
|
|
|
) -> List[Document]:
|
|
|
|
|
# Ugly thing that decides how to load data
|
|
|
|
|
# It aint much, but it's honest work
|
|
|
|
|
is_text = data_source.endswith(".txt")
|
|
|
|
@ -272,7 +279,7 @@ def load_any_data_source(data_source, chunk_size=CHUNK_SIZE):
|
|
|
|
|
st.stop()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def clean_data_source_string(data_source_string):
|
|
|
|
|
def clean_data_source_string(data_source_string: str) -> str:
|
|
|
|
|
# replace all non-word characters with dashes
|
|
|
|
|
# to get a string that can be used to create a new dataset
|
|
|
|
|
dashed_string = re.sub(r"\W+", "-", data_source_string)
|
|
|
|
@ -280,7 +287,7 @@ def clean_data_source_string(data_source_string):
|
|
|
|
|
return cleaned_string
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def setup_vector_store(data_source, chunk_size=CHUNK_SIZE):
|
|
|
|
|
def setup_vector_store(data_source: str, chunk_size: int = CHUNK_SIZE) -> VectorStore:
|
|
|
|
|
# either load existing vector store or upload a new one to the hub
|
|
|
|
|
embeddings = OpenAIEmbeddings(
|
|
|
|
|
disallowed_special=(), openai_api_key=st.session_state["openai_api_key"]
|
|
|
|
@ -310,13 +317,13 @@ def setup_vector_store(data_source, chunk_size=CHUNK_SIZE):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def build_chain(
|
|
|
|
|
data_source,
|
|
|
|
|
k=K,
|
|
|
|
|
fetch_k=FETCH_K,
|
|
|
|
|
chunk_size=CHUNK_SIZE,
|
|
|
|
|
temperature=TEMPERATURE,
|
|
|
|
|
max_tokens=MAX_TOKENS,
|
|
|
|
|
):
|
|
|
|
|
data_source: str,
|
|
|
|
|
k: int = K,
|
|
|
|
|
fetch_k: int = FETCH_K,
|
|
|
|
|
chunk_size: int = CHUNK_SIZE,
|
|
|
|
|
temperature: float = TEMPERATURE,
|
|
|
|
|
max_tokens: int = MAX_TOKENS,
|
|
|
|
|
) -> ConversationalRetrievalChain:
|
|
|
|
|
# create the langchain that will be called to generate responses
|
|
|
|
|
vector_store = setup_vector_store(data_source, chunk_size)
|
|
|
|
|
retriever = vector_store.as_retriever()
|
|
|
|
@ -348,8 +355,8 @@ def build_chain(
|
|
|
|
|
return chain
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def update_chain():
|
|
|
|
|
# Build chain with parameters from session state and store it there
|
|
|
|
|
def update_chain() -> None:
|
|
|
|
|
# Build chain with parameters from session state and store it back
|
|
|
|
|
# Also delete chat history to not confuse the bot with old context
|
|
|
|
|
st.session_state["chain"] = build_chain(
|
|
|
|
|
data_source=st.session_state["data_source"],
|
|
|
|
@ -362,7 +369,7 @@ def update_chain():
|
|
|
|
|
st.session_state["chat_history"] = []
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def update_usage(cb):
|
|
|
|
|
def update_usage(cb: OpenAICallbackHandler) -> None:
|
|
|
|
|
# Accumulate API call usage via callbacks
|
|
|
|
|
logger.info(f"Usage: {cb}")
|
|
|
|
|
callback_properties = [
|
|
|
|
@ -377,7 +384,7 @@ def update_usage(cb):
|
|
|
|
|
st.session_state["usage"][prop] += value
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def generate_response(prompt):
|
|
|
|
|
def generate_response(prompt: str) -> str:
|
|
|
|
|
# call the chain to generate responses and add them to the chat history
|
|
|
|
|
with st.spinner("Generating response"), get_openai_callback() as cb:
|
|
|
|
|
response = st.session_state["chain"](
|
|
|
|
|