|
|
|
@ -9,41 +9,36 @@ import deeplake
|
|
|
|
|
import openai
|
|
|
|
|
import streamlit as st
|
|
|
|
|
from dotenv import load_dotenv
|
|
|
|
|
from langchain.base_language import BaseLanguageModel
|
|
|
|
|
from langchain.callbacks import OpenAICallbackHandler, get_openai_callback
|
|
|
|
|
from langchain.chains import ConversationalRetrievalChain
|
|
|
|
|
from langchain.chat_models import ChatOpenAI
|
|
|
|
|
from langchain.document_loaders import (
|
|
|
|
|
CSVLoader,
|
|
|
|
|
DirectoryLoader,
|
|
|
|
|
EverNoteLoader,
|
|
|
|
|
GitLoader,
|
|
|
|
|
NotebookLoader,
|
|
|
|
|
OnlinePDFLoader,
|
|
|
|
|
PDFMinerLoader,
|
|
|
|
|
PythonLoader,
|
|
|
|
|
TextLoader,
|
|
|
|
|
UnstructuredEPubLoader,
|
|
|
|
|
UnstructuredFileLoader,
|
|
|
|
|
UnstructuredHTMLLoader,
|
|
|
|
|
UnstructuredPDFLoader,
|
|
|
|
|
UnstructuredMarkdownLoader,
|
|
|
|
|
UnstructuredODTLoader,
|
|
|
|
|
UnstructuredPowerPointLoader,
|
|
|
|
|
UnstructuredWordDocumentLoader,
|
|
|
|
|
WebBaseLoader,
|
|
|
|
|
)
|
|
|
|
|
from langchain.embeddings.openai import OpenAIEmbeddings
|
|
|
|
|
from langchain.document_loaders.base import BaseLoader
|
|
|
|
|
from langchain.embeddings.openai import Embeddings, OpenAIEmbeddings
|
|
|
|
|
from langchain.schema import Document
|
|
|
|
|
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
|
|
|
|
from langchain.vectorstores import DeepLake, VectorStore
|
|
|
|
|
from streamlit.runtime.uploaded_file_manager import UploadedFile
|
|
|
|
|
|
|
|
|
|
from constants import (
|
|
|
|
|
APP_NAME,
|
|
|
|
|
CHUNK_SIZE,
|
|
|
|
|
DATA_PATH,
|
|
|
|
|
FETCH_K,
|
|
|
|
|
MAX_TOKENS,
|
|
|
|
|
MODEL,
|
|
|
|
|
PAGE_ICON,
|
|
|
|
|
REPO_URL,
|
|
|
|
|
TEMPERATURE,
|
|
|
|
|
K,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
from constants import APP_NAME, DATA_PATH, PAGE_ICON, PROJECT_URL
|
|
|
|
|
|
|
|
|
|
# loads environment variables
|
|
|
|
|
load_dotenv()
|
|
|
|
@ -116,66 +111,10 @@ def authenticate(
|
|
|
|
|
logger.info("Authentification successful!")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def advanced_options_form() -> None:
    """Render the "Advanced Options" form and rebuild the chain on submit.

    A checkbox toggles the form. Inside it the user can tune the model
    temperature, the retrieval sizes ``fetch_k`` and ``k``, the text
    ``chunk_size`` and the ``max_tokens`` limit. When "Apply" is pressed the
    chosen values are written to ``st.session_state`` and ``update_chain()``
    is called to rebuild the chain with them.
    """
    show_form = st.checkbox(
        "Advanced Options", help="Caution! This may break things!"
    )
    if not show_form:
        return

    with st.form("advanced_options"):
        temperature = st.slider(
            "temperature",
            min_value=0.0,
            max_value=1.0,
            value=TEMPERATURE,
            help="Controls the randomness of the language model output",
        )
        col1, col2 = st.columns(2)
        fetch_k = col1.number_input(
            "k_fetch",
            min_value=1,
            max_value=1000,
            value=FETCH_K,
            help="The number of documents to pull from the vector database",
        )
        k = col2.number_input(
            "k",
            min_value=1,
            max_value=100,
            value=K,
            help="The number of most similar documents to build the context from",
        )
        chunk_size = col1.number_input(
            "chunk_size",
            min_value=1,
            max_value=100000,
            value=CHUNK_SIZE,
            help=(
                "The size at which the text is divided into smaller chunks "
                "before being embedded.\n\nChanging this parameter makes re-embedding "
                "and re-uploading the data to the database necessary "
            ),
        )
        max_tokens = col2.number_input(
            "max_tokens",
            min_value=1,
            # Upper bound is the 4096-token limit (previously mistyped as 4069).
            max_value=4096,
            value=MAX_TOKENS,
            help="Limits the documents returned from database based on number of tokens",
        )
        applied = st.form_submit_button("Apply")
        if applied:
            # Persist the chosen values, then rebuild the chain from session state.
            st.session_state["k"] = k
            st.session_state["fetch_k"] = fetch_k
            st.session_state["chunk_size"] = chunk_size
            st.session_state["temperature"] = temperature
            st.session_state["max_tokens"] = max_tokens
            update_chain()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def save_uploaded_file(uploaded_file: UploadedFile) -> str:
|
|
|
|
|
def save_uploaded_file() -> str:
|
|
|
|
|
# streamlit uploaded files need to be stored locally
|
|
|
|
|
# before embedded and uploaded to the hub
|
|
|
|
|
uploaded_file = st.session_state["uploaded_file"]
|
|
|
|
|
if not os.path.exists(DATA_PATH):
|
|
|
|
|
os.makedirs(DATA_PATH)
|
|
|
|
|
file_path = str(DATA_PATH / uploaded_file.name)
|
|
|
|
@ -188,128 +127,162 @@ def save_uploaded_file(uploaded_file: UploadedFile) -> str:
|
|
|
|
|
return file_path
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def delete_uploaded_file(uploaded_file: UploadedFile) -> None:
|
|
|
|
|
def delete_uploaded_file() -> None:
|
|
|
|
|
# cleanup locally stored files
|
|
|
|
|
file_path = DATA_PATH / uploaded_file.name
|
|
|
|
|
file_path = DATA_PATH / st.session_state["uploaded_file"].name
|
|
|
|
|
if os.path.exists(DATA_PATH):
|
|
|
|
|
os.remove(file_path)
|
|
|
|
|
logger.info(f"Removed: {file_path}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def handle_load_error(e: str = None) -> None:
    """Report a failed data-source load and halt the current Streamlit run.

    Args:
        e: Detail text (or exception) appended to the message; it is
            formatted into the message as-is, so ``None`` renders as "None".

    The message is shown in the UI via ``st.error``, written to the logger,
    and then ``st.stop()`` is called.
    """
    failure_report = f"Failed to load '{st.session_state['data_source']}':\n\n{e}"
    # Show the problem to the user first, then record it in the log.
    st.error(failure_report, icon=PAGE_ICON)
    logger.error(failure_report)
    st.stop()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def load_git(data_source: str, chunk_size: int = CHUNK_SIZE) -> List[Document]:
|
|
|
|
|
# We need to try both common main branches
|
|
|
|
|
# Thank you github for the "master" to "main" switch
|
|
|
|
|
# we need to make sure the data path exists
|
|
|
|
|
if not os.path.exists(DATA_PATH):
|
|
|
|
|
os.makedirs(DATA_PATH)
|
|
|
|
|
repo_name = data_source.split("/")[-1].split(".")[0]
|
|
|
|
|
repo_path = str(DATA_PATH / repo_name)
|
|
|
|
|
clone_url = data_source
|
|
|
|
|
if os.path.exists(repo_path):
|
|
|
|
|
clone_url = None
|
|
|
|
|
text_splitter = RecursiveCharacterTextSplitter(
|
|
|
|
|
chunk_size=chunk_size, chunk_overlap=0
|
|
|
|
|
)
|
|
|
|
|
branches = ["main", "master"]
|
|
|
|
|
for branch in branches:
|
|
|
|
|
class AutoGitLoader:
|
|
|
|
|
def __init__(self, data_source: str) -> None:
|
|
|
|
|
self.data_source = data_source
|
|
|
|
|
|
|
|
|
|
def load(self) -> List[Document]:
|
|
|
|
|
# We need to try both common main branches
|
|
|
|
|
# Thank you github for the "master" to "main" switch
|
|
|
|
|
# we need to make sure the data path exists
|
|
|
|
|
if not os.path.exists(DATA_PATH):
|
|
|
|
|
os.makedirs(DATA_PATH)
|
|
|
|
|
repo_name = self.data_source.split("/")[-1].split(".")[0]
|
|
|
|
|
repo_path = str(DATA_PATH / repo_name)
|
|
|
|
|
clone_url = self.data_source
|
|
|
|
|
if os.path.exists(repo_path):
|
|
|
|
|
clone_url = None
|
|
|
|
|
branches = ["main", "master"]
|
|
|
|
|
for branch in branches:
|
|
|
|
|
try:
|
|
|
|
|
docs = GitLoader(repo_path, clone_url, branch).load()
|
|
|
|
|
break
|
|
|
|
|
except Exception as e:
|
|
|
|
|
logger.error(f"Error loading git: {e}")
|
|
|
|
|
if os.path.exists(repo_path):
|
|
|
|
|
# cleanup repo afterwards
|
|
|
|
|
shutil.rmtree(repo_path)
|
|
|
|
|
try:
|
|
|
|
|
docs = GitLoader(repo_path, clone_url, branch).load_and_split(text_splitter)
|
|
|
|
|
break
|
|
|
|
|
except Exception as e:
|
|
|
|
|
logger.error(f"Error loading git: {e}")
|
|
|
|
|
if os.path.exists(repo_path):
|
|
|
|
|
# cleanup repo afterwards
|
|
|
|
|
shutil.rmtree(repo_path)
|
|
|
|
|
try:
|
|
|
|
|
return docs
|
|
|
|
|
except:
|
|
|
|
|
msg = "Make sure to use HTTPS git repo links"
|
|
|
|
|
handle_load_error(msg)
|
|
|
|
|
return docs
|
|
|
|
|
except:
|
|
|
|
|
raise RuntimeError("Make sure to use HTTPS GitHub repo links")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Maps a local file extension (including the leading dot) to a tuple of
# (loader class, keyword arguments passed to the loader after the file path).
FILE_LOADER_MAPPING = {
    ".csv": (CSVLoader, {"encoding": "utf-8"}),
    ".doc": (UnstructuredWordDocumentLoader, {}),
    ".docx": (UnstructuredWordDocumentLoader, {}),
    ".enex": (EverNoteLoader, {}),
    ".epub": (UnstructuredEPubLoader, {}),
    ".html": (UnstructuredHTMLLoader, {}),
    ".md": (UnstructuredMarkdownLoader, {}),
    ".odt": (UnstructuredODTLoader, {}),
    ".pdf": (PDFMinerLoader, {}),
    ".ppt": (UnstructuredPowerPointLoader, {}),
    ".pptx": (UnstructuredPowerPointLoader, {}),
    ".txt": (TextLoader, {"encoding": "utf8"}),
    ".ipynb": (NotebookLoader, {}),
    ".py": (PythonLoader, {}),
    # Add more mappings for other file extensions and loaders as needed
}
|
|
|
|
|
|
|
|
|
|
# Maps the extension at the end of a web data source (including the leading
# dot) to a tuple of (loader class, keyword arguments for the loader).
WEB_LOADER_MAPPING = {
    ".git": (AutoGitLoader, {}),
    ".pdf": (OnlinePDFLoader, {}),
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_loader(
    file_path: str, mapping: dict, default_loader: "type[BaseLoader]"
) -> "BaseLoader":
    """Choose and instantiate a loader for ``file_path`` from its extension.

    Args:
        file_path: Path or URL of the data source. The text after its last
            "." is treated as the extension.
        mapping: Maps extensions with a leading dot (e.g. ".csv") to
            ``(loader_class, loader_kwargs)`` tuples.
        default_loader: Loader class used when no mapping entry matches. It
            is called with ``file_path`` as its only argument.

    Returns:
        The constructed loader instance.
    """
    _, dot, suffix = file_path.rpartition(".")
    # Without any "." there is no extension; otherwise a bare name such as
    # "csv" would wrongly be treated as a ".csv" file.
    ext = f".{suffix}" if dot else ""
    if ext not in mapping:
        # Fall back to a case-insensitive match so ".PDF" resolves like ".pdf".
        ext = ext.lower()
    if ext in mapping:
        loader_class, loader_args = mapping[ext]
        return loader_class(file_path, **loader_args)
    return default_loader(file_path)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def load_any_data_source(
|
|
|
|
|
data_source: str, chunk_size: int = CHUNK_SIZE
|
|
|
|
|
) -> List[Document]:
|
|
|
|
|
def load_data_source() -> List[Document]:
|
|
|
|
|
# Ugly thing that decides how to load data
|
|
|
|
|
# It aint much, but it's honest work
|
|
|
|
|
is_text = data_source.endswith(".txt")
|
|
|
|
|
data_source = st.session_state["data_source"]
|
|
|
|
|
is_web = data_source.startswith("http")
|
|
|
|
|
is_pdf = data_source.endswith(".pdf")
|
|
|
|
|
is_csv = data_source.endswith("csv")
|
|
|
|
|
is_html = data_source.endswith(".html")
|
|
|
|
|
is_git = data_source.endswith(".git")
|
|
|
|
|
is_notebook = data_source.endswith(".ipynb")
|
|
|
|
|
is_doc = data_source.endswith(".doc")
|
|
|
|
|
is_py = data_source.endswith(".py")
|
|
|
|
|
is_dir = os.path.isdir(data_source)
|
|
|
|
|
is_file = os.path.isfile(data_source)
|
|
|
|
|
|
|
|
|
|
loader = None
|
|
|
|
|
if is_dir:
|
|
|
|
|
loader = DirectoryLoader(data_source, recursive=True, silent_errors=True)
|
|
|
|
|
elif is_git:
|
|
|
|
|
return load_git(data_source, chunk_size)
|
|
|
|
|
elif is_web:
|
|
|
|
|
if is_pdf:
|
|
|
|
|
loader = OnlinePDFLoader(data_source)
|
|
|
|
|
else:
|
|
|
|
|
loader = WebBaseLoader(data_source)
|
|
|
|
|
loader = get_loader(data_source, WEB_LOADER_MAPPING, WebBaseLoader)
|
|
|
|
|
elif is_file:
|
|
|
|
|
if is_text:
|
|
|
|
|
loader = TextLoader(data_source, encoding="utf-8")
|
|
|
|
|
elif is_notebook:
|
|
|
|
|
loader = NotebookLoader(data_source)
|
|
|
|
|
elif is_pdf:
|
|
|
|
|
loader = UnstructuredPDFLoader(data_source)
|
|
|
|
|
elif is_html:
|
|
|
|
|
loader = UnstructuredHTMLLoader(data_source)
|
|
|
|
|
elif is_doc:
|
|
|
|
|
loader = UnstructuredWordDocumentLoader(data_source)
|
|
|
|
|
elif is_csv:
|
|
|
|
|
loader = CSVLoader(data_source, encoding="utf-8")
|
|
|
|
|
elif is_py:
|
|
|
|
|
loader = PythonLoader(data_source)
|
|
|
|
|
else:
|
|
|
|
|
loader = UnstructuredFileLoader(data_source)
|
|
|
|
|
loader = get_loader(data_source, FILE_LOADER_MAPPING, UnstructuredFileLoader)
|
|
|
|
|
try:
|
|
|
|
|
# Chunk size is a major trade-off parameter to control result accuracy over computation
|
|
|
|
|
text_splitter = RecursiveCharacterTextSplitter(
|
|
|
|
|
chunk_size=chunk_size, chunk_overlap=0
|
|
|
|
|
chunk_size=st.session_state["chunk_size"],
|
|
|
|
|
chunk_overlap=st.session_state["chunk_overlap"],
|
|
|
|
|
)
|
|
|
|
|
docs = loader.load_and_split(text_splitter)
|
|
|
|
|
docs = loader.load()
|
|
|
|
|
docs = text_splitter.split_documents(docs)
|
|
|
|
|
logger.info(f"Loaded: {len(docs)} document chucks")
|
|
|
|
|
return docs
|
|
|
|
|
except Exception as e:
|
|
|
|
|
msg = (
|
|
|
|
|
e
|
|
|
|
|
if loader
|
|
|
|
|
else f"No Loader found for your data source. Consider contributing: {REPO_URL}!"
|
|
|
|
|
else f"No Loader found for your data source. Consider contributing: {PROJECT_URL}!"
|
|
|
|
|
)
|
|
|
|
|
handle_load_error(msg)
|
|
|
|
|
error_msg = f"Failed to load '{st.session_state['data_source']}':\n\n{msg}"
|
|
|
|
|
st.error(error_msg, icon=PAGE_ICON)
|
|
|
|
|
logger.error(error_msg)
|
|
|
|
|
st.stop()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def clean_data_source_string(data_source_string: str) -> str:
|
|
|
|
|
def get_data_source_string() -> str:
|
|
|
|
|
# replace all non-word characters with dashes
|
|
|
|
|
# to get a string that can be used to create a new dataset
|
|
|
|
|
dashed_string = re.sub(r"\W+", "-", data_source_string)
|
|
|
|
|
dashed_string = re.sub(r"\W+", "-", st.session_state["data_source"])
|
|
|
|
|
cleaned_string = re.sub(r"--+", "- ", dashed_string).strip("-")
|
|
|
|
|
return cleaned_string
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def setup_vector_store(data_source: str, chunk_size: int = CHUNK_SIZE) -> VectorStore:
|
|
|
|
|
def get_model() -> BaseLanguageModel:
|
|
|
|
|
match st.session_state["model"]:
|
|
|
|
|
case "gpt-3.5-turbo":
|
|
|
|
|
model = ChatOpenAI(
|
|
|
|
|
model_name=st.session_state["model"],
|
|
|
|
|
temperature=st.session_state["temperature"],
|
|
|
|
|
openai_api_key=st.session_state["openai_api_key"],
|
|
|
|
|
)
|
|
|
|
|
# Add more models as needed
|
|
|
|
|
case _default:
|
|
|
|
|
msg = f"Model {st.session_state['model']} not supported!"
|
|
|
|
|
logger.error(msg)
|
|
|
|
|
st.error(msg)
|
|
|
|
|
exit
|
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_embeddings() -> Embeddings:
|
|
|
|
|
match st.session_state["embeddings"]:
|
|
|
|
|
case "openai":
|
|
|
|
|
embeddings = OpenAIEmbeddings(
|
|
|
|
|
disallowed_special=(), openai_api_key=st.session_state["openai_api_key"]
|
|
|
|
|
)
|
|
|
|
|
# Add more embeddings as needed
|
|
|
|
|
case _default:
|
|
|
|
|
msg = f"Embeddings {st.session_state['embeddings']} not supported!"
|
|
|
|
|
logger.error(msg)
|
|
|
|
|
st.error(msg)
|
|
|
|
|
exit
|
|
|
|
|
return embeddings
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_vector_store() -> VectorStore:
|
|
|
|
|
# either load existing vector store or upload a new one to the hub
|
|
|
|
|
embeddings = OpenAIEmbeddings(
|
|
|
|
|
disallowed_special=(), openai_api_key=st.session_state["openai_api_key"]
|
|
|
|
|
)
|
|
|
|
|
data_source_name = clean_data_source_string(data_source)
|
|
|
|
|
dataset_path = f"hub://{st.session_state['activeloop_org_name']}/{data_source_name}-{chunk_size}"
|
|
|
|
|
embeddings = get_embeddings()
|
|
|
|
|
data_source_name = get_data_source_string()
|
|
|
|
|
dataset_path = f"hub://{st.session_state['activeloop_org_name']}/{data_source_name}-{st.session_state['chunk_size']}"
|
|
|
|
|
if deeplake.exists(dataset_path, token=st.session_state["activeloop_token"]):
|
|
|
|
|
with st.spinner("Loading vector store..."):
|
|
|
|
|
logger.info(f"Dataset '{dataset_path}' exists -> loading")
|
|
|
|
@ -322,7 +295,7 @@ def setup_vector_store(data_source: str, chunk_size: int = CHUNK_SIZE) -> Vector
|
|
|
|
|
else:
|
|
|
|
|
with st.spinner("Reading, embedding and uploading data to hub..."):
|
|
|
|
|
logger.info(f"Dataset '{dataset_path}' does not exist -> uploading")
|
|
|
|
|
docs = load_any_data_source(data_source, chunk_size)
|
|
|
|
|
docs = load_data_source()
|
|
|
|
|
vector_store = DeepLake.from_documents(
|
|
|
|
|
docs,
|
|
|
|
|
embeddings,
|
|
|
|
@ -332,16 +305,9 @@ def setup_vector_store(data_source: str, chunk_size: int = CHUNK_SIZE) -> Vector
|
|
|
|
|
return vector_store
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def build_chain(
|
|
|
|
|
data_source: str,
|
|
|
|
|
k: int = K,
|
|
|
|
|
fetch_k: int = FETCH_K,
|
|
|
|
|
chunk_size: int = CHUNK_SIZE,
|
|
|
|
|
temperature: float = TEMPERATURE,
|
|
|
|
|
max_tokens: int = MAX_TOKENS,
|
|
|
|
|
) -> ConversationalRetrievalChain:
|
|
|
|
|
def get_chain() -> ConversationalRetrievalChain:
|
|
|
|
|
# create the langchain that will be called to generate responses
|
|
|
|
|
vector_store = setup_vector_store(data_source, chunk_size)
|
|
|
|
|
vector_store = get_vector_store()
|
|
|
|
|
retriever = vector_store.as_retriever()
|
|
|
|
|
# Search params "fetch_k" and "k" define how many documents are pulled from the hub
|
|
|
|
|
# and selected after the document matching to build the context
|
|
|
|
@ -349,15 +315,11 @@ def build_chain(
|
|
|
|
|
search_kwargs = {
|
|
|
|
|
"maximal_marginal_relevance": True,
|
|
|
|
|
"distance_metric": "cos",
|
|
|
|
|
"fetch_k": fetch_k,
|
|
|
|
|
"k": k,
|
|
|
|
|
"fetch_k": st.session_state["fetch_k"],
|
|
|
|
|
"k": st.session_state["k"],
|
|
|
|
|
}
|
|
|
|
|
retriever.search_kwargs.update(search_kwargs)
|
|
|
|
|
model = ChatOpenAI(
|
|
|
|
|
model_name=MODEL,
|
|
|
|
|
temperature=temperature,
|
|
|
|
|
openai_api_key=st.session_state["openai_api_key"],
|
|
|
|
|
)
|
|
|
|
|
model = get_model()
|
|
|
|
|
chain = ConversationalRetrievalChain.from_llm(
|
|
|
|
|
model,
|
|
|
|
|
retriever=retriever,
|
|
|
|
@ -365,9 +327,8 @@ def build_chain(
|
|
|
|
|
verbose=True,
|
|
|
|
|
# we limit the maximum number of used tokens
|
|
|
|
|
# to prevent running into the models token limit of 4096
|
|
|
|
|
max_tokens_limit=max_tokens,
|
|
|
|
|
max_tokens_limit=st.session_state["max_tokens"],
|
|
|
|
|
)
|
|
|
|
|
logger.info(f"Data source '{data_source}' is ready to go!")
|
|
|
|
|
return chain
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -375,15 +336,11 @@ def update_chain() -> None:
|
|
|
|
|
# Build chain with parameters from session state and store it back
|
|
|
|
|
# Also delete chat history to not confuse the bot with old context
|
|
|
|
|
try:
|
|
|
|
|
st.session_state["chain"] = build_chain(
|
|
|
|
|
data_source=st.session_state["data_source"],
|
|
|
|
|
k=st.session_state["k"],
|
|
|
|
|
fetch_k=st.session_state["fetch_k"],
|
|
|
|
|
chunk_size=st.session_state["chunk_size"],
|
|
|
|
|
temperature=st.session_state["temperature"],
|
|
|
|
|
max_tokens=st.session_state["max_tokens"],
|
|
|
|
|
)
|
|
|
|
|
st.session_state["chain"] = get_chain()
|
|
|
|
|
st.session_state["chat_history"] = []
|
|
|
|
|
msg = f"Data source '{st.session_state['data_source']}' is ready to go!"
|
|
|
|
|
logger.info(msg)
|
|
|
|
|
st.info(msg, icon=PAGE_ICON)
|
|
|
|
|
except Exception as e:
|
|
|
|
|
msg = f"Failed to build chain for data source '{st.session_state['data_source']}' with error: {e}"
|
|
|
|
|
logger.error(msg)
|
|
|
|
|