init commit

pull/1/head
Gustav von Zitzewitz 1 year ago
commit 2300731e56

3
.gitignore vendored

@ -0,0 +1,3 @@
data
__pycache__
.streamlit/secrets.toml

@ -0,0 +1,3 @@
OPENAI_API_KEY = "your openai key"
ACTIVELOOP_TOKEN = "your activeloop key"
ACTIVELOOP_ORG_NAME = "your activeloop organization name"

@ -0,0 +1,23 @@
# DataChad 🤖
This is an app that lets you ask questions about any data source by leveraging [embeddings](https://platform.openai.com/docs/guides/embeddings), [vector databases](https://www.activeloop.ai/), [large language models](https://platform.openai.com/docs/models/gpt-3-5) and last but not least [langchains](https://github.com/hwchase17/langchain)
## How does it work?
1. Upload any `file` or enter any `path` or `url`
2. The data source is detected and loaded into text documents
3. The text documents are embedded using openai embeddings
4. The embeddings are stored as a vector dataset to a datalake
5. A langchain is created consisting of a LLM model (`gpt-3.5-turbo` by default) and the embedding database index as retriever
6. When sending questions to the bot this chain is used as context to answer your questions
7. Finally the chat history is cached locally to enable a [ChatGPT](https://chat.openai.com/) like Q&A conversation
## Good to know
- As default context this git repository is taken, so you can directly start asking questions about its functionality without choosing your own data source.
- To run locally or deploy somewhere, execute:
```cp .streamlit/secrets.toml.template .streamlit/secrets.toml```
and set necessary keys in the newly created secrets file. Another option is to manually set environment variables
- Yes, Chad in `DataChad` refers to the well-known [meme](https://www.google.com/search?q=chad+meme)

@ -0,0 +1,97 @@
import streamlit as st
from streamlit_chat import message
from constants import APP_NAME, DEFAULT_DATA_SOURCE, PAGE_ICON
from utils import (
generate_response,
get_chain,
reset_data_source,
save_uploaded_file,
validate_keys,
)
# Page options and header
st.set_option("client.showErrorDetails", True)
st.set_page_config(page_title=APP_NAME, page_icon=PAGE_ICON)
st.markdown(
    f"<h1 style='text-align: center;'>{APP_NAME} {PAGE_ICON} <br> I know all about your data!</h1>",
    unsafe_allow_html=True,
)

# Initialise session state variables.
# Defaults are only applied on the first script run; streamlit re-executes
# the whole script on every interaction, so existing keys are left untouched.
SESSION_DEFAULTS = {
    "chat_history": [],
    "generated": [],
    "past": [],
    "auth_ok": False,
}
for state_key, default_value in SESSION_DEFAULTS.items():
    if state_key not in st.session_state:
        st.session_state[state_key] = default_value
# Sidebar: credential entry and conversation reset.
with st.sidebar:
    st.title("Authentication")
    # All three credentials are collected in one form and validated together
    # when the user presses Submit.
    with st.form("authentication"):
        openai_key = st.text_input("OpenAI API Key", type="password", key="openai_key")
        activeloop_token = st.text_input(
            "ActiveLoop Token", type="password", key="activeloop_token"
        )
        activeloop_org_name = st.text_input(
            "ActiveLoop Organisation Name", type="password", key="activeloop_org_name"
        )
        submitted = st.form_submit_button("Submit")

    if submitted:
        # validate_keys sets st.session_state["auth_ok"] (and st.stop()s on failure).
        validate_keys(openai_key, activeloop_token, activeloop_org_name)

    if not st.session_state["auth_ok"]:
        # Halt the script here so nothing below runs without valid credentials.
        st.stop()

    clear_button = st.button("Clear Conversation and Reset Data", key="clear")

# the chain can only be initialized after authentication is OK
# (the st.stop() above guarantees we only reach this point once auth_ok is True)
if "chain" not in st.session_state:
    st.session_state["chain"] = get_chain(DEFAULT_DATA_SOURCE)

if clear_button:
    # reset everything: chat history, generated messages, and the chain itself
    reset_data_source(DEFAULT_DATA_SOURCE)
# upload file or enter data source
uploaded_file = st.file_uploader("Upload a file")
data_source = st.text_input(
    "Enter any data source",
    placeholder="Any path or url pointing to a file or directory of files",
)

if uploaded_file:
    # An uploaded file takes precedence over a typed-in source. Using `elif`
    # below prevents the bug where the saved file path also satisfied the
    # `if data_source:` branch, triggering a second (expensive) reset and
    # vector-store rebuild for the same data.
    print(f"uploaded file: '{uploaded_file.name}'")
    data_source = save_uploaded_file(uploaded_file)
    reset_data_source(data_source)
elif data_source:
    print(f"data source provided: '{data_source}'")
    reset_data_source(data_source)
# container for chat history
response_container = st.container()
# container for text box
container = st.container()

with container:
    # clear_on_submit empties the text area after each question is sent
    with st.form(key="prompt_input", clear_on_submit=True):
        user_input = st.text_area("You:", key="input", height=100)
        submit_button = st.form_submit_button(label="Send")

    if submit_button and user_input:
        output = generate_response(user_input)
        st.session_state["past"].append(user_input)
        st.session_state["generated"].append(output)

if st.session_state["generated"]:
    # Render the full conversation so far, question/answer pairs in order.
    with response_container:
        history = zip(st.session_state["past"], st.session_state["generated"])
        for i, (question, answer) in enumerate(history):
            message(question, is_user=True, key=f"{i}_user")
            message(answer, key=str(i))

@ -0,0 +1,8 @@
from pathlib import Path

# Display name of the app, used in the page title and header.
APP_NAME = "DataChad"
# OpenAI chat model used to build the conversational retrieval chain.
MODEL = "gpt-3.5-turbo"
# Emoji shown as the page icon and on error messages.
PAGE_ICON = "🤖"
# Local folder where uploaded files and cloned repositories are stored.
DATA_PATH = Path.cwd() / "data"
# Data source loaded on startup: this project's own repository.
DEFAULT_DATA_SOURCE = "git@github.com:gustavz/DataChad.git"

@ -0,0 +1,11 @@
streamlit==1.22.0
streamlit-chat==0.0.2.2
deeplake==3.4.1
openai==0.27.6
langchain==0.0.164
tiktoken==0.4.0
unstructured==0.6.5
pdf2image==1.16.3
pytesseract==0.3.10
beautifulsoup4==4.12.2
bs4==0.0.1

@ -0,0 +1,214 @@
import os
import re
import deeplake
import streamlit as st
from langchain.chains import ConversationalRetrievalChain
from langchain.chat_models import ChatOpenAI
from langchain.document_loaders import (
CSVLoader,
DirectoryLoader,
GitLoader,
NotebookLoader,
OnlinePDFLoader,
PythonLoader,
TextLoader,
UnstructuredFileLoader,
UnstructuredHTMLLoader,
UnstructuredPDFLoader,
UnstructuredWordDocumentLoader,
WebBaseLoader,
)
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import DeepLake
from constants import DATA_PATH, MODEL, PAGE_ICON
def validate_keys(openai_key, activeloop_token, activeloop_org_name):
    """Validate and register the API credentials for this session.

    If any credential was entered in the sidebar form, all three are required;
    otherwise the function falls back to environment variables or streamlit
    secrets (for local development or pre-configured deployments). On success
    the values are exported as environment variables and
    ``st.session_state["auth_ok"]`` is set to True; on failure an error is
    shown and the script is stopped via ``st.stop()``.

    TODO: Do proper token/key validation, currently activeloop has none.
    """
    required = {
        "OPENAI_API_KEY": openai_key,
        "ACTIVELOOP_TOKEN": activeloop_token,
        "ACTIVELOOP_ORG_NAME": activeloop_org_name,
    }

    def _fail():
        # Single failure path keeps the UX consistent (the original raised an
        # uncaught AssertionError in the secrets fallback).
        st.session_state["auth_ok"] = False
        st.error("Authentication failed", icon=PAGE_ICON)
        st.stop()

    if any(required.values()):
        # Credentials were entered in the form: all three must be present.
        # NOTE: never print/log the actual key values - they are secrets.
        if not all(required.values()):
            _fail()
        os.environ.update(required)
    else:
        # Fallback for local development or deployments with provided
        # credentials: either env variables or streamlit secrets must be set.
        for name in required:
            value = os.environ.get(name) or st.secrets.get(name)
            if not value:
                _fail()
            os.environ[name] = value

    st.session_state["auth_ok"] = True
def save_uploaded_file(uploaded_file):
    """Persist a streamlit ``UploadedFile`` under ``DATA_PATH`` and return its path.

    Streamlit keeps uploads in memory; the document loaders need a real file
    on disk, so the bytes are written to ``DATA_PATH / <original name>``.

    TODO: delete local files after they are uploaded to the datalake.

    Returns:
        str: absolute path of the saved file.
    """
    # parents/exist_ok make this safe to call repeatedly and on first run.
    DATA_PATH.mkdir(parents=True, exist_ok=True)
    file_path = DATA_PATH / uploaded_file.name
    # rewind in case the buffer was already read elsewhere
    uploaded_file.seek(0)
    # write_bytes opens and closes the file for us (the original used a bare
    # open()/close() pair with no context manager)
    file_path.write_bytes(uploaded_file.read())
    return str(file_path)
def load_git(data_source):
    """Clone (or reuse a local clone of) a git repo and split it into documents.

    Thank you github for the "master" to "main" switch: the default branch is
    unknown, so "main" is tried first and "master" second; the first branch
    that loads wins. If neither works, an error is shown and the app stops
    (the original fell through to an unbound ``docs`` and raised NameError).
    """
    repo_name = data_source.split("/")[-1].split(".")[0]
    repo_path = str(DATA_PATH / repo_name)
    if os.path.exists(repo_path):
        # Repo is already cloned locally: passing clone_url=None makes
        # GitLoader load from disk instead of cloning again.
        data_source = None
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
    for branch in ("main", "master"):
        try:
            # Return on first success; the original kept looping and
            # attempted to re-load the other branch needlessly.
            return GitLoader(repo_path, data_source, branch).load_and_split(
                text_splitter
            )
        except Exception as e:
            print(f"error loading git: {e}")
    error_msg = f"Failed to load git repository {data_source}"
    st.error(error_msg, icon=PAGE_ICON)
    st.stop()
def load_any_data_source(data_source):
    """Detect the type of ``data_source`` and load it into split text documents.

    Dispatches on path / url / extension to the matching langchain loader,
    splits the loaded text into ~1000-character chunks and returns them.
    Shows a streamlit error and stops the app if nothing could be loaded.
    """
    is_text = data_source.endswith(".txt")
    is_web = data_source.startswith("http")
    is_pdf = data_source.endswith(".pdf")
    # was endswith("csv"), which also matched any name merely ending in the
    # letters "csv"; require the full ".csv" extension
    is_csv = data_source.endswith(".csv")
    is_html = data_source.endswith(".html")
    is_git = data_source.endswith(".git")
    is_notebook = data_source.endswith(".ipynb")
    # also match the modern ".docx" extension, which the Word loader supports
    is_doc = data_source.endswith((".doc", ".docx"))
    is_py = data_source.endswith(".py")
    is_dir = os.path.isdir(data_source)
    is_file = os.path.isfile(data_source)

    loader = None
    if is_dir:
        loader = DirectoryLoader(data_source, recursive=True)
    if is_git:
        # git repos have their own multi-branch loading logic
        return load_git(data_source)
    if is_web:
        if is_pdf:
            loader = OnlinePDFLoader(data_source)
        else:
            loader = WebBaseLoader(data_source)
    if is_file:
        if is_text:
            loader = TextLoader(data_source)
        elif is_notebook:
            loader = NotebookLoader(data_source)
        elif is_pdf:
            loader = UnstructuredPDFLoader(data_source)
        elif is_html:
            loader = UnstructuredHTMLLoader(data_source)
        elif is_doc:
            loader = UnstructuredWordDocumentLoader(data_source)
        elif is_csv:
            loader = CSVLoader(data_source, encoding="utf-8")
        elif is_py:
            loader = PythonLoader(data_source)
        else:
            # generic fallback for unrecognised file types
            loader = UnstructuredFileLoader(data_source)

    if loader:
        text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
        docs = loader.load_and_split(text_splitter)
        print(f"loaded {len(docs)} document chunks")
        return docs

    error_msg = f"Failed to load {data_source}"
    st.error(error_msg, icon=PAGE_ICON)
    print(error_msg)
    st.stop()
def clean_data_source_string(data_source):
    """Normalize ``data_source`` into a name usable as a datalake dataset id.

    Every run of non-word characters is collapsed to a single dash and
    leading/trailing dashes are stripped, e.g.
    "git@github.com:user/Repo.git" -> "git-github-com-user-Repo-git".
    """
    # \W+ already collapses runs of non-word chars (including dashes) into a
    # single "-", so no consecutive dashes can remain afterwards. The original
    # second pass substituted "- " (with a trailing space!) for dash runs,
    # which would have produced an invalid dataset path had it ever matched.
    dashed_string = re.sub(r"\W+", "-", data_source)
    return dashed_string.strip("-")
def setup_vector_store(data_source):
    """Load the existing DeepLake vector store for ``data_source`` or build one.

    The dataset lives at ``hub://<org>/<cleaned-source-name>``. If it already
    exists it is opened read-only; otherwise the source is loaded, embedded
    with OpenAI embeddings, and uploaded as a new dataset.
    """
    embeddings = OpenAIEmbeddings(disallowed_special=())
    data_source_name = clean_data_source_string(data_source)
    dataset_path = f"hub://{os.environ['ACTIVELOOP_ORG_NAME']}/{data_source_name}"
    if deeplake.exists(dataset_path):
        print(f"{dataset_path} exists -> loading")
        vector_store = DeepLake(
            dataset_path=dataset_path, read_only=True, embedding_function=embeddings
        )
    else:
        print(f"{dataset_path} does not exist -> uploading")
        docs = load_any_data_source(data_source)
        # reuse dataset_path rather than rebuilding the identical f-string
        vector_store = DeepLake.from_documents(
            docs,
            embeddings,
            dataset_path=dataset_path,
        )
    return vector_store
def get_chain(data_source):
    """Build the ConversationalRetrievalChain used to answer questions.

    Sets up (or loads) the vector store for ``data_source``, wraps it in a
    retriever tuned for cosine-distance MMR search, and combines it with the
    chat model into a "stuff"-type retrieval chain.
    """
    retriever = setup_vector_store(data_source).as_retriever()
    # tune retrieval: cosine distance, fetch 20 candidates, return top 10
    # with maximal marginal relevance re-ranking
    retriever.search_kwargs.update(
        {
            "distance_metric": "cos",
            "fetch_k": 20,
            "maximal_marginal_relevance": True,
            "k": 10,
        }
    )
    llm = ChatOpenAI(model_name=MODEL)
    chain = ConversationalRetrievalChain.from_llm(
        llm,
        retriever=retriever,
        chain_type="stuff",
        verbose=True,
        max_tokens_limit=3375,
    )
    print(f"{data_source} is ready to go!")
    return chain
def reset_data_source(data_source):
    """Clear all cached conversation state and rebuild the chain.

    We need to reset every cache when a new data source is loaded, otherwise
    the langchain is confused and produces garbage.
    """
    for state_key in ("past", "generated", "chat_history"):
        st.session_state[state_key] = []
    st.session_state["chain"] = get_chain(data_source)
def generate_response(prompt):
    """Run the chain on ``prompt`` with the cached history; return the answer.

    The (question, answer) pair is appended to the session chat history so
    follow-up questions keep their conversational context.
    """
    chain = st.session_state["chain"]
    history = st.session_state["chat_history"]
    response = chain({"question": prompt, "chat_history": history})
    print(f"{response=}")
    answer = response["answer"]
    history.append((prompt, answer))
    return answer
Loading…
Cancel
Save