refactor to enable mode and model selection

main
Gustav von Zitzewitz 1 year ago
parent bcd4395907
commit 56cd7e3ba5

app.py
@@ -1,33 +1,33 @@
 import streamlit as st
 from streamlit_chat import message
-from constants import (
+from datachad.chain import generate_response, update_chain
+from datachad.constants import (
     ACTIVELOOP_HELP,
     APP_NAME,
     AUTHENTICATION_HELP,
     CHUNK_OVERLAP,
     CHUNK_SIZE,
     DEFAULT_DATA_SOURCE,
-    EMBEDDINGS,
     ENABLE_ADVANCED_OPTIONS,
+    ENABLE_LOCAL_MODE,
     FETCH_K,
+    LOCAL_MODE_DISABLED_HELP,
     MAX_TOKENS,
-    MODEL,
+    MODEL_N_CTX,
     OPENAI_HELP,
     PAGE_ICON,
     PROJECT_URL,
     TEMPERATURE,
     USAGE_HELP,
-    MODEL_N_CTX,
     K,
 )
-from utils import (
+from datachad.models import MODELS, MODES
+from datachad.utils import (
     authenticate,
     delete_uploaded_file,
-    generate_response,
     logger,
     save_uploaded_file,
-    update_chain,
 )
 # Page options and header
@@ -52,8 +52,8 @@ SESSION_DEFAULTS = {
"activeloop_org_name": None, "activeloop_org_name": None,
"uploaded_file": None, "uploaded_file": None,
"data_source": DEFAULT_DATA_SOURCE, "data_source": DEFAULT_DATA_SOURCE,
"model": MODEL, "mode": MODES.OPENAI,
"embeddings": EMBEDDINGS, "model": MODELS.GPT35TURBO,
"k": K, "k": K,
"fetch_k": FETCH_K, "fetch_k": FETCH_K,
"chunk_size": CHUNK_SIZE, "chunk_size": CHUNK_SIZE,
@@ -72,7 +72,7 @@ def authentication_form() -> None:
st.title("Authentication", help=AUTHENTICATION_HELP) st.title("Authentication", help=AUTHENTICATION_HELP)
with st.form("authentication"): with st.form("authentication"):
openai_api_key = st.text_input( openai_api_key = st.text_input(
"OpenAI API Key", f"{st.session_state['mode']} API Key",
type="password", type="password",
help=OPENAI_HELP, help=OPENAI_HELP,
placeholder="This field is mandatory", placeholder="This field is mandatory",
@@ -101,29 +101,34 @@ def advanced_options_form() -> None:
     )
     if advanced_options:
         with st.form("advanced_options"):
-            temperature = st.slider(
+            col1, col2 = st.columns(2)
+            col1.selectbox("model", MODELS.for_mode(st.session_state["mode"]))
+            col2.number_input(
                 "temperature",
                 min_value=0.0,
                 max_value=1.0,
                 value=TEMPERATURE,
                 help="Controls the randomness of the language model output",
+                key="temperature",
             )
-            col1, col2 = st.columns(2)
-            fetch_k = col1.number_input(
+            col1.number_input(
                 "k_fetch",
                 min_value=1,
                 max_value=1000,
                 value=FETCH_K,
                 help="The number of documents to pull from the vector database",
+                key="k_fetch",
             )
-            k = col2.number_input(
+            col2.number_input(
                 "k",
                 min_value=1,
                 max_value=100,
                 value=K,
                 help="The number of most similar documents to build the context from",
+                key="k",
             )
-            chunk_size = col1.number_input(
+            col1.number_input(
                 "chunk_size",
                 min_value=1,
                 max_value=100000,
@@ -133,35 +138,37 @@ def advanced_options_form() -> None:
"before being embedded.\n\nChanging this parameter makes re-embedding " "before being embedded.\n\nChanging this parameter makes re-embedding "
"and re-uploading the data to the database necessary " "and re-uploading the data to the database necessary "
), ),
key="chunk_size",
) )
max_tokens = col2.number_input( col2.number_input(
"max_tokens", "max_tokens",
min_value=1, min_value=1,
max_value=4069, max_value=30000,
value=MAX_TOKENS, value=MAX_TOKENS,
help="Limits the documents returned from database based on number of tokens", help="Limits the documents returned from database based on number of tokens",
key="max_tokens",
) )
applied = st.form_submit_button("Apply") applied = st.form_submit_button("Apply")
if applied: if applied:
st.session_state["k"] = k
st.session_state["fetch_k"] = fetch_k
st.session_state["chunk_size"] = chunk_size
st.session_state["temperature"] = temperature
st.session_state["max_tokens"] = max_tokens
update_chain() update_chain()
 # Sidebar with Authentication and Advanced Options
 with st.sidebar:
-    authentication_form()
+    mode = st.selectbox("Mode", MODES.values(), key="mode")
+    if mode == MODES.LOCAL and not ENABLE_LOCAL_MODE:
+        st.error(LOCAL_MODE_DISABLED_HELP, icon=PAGE_ICON)
+        st.stop()
+    if mode != MODES.LOCAL:
+        authentication_form()
     st.info(f"Learn how it works [here]({PROJECT_URL})")
 # Only start App if authentication is OK
-if not st.session_state["auth_ok"]:
+if not (st.session_state["auth_ok"] or mode == MODES.LOCAL):
     st.stop()
 # Clear button to reset all chat communication
-clear_button = st.button("Clear Conversation", key="clear")
+clear_button = st.button("Clear Conversation")
 # Advanced Options
 if ENABLE_ADVANCED_OPTIONS:
@@ -214,6 +221,7 @@ with text_container:
submit_button = st.form_submit_button(label="Send") submit_button = st.form_submit_button(label="Send")
if submit_button and user_input: if submit_button and user_input:
text_container.empty()
output = generate_response(user_input) output = generate_response(user_input)
st.session_state["past"].append(user_input) st.session_state["past"].append(user_input)
st.session_state["generated"].append(output) st.session_state["generated"].append(output)

datachad/chain.py (new file)
@@ -0,0 +1,77 @@
import streamlit as st
from langchain.callbacks import OpenAICallbackHandler, get_openai_callback
from langchain.chains import ConversationalRetrievalChain

from datachad.constants import PAGE_ICON
from datachad.database import get_vector_store
from datachad.models import get_model
from datachad.utils import logger


def get_chain() -> ConversationalRetrievalChain:
    # create the langchain that will be called to generate responses
    vector_store = get_vector_store()
    retriever = vector_store.as_retriever()
    # Search params "fetch_k" and "k" define how many documents are pulled from the hub
    # and selected after the document matching to build the context
    # that is fed to the model together with your prompt
    search_kwargs = {
        "maximal_marginal_relevance": True,
        "distance_metric": "cos",
        "fetch_k": st.session_state["fetch_k"],
        "k": st.session_state["k"],
    }
    retriever.search_kwargs.update(search_kwargs)
    model = get_model()
    chain = ConversationalRetrievalChain.from_llm(
        model,
        retriever=retriever,
        chain_type="stuff",
        verbose=True,
        # we limit the maximum number of used tokens
        # to prevent running into the model's token limit of 4096
        max_tokens_limit=st.session_state["max_tokens"],
    )
    return chain


def update_chain() -> None:
    # Build the chain with parameters from the session state and store it back
    # Also delete the chat history to not confuse the bot with old context
    try:
        st.session_state["chain"] = get_chain()
        st.session_state["chat_history"] = []
        msg = f"Data source '{st.session_state['data_source']}' is ready to go!"
        logger.info(msg)
        st.info(msg, icon=PAGE_ICON)
    except Exception as e:
        msg = f"Failed to build chain for data source '{st.session_state['data_source']}' with error: {e}"
        logger.error(msg)
        st.error(msg, icon=PAGE_ICON)


def update_usage(cb: OpenAICallbackHandler) -> None:
    # Accumulate API call usage via callbacks
    logger.info(f"Usage: {cb}")
    callback_properties = [
        "total_tokens",
        "prompt_tokens",
        "completion_tokens",
        "total_cost",
    ]
    for prop in callback_properties:
        value = getattr(cb, prop, 0)
        st.session_state["usage"].setdefault(prop, 0)
        st.session_state["usage"][prop] += value


def generate_response(prompt: str) -> str:
    # call the chain to generate responses and add them to the chat history
    with st.spinner("Generating response"), get_openai_callback() as cb:
        response = st.session_state["chain"](
            {"question": prompt, "chat_history": st.session_state["chat_history"]}
        )
        update_usage(cb)
    logger.info(f"Response: '{response}'")
    st.session_state["chat_history"].append((prompt, response["answer"]))
    return response["answer"]
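For orientation, a minimal sketch of how app.py drives these functions (assumes st.session_state already holds the settings referenced above):

    import streamlit as st
    from datachad.chain import generate_response, update_chain

    st.session_state.setdefault("usage", {})
    st.session_state.setdefault("chat_history", [])
    update_chain()  # builds the chain from the current session settings
    answer = generate_response("What does this repo do?")  # example question; appends to chat_history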

datachad/constants.py
@@ -1,28 +1,31 @@
 from pathlib import Path
-APP_NAME = "DataChad"
-MODEL = "gpt-3.5-turbo"
-EMBEDDINGS = "openai"
 PAGE_ICON = "🤖"
+APP_NAME = "DataChad"
 PROJECT_URL = "https://github.com/gustavz/DataChad"
 K = 10
 FETCH_K = 20
 CHUNK_SIZE = 1000
 CHUNK_OVERLAP = 0
 TEMPERATURE = 0.7
 MAX_TOKENS = 3357
+MODEL_N_CTX = 1000
+ENABLE_LOCAL_MODE = False
 ENABLE_ADVANCED_OPTIONS = True
-MODEL_N_CTX = 1000
-LLAMACPP_MODEL_PATH = ""
-GPT4ALL_MODEL_PATH = ""
-ENABLE_LOCAL_MODELS = False
 DATA_PATH = Path.cwd() / "data"
 DEFAULT_DATA_SOURCE = "https://github.com/gustavz/DataChad.git"
+LOCAL_MODE_DISABLED_HELP = """
+This is a demo hosted with limited resources. Local Mode is not enabled.\n
+To use Local Mode deploy the app on your machine of choice with ENABLE_LOCAL_MODE set to True.
+"""
 AUTHENTICATION_HELP = f"""
 Your credentials are only stored in your session state.\n
 The keys are neither exposed nor made visible or stored permanently in any way.\n
@@ -31,7 +34,7 @@ Feel free to check out [the code base]({PROJECT_URL}) to validate how things wor
 USAGE_HELP = f"""
 These are the accumulated OpenAI API usage metrics.\n
-The app uses '{MODEL}' for chat and 'text-embedding-ada-002' for embeddings.\n
+The app uses 'gpt-3.5-turbo' for chat and 'text-embedding-ada-002' for embeddings.\n
 Learn more about OpenAI's pricing [here](https://openai.com/pricing#language-models)
 """

datachad/database.py (new file)
@@ -0,0 +1,51 @@
import os
import re

import deeplake
import streamlit as st
from langchain.vectorstores import DeepLake, VectorStore

from datachad.constants import DATA_PATH
from datachad.loader import load_data_source
from datachad.models import MODES, get_embeddings
from datachad.utils import logger


def get_dataset_path() -> str:
    # replace all non-word characters with dashes
    # to get a string that can be used to create a new dataset
    dataset_name = re.sub(r"\W+", "-", st.session_state["data_source"])
    dataset_name = re.sub(r"--+", "- ", dataset_name).strip("-")
    if st.session_state["mode"] == MODES.LOCAL:
        if not os.path.exists(DATA_PATH):
            os.makedirs(DATA_PATH)
        dataset_path = str(DATA_PATH / dataset_name)
    else:
        dataset_path = f"hub://{st.session_state['activeloop_org_name']}/{dataset_name}-{st.session_state['chunk_size']}"
    return dataset_path


def get_vector_store() -> VectorStore:
    # either load an existing vector store or upload a new one to the hub
    embeddings = get_embeddings()
    dataset_path = get_dataset_path()
    if deeplake.exists(dataset_path, token=st.session_state["activeloop_token"]):
        with st.spinner("Loading vector store..."):
            logger.info(f"Dataset '{dataset_path}' exists -> loading")
            vector_store = DeepLake(
                dataset_path=dataset_path,
                read_only=True,
                embedding_function=embeddings,
                token=st.session_state["activeloop_token"],
            )
    else:
        with st.spinner("Reading, embedding and uploading data to hub..."):
            logger.info(f"Dataset '{dataset_path}' does not exist -> uploading")
            docs = load_data_source()
            vector_store = DeepLake.from_documents(
                docs,
                embeddings,
                dataset_path=dataset_path,
                token=st.session_state["activeloop_token"],
            )
    return vector_store
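For intuition, what get_dataset_path produces for the default data source in OpenAI mode (the org name and chunk size below are illustrative):

    import re

    source = "https://github.com/gustavz/DataChad.git"
    name = re.sub(r"\W+", "-", source)  # 'https-github-com-gustavz-DataChad-git'
    print(f"hub://my-org/{name.strip('-')}-1000")  # hub://my-org/https-github-com-gustavz-DataChad-git-1000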

datachad/loader.py (new file)
@@ -0,0 +1,133 @@
import os
import shutil
from typing import List

import streamlit as st
from langchain.document_loaders import (
    CSVLoader,
    DirectoryLoader,
    EverNoteLoader,
    GitLoader,
    NotebookLoader,
    OnlinePDFLoader,
    PDFMinerLoader,
    PythonLoader,
    TextLoader,
    UnstructuredEPubLoader,
    UnstructuredFileLoader,
    UnstructuredHTMLLoader,
    UnstructuredMarkdownLoader,
    UnstructuredODTLoader,
    UnstructuredPowerPointLoader,
    UnstructuredWordDocumentLoader,
    WebBaseLoader,
)
from langchain.document_loaders.base import BaseLoader
from langchain.schema import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter

from datachad.constants import DATA_PATH, PAGE_ICON, PROJECT_URL
from datachad.utils import logger


class AutoGitLoader:
    def __init__(self, data_source: str) -> None:
        self.data_source = data_source

    def load(self) -> List[Document]:
        # We need to try both common default branches
        # Thank you github for the "master" to "main" switch
        # we also need to make sure the data path exists
        if not os.path.exists(DATA_PATH):
            os.makedirs(DATA_PATH)
        repo_name = self.data_source.split("/")[-1].split(".")[0]
        repo_path = str(DATA_PATH / repo_name)
        clone_url = self.data_source
        if os.path.exists(repo_path):
            clone_url = None
        branches = ["main", "master"]
        for branch in branches:
            try:
                docs = GitLoader(repo_path, clone_url, branch).load()
                break
            except Exception as e:
                logger.error(f"Error loading git: {e}")
        if os.path.exists(repo_path):
            # clean up the cloned repo afterwards
            shutil.rmtree(repo_path)
        try:
            return docs
        except NameError:
            raise RuntimeError("Make sure to use HTTPS GitHub repo links")


FILE_LOADER_MAPPING = {
    ".csv": (CSVLoader, {"encoding": "utf-8"}),
    ".doc": (UnstructuredWordDocumentLoader, {}),
    ".docx": (UnstructuredWordDocumentLoader, {}),
    ".enex": (EverNoteLoader, {}),
    ".epub": (UnstructuredEPubLoader, {}),
    ".html": (UnstructuredHTMLLoader, {}),
    ".md": (UnstructuredMarkdownLoader, {}),
    ".odt": (UnstructuredODTLoader, {}),
    ".pdf": (PDFMinerLoader, {}),
    ".ppt": (UnstructuredPowerPointLoader, {}),
    ".pptx": (UnstructuredPowerPointLoader, {}),
    ".txt": (TextLoader, {"encoding": "utf8"}),
    ".ipynb": (NotebookLoader, {}),
    ".py": (PythonLoader, {}),
    # Add more mappings for other file extensions and loaders as needed
}

WEB_LOADER_MAPPING = {
    ".git": (AutoGitLoader, {}),
    ".pdf": (OnlinePDFLoader, {}),
}


def get_loader(file_path: str, mapping: dict, default_loader: BaseLoader) -> BaseLoader:
    # Choose the loader from the mapping, fall back to the default if no match is found
    ext = "." + file_path.rsplit(".", 1)[-1]
    if ext in mapping:
        loader_class, loader_args = mapping[ext]
        loader = loader_class(file_path, **loader_args)
    else:
        loader = default_loader(file_path)
    return loader


def load_data_source() -> List[Document]:
    # Ugly thing that decides how to load data
    # It ain't much, but it's honest work
    data_source = st.session_state["data_source"]
    is_web = data_source.startswith("http")
    is_dir = os.path.isdir(data_source)
    is_file = os.path.isfile(data_source)
    loader = None
    if is_dir:
        loader = DirectoryLoader(data_source, recursive=True, silent_errors=True)
    elif is_web:
        loader = get_loader(data_source, WEB_LOADER_MAPPING, WebBaseLoader)
    elif is_file:
        loader = get_loader(data_source, FILE_LOADER_MAPPING, UnstructuredFileLoader)
    try:
        # Chunk size is a major trade-off parameter to control result accuracy over computation
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=st.session_state["chunk_size"],
            chunk_overlap=st.session_state["chunk_overlap"],
        )
        docs = loader.load()
        docs = text_splitter.split_documents(docs)
        logger.info(f"Loaded: {len(docs)} document chunks")
        return docs
    except Exception as e:
        msg = (
            e
            if loader
            else f"No Loader found for your data source. Consider contributing: {PROJECT_URL}!"
        )
        error_msg = f"Failed to load '{st.session_state['data_source']}':\n\n{msg}"
        st.error(error_msg, icon=PAGE_ICON)
        logger.error(error_msg)
        st.stop()
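A quick illustration of the extension-based dispatch in get_loader (the file names are made up):

    loader = get_loader("docs/paper.pdf", FILE_LOADER_MAPPING, UnstructuredFileLoader)
    print(type(loader).__name__)  # PDFMinerLoader
    loader = get_loader("docs/notes.xyz", FILE_LOADER_MAPPING, UnstructuredFileLoader)
    print(type(loader).__name__)  # UnstructuredFileLoader (the fallback)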

datachad/models.py (new file)
@@ -0,0 +1,111 @@
from dataclasses import dataclass

import streamlit as st
from langchain.base_language import BaseLanguageModel
from langchain.chat_models import ChatOpenAI
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.embeddings.openai import Embeddings, OpenAIEmbeddings
from langchain.llms import GPT4All, LlamaCpp

from datachad.utils import logger


class Enum:
    @classmethod
    def values(cls):
        return [v for k, v in cls.__dict__.items() if not k.startswith("_")]

    @classmethod
    def dict(cls):
        return {k: v for k, v in cls.__dict__.items() if not k.startswith("_")}


@dataclass
class Model:
    name: str
    mode: str
    embedding: str
    path: str = None  # for local models only

    def __str__(self):
        return self.name


class MODES(Enum):
    OPENAI = "OpenAI"
    LOCAL = "Local"


class EMBEDDINGS(Enum):
    OPENAI = "openai"
    HUGGINGFACE = "all-MiniLM-L6-v2"


class MODELS(Enum):
    GPT35TURBO = Model("gpt-3.5-turbo", MODES.OPENAI, EMBEDDINGS.OPENAI)
    GPT4 = Model("gpt-4", MODES.OPENAI, EMBEDDINGS.OPENAI)
    LLAMACPP = Model(
        "LLAMA", MODES.LOCAL, EMBEDDINGS.HUGGINGFACE, "models/llamacpp.bin"
    )
    GPT4ALL = Model(
        "GPT4All", MODES.LOCAL, EMBEDDINGS.HUGGINGFACE, "models/gpt4all.bin"
    )

    @classmethod
    def for_mode(cls, mode):
        return [v for v in cls.values() if isinstance(v, Model) and v.mode == mode]


def get_model() -> BaseLanguageModel:
    match st.session_state["model"].name:
        case MODELS.GPT35TURBO.name:
            model = ChatOpenAI(
                model_name=st.session_state["model"].name,
                temperature=st.session_state["temperature"],
                openai_api_key=st.session_state["openai_api_key"],
            )
        case MODELS.GPT4.name:
            model = ChatOpenAI(
                model_name=st.session_state["model"].name,
                temperature=st.session_state["temperature"],
                openai_api_key=st.session_state["openai_api_key"],
            )
        case MODELS.LLAMACPP.name:
            model = LlamaCpp(
                model_path=st.session_state["model"].path,
                n_ctx=st.session_state["model_n_ctx"],
                temperature=st.session_state["temperature"],
                verbose=True,
            )
        case MODELS.GPT4ALL.name:
            model = GPT4All(
                model=st.session_state["model"].path,
                n_ctx=st.session_state["model_n_ctx"],
                backend="gptj",
                temp=st.session_state["temperature"],
                verbose=True,
            )
        # Add more models as needed
        case _:
            msg = f"Model {st.session_state['model']} not supported!"
            logger.error(msg)
            st.error(msg)
            st.stop()
    return model


def get_embeddings() -> Embeddings:
    match st.session_state["model"].embedding:
        case EMBEDDINGS.OPENAI:
            embeddings = OpenAIEmbeddings(
                disallowed_special=(), openai_api_key=st.session_state["openai_api_key"]
            )
        case EMBEDDINGS.HUGGINGFACE:
            embeddings = HuggingFaceEmbeddings(model_name=EMBEDDINGS.HUGGINGFACE)
        # Add more embeddings as needed
        case _:
            msg = f"Embeddings {st.session_state['model'].embedding} not supported!"
            logger.error(msg)
            st.error(msg)
            st.stop()
    return embeddings
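Given the definitions above, the lightweight Enum helpers behave like this:

    print(MODES.values())  # ['OpenAI', 'Local']
    print([str(m) for m in MODELS.for_mode(MODES.OPENAI)])  # ['gpt-3.5-turbo', 'gpt-4']
    print(MODELS.GPT4ALL.embedding)  # 'all-MiniLM-L6-v2'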

datachad/utils.py (new file)
@@ -0,0 +1,104 @@
import logging
import os
import sys

import deeplake
import openai
import streamlit as st
from dotenv import load_dotenv

from datachad.constants import APP_NAME, DATA_PATH, PAGE_ICON

# loads environment variables
load_dotenv()

logger = logging.getLogger(APP_NAME)


def configure_logger(debug: int = 0) -> None:
    # boilerplate code to enable logging in the streamlit app console
    log_level = logging.DEBUG if debug == 1 else logging.INFO
    logger.setLevel(log_level)
    stream_handler = logging.StreamHandler(stream=sys.stdout)
    stream_handler.setLevel(log_level)
    formatter = logging.Formatter("%(message)s")
    stream_handler.setFormatter(formatter)
    logger.addHandler(stream_handler)
    logger.propagate = False


configure_logger(0)


def authenticate(
    openai_api_key: str, activeloop_token: str, activeloop_org_name: str
) -> None:
    # Validate that all credentials are set and correct
    # Check for env variables to enable local dev and deployments with shared credentials
    openai_api_key = (
        openai_api_key
        or os.environ.get("OPENAI_API_KEY")
        or st.secrets.get("OPENAI_API_KEY")
    )
    activeloop_token = (
        activeloop_token
        or os.environ.get("ACTIVELOOP_TOKEN")
        or st.secrets.get("ACTIVELOOP_TOKEN")
    )
    activeloop_org_name = (
        activeloop_org_name
        or os.environ.get("ACTIVELOOP_ORG_NAME")
        or st.secrets.get("ACTIVELOOP_ORG_NAME")
    )
    if not (openai_api_key and activeloop_token and activeloop_org_name):
        st.session_state["auth_ok"] = False
        st.error("Credentials neither set nor stored", icon=PAGE_ICON)
        return
    try:
        # Try to access openai and deeplake
        with st.spinner("Authenticating..."):
            openai.api_key = openai_api_key
            openai.Model.list()
            deeplake.exists(
                f"hub://{activeloop_org_name}/DataChad-Authentication-Check",
                token=activeloop_token,
            )
    except Exception as e:
        logger.error(f"Authentication failed with {e}")
        st.session_state["auth_ok"] = False
        st.error("Authentication failed", icon=PAGE_ICON)
        return
    # store credentials in the session state
    st.session_state["auth_ok"] = True
    st.session_state["openai_api_key"] = openai_api_key
    st.session_state["activeloop_token"] = activeloop_token
    st.session_state["activeloop_org_name"] = activeloop_org_name
    logger.info("Authentication successful!")


def save_uploaded_file() -> str:
    # streamlit uploaded files need to be stored locally
    # before being embedded and uploaded to the hub
    uploaded_file = st.session_state["uploaded_file"]
    if not os.path.exists(DATA_PATH):
        os.makedirs(DATA_PATH)
    file_path = str(DATA_PATH / uploaded_file.name)
    uploaded_file.seek(0)
    file_bytes = uploaded_file.read()
    with open(file_path, "wb") as file:
        file.write(file_bytes)
    logger.info(f"Saved: {file_path}")
    return file_path


def delete_uploaded_file() -> None:
    # clean up locally stored files
    file_path = DATA_PATH / st.session_state["uploaded_file"].name
    if os.path.exists(file_path):
        os.remove(file_path)
        logger.info(f"Removed: {file_path}")

requirements.txt
@@ -2,7 +2,7 @@ streamlit==1.22.0
 streamlit-chat==0.0.2.2
 deeplake==3.4.1
 openai==0.27.6
-langchain==0.0.164
+langchain==0.0.173
 tiktoken==0.4.0
 unstructured==0.6.5
 pdf2image==1.16.3

utils.py (deleted)
@@ -1,393 +0,0 @@
import logging
import os
import re
import shutil
import sys
from typing import List

import deeplake
import openai
import streamlit as st
from dotenv import load_dotenv
from langchain.base_language import BaseLanguageModel
from langchain.callbacks import OpenAICallbackHandler, get_openai_callback
from langchain.chains import ConversationalRetrievalChain
from langchain.chat_models import ChatOpenAI
from langchain.document_loaders import (
    CSVLoader,
    DirectoryLoader,
    EverNoteLoader,
    GitLoader,
    NotebookLoader,
    OnlinePDFLoader,
    PDFMinerLoader,
    PythonLoader,
    TextLoader,
    UnstructuredEPubLoader,
    UnstructuredFileLoader,
    UnstructuredHTMLLoader,
    UnstructuredMarkdownLoader,
    UnstructuredODTLoader,
    UnstructuredPowerPointLoader,
    UnstructuredWordDocumentLoader,
    WebBaseLoader,
)
from langchain.document_loaders.base import BaseLoader
from langchain.embeddings.openai import Embeddings, OpenAIEmbeddings
from langchain.schema import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import DeepLake, VectorStore
from langchain.llms import GPT4All, LlamaCpp
from langchain.embeddings import HuggingFaceEmbeddings

from constants import APP_NAME, DATA_PATH, PAGE_ICON, PROJECT_URL, LLAMACPP_MODEL_PATH, GPT4ALL_MODEL_PATH

# loads environment variables
load_dotenv()

logger = logging.getLogger(APP_NAME)


def configure_logger(debug: int = 0) -> None:
    # boilerplate code to enable logging in the streamlit app console
    log_level = logging.DEBUG if debug == 1 else logging.INFO
    logger.setLevel(log_level)
    stream_handler = logging.StreamHandler(stream=sys.stdout)
    stream_handler.setLevel(log_level)
    formatter = logging.Formatter("%(message)s")
    stream_handler.setFormatter(formatter)
    logger.addHandler(stream_handler)
    logger.propagate = False


configure_logger(0)


def authenticate(
    openai_api_key: str, activeloop_token: str, activeloop_org_name: str
) -> None:
    # Validate all credentials are set and correct
    # Check for env variables to enable local dev and deployments with shared credentials
    openai_api_key = (
        openai_api_key
        or os.environ.get("OPENAI_API_KEY")
        or st.secrets.get("OPENAI_API_KEY")
    )
    activeloop_token = (
        activeloop_token
        or os.environ.get("ACTIVELOOP_TOKEN")
        or st.secrets.get("ACTIVELOOP_TOKEN")
    )
    activeloop_org_name = (
        activeloop_org_name
        or os.environ.get("ACTIVELOOP_ORG_NAME")
        or st.secrets.get("ACTIVELOOP_ORG_NAME")
    )
    if not (openai_api_key and activeloop_token and activeloop_org_name):
        st.session_state["auth_ok"] = False
        st.error("Credentials neither set nor stored", icon=PAGE_ICON)
        return
    try:
        # Try to access openai and deeplake
        with st.spinner("Authentifying..."):
            openai.api_key = openai_api_key
            openai.Model.list()
            deeplake.exists(
                f"hub://{activeloop_org_name}/DataChad-Authentication-Check",
                token=activeloop_token,
            )
    except Exception as e:
        logger.error(f"Authentication failed with {e}")
        st.session_state["auth_ok"] = False
        st.error("Authentication failed", icon=PAGE_ICON)
        return
    # store credentials in the session state
    st.session_state["auth_ok"] = True
    st.session_state["openai_api_key"] = openai_api_key
    st.session_state["activeloop_token"] = activeloop_token
    st.session_state["activeloop_org_name"] = activeloop_org_name
    logger.info("Authentification successful!")


def save_uploaded_file() -> str:
    # streamlit uploaded files need to be stored locally
    # before embedded and uploaded to the hub
    uploaded_file = st.session_state["uploaded_file"]
    if not os.path.exists(DATA_PATH):
        os.makedirs(DATA_PATH)
    file_path = str(DATA_PATH / uploaded_file.name)
    uploaded_file.seek(0)
    file_bytes = uploaded_file.read()
    file = open(file_path, "wb")
    file.write(file_bytes)
    file.close()
    logger.info(f"Saved: {file_path}")
    return file_path


def delete_uploaded_file() -> None:
    # cleanup locally stored files
    file_path = DATA_PATH / st.session_state["uploaded_file"].name
    if os.path.exists(DATA_PATH):
        os.remove(file_path)
        logger.info(f"Removed: {file_path}")


class AutoGitLoader:
    def __init__(self, data_source: str) -> None:
        self.data_source = data_source

    def load(self) -> List[Document]:
        # We need to try both common main branches
        # Thank you github for the "master" to "main" switch
        # we need to make sure the data path exists
        if not os.path.exists(DATA_PATH):
            os.makedirs(DATA_PATH)
        repo_name = self.data_source.split("/")[-1].split(".")[0]
        repo_path = str(DATA_PATH / repo_name)
        clone_url = self.data_source
        if os.path.exists(repo_path):
            clone_url = None
        branches = ["main", "master"]
        for branch in branches:
            try:
                docs = GitLoader(repo_path, clone_url, branch).load()
                break
            except Exception as e:
                logger.error(f"Error loading git: {e}")
        if os.path.exists(repo_path):
            # cleanup repo afterwards
            shutil.rmtree(repo_path)
        try:
            return docs
        except:
            raise RuntimeError("Make sure to use HTTPS GitHub repo links")


FILE_LOADER_MAPPING = {
    ".csv": (CSVLoader, {"encoding": "utf-8"}),
    ".doc": (UnstructuredWordDocumentLoader, {}),
    ".docx": (UnstructuredWordDocumentLoader, {}),
    ".enex": (EverNoteLoader, {}),
    ".epub": (UnstructuredEPubLoader, {}),
    ".html": (UnstructuredHTMLLoader, {}),
    ".md": (UnstructuredMarkdownLoader, {}),
    ".odt": (UnstructuredODTLoader, {}),
    ".pdf": (PDFMinerLoader, {}),
    ".ppt": (UnstructuredPowerPointLoader, {}),
    ".pptx": (UnstructuredPowerPointLoader, {}),
    ".txt": (TextLoader, {"encoding": "utf8"}),
    ".ipynb": (NotebookLoader, {}),
    ".py": (PythonLoader, {}),
    # Add more mappings for other file extensions and loaders as needed
}

WEB_LOADER_MAPPING = {
    ".git": (AutoGitLoader, {}),
    ".pdf": (OnlinePDFLoader, {}),
}


def get_loader(file_path: str, mapping: dict, default_loader: BaseLoader) -> BaseLoader:
    # Choose loader from mapping, load default if no match found
    ext = "." + file_path.rsplit(".", 1)[-1]
    if ext in mapping:
        loader_class, loader_args = mapping[ext]
        loader = loader_class(file_path, **loader_args)
    else:
        loader = default_loader(file_path)
    return loader


def load_data_source() -> List[Document]:
    # Ugly thing that decides how to load data
    # It aint much, but it's honest work
    data_source = st.session_state["data_source"]
    is_web = data_source.startswith("http")
    is_dir = os.path.isdir(data_source)
    is_file = os.path.isfile(data_source)
    loader = None
    if is_dir:
        loader = DirectoryLoader(data_source, recursive=True, silent_errors=True)
    elif is_web:
        loader = get_loader(data_source, WEB_LOADER_MAPPING, WebBaseLoader)
    elif is_file:
        loader = get_loader(data_source, FILE_LOADER_MAPPING, UnstructuredFileLoader)
    try:
        # Chunk size is a major trade-off parameter to control result accuracy over computaion
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=st.session_state["chunk_size"],
            chunk_overlap=st.session_state["chunk_overlap"],
        )
        docs = loader.load()
        docs = text_splitter.split_documents(docs)
        logger.info(f"Loaded: {len(docs)} document chucks")
        return docs
    except Exception as e:
        msg = (
            e
            if loader
            else f"No Loader found for your data source. Consider contributing: {PROJECT_URL}!"
        )
        error_msg = f"Failed to load '{st.session_state['data_source']}':\n\n{msg}"
        st.error(error_msg, icon=PAGE_ICON)
        logger.error(error_msg)
        st.stop()


def get_dataset_name() -> str:
    # replace all non-word characters with dashes
    # to get a string that can be used to create a new dataset
    dashed_string = re.sub(r"\W+", "-", st.session_state["data_source"])
    cleaned_string = re.sub(r"--+", "- ", dashed_string).strip("-")
    return cleaned_string


def get_model() -> BaseLanguageModel:
    match st.session_state["model"]:
        case "gpt-3.5-turbo":
            model = ChatOpenAI(
                model_name=st.session_state["model"],
                temperature=st.session_state["temperature"],
                openai_api_key=st.session_state["openai_api_key"],
            )
        case "LlamaCpp":
            model = LlamaCpp(
                model_path=LLAMACPP_MODEL_PATH,
                n_ctx=st.session_state["model_n_ctx"],
                temperature=st.session_state["temperature"],
                verbose=True,
            )
        case "GPT4All":
            model = GPT4All(
                model=GPT4ALL_MODEL_PATH,
                n_ctx=st.session_state["model_n_ctx"],
                backend="gptj",
                temp=st.session_state["temperature"],
                verbose=True,
            )
        # Add more models as needed
        case _default:
            msg = f"Model {st.session_state['model']} not supported!"
            logger.error(msg)
            st.error(msg)
            exit
    return model


def get_embeddings() -> Embeddings:
    match st.session_state["embeddings"]:
        case "openai":
            embeddings = OpenAIEmbeddings(
                disallowed_special=(), openai_api_key=st.session_state["openai_api_key"]
            )
        case "huggingface-Fall-MiniLM-L6-v2":
            embeddings = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
        # Add more embeddings as needed
        case _default:
            msg = f"Embeddings {st.session_state['embeddings']} not supported!"
            logger.error(msg)
            st.error(msg)
            exit
    return embeddings


def get_vector_store() -> VectorStore:
    # either load existing vector store or upload a new one to the hub
    embeddings = get_embeddings()
    dataset_name = get_dataset_name()
    dataset_path = f"hub://{st.session_state['activeloop_org_name']}/{dataset_name}-{st.session_state['chunk_size']}"
    if deeplake.exists(dataset_path, token=st.session_state["activeloop_token"]):
        with st.spinner("Loading vector store..."):
            logger.info(f"Dataset '{dataset_path}' exists -> loading")
            vector_store = DeepLake(
                dataset_path=dataset_path,
                read_only=True,
                embedding_function=embeddings,
                token=st.session_state["activeloop_token"],
            )
    else:
        with st.spinner("Reading, embedding and uploading data to hub..."):
            logger.info(f"Dataset '{dataset_path}' does not exist -> uploading")
            docs = load_data_source()
            vector_store = DeepLake.from_documents(
                docs,
                embeddings,
                dataset_path=dataset_path,
                token=st.session_state["activeloop_token"],
            )
    return vector_store


def get_chain() -> ConversationalRetrievalChain:
    # create the langchain that will be called to generate responses
    vector_store = get_vector_store()
    retriever = vector_store.as_retriever()
    # Search params "fetch_k" and "k" define how many documents are pulled from the hub
    # and selected after the document matching to build the context
    # that is fed to the model together with your prompt
    search_kwargs = {
        "maximal_marginal_relevance": True,
        "distance_metric": "cos",
        "fetch_k": st.session_state["fetch_k"],
        "k": st.session_state["k"],
    }
    retriever.search_kwargs.update(search_kwargs)
    model = get_model()
    chain = ConversationalRetrievalChain.from_llm(
        model,
        retriever=retriever,
        chain_type="stuff",
        verbose=True,
        # we limit the maximum number of used tokens
        # to prevent running into the models token limit of 4096
        max_tokens_limit=st.session_state["max_tokens"],
    )
    return chain


def update_chain() -> None:
    # Build chain with parameters from session state and store it back
    # Also delete chat history to not confuse the bot with old context
    try:
        st.session_state["chain"] = get_chain()
        st.session_state["chat_history"] = []
        msg = f"Data source '{st.session_state['data_source']}' is ready to go!"
        logger.info(msg)
        st.info(msg, icon=PAGE_ICON)
    except Exception as e:
        msg = f"Failed to build chain for data source '{st.session_state['data_source']}' with error: {e}"
        logger.error(msg)
        st.error(msg, icon=PAGE_ICON)


def update_usage(cb: OpenAICallbackHandler) -> None:
    # Accumulate API call usage via callbacks
    logger.info(f"Usage: {cb}")
    callback_properties = [
        "total_tokens",
        "prompt_tokens",
        "completion_tokens",
        "total_cost",
    ]
    for prop in callback_properties:
        value = getattr(cb, prop, 0)
        st.session_state["usage"].setdefault(prop, 0)
        st.session_state["usage"][prop] += value


def generate_response(prompt: str) -> str:
    # call the chain to generate responses and add them to the chat history
    with st.spinner("Generating response"), get_openai_callback() as cb:
        response = st.session_state["chain"](
            {"question": prompt, "chat_history": st.session_state["chat_history"]}
        )
        update_usage(cb)
    logger.info(f"Response: '{response}'")
    st.session_state["chat_history"].append((prompt, response["answer"]))
    return response["answer"]