refactor to enable mode and model selection

main
Gustav von Zitzewitz 1 year ago
parent bcd4395907
commit 56cd7e3ba5

app.py
@@ -1,33 +1,33 @@
import streamlit as st
from streamlit_chat import message
from constants import (
from datachad.chain import generate_response, update_chain
from datachad.constants import (
ACTIVELOOP_HELP,
APP_NAME,
AUTHENTICATION_HELP,
CHUNK_OVERLAP,
CHUNK_SIZE,
DEFAULT_DATA_SOURCE,
EMBEDDINGS,
ENABLE_ADVANCED_OPTIONS,
ENABLE_LOCAL_MODE,
FETCH_K,
LOCAL_MODE_DISABLED_HELP,
MAX_TOKENS,
MODEL,
MODEL_N_CTX,
OPENAI_HELP,
PAGE_ICON,
PROJECT_URL,
TEMPERATURE,
USAGE_HELP,
MODEL_N_CTX,
K,
)
from utils import (
from datachad.models import MODELS, MODES
from datachad.utils import (
authenticate,
delete_uploaded_file,
generate_response,
logger,
save_uploaded_file,
update_chain,
)
# Page options and header
@@ -52,8 +52,8 @@ SESSION_DEFAULTS = {
"activeloop_org_name": None,
"uploaded_file": None,
"data_source": DEFAULT_DATA_SOURCE,
"model": MODEL,
"embeddings": EMBEDDINGS,
"mode": MODES.OPENAI,
"model": MODELS.GPT35TURBO,
"k": K,
"fetch_k": FETCH_K,
"chunk_size": CHUNK_SIZE,
@@ -72,7 +72,7 @@ def authentication_form() -> None:
st.title("Authentication", help=AUTHENTICATION_HELP)
with st.form("authentication"):
openai_api_key = st.text_input(
"OpenAI API Key",
f"{st.session_state['mode']} API Key",
type="password",
help=OPENAI_HELP,
placeholder="This field is mandatory",
@@ -101,29 +101,34 @@ def advanced_options_form() -> None:
)
if advanced_options:
with st.form("advanced_options"):
temperature = st.slider(
col1, col2 = st.columns(2)
col1.selectbox("model", MODELS.for_mode(st.session_state["mode"]))
col2.number_input(
"temperature",
min_value=0.0,
max_value=1.0,
value=TEMPERATURE,
help="Controls the randomness of the language model output",
key="temperature",
)
col1, col2 = st.columns(2)
fetch_k = col1.number_input(
col1.number_input(
"k_fetch",
min_value=1,
max_value=1000,
value=FETCH_K,
help="The number of documents to pull from the vector database",
key="k_fetch",
)
k = col2.number_input(
col2.number_input(
"k",
min_value=1,
max_value=100,
value=K,
help="The number of most similar documents to build the context from",
key="k",
)
chunk_size = col1.number_input(
col1.number_input(
"chunk_size",
min_value=1,
max_value=100000,
@@ -133,35 +138,37 @@ def advanced_options_form() -> None:
"before being embedded.\n\nChanging this parameter makes re-embedding "
"and re-uploading the data to the database necessary "
),
key="chunk_size",
)
max_tokens = col2.number_input(
col2.number_input(
"max_tokens",
min_value=1,
max_value=4069,
max_value=30000,
value=MAX_TOKENS,
help="Limits the documents returned from database based on number of tokens",
key="max_tokens",
)
applied = st.form_submit_button("Apply")
if applied:
st.session_state["k"] = k
st.session_state["fetch_k"] = fetch_k
st.session_state["chunk_size"] = chunk_size
st.session_state["temperature"] = temperature
st.session_state["max_tokens"] = max_tokens
update_chain()
# Sidebar with Authentication and Advanced Options
with st.sidebar:
authentication_form()
mode = st.selectbox("Mode", MODES.values(), key="mode")
if mode == MODES.LOCAL and not ENABLE_LOCAL_MODE:
st.error(LOCAL_MODE_DISABLED_HELP, icon=PAGE_ICON)
st.stop()
if mode != MODES.LOCAL:
authentication_form()
st.info(f"Learn how it works [here]({PROJECT_URL})")
# Only start App if authentication is OK
if not st.session_state["auth_ok"]:
if not (st.session_state["auth_ok"] or mode == MODES.LOCAL):
st.stop()
# Clear button to reset all chat communication
clear_button = st.button("Clear Conversation", key="clear")
clear_button = st.button("Clear Conversation")
# Advanced Options
if ENABLE_ADVANCED_OPTIONS:
@@ -214,6 +221,7 @@ with text_container:
submit_button = st.form_submit_button(label="Send")
if submit_button and user_input:
text_container.empty()
output = generate_response(user_input)
st.session_state["past"].append(user_input)
st.session_state["generated"].append(output)

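Note on the refactored advanced options form above: the manual session-state writes could be dropped because a Streamlit widget created with a key argument stores its value in st.session_state under that key automatically. A minimal standalone sketch of the mechanism (widget and values are hypothetical):

import streamlit as st

# The `key` binds the widget to st.session_state["temperature"]; no explicit
# `st.session_state["temperature"] = temperature` assignment is needed.
st.number_input(
    "temperature", min_value=0.0, max_value=1.0, value=0.7, key="temperature"
)
st.write(st.session_state["temperature"])  # reflects the widget's current value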
datachad/chain.py
@@ -0,0 +1,77 @@
import streamlit as st
from langchain.callbacks import OpenAICallbackHandler, get_openai_callback
from langchain.chains import ConversationalRetrievalChain
from datachad.constants import PAGE_ICON
from datachad.database import get_vector_store
from datachad.models import get_model
from datachad.utils import logger
def get_chain() -> ConversationalRetrievalChain:
# create the LangChain chain that is called to generate responses
vector_store = get_vector_store()
retriever = vector_store.as_retriever()
# Search params "fetch_k" and "k" define how many documents are pulled from the hub
# and how many of those are selected after matching to build the context
# that is fed to the model together with your prompt
search_kwargs = {
"maximal_marginal_relevance": True,
"distance_metric": "cos",
"fetch_k": st.session_state["fetch_k"],
"k": st.session_state["k"],
}
retriever.search_kwargs.update(search_kwargs)
model = get_model()
chain = ConversationalRetrievalChain.from_llm(
model,
retriever=retriever,
chain_type="stuff",
verbose=True,
# we limit the maximum number of used tokens
# to prevent running into the model's token limit of 4096
max_tokens_limit=st.session_state["max_tokens"],
)
return chain
def update_chain() -> None:
# Build chain with parameters from session state and store it back
# Also delete chat history to not confuse the bot with old context
try:
st.session_state["chain"] = get_chain()
st.session_state["chat_history"] = []
msg = f"Data source '{st.session_state['data_source']}' is ready to go!"
logger.info(msg)
st.info(msg, icon=PAGE_ICON)
except Exception as e:
msg = f"Failed to build chain for data source '{st.session_state['data_source']}' with error: {e}"
logger.error(msg)
st.error(msg, icon=PAGE_ICON)
def update_usage(cb: OpenAICallbackHandler) -> None:
# Accumulate API call usage via callbacks
logger.info(f"Usage: {cb}")
callback_properties = [
"total_tokens",
"prompt_tokens",
"completion_tokens",
"total_cost",
]
for prop in callback_properties:
value = getattr(cb, prop, 0)
st.session_state["usage"].setdefault(prop, 0)
st.session_state["usage"][prop] += value
def generate_response(prompt: str) -> str:
# call the chain to generate responses and add them to the chat history
with st.spinner("Generating response"), get_openai_callback() as cb:
response = st.session_state["chain"](
{"question": prompt, "chat_history": st.session_state["chat_history"]}
)
update_usage(cb)
logger.info(f"Response: '{response}'")
st.session_state["chat_history"].append((prompt, response["answer"]))
return response["answer"]

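The usage tracking in chain.py builds on LangChain's OpenAI callback context manager. A minimal sketch of the pattern, assuming OPENAI_API_KEY is set (the model call is illustrative):

from langchain.callbacks import get_openai_callback
from langchain.chat_models import ChatOpenAI
from langchain.schema import HumanMessage

llm = ChatOpenAI(model_name="gpt-3.5-turbo")
with get_openai_callback() as cb:
    # every OpenAI call made inside the block is tallied on the handler
    llm([HumanMessage(content="Say hi")])
print(cb.total_tokens, cb.prompt_tokens, cb.completion_tokens, cb.total_cost)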
datachad/constants.py
@@ -1,28 +1,31 @@
from pathlib import Path
APP_NAME = "DataChad"
MODEL = "gpt-3.5-turbo"
EMBEDDINGS = "openai"
PAGE_ICON = "🤖"
APP_NAME = "DataChad"
PROJECT_URL = "https://github.com/gustavz/DataChad"
K = 10
FETCH_K = 20
CHUNK_SIZE = 1000
CHUNK_OVERLAP = 0
TEMPERATURE = 0.7
MAX_TOKENS = 3357
MODEL_N_CTX = 1000
ENABLE_LOCAL_MODE = False
ENABLE_ADVANCED_OPTIONS = True
MODEL_N_CTX = 1000
LLAMACPP_MODEL_PATH = ""
GPT4ALL_MODEL_PATH = ""
ENABLE_LOCAL_MODELS = False
DATA_PATH = Path.cwd() / "data"
DEFAULT_DATA_SOURCE = "https://github.com/gustavz/DataChad.git"
LOCAL_MODE_DISABLED_HELP = """
This is a demo hosted with limited resources. Local Mode is not enabled.\n
To use Local Mode, deploy the app on a machine of your choice with ENABLE_LOCAL_MODE set to True.
"""
AUTHENTICATION_HELP = f"""
Your credentials are only stored in your session state.\n
The keys are neither exposed nor made visible or stored permanently in any way.\n
@@ -31,7 +34,7 @@ Feel free to check out [the code base]({PROJECT_URL}) to validate how things wor
USAGE_HELP = f"""
These are the accumulated OpenAI API usage metrics.\n
The app uses '{MODEL}' for chat and 'text-embedding-ada-002' for embeddings.\n
The app uses 'gpt-3.5-turbo' for chat and 'text-embedding-ada-002' for embeddings.\n
Learn more about OpenAI's pricing [here](https://openai.com/pricing#language-models)
"""

datachad/database.py
@@ -0,0 +1,51 @@
import os
import re
import deeplake
import streamlit as st
from langchain.vectorstores import DeepLake, VectorStore
from datachad.constants import DATA_PATH
from datachad.loader import load_data_source
from datachad.models import MODES, get_embeddings
from datachad.utils import logger
def get_dataset_path() -> str:
# replace all non-word characters with dashes
# to get a string that can be used to create a new dataset
dataset_name = re.sub(r"\W+", "-", st.session_state["data_source"])
dataset_name = re.sub(r"--+", "- ", dataset_name).strip("-")
if st.session_state["mode"] == MODES.LOCAL:
if not os.path.exists(DATA_PATH):
os.makedirs(DATA_PATH)
dataset_path = str(DATA_PATH / dataset_name)
else:
dataset_path = f"hub://{st.session_state['activeloop_org_name']}/{dataset_name}-{st.session_state['chunk_size']}"
return dataset_path
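For illustration, the slug logic in get_dataset_path turns a data source string into a dataset name roughly like this (the source value is hypothetical):

import re

source = "https://github.com/gustavz/DataChad.git"
name = re.sub(r"\W+", "-", source)           # "https-github-com-gustavz-DataChad-git"
name = re.sub(r"--+", "-", name).strip("-")  # collapse remaining dash runs, trim edges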
def get_vector_store() -> VectorStore:
# either load existing vector store or upload a new one to the hub
embeddings = get_embeddings()
dataset_path = get_dataset_path()
if deeplake.exists(dataset_path, token=st.session_state["activeloop_token"]):
with st.spinner("Loading vector store..."):
logger.info(f"Dataset '{dataset_path}' exists -> loading")
vector_store = DeepLake(
dataset_path=dataset_path,
read_only=True,
embedding_function=embeddings,
token=st.session_state["activeloop_token"],
)
else:
with st.spinner("Reading, embedding and uploading data to hub..."):
logger.info(f"Dataset '{dataset_path}' does not exist -> uploading")
docs = load_data_source()
vector_store = DeepLake.from_documents(
docs,
embeddings,
dataset_path=dataset_path,
token=st.session_state["activeloop_token"],
)
return vector_store

datachad/loader.py
@@ -0,0 +1,133 @@
import os
import shutil
from typing import List
import streamlit as st
from langchain.document_loaders import (
CSVLoader,
DirectoryLoader,
EverNoteLoader,
GitLoader,
NotebookLoader,
OnlinePDFLoader,
PDFMinerLoader,
PythonLoader,
TextLoader,
UnstructuredEPubLoader,
UnstructuredFileLoader,
UnstructuredHTMLLoader,
UnstructuredMarkdownLoader,
UnstructuredODTLoader,
UnstructuredPowerPointLoader,
UnstructuredWordDocumentLoader,
WebBaseLoader,
)
from langchain.document_loaders.base import BaseLoader
from langchain.schema import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from datachad.constants import DATA_PATH, PAGE_ICON, PROJECT_URL
from datachad.utils import logger
class AutoGitLoader:
def __init__(self, data_source: str) -> None:
self.data_source = data_source
def load(self) -> List[Document]:
# We need to try both common default branches ("main" and "master")
# Thank you GitHub for the "master" to "main" switch
# we also need to make sure the data path exists
if not os.path.exists(DATA_PATH):
os.makedirs(DATA_PATH)
repo_name = self.data_source.split("/")[-1].split(".")[0]
repo_path = str(DATA_PATH / repo_name)
clone_url = self.data_source
if os.path.exists(repo_path):
clone_url = None
branches = ["main", "master"]
for branch in branches:
try:
docs = GitLoader(repo_path, clone_url, branch).load()
break
except Exception as e:
logger.error(f"Error loading git: {e}")
if os.path.exists(repo_path):
# cleanup repo afterwards
shutil.rmtree(repo_path)
try:
return docs
except NameError:  # docs is unbound if every branch failed to load
raise RuntimeError("Make sure to use HTTPS GitHub repo links")
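Hypothetical usage of the loader above; it clones over HTTPS, tries main then master, and cleans up the checkout afterwards:

docs = AutoGitLoader("https://github.com/gustavz/DataChad.git").load()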
FILE_LOADER_MAPPING = {
".csv": (CSVLoader, {"encoding": "utf-8"}),
".doc": (UnstructuredWordDocumentLoader, {}),
".docx": (UnstructuredWordDocumentLoader, {}),
".enex": (EverNoteLoader, {}),
".epub": (UnstructuredEPubLoader, {}),
".html": (UnstructuredHTMLLoader, {}),
".md": (UnstructuredMarkdownLoader, {}),
".odt": (UnstructuredODTLoader, {}),
".pdf": (PDFMinerLoader, {}),
".ppt": (UnstructuredPowerPointLoader, {}),
".pptx": (UnstructuredPowerPointLoader, {}),
".txt": (TextLoader, {"encoding": "utf8"}),
".ipynb": (NotebookLoader, {}),
".py": (PythonLoader, {}),
# Add more mappings for other file extensions and loaders as needed
}
WEB_LOADER_MAPPING = {
".git": (AutoGitLoader, {}),
".pdf": (OnlinePDFLoader, {}),
}
def get_loader(file_path: str, mapping: dict, default_loader: BaseLoader) -> BaseLoader:
# Choose loader from mapping, load default if no match found
ext = "." + file_path.rsplit(".", 1)[-1]
if ext in mapping:
loader_class, loader_args = mapping[ext]
loader = loader_class(file_path, **loader_args)
else:
loader = default_loader(file_path)
return loader
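An illustrative call of the extension-based dispatch above (the path is hypothetical):

# ".md" is in the mapping -> UnstructuredMarkdownLoader; unknown extensions
# fall back to the default loader passed as the third argument.
loader = get_loader("docs/readme.md", FILE_LOADER_MAPPING, UnstructuredFileLoader)
docs = loader.load()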
def load_data_source() -> List[Document]:
# Ugly thing that decides how to load data
# It ain't much, but it's honest work
data_source = st.session_state["data_source"]
is_web = data_source.startswith("http")
is_dir = os.path.isdir(data_source)
is_file = os.path.isfile(data_source)
loader = None
if is_dir:
loader = DirectoryLoader(data_source, recursive=True, silent_errors=True)
elif is_web:
loader = get_loader(data_source, WEB_LOADER_MAPPING, WebBaseLoader)
elif is_file:
loader = get_loader(data_source, FILE_LOADER_MAPPING, UnstructuredFileLoader)
try:
# Chunk size is a major trade-off parameter to control result accuracy over computation
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=st.session_state["chunk_size"],
chunk_overlap=st.session_state["chunk_overlap"],
)
docs = loader.load()
docs = text_splitter.split_documents(docs)
logger.info(f"Loaded: {len(docs)} document chucks")
return docs
except Exception as e:
msg = (
e
if loader
else f"No Loader found for your data source. Consider contributing:  {PROJECT_URL}!"
)
error_msg = f"Failed to load '{st.session_state['data_source']}':\n\n{msg}"
st.error(error_msg, icon=PAGE_ICON)
logger.error(error_msg)
st.stop()

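The chunk_size trade-off mentioned above, in a standalone sketch (text and sizes hypothetical): smaller chunks match queries more precisely but produce more embeddings to compute and store.

from langchain.text_splitter import RecursiveCharacterTextSplitter

splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
chunks = splitter.split_text("some long document text ... " * 200)
print(len(chunks))  # more, finer-grained chunks as chunk_size shrinks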
datachad/models.py
@@ -0,0 +1,111 @@
from dataclasses import dataclass
import streamlit as st
from langchain.base_language import BaseLanguageModel
from langchain.chat_models import ChatOpenAI
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.embeddings.openai import Embeddings, OpenAIEmbeddings
from langchain.llms import GPT4All, LlamaCpp
from datachad.utils import logger
class Enum:
@classmethod
def values(cls):
return [v for k, v in cls.__dict__.items() if not k.startswith("_")]
@classmethod
def dict(cls):
return {k: v for k, v in cls.__dict__.items() if not k.startswith("_")}
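A quick sketch of how this minimal Enum helper behaves (the subclass is hypothetical):

class COLORS(Enum):
    RED = "red"
    BLUE = "blue"

assert COLORS.values() == ["red", "blue"]
assert COLORS.dict() == {"RED": "red", "BLUE": "blue"}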
@dataclass
class Model:
name: str
mode: str
embedding: str
path: str = None # for local models only
def __str__(self):
return self.name
class MODES(Enum):
OPENAI = "OpenAI"
LOCAL = "Local"
class EMBEDDINGS(Enum):
OPENAI = "openai"
HUGGINGFACE = "all-MiniLM-L6-v2"
class MODELS(Enum):
GPT35TURBO = Model("gpt-3.5-turbo", MODES.OPENAI, EMBEDDINGS.OPENAI)
GPT4 = Model("gpt-4", MODES.OPENAI, EMBEDDINGS.OPENAI)
LLAMACPP = Model(
"LLAMA", MODES.LOCAL, EMBEDDINGS.HUGGINGFACE, "models/llamacpp.bin"
)
GPT4ALL = Model(
"GPT4All", MODES.LOCAL, EMBEDDINGS.HUGGINGFACE, "models/gpt4all.bin"
)
@classmethod
def for_mode(cls, mode):
return [v for v in cls.values() if isinstance(v, Model) and v.mode == mode]
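So the sidebar can narrow the model selectbox to the active mode, e.g. (illustrative):

MODELS.for_mode(MODES.OPENAI)  # the two OpenAI models (gpt-3.5-turbo, gpt-4)
MODELS.for_mode(MODES.LOCAL)   # the two local models (LLAMA, GPT4All)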
def get_model() -> BaseLanguageModel:
match st.session_state["model"].name:
case MODELS.GPT35TURBO.name:
model = ChatOpenAI(
model_name=st.session_state["model"].name,
temperature=st.session_state["temperature"],
openai_api_key=st.session_state["openai_api_key"],
)
case MODELS.GPT4.name:
model = ChatOpenAI(
model_name=st.session_state["model"].name,
temperature=st.session_state["temperature"],
openai_api_key=st.session_state["openai_api_key"],
)
case MODELS.LLAMACPP.name:
model = LlamaCpp(
model_path=st.session_state["model"].path,
n_ctx=st.session_state["model_n_ctx"],
temperature=st.session_state["temperature"],
verbose=True,
)
case MODELS.GPT4ALL.name:
model = GPT4All(
model=st.session_state["model"].path,
n_ctx=st.session_state["model_n_ctx"],
backend="gptj",
temp=st.session_state["temperature"],
verbose=True,
)
# Add more models as needed
case _:
msg = f"Model {st.session_state['model']} not supported!"
logger.error(msg)
st.error(msg)
st.stop()
return model
def get_embeddings() -> Embeddings:
match st.session_state["model"].embedding:
case EMBEDDINGS.OPENAI:
embeddings = OpenAIEmbeddings(
disallowed_special=(), openai_api_key=st.session_state["openai_api_key"]
)
case EMBEDDINGS.HUGGINGFACE:
embeddings = HuggingFaceEmbeddings(model_name=EMBEDDINGS.HUGGINGFACE)
# Add more embeddings as needed
case _:
msg = f"Embeddings {st.session_state['model'].embedding} not supported!"
logger.error(msg)
st.error(msg)
st.stop()
return embeddings

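Note that the case arms in get_model and get_embeddings use dotted names, which Python's structural pattern matching treats as value patterns (compared with ==), not capture patterns. A minimal sketch:

from datachad.models import MODELS

name = "gpt-4"
match name:
    case MODELS.GPT4.name:  # dotted name -> value pattern, equality comparison
        print("OpenAI chat model")
    case _:
        print("unsupported")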
datachad/utils.py
@@ -0,0 +1,104 @@
import logging
import os
import sys
import deeplake
import openai
import streamlit as st
from dotenv import load_dotenv
from datachad.constants import APP_NAME, DATA_PATH, PAGE_ICON
# loads environment variables
load_dotenv()
logger = logging.getLogger(APP_NAME)
def configure_logger(debug: int = 0) -> None:
# boilerplate code to enable logging in the Streamlit app console
log_level = logging.DEBUG if debug == 1 else logging.INFO
logger.setLevel(log_level)
stream_handler = logging.StreamHandler(stream=sys.stdout)
stream_handler.setLevel(log_level)
formatter = logging.Formatter("%(message)s")
stream_handler.setFormatter(formatter)
logger.addHandler(stream_handler)
logger.propagate = False
configure_logger(0)
def authenticate(
openai_api_key: str, activeloop_token: str, activeloop_org_name: str
) -> None:
# Validate all credentials are set and correct
# Check for env variables to enable local dev and deployments with shared credentials
openai_api_key = (
openai_api_key
or os.environ.get("OPENAI_API_KEY")
or st.secrets.get("OPENAI_API_KEY")
)
activeloop_token = (
activeloop_token
or os.environ.get("ACTIVELOOP_TOKEN")
or st.secrets.get("ACTIVELOOP_TOKEN")
)
activeloop_org_name = (
activeloop_org_name
or os.environ.get("ACTIVELOOP_ORG_NAME")
or st.secrets.get("ACTIVELOOP_ORG_NAME")
)
if not (openai_api_key and activeloop_token and activeloop_org_name):
st.session_state["auth_ok"] = False
st.error("Credentials neither set nor stored", icon=PAGE_ICON)
return
try:
# Try to access openai and deeplake
with st.spinner("Authentifying..."):
openai.api_key = openai_api_key
openai.Model.list()
deeplake.exists(
f"hub://{activeloop_org_name}/DataChad-Authentication-Check",
token=activeloop_token,
)
except Exception as e:
logger.error(f"Authentication failed with {e}")
st.session_state["auth_ok"] = False
st.error("Authentication failed", icon=PAGE_ICON)
return
# store credentials in the session state
st.session_state["auth_ok"] = True
st.session_state["openai_api_key"] = openai_api_key
st.session_state["activeloop_token"] = activeloop_token
st.session_state["activeloop_org_name"] = activeloop_org_name
logger.info("Authentification successful!")
def save_uploaded_file() -> str:
# Streamlit uploaded files need to be stored locally
# before being embedded and uploaded to the hub
uploaded_file = st.session_state["uploaded_file"]
if not os.path.exists(DATA_PATH):
os.makedirs(DATA_PATH)
file_path = str(DATA_PATH / uploaded_file.name)
uploaded_file.seek(0)
file_bytes = uploaded_file.read()
with open(file_path, "wb") as file:
file.write(file_bytes)
logger.info(f"Saved: {file_path}")
return file_path
def delete_uploaded_file() -> None:
# cleanup locally stored files
file_path = DATA_PATH / st.session_state["uploaded_file"].name
if os.path.exists(file_path):
os.remove(file_path)
logger.info(f"Removed: {file_path}")

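Because authenticate falls back to environment variables and Streamlit secrets, local development can supply shared credentials via a .env file picked up by load_dotenv(). A hypothetical example (all values are placeholders):

OPENAI_API_KEY=sk-...
ACTIVELOOP_TOKEN=...
ACTIVELOOP_ORG_NAME=my-org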
requirements.txt
@@ -2,7 +2,7 @@ streamlit==1.22.0
streamlit-chat==0.0.2.2
deeplake==3.4.1
openai==0.27.6
langchain==0.0.164
langchain==0.0.173
tiktoken==0.4.0
unstructured==0.6.5
pdf2image==1.16.3

utils.py
@@ -1,393 +0,0 @@
import logging
import os
import re
import shutil
import sys
from typing import List
import deeplake
import openai
import streamlit as st
from dotenv import load_dotenv
from langchain.base_language import BaseLanguageModel
from langchain.callbacks import OpenAICallbackHandler, get_openai_callback
from langchain.chains import ConversationalRetrievalChain
from langchain.chat_models import ChatOpenAI
from langchain.document_loaders import (
CSVLoader,
DirectoryLoader,
EverNoteLoader,
GitLoader,
NotebookLoader,
OnlinePDFLoader,
PDFMinerLoader,
PythonLoader,
TextLoader,
UnstructuredEPubLoader,
UnstructuredFileLoader,
UnstructuredHTMLLoader,
UnstructuredMarkdownLoader,
UnstructuredODTLoader,
UnstructuredPowerPointLoader,
UnstructuredWordDocumentLoader,
WebBaseLoader,
)
from langchain.document_loaders.base import BaseLoader
from langchain.embeddings.openai import Embeddings, OpenAIEmbeddings
from langchain.schema import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import DeepLake, VectorStore
from langchain.llms import GPT4All, LlamaCpp
from langchain.embeddings import HuggingFaceEmbeddings
from constants import APP_NAME, DATA_PATH, PAGE_ICON, PROJECT_URL, LLAMACPP_MODEL_PATH, GPT4ALL_MODEL_PATH
# loads environment variables
load_dotenv()
logger = logging.getLogger(APP_NAME)
def configure_logger(debug: int = 0) -> None:
# boilerplate code to enable logging in the streamlit app console
log_level = logging.DEBUG if debug == 1 else logging.INFO
logger.setLevel(log_level)
stream_handler = logging.StreamHandler(stream=sys.stdout)
stream_handler.setLevel(log_level)
formatter = logging.Formatter("%(message)s")
stream_handler.setFormatter(formatter)
logger.addHandler(stream_handler)
logger.propagate = False
configure_logger(0)
def authenticate(
openai_api_key: str, activeloop_token: str, activeloop_org_name: str
) -> None:
# Validate all credentials are set and correct
# Check for env variables to enable local dev and deployments with shared credentials
openai_api_key = (
openai_api_key
or os.environ.get("OPENAI_API_KEY")
or st.secrets.get("OPENAI_API_KEY")
)
activeloop_token = (
activeloop_token
or os.environ.get("ACTIVELOOP_TOKEN")
or st.secrets.get("ACTIVELOOP_TOKEN")
)
activeloop_org_name = (
activeloop_org_name
or os.environ.get("ACTIVELOOP_ORG_NAME")
or st.secrets.get("ACTIVELOOP_ORG_NAME")
)
if not (openai_api_key and activeloop_token and activeloop_org_name):
st.session_state["auth_ok"] = False
st.error("Credentials neither set nor stored", icon=PAGE_ICON)
return
try:
# Try to access openai and deeplake
with st.spinner("Authentifying..."):
openai.api_key = openai_api_key
openai.Model.list()
deeplake.exists(
f"hub://{activeloop_org_name}/DataChad-Authentication-Check",
token=activeloop_token,
)
except Exception as e:
logger.error(f"Authentication failed with {e}")
st.session_state["auth_ok"] = False
st.error("Authentication failed", icon=PAGE_ICON)
return
# store credentials in the session state
st.session_state["auth_ok"] = True
st.session_state["openai_api_key"] = openai_api_key
st.session_state["activeloop_token"] = activeloop_token
st.session_state["activeloop_org_name"] = activeloop_org_name
logger.info("Authentification successful!")
def save_uploaded_file() -> str:
# Streamlit uploaded files need to be stored locally
# before being embedded and uploaded to the hub
uploaded_file = st.session_state["uploaded_file"]
if not os.path.exists(DATA_PATH):
os.makedirs(DATA_PATH)
file_path = str(DATA_PATH / uploaded_file.name)
uploaded_file.seek(0)
file_bytes = uploaded_file.read()
with open(file_path, "wb") as file:
file.write(file_bytes)
logger.info(f"Saved: {file_path}")
return file_path
def delete_uploaded_file() -> None:
# cleanup locally stored files
file_path = DATA_PATH / st.session_state["uploaded_file"].name
if os.path.exists(file_path):
os.remove(file_path)
logger.info(f"Removed: {file_path}")
class AutoGitLoader:
def __init__(self, data_source: str) -> None:
self.data_source = data_source
def load(self) -> List[Document]:
# We need to try both common default branches ("main" and "master")
# Thank you GitHub for the "master" to "main" switch
# we also need to make sure the data path exists
if not os.path.exists(DATA_PATH):
os.makedirs(DATA_PATH)
repo_name = self.data_source.split("/")[-1].split(".")[0]
repo_path = str(DATA_PATH / repo_name)
clone_url = self.data_source
if os.path.exists(repo_path):
clone_url = None
branches = ["main", "master"]
for branch in branches:
try:
docs = GitLoader(repo_path, clone_url, branch).load()
break
except Exception as e:
logger.error(f"Error loading git: {e}")
if os.path.exists(repo_path):
# cleanup repo afterwards
shutil.rmtree(repo_path)
try:
return docs
except NameError:  # docs is unbound if every branch failed to load
raise RuntimeError("Make sure to use HTTPS GitHub repo links")
FILE_LOADER_MAPPING = {
".csv": (CSVLoader, {"encoding": "utf-8"}),
".doc": (UnstructuredWordDocumentLoader, {}),
".docx": (UnstructuredWordDocumentLoader, {}),
".enex": (EverNoteLoader, {}),
".epub": (UnstructuredEPubLoader, {}),
".html": (UnstructuredHTMLLoader, {}),
".md": (UnstructuredMarkdownLoader, {}),
".odt": (UnstructuredODTLoader, {}),
".pdf": (PDFMinerLoader, {}),
".ppt": (UnstructuredPowerPointLoader, {}),
".pptx": (UnstructuredPowerPointLoader, {}),
".txt": (TextLoader, {"encoding": "utf8"}),
".ipynb": (NotebookLoader, {}),
".py": (PythonLoader, {}),
# Add more mappings for other file extensions and loaders as needed
}
WEB_LOADER_MAPPING = {
".git": (AutoGitLoader, {}),
".pdf": (OnlinePDFLoader, {}),
}
def get_loader(file_path: str, mapping: dict, default_loader: BaseLoader) -> BaseLoader:
# Choose loader from mapping, load default if no match found
ext = "." + file_path.rsplit(".", 1)[-1]
if ext in mapping:
loader_class, loader_args = mapping[ext]
loader = loader_class(file_path, **loader_args)
else:
loader = default_loader(file_path)
return loader
def load_data_source() -> List[Document]:
# Ugly thing that decides how to load data
# It ain't much, but it's honest work
data_source = st.session_state["data_source"]
is_web = data_source.startswith("http")
is_dir = os.path.isdir(data_source)
is_file = os.path.isfile(data_source)
loader = None
if is_dir:
loader = DirectoryLoader(data_source, recursive=True, silent_errors=True)
elif is_web:
loader = get_loader(data_source, WEB_LOADER_MAPPING, WebBaseLoader)
elif is_file:
loader = get_loader(data_source, FILE_LOADER_MAPPING, UnstructuredFileLoader)
try:
# Chunk size is a major trade-off parameter to control result accuracy over computation
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=st.session_state["chunk_size"],
chunk_overlap=st.session_state["chunk_overlap"],
)
docs = loader.load()
docs = text_splitter.split_documents(docs)
logger.info(f"Loaded: {len(docs)} document chucks")
return docs
except Exception as e:
msg = (
e
if loader
else f"No Loader found for your data source. Consider contributing:  {PROJECT_URL}!"
)
error_msg = f"Failed to load '{st.session_state['data_source']}':\n\n{msg}"
st.error(error_msg, icon=PAGE_ICON)
logger.error(error_msg)
st.stop()
def get_dataset_name() -> str:
# replace all non-word characters with dashes
# to get a string that can be used to create a new dataset
dashed_string = re.sub(r"\W+", "-", st.session_state["data_source"])
cleaned_string = re.sub(r"--+", "- ", dashed_string).strip("-")
return cleaned_string
def get_model() -> BaseLanguageModel:
match st.session_state["model"]:
case "gpt-3.5-turbo":
model = ChatOpenAI(
model_name=st.session_state["model"],
temperature=st.session_state["temperature"],
openai_api_key=st.session_state["openai_api_key"],
)
case "LlamaCpp":
model = LlamaCpp(
model_path=LLAMACPP_MODEL_PATH,
n_ctx=st.session_state["model_n_ctx"],
temperature=st.session_state["temperature"],
verbose=True,
)
case "GPT4All":
model = GPT4All(
model=GPT4ALL_MODEL_PATH,
n_ctx=st.session_state["model_n_ctx"],
backend="gptj",
temp=st.session_state["temperature"],
verbose=True,
)
# Add more models as needed
case _:
msg = f"Model {st.session_state['model']} not supported!"
logger.error(msg)
st.error(msg)
st.stop()
return model
def get_embeddings() -> Embeddings:
match st.session_state["embeddings"]:
case "openai":
embeddings = OpenAIEmbeddings(
disallowed_special=(), openai_api_key=st.session_state["openai_api_key"]
)
case "huggingface-Fall-MiniLM-L6-v2":
embeddings = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
# Add more embeddings as needed
case _:
msg = f"Embeddings {st.session_state['embeddings']} not supported!"
logger.error(msg)
st.error(msg)
st.stop()
return embeddings
def get_vector_store() -> VectorStore:
# either load existing vector store or upload a new one to the hub
embeddings = get_embeddings()
dataset_name = get_dataset_name()
dataset_path = f"hub://{st.session_state['activeloop_org_name']}/{dataset_name}-{st.session_state['chunk_size']}"
if deeplake.exists(dataset_path, token=st.session_state["activeloop_token"]):
with st.spinner("Loading vector store..."):
logger.info(f"Dataset '{dataset_path}' exists -> loading")
vector_store = DeepLake(
dataset_path=dataset_path,
read_only=True,
embedding_function=embeddings,
token=st.session_state["activeloop_token"],
)
else:
with st.spinner("Reading, embedding and uploading data to hub..."):
logger.info(f"Dataset '{dataset_path}' does not exist -> uploading")
docs = load_data_source()
vector_store = DeepLake.from_documents(
docs,
embeddings,
dataset_path=dataset_path,
token=st.session_state["activeloop_token"],
)
return vector_store
def get_chain() -> ConversationalRetrievalChain:
# create the LangChain chain that is called to generate responses
vector_store = get_vector_store()
retriever = vector_store.as_retriever()
# Search params "fetch_k" and "k" define how many documents are pulled from the hub
# and how many of those are selected after matching to build the context
# that is fed to the model together with your prompt
search_kwargs = {
"maximal_marginal_relevance": True,
"distance_metric": "cos",
"fetch_k": st.session_state["fetch_k"],
"k": st.session_state["k"],
}
retriever.search_kwargs.update(search_kwargs)
model = get_model()
chain = ConversationalRetrievalChain.from_llm(
model,
retriever=retriever,
chain_type="stuff",
verbose=True,
# we limit the maximum number of used tokens
# to prevent running into the model's token limit of 4096
max_tokens_limit=st.session_state["max_tokens"],
)
return chain
def update_chain() -> None:
# Build chain with parameters from session state and store it back
# Also delete chat history to not confuse the bot with old context
try:
st.session_state["chain"] = get_chain()
st.session_state["chat_history"] = []
msg = f"Data source '{st.session_state['data_source']}' is ready to go!"
logger.info(msg)
st.info(msg, icon=PAGE_ICON)
except Exception as e:
msg = f"Failed to build chain for data source '{st.session_state['data_source']}' with error: {e}"
logger.error(msg)
st.error(msg, icon=PAGE_ICON)
def update_usage(cb: OpenAICallbackHandler) -> None:
# Accumulate API call usage via callbacks
logger.info(f"Usage: {cb}")
callback_properties = [
"total_tokens",
"prompt_tokens",
"completion_tokens",
"total_cost",
]
for prop in callback_properties:
value = getattr(cb, prop, 0)
st.session_state["usage"].setdefault(prop, 0)
st.session_state["usage"][prop] += value
def generate_response(prompt: str) -> str:
# call the chain to generate responses and add them to the chat history
with st.spinner("Generating response"), get_openai_callback() as cb:
response = st.session_state["chain"](
{"question": prompt, "chat_history": st.session_state["chat_history"]}
)
update_usage(cb)
logger.info(f"Response: '{response}'")
st.session_state["chat_history"].append((prompt, response["answer"]))
return response["answer"]