init commit
commit
2300731e56
@ -0,0 +1,3 @@
|
||||
data
|
||||
__pycache__
|
||||
.streamlit/secrets.toml
|
@ -0,0 +1,3 @@
|
||||
OPENAI_API_KEY = "your openai key"
|
||||
ACTIVELOOP_TOKEN = "your activeloop key"
|
||||
ACTIVELOOP_ORG_NAME = "your activeloop organization name"
|
@ -0,0 +1,23 @@
|
||||
# DataChad 🤖
|
||||
|
||||
This is an app that lets you ask questions about any data source by leveraging [embeddings](https://platform.openai.com/docs/guides/embeddings), [vector databases](https://www.activeloop.ai/), [large language models](https://platform.openai.com/docs/models/gpt-3-5) and last but not least [langchains](https://github.com/hwchase17/langchain)
|
||||
|
||||
## How does it work?
|
||||
|
||||
1. Upload any `file` or enter any `path` or `url`
|
||||
2. The data source is detected and loaded into text documents
|
||||
3. The text documents are embedded using openai embeddings
|
||||
4. The embeddings are stored as a vector dataset to a datalake
|
||||
5. A langchain is created consisting of a LLM model (`gpt-3.5-turbo` by default) and the embedding database index as retriever
|
||||
6. When sending questions to the bot this chain is used as context to answer your questions
|
||||
7. Finally the chat history is cached locally to enable a [ChatGPT](https://chat.openai.com/) like Q&A conversation
|
||||
|
||||
## Good to know
|
||||
|
||||
- As default context this git repository is taken so you can directly start asking questions about its functionality without choosing your own data source.
|
||||
- To run locally or deploy somewhere, execute:
|
||||
|
||||
```cp .streamlit/secrets.toml.template .streamlit/secrets.toml```
|
||||
|
||||
and set necessary keys in the newly created secrets file. Another option is to manually set environment variables
|
||||
- Yes, Chad in `DataChad` refers to the well-known [meme](https://www.google.com/search?q=chad+meme)
|
@ -0,0 +1,97 @@
|
||||
"""Streamlit entry point for DataChad.

Renders the header, handles authentication via the sidebar form,
lets the user pick a data source, and runs the chat Q&A loop.
"""
import streamlit as st
from streamlit_chat import message

from constants import APP_NAME, DEFAULT_DATA_SOURCE, PAGE_ICON
from utils import (
    generate_response,
    get_chain,
    reset_data_source,
    save_uploaded_file,
    validate_keys,
)


# Page options and header
st.set_option("client.showErrorDetails", True)
st.set_page_config(page_title=APP_NAME, page_icon=PAGE_ICON)
st.markdown(
    f"<h1 style='text-align: center;'>{APP_NAME} {PAGE_ICON} <br> I know all about your data!</h1>",
    unsafe_allow_html=True,
)

# Initialise session state variables
# chat_history: (question, answer) tuples fed back into the chain as context
if "chat_history" not in st.session_state:
    st.session_state["chat_history"] = []
# generated: bot messages shown in the response container
if "generated" not in st.session_state:
    st.session_state["generated"] = []
# past: user messages shown in the response container
if "past" not in st.session_state:
    st.session_state["past"] = []
# auth_ok: flipped to True by validate_keys() once credentials are accepted
if "auth_ok" not in st.session_state:
    st.session_state["auth_ok"] = False


# Sidebar: credential form and reset button
with st.sidebar:
    st.title("Authentication")
    with st.form("authentication"):
        openai_key = st.text_input("OpenAI API Key", type="password", key="openai_key")
        activeloop_token = st.text_input(
            "ActiveLoop Token", type="password", key="activeloop_token"
        )
        activeloop_org_name = st.text_input(
            "ActiveLoop Organisation Name", type="password", key="activeloop_org_name"
        )
        submitted = st.form_submit_button("Submit")
    if submitted:
        validate_keys(openai_key, activeloop_token, activeloop_org_name)

    # halt rendering of the rest of the app until authentication succeeded
    if not st.session_state["auth_ok"]:
        st.stop()

    clear_button = st.button("Clear Conversation and Reset Data", key="clear")

# the chain can only be initialized after authentication is OK
if "chain" not in st.session_state:
    st.session_state["chain"] = get_chain(DEFAULT_DATA_SOURCE)

if clear_button:
    # reset everything back to the default data source
    reset_data_source(DEFAULT_DATA_SOURCE)

# upload file or enter data source
uploaded_file = st.file_uploader("Upload a file")
data_source = st.text_input(
    "Enter any data source",
    placeholder="Any path or url pointing to a file or directory of files",
)

if uploaded_file:
    print(f"uploaded file: '{uploaded_file.name}'")
    # uploads are in memory only; persist locally so loaders can read them
    data_source = save_uploaded_file(uploaded_file)
    reset_data_source(data_source)

if data_source:
    print(f"data source provided: '{data_source}'")
    reset_data_source(data_source)

# container for chat history
response_container = st.container()
# container for text box
container = st.container()

with container:
    with st.form(key="prompt_input", clear_on_submit=True):
        user_input = st.text_area("You:", key="input", height=100)
        submit_button = st.form_submit_button(label="Send")

    if submit_button and user_input:
        output = generate_response(user_input)
        st.session_state["past"].append(user_input)
        st.session_state["generated"].append(output)


if st.session_state["generated"]:
    with response_container:
        # render the conversation as alternating user/bot chat bubbles
        for i in range(len(st.session_state["generated"])):
            message(st.session_state["past"][i], is_user=True, key=str(i) + "_user")
            message(st.session_state["generated"][i], key=str(i))
|
@ -0,0 +1,8 @@
|
||||
"""Application-wide constants for DataChad."""
from pathlib import Path

# Display name of the app, used in the page title and header
APP_NAME = "DataChad"
# OpenAI chat model used to answer questions
MODEL = "gpt-3.5-turbo"
# Emoji used as the page icon and in error messages
PAGE_ICON = "🤖"

# Local directory where uploaded files and cloned repos are stored
DATA_PATH = Path.cwd() / "data"
# Repository loaded as the default context on app start
DEFAULT_DATA_SOURCE = "git@github.com:gustavz/DataChad.git"
|
@ -0,0 +1,11 @@
|
||||
streamlit==1.22.0
|
||||
streamlit-chat==0.0.2.2
|
||||
deeplake==3.4.1
|
||||
openai==0.27.6
|
||||
langchain==0.0.164
|
||||
tiktoken==0.4.0
|
||||
unstructured==0.6.5
|
||||
pdf2image==1.16.3
|
||||
pytesseract==0.3.10
|
||||
beautifulsoup4==4.12.2
|
||||
bs4==0.0.1
|
@ -0,0 +1,214 @@
|
||||
import os
|
||||
import re
|
||||
|
||||
import deeplake
|
||||
import streamlit as st
|
||||
from langchain.chains import ConversationalRetrievalChain
|
||||
from langchain.chat_models import ChatOpenAI
|
||||
from langchain.document_loaders import (
|
||||
CSVLoader,
|
||||
DirectoryLoader,
|
||||
GitLoader,
|
||||
NotebookLoader,
|
||||
OnlinePDFLoader,
|
||||
PythonLoader,
|
||||
TextLoader,
|
||||
UnstructuredFileLoader,
|
||||
UnstructuredHTMLLoader,
|
||||
UnstructuredPDFLoader,
|
||||
UnstructuredWordDocumentLoader,
|
||||
WebBaseLoader,
|
||||
)
|
||||
from langchain.embeddings.openai import OpenAIEmbeddings
|
||||
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
||||
from langchain.vectorstores import DeepLake
|
||||
|
||||
from constants import DATA_PATH, MODEL, PAGE_ICON
|
||||
|
||||
|
||||
def validate_keys(openai_key, activeloop_token, activeloop_org_name):
    """Validate and register the API credentials the app needs.

    If the user submitted any value via the sidebar form, all three values
    must be present. Otherwise fall back to environment variables or
    streamlit secrets (for local development or pre-configured deployments).

    On success the credentials are exported as environment variables and
    ``st.session_state["auth_ok"]`` is set to True; on failure an error is
    shown and the app is stopped.
    """
    # TODO: Do proper token/key validation, currently activeloop has none
    all_keys = [openai_key, activeloop_token, activeloop_org_name]
    if any(all_keys):
        # NOTE: deliberately do NOT print the submitted values — they are
        # secrets and must not end up in server logs
        if not all(all_keys):
            st.session_state["auth_ok"] = False
            st.error("Authentication failed", icon=PAGE_ICON)
            st.stop()
        os.environ["OPENAI_API_KEY"] = openai_key
        os.environ["ACTIVELOOP_TOKEN"] = activeloop_token
        os.environ["ACTIVELOOP_ORG_NAME"] = activeloop_org_name
    else:
        # Fallback for local development or deployments with provided
        # credentials: either env variables or streamlit secrets must be set
        required = ("OPENAI_API_KEY", "ACTIVELOOP_TOKEN", "ACTIVELOOP_ORG_NAME")
        if not all(os.environ.get(key) for key in required):
            try:
                # copy from streamlit secrets; st.secrets[...] raises KeyError
                # when missing (unlike .get(), which would silently yield None
                # and crash the os.environ assignment with a TypeError)
                for key in required:
                    os.environ[key] = st.secrets[key]
            except KeyError:
                st.session_state["auth_ok"] = False
                st.error("Authentication failed", icon=PAGE_ICON)
                st.stop()
    st.session_state["auth_ok"] = True
|
||||
|
||||
|
||||
def save_uploaded_file(uploaded_file):
    """Persist a streamlit upload to the local data directory.

    Streamlit keeps uploaded files in memory, so they need to be written
    to disk before the document loaders can read them.

    Returns the path of the written file as a string.
    """
    # TODO: delete local files after they are uploaded to the datalake
    # exist_ok avoids the check-then-create race of the original code
    os.makedirs(DATA_PATH, exist_ok=True)
    file_path = str(DATA_PATH / uploaded_file.name)
    uploaded_file.seek(0)
    # context manager guarantees the handle is closed even if write() fails
    with open(file_path, "wb") as f:
        f.write(uploaded_file.read())
    return file_path
|
||||
|
||||
|
||||
def load_git(data_source):
    """Clone (or reuse a local clone of) a git repo and split it into documents.

    Tries the "main" branch first, then falls back to "master"
    (thank you github for the "master" to "main" switch).

    Returns the list of document chunks; re-raises the last loader error
    if neither branch could be loaded.
    """
    repo_name = data_source.split("/")[-1].split(".")[0]
    repo_path = str(DATA_PATH / repo_name)
    # when a local clone already exists, pass no clone url so GitLoader
    # reuses it instead of cloning again
    clone_url = None if os.path.exists(repo_path) else data_source
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
    last_error = None
    for branch in ("main", "master"):
        try:
            # return on the first branch that loads successfully; the
            # original returned unconditionally after the first iteration,
            # raising UnboundLocalError when "main" failed
            return GitLoader(repo_path, clone_url, branch).load_and_split(
                text_splitter
            )
        except Exception as e:
            print(f"error loading git: {e}")
            last_error = e
    # no branch worked — surface the real failure to the caller
    raise last_error
|
||||
|
||||
|
||||
def load_any_data_source(data_source):
    """Detect the type of ``data_source`` and load it into document chunks.

    Supports git repos, web pages, local directories and single files of
    several formats. Shows an error and stops the streamlit app if the
    source cannot be loaded.
    """
    is_text = data_source.endswith(".txt")
    is_web = data_source.startswith("http")
    is_pdf = data_source.endswith(".pdf")
    # match the ".csv" extension; the original endswith("csv") also matched
    # names like "mycsv" that merely end with those three letters
    is_csv = data_source.endswith(".csv")
    is_html = data_source.endswith(".html")
    is_git = data_source.endswith(".git")
    is_notebook = data_source.endswith(".ipynb")
    is_doc = data_source.endswith(".doc")
    is_py = data_source.endswith(".py")
    is_dir = os.path.isdir(data_source)
    is_file = os.path.isfile(data_source)

    if is_git:
        # git repos are cloned and split by a dedicated loader
        return load_git(data_source)

    loader = None
    if is_dir:
        loader = DirectoryLoader(data_source, recursive=True)
    elif is_web:
        loader = OnlinePDFLoader(data_source) if is_pdf else WebBaseLoader(data_source)
    elif is_file:
        if is_text:
            loader = TextLoader(data_source)
        elif is_notebook:
            loader = NotebookLoader(data_source)
        elif is_pdf:
            loader = UnstructuredPDFLoader(data_source)
        elif is_html:
            loader = UnstructuredHTMLLoader(data_source)
        elif is_doc:
            loader = UnstructuredWordDocumentLoader(data_source)
        elif is_csv:
            loader = CSVLoader(data_source, encoding="utf-8")
        elif is_py:
            loader = PythonLoader(data_source)
        else:
            # generic fallback for unrecognized file types
            loader = UnstructuredFileLoader(data_source)

    if loader:
        text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
        docs = loader.load_and_split(text_splitter)
        print(f"loaded {len(docs)} document chunks")
        return docs

    error_msg = f"Failed to load {data_source}"
    st.error(error_msg, icon=PAGE_ICON)
    print(error_msg)
    st.stop()
|
||||
|
||||
|
||||
def clean_data_source_string(data_source):
    """Turn a data source path/url into a name usable as a datalake dataset id.

    Replaces every run of non-word characters with a single dash and strips
    leading/trailing dashes, so the result contains only word characters
    separated by single dashes.
    """
    dashed_string = re.sub(r"\W+", "-", data_source)
    # collapse any remaining runs of dashes to a single dash; the original
    # replaced them with "- ", which would inject a space into the dataset path
    cleaned_string = re.sub(r"--+", "-", dashed_string).strip("-")
    return cleaned_string
|
||||
|
||||
|
||||
def setup_vector_store(data_source):
    """Load the vector store for ``data_source`` from the datalake,
    or embed and upload it if it does not exist there yet."""
    embeddings = OpenAIEmbeddings(disallowed_special=())
    dataset_name = clean_data_source_string(data_source)
    dataset_path = f"hub://{os.environ['ACTIVELOOP_ORG_NAME']}/{dataset_name}"
    if not deeplake.exists(dataset_path):
        print(f"{dataset_path} does not exist -> uploading")
        documents = load_any_data_source(data_source)
        # embeds the documents and pushes the new dataset to the datalake
        return DeepLake.from_documents(
            documents,
            embeddings,
            dataset_path=dataset_path,
        )
    print(f"{dataset_path} exists -> loading")
    # existing dataset: open read-only, no re-embedding needed
    return DeepLake(
        dataset_path=dataset_path, read_only=True, embedding_function=embeddings
    )
|
||||
|
||||
|
||||
def get_chain(data_source):
    """Build the conversational retrieval chain for the given data source."""
    vector_store = setup_vector_store(data_source)
    retriever = vector_store.as_retriever()
    # retrieval tuning: cosine distance, MMR re-ranking, top 10 of 20 fetched
    retriever.search_kwargs.update(
        {
            "distance_metric": "cos",
            "fetch_k": 20,
            "maximal_marginal_relevance": True,
            "k": 10,
        }
    )
    llm = ChatOpenAI(model_name=MODEL)
    chain = ConversationalRetrievalChain.from_llm(
        llm,
        retriever=retriever,
        chain_type="stuff",
        verbose=True,
        max_tokens_limit=3375,
    )
    print(f"{data_source} is ready to go!")
    return chain
|
||||
|
||||
|
||||
def reset_data_source(data_source):
    """Clear all cached conversation state and rebuild the chain.

    Every cache must be reset when a new data source is loaded,
    otherwise the langchain is confused and produces garbage.
    """
    for state_key in ("past", "generated", "chat_history"):
        st.session_state[state_key] = []
    st.session_state["chain"] = get_chain(data_source)
|
||||
|
||||
|
||||
def generate_response(prompt):
    """Run the chain on ``prompt``, record the exchange, and return the answer."""
    chain_input = {
        "question": prompt,
        "chat_history": st.session_state["chat_history"],
    }
    response = st.session_state["chain"](chain_input)
    print(f"{response=}")
    answer = response["answer"]
    # remember the Q&A pair so it is fed back as context on the next turn
    st.session_state["chat_history"].append((prompt, answer))
    return answer
|
Loading…
Reference in New Issue