init commit
commit
2300731e56
@ -0,0 +1,3 @@
|
|||||||
|
data
|
||||||
|
__pycache__
|
||||||
|
.streamlit/secrets.toml
|
@ -0,0 +1,3 @@
|
|||||||
|
OPENAI_API_KEY = "your openai key"
|
||||||
|
ACTIVELOOP_TOKEN = "your activeloop key"
|
||||||
|
ACTIVELOOP_ORG_NAME = "your activeloop organization name"
|
@ -0,0 +1,23 @@
|
|||||||
|
# DataChad 🤖
|
||||||
|
|
||||||
|
This is an app that lets you ask questions about any data source by leveraging [embeddings](https://platform.openai.com/docs/guides/embeddings), [vector databases](https://www.activeloop.ai/), [large language models](https://platform.openai.com/docs/models/gpt-3-5) and last but not least [langchains](https://github.com/hwchase17/langchain)
|
||||||
|
|
||||||
|
## How does it work?
|
||||||
|
|
||||||
|
1. Upload any `file` or enter any `path` or `url`
|
||||||
|
2. The data source is detected and loaded into text documents
|
||||||
|
3. The text documents are embedded using openai embeddings
|
||||||
|
4. The embeddings are stored as a vector dataset to a datalake
|
||||||
|
5. A langchain is created consisting of a LLM model (`gpt-3.5-turbo` by default) and the embedding database index as retriever
|
||||||
|
6. When sending questions to the bot this chain is used as context to answer your questions
|
||||||
|
7. Finally the chat history is cached locally to enable a [ChatGPT](https://chat.openai.com/) like Q&A conversation
|
||||||
|
|
||||||
|
## Good to know
|
||||||
|
|
||||||
|
- By default this git repository is used as context, so you can directly start asking questions about its functionality without choosing your own data source.
|
||||||
|
- To run locally or deploy somewhere, execute:
|
||||||
|
|
||||||
|
```cp .streamlit/secrets.toml.template .streamlit/secrets.toml```
|
||||||
|
|
||||||
|
and set necessary keys in the newly created secrets file. Another option is to manually set environment variables
|
||||||
|
- Yes, Chad in `DataChad` refers to the well-known [meme](https://www.google.com/search?q=chad+meme)
|
@ -0,0 +1,97 @@
|
|||||||
|
import streamlit as st
from streamlit_chat import message

from constants import APP_NAME, DEFAULT_DATA_SOURCE, PAGE_ICON
from utils import (
    generate_response,
    get_chain,
    reset_data_source,
    save_uploaded_file,
    validate_keys,
)


# Page options and header
st.set_option("client.showErrorDetails", True)
st.set_page_config(page_title=APP_NAME, page_icon=PAGE_ICON)
st.markdown(
    f"<h1 style='text-align: center;'>{APP_NAME} {PAGE_ICON} <br> I know all about your data!</h1>",
    unsafe_allow_html=True,
)

# Initialise session state variables
if "chat_history" not in st.session_state:
    st.session_state["chat_history"] = []
if "generated" not in st.session_state:
    st.session_state["generated"] = []
if "past" not in st.session_state:
    st.session_state["past"] = []
if "auth_ok" not in st.session_state:
    st.session_state["auth_ok"] = False


# Sidebar
with st.sidebar:
    st.title("Authentication")
    with st.form("authentication"):
        openai_key = st.text_input("OpenAI API Key", type="password", key="openai_key")
        activeloop_token = st.text_input(
            "ActiveLoop Token", type="password", key="activeloop_token"
        )
        activeloop_org_name = st.text_input(
            "ActiveLoop Organisation Name", type="password", key="activeloop_org_name"
        )
        submitted = st.form_submit_button("Submit")
        if submitted:
            validate_keys(openai_key, activeloop_token, activeloop_org_name)

    # Nothing below runs until authentication succeeded
    if not st.session_state["auth_ok"]:
        st.stop()

    clear_button = st.button("Clear Conversation and Reset Data", key="clear")

# the chain can only be initialized after authentication is OK
if "chain" not in st.session_state:
    st.session_state["chain"] = get_chain(DEFAULT_DATA_SOURCE)

if clear_button:
    # reset everything
    reset_data_source(DEFAULT_DATA_SOURCE)

# upload file or enter data source
uploaded_file = st.file_uploader("Upload a file")
data_source = st.text_input(
    "Enter any data source",
    placeholder="Any path or url pointing to a file or directory of files",
)

# BUG FIX: previously an uploaded file triggered reset_data_source() twice,
# because save_uploaded_file() assigns a truthy path to data_source and the
# second `if data_source:` branch then ran as well. `elif` ensures at most
# one reset per rerun, with the uploaded file taking precedence.
if uploaded_file:
    print(f"uploaded file: '{uploaded_file.name}'")
    data_source = save_uploaded_file(uploaded_file)
    reset_data_source(data_source)
elif data_source:
    print(f"data source provided: '{data_source}'")
    reset_data_source(data_source)

# container for chat history
response_container = st.container()
# container for text box
container = st.container()

with container:
    with st.form(key="prompt_input", clear_on_submit=True):
        user_input = st.text_area("You:", key="input", height=100)
        submit_button = st.form_submit_button(label="Send")

    if submit_button and user_input:
        output = generate_response(user_input)
        st.session_state["past"].append(user_input)
        st.session_state["generated"].append(output)


if st.session_state["generated"]:
    with response_container:
        for i in range(len(st.session_state["generated"])):
            message(st.session_state["past"][i], is_user=True, key=str(i) + "_user")
            message(st.session_state["generated"][i], key=str(i))
|
@ -0,0 +1,8 @@
|
|||||||
|
from pathlib import Path

# Shared constants imported by app.py and utils.py.

# Display name of the app (page title / header).
APP_NAME = "DataChad"
# OpenAI chat model used for the conversational retrieval chain.
MODEL = "gpt-3.5-turbo"
# Emoji used as page icon and in error messages.
PAGE_ICON = "🤖"

# Local directory where uploaded files and cloned git repos are stored.
DATA_PATH = Path.cwd() / "data"
# Data source loaded as the default context when the app starts.
DEFAULT_DATA_SOURCE = "git@github.com:gustavz/DataChad.git"
|
@ -0,0 +1,11 @@
|
|||||||
|
streamlit==1.22.0
|
||||||
|
streamlit-chat==0.0.2.2
|
||||||
|
deeplake==3.4.1
|
||||||
|
openai==0.27.6
|
||||||
|
langchain==0.0.164
|
||||||
|
tiktoken==0.4.0
|
||||||
|
unstructured==0.6.5
|
||||||
|
pdf2image==1.16.3
|
||||||
|
pytesseract==0.3.10
|
||||||
|
beautifulsoup4==4.12.2
|
||||||
|
bs4==0.0.1
|
@ -0,0 +1,214 @@
|
|||||||
|
import os
|
||||||
|
import re
|
||||||
|
|
||||||
|
import deeplake
|
||||||
|
import streamlit as st
|
||||||
|
from langchain.chains import ConversationalRetrievalChain
|
||||||
|
from langchain.chat_models import ChatOpenAI
|
||||||
|
from langchain.document_loaders import (
|
||||||
|
CSVLoader,
|
||||||
|
DirectoryLoader,
|
||||||
|
GitLoader,
|
||||||
|
NotebookLoader,
|
||||||
|
OnlinePDFLoader,
|
||||||
|
PythonLoader,
|
||||||
|
TextLoader,
|
||||||
|
UnstructuredFileLoader,
|
||||||
|
UnstructuredHTMLLoader,
|
||||||
|
UnstructuredPDFLoader,
|
||||||
|
UnstructuredWordDocumentLoader,
|
||||||
|
WebBaseLoader,
|
||||||
|
)
|
||||||
|
from langchain.embeddings.openai import OpenAIEmbeddings
|
||||||
|
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
||||||
|
from langchain.vectorstores import DeepLake
|
||||||
|
|
||||||
|
from constants import DATA_PATH, MODEL, PAGE_ICON
|
||||||
|
|
||||||
|
|
||||||
|
def validate_keys(openai_key, activeloop_token, activeloop_org_name):
    """Validate API credentials and export them as environment variables.

    Credentials come either from the submitted form arguments or, as a
    fallback for local development / pre-configured deployments, from
    environment variables or streamlit secrets. On failure the streamlit
    script is stopped with an error; on success ``st.session_state["auth_ok"]``
    is set to True.

    TODO: Do proper token/key validation, currently activeloop has none
    """
    form_keys = [openai_key, activeloop_token, activeloop_org_name]
    if any(form_keys):
        # NOTE(security): the previous version printed the submitted keys in
        # plain text; never log secrets.
        if not all(form_keys):
            st.session_state["auth_ok"] = False
            st.error("Authentication failed", icon=PAGE_ICON)
            st.stop()
        os.environ["OPENAI_API_KEY"] = openai_key
        os.environ["ACTIVELOOP_TOKEN"] = activeloop_token
        os.environ["ACTIVELOOP_ORG_NAME"] = activeloop_org_name
    else:
        # Fallback for local development or deployments with provided
        # credentials: either env variables or streamlit secrets need to be set.
        names = ("OPENAI_API_KEY", "ACTIVELOOP_TOKEN", "ACTIVELOOP_ORG_NAME")
        if not all(os.environ.get(name) for name in names):
            # `assert` was used here before; it is stripped under `python -O`
            # and raised an opaque AssertionError — fail explicitly instead.
            for name in names:
                value = st.secrets.get(name)
                if not value:
                    st.session_state["auth_ok"] = False
                    st.error("Authentication failed", icon=PAGE_ICON)
                    st.stop()
                os.environ[name] = value
    st.session_state["auth_ok"] = True
|
||||||
|
|
||||||
|
|
||||||
|
def save_uploaded_file(uploaded_file):
    """Persist a streamlit ``UploadedFile`` under ``DATA_PATH`` and return its path.

    streamlit uploaded files need to be stored locally before they can be
    passed to the document loaders.
    TODO: delete local files after they are uploaded to the datalake
    """
    # exist_ok avoids the exists()/makedirs() race of the previous version
    os.makedirs(DATA_PATH, exist_ok=True)
    file_path = str(DATA_PATH / uploaded_file.name)
    uploaded_file.seek(0)
    # context manager guarantees the handle is closed even if write() fails
    # (the previous version leaked the handle on error)
    with open(file_path, "wb") as file:
        file.write(uploaded_file.read())
    return file_path
|
||||||
|
|
||||||
|
|
||||||
|
def load_git(data_source):
    """Clone (or reuse a local clone of) a git repo and split it into documents.

    Tries the "main" branch first, then "master".
    Thank you github for the "master" to "main" switch.

    Raises:
        RuntimeError: if neither branch could be loaded.
    """
    repo_name = data_source.split("/")[-1].split(".")[0]
    repo_path = str(DATA_PATH / repo_name)
    # GitLoader clones when given a url; pass None to reuse an existing clone
    clone_url = None if os.path.exists(repo_path) else data_source

    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
    docs = None
    for branch in ["main", "master"]:
        try:
            docs = GitLoader(repo_path, clone_url, branch).load_and_split(
                text_splitter
            )
            # BUG FIX: stop at the first branch that loads. Previously the
            # loop always tried both branches, and `docs` was unbound
            # (NameError on return) when both failed.
            break
        except Exception as e:
            print(f"error loading git: {e}")
    if docs is None:
        raise RuntimeError(f"failed to load git repository {data_source}")
    return docs
|
||||||
|
|
||||||
|
|
||||||
|
def load_any_data_source(data_source):
    """Detect the type of ``data_source`` and load it into split text documents.

    Supports local files and directories, git repositories, and web urls.
    Stops the streamlit script with an error if nothing could be loaded.
    """
    # ugly thing that decides how to load data
    is_text = data_source.endswith(".txt")
    is_web = data_source.startswith("http")
    is_pdf = data_source.endswith(".pdf")
    # BUG FIX: was endswith("csv"), which also matched any name merely
    # ending in the letters "csv" (e.g. "notacsv")
    is_csv = data_source.endswith(".csv")
    is_html = data_source.endswith(".html")
    is_git = data_source.endswith(".git")
    is_notebook = data_source.endswith(".ipynb")
    is_doc = data_source.endswith(".doc")
    is_py = data_source.endswith(".py")
    is_dir = os.path.isdir(data_source)
    is_file = os.path.isfile(data_source)

    loader = None
    if is_dir:
        loader = DirectoryLoader(data_source, recursive=True)
    if is_git:
        return load_git(data_source)
    if is_web:
        if is_pdf:
            loader = OnlinePDFLoader(data_source)
        else:
            loader = WebBaseLoader(data_source)
    if is_file:
        if is_text:
            loader = TextLoader(data_source)
        elif is_notebook:
            loader = NotebookLoader(data_source)
        elif is_pdf:
            loader = UnstructuredPDFLoader(data_source)
        elif is_html:
            loader = UnstructuredHTMLLoader(data_source)
        elif is_doc:
            loader = UnstructuredWordDocumentLoader(data_source)
        elif is_csv:
            loader = CSVLoader(data_source, encoding="utf-8")
        elif is_py:
            loader = PythonLoader(data_source)
        else:
            # fall back to the generic unstructured loader for unknown files
            loader = UnstructuredFileLoader(data_source)
    if loader:
        text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
        docs = loader.load_and_split(text_splitter)
        # typo fix in log message: "chucks" -> "chunks"
        print(f"loaded {len(docs)} document chunks")
        return docs

    error_msg = f"Failed to load {data_source}"
    st.error(error_msg, icon=PAGE_ICON)
    print(error_msg)
    st.stop()
|
||||||
|
|
||||||
|
|
||||||
|
def clean_data_source_string(data_source):
    """Normalize *data_source* into a string usable as a datalake dataset name.

    Replaces all non-word characters with dashes, collapses dash runs, and
    strips leading/trailing dashes.
    """
    # replace all non-word characters with dashes
    # to get a string that can be used to create a datalake dataset
    dashed_string = re.sub(r"\W+", "-", data_source)
    # BUG FIX: the replacement was "- " (dash + space), which would inject a
    # space into the dataset name; collapse dash runs to a single dash instead
    cleaned_string = re.sub(r"-{2,}", "-", dashed_string).strip("-")
    return cleaned_string
|
||||||
|
|
||||||
|
|
||||||
|
def setup_vector_store(data_source):
    """Load an existing vector store from the datalake, or build and upload one.

    The dataset name is derived from the (cleaned) data source string, so the
    same source always maps to the same hub dataset.
    """
    embeddings = OpenAIEmbeddings(disallowed_special=())
    data_source_name = clean_data_source_string(data_source)
    dataset_path = f"hub://{os.environ['ACTIVELOOP_ORG_NAME']}/{data_source_name}"
    if deeplake.exists(dataset_path):
        print(f"{dataset_path} exists -> loading")
        vector_store = DeepLake(
            dataset_path=dataset_path, read_only=True, embedding_function=embeddings
        )
    else:
        print(f"{dataset_path} does not exist -> uploading")
        docs = load_any_data_source(data_source)
        vector_store = DeepLake.from_documents(
            docs,
            embeddings,
            # reuse dataset_path instead of rebuilding the identical f-string
            dataset_path=dataset_path,
        )
    return vector_store
|
||||||
|
|
||||||
|
|
||||||
|
def get_chain(data_source):
    """Create the conversational retrieval chain used to answer questions.

    Builds (or loads) the vector store for *data_source*, wraps it in a
    retriever with tuned search parameters, and combines it with the chat
    model into a ConversationalRetrievalChain.
    """
    retriever = setup_vector_store(data_source).as_retriever()
    # cosine distance, MMR over 20 candidates, return the top 10
    retriever.search_kwargs.update(
        {
            "distance_metric": "cos",
            "fetch_k": 20,
            "maximal_marginal_relevance": True,
            "k": 10,
        }
    )
    llm = ChatOpenAI(model_name=MODEL)
    chain = ConversationalRetrievalChain.from_llm(
        llm,
        retriever=retriever,
        chain_type="stuff",
        verbose=True,
        max_tokens_limit=3375,
    )
    print(f"{data_source} is ready to go!")
    return chain
|
||||||
|
|
||||||
|
|
||||||
|
def reset_data_source(data_source):
    """Wipe all cached conversation state and rebuild the chain for *data_source*.

    All caches must be reset when a new data source is loaded, otherwise the
    langchain is confused and produces garbage.
    """
    for cache_key in ("past", "generated", "chat_history"):
        st.session_state[cache_key] = []
    st.session_state["chain"] = get_chain(data_source)
|
||||||
|
|
||||||
|
|
||||||
|
def generate_response(prompt):
    """Run *prompt* through the cached chain and return the answer.

    The (prompt, answer) pair is appended to the cached chat history so
    follow-up questions keep their conversational context.
    """
    history = st.session_state["chat_history"]
    response = st.session_state["chain"](
        {"question": prompt, "chat_history": history}
    )
    print(f"{response=}")
    answer = response["answer"]
    history.append((prompt, answer))
    return answer
|
Loading…
Reference in New Issue