refactor utils, prepared for new features

main
Gustav von Zitzewitz 1 year ago
parent ba6b376d56
commit 2a0e0bd4cc

README.md

@@ -26,7 +26,7 @@ This is an app that let's you ask questions about any data source by leveraging
 ## TODO LIST
 If you like to contribute, feel free to grab any task
+- [x] Refactor utils, especially the loaders
 - [ ] Add option to choose model and embeddings
 - [ ] Enable fully local / private mode
-- [ ] Refactor utils, especially the loaders
 - [ ] Add Image caption and Audio transcription support

app.py (103 changed lines)

@@ -5,20 +5,22 @@ from constants import (
     ACTIVELOOP_HELP,
     APP_NAME,
     AUTHENTICATION_HELP,
+    CHUNK_OVERLAP,
     CHUNK_SIZE,
     DEFAULT_DATA_SOURCE,
+    EMBEDDINGS,
     ENABLE_ADVANCED_OPTIONS,
     FETCH_K,
     MAX_TOKENS,
+    MODEL,
     OPENAI_HELP,
     PAGE_ICON,
-    REPO_URL,
+    PROJECT_URL,
     TEMPERATURE,
     USAGE_HELP,
     K,
 )
 from utils import (
-    advanced_options_form,
     authenticate,
     delete_uploaded_file,
     generate_response,
@@ -49,9 +51,12 @@ SESSION_DEFAULTS = {
     "openai_api_key": None,
     "activeloop_token": None,
     "activeloop_org_name": None,
+    "model": MODEL,
+    "embeddings": EMBEDDINGS,
     "k": K,
     "fetch_k": FETCH_K,
     "chunk_size": CHUNK_SIZE,
+    "chunk_overlap": CHUNK_OVERLAP,
     "temperature": TEMPERATURE,
     "max_tokens": MAX_TOKENS,
 }
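The three new session keys (`model`, `embeddings`, `chunk_overlap`) ride on the same initialization idiom the app already uses: seed `st.session_state` once, then let widgets overwrite the values. A minimal standalone sketch of that idiom (the defaults here are illustrative):

    import streamlit as st

    SESSION_DEFAULTS = {"model": "gpt-3.5-turbo", "chunk_overlap": 0}

    # session_state survives Streamlit's top-to-bottom reruns, so seeding it
    # only when a key is missing preserves user edits across interactions
    for key, value in SESSION_DEFAULTS.items():
        if key not in st.session_state:
            st.session_state[key] = value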

@@ -61,9 +66,7 @@ for k, v in SESSION_DEFAULTS.items():
         st.session_state[k] = v
 
-
-# Sidebar with Authentication
-# Only start App if authentication is OK
-with st.sidebar:
+
+def authentication_form() -> None:
     st.title("Authentication", help=AUTHENTICATION_HELP)
     with st.form("authentication"):
         openai_api_key = st.text_input(
@@ -88,7 +91,70 @@ with st.sidebar:
         if submitted:
             authenticate(openai_api_key, activeloop_token, activeloop_org_name)
-    st.info(f"Learn how it works [here]({REPO_URL})")
+
+
+def advanced_options_form() -> None:
+    # Input Form that takes advanced options and rebuilds chain with them
+    advanced_options = st.checkbox(
+        "Advanced Options", help="Caution! This may break things!"
+    )
+    if advanced_options:
+        with st.form("advanced_options"):
+            temperature = st.slider(
+                "temperature",
+                min_value=0.0,
+                max_value=1.0,
+                value=TEMPERATURE,
+                help="Controls the randomness of the language model output",
+            )
+            col1, col2 = st.columns(2)
+            fetch_k = col1.number_input(
+                "k_fetch",
+                min_value=1,
+                max_value=1000,
+                value=FETCH_K,
+                help="The number of documents to pull from the vector database",
+            )
+            k = col2.number_input(
+                "k",
+                min_value=1,
+                max_value=100,
+                value=K,
+                help="The number of most similar documents to build the context from",
+            )
+            chunk_size = col1.number_input(
+                "chunk_size",
+                min_value=1,
+                max_value=100000,
+                value=CHUNK_SIZE,
+                help=(
+                    "The size at which the text is divided into smaller chunks "
+                    "before being embedded.\n\nChanging this parameter makes re-embedding "
+                    "and re-uploading the data to the database necessary "
+                ),
+            )
+            max_tokens = col2.number_input(
+                "max_tokens",
+                min_value=1,
+                max_value=4069,
+                value=MAX_TOKENS,
+                help="Limits the documents returned from database based on number of tokens",
+            )
+            applied = st.form_submit_button("Apply")
+            if applied:
+                st.session_state["k"] = k
+                st.session_state["fetch_k"] = fetch_k
+                st.session_state["chunk_size"] = chunk_size
+                st.session_state["temperature"] = temperature
+                st.session_state["max_tokens"] = max_tokens
+                update_chain()
+
+
+# Sidebar with Authentication and Advanced Options
+with st.sidebar:
+    authentication_form()
+    st.info(f"Learn how it works [here]({PROJECT_URL})")
+
+# Only start App if authentication is OK
 if not st.session_state["auth_ok"]:
     st.stop()
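Keeping the advanced options inside `st.form` matters here: Streamlit reruns the whole script on every widget interaction, but form inputs are only committed when the submit button fires, so `update_chain()` runs once per Apply rather than once per slider tick. The pattern in isolation, with a hypothetical `rebuild()` standing in for `update_chain()`:

    import streamlit as st

    def rebuild() -> None:
        # hypothetical stand-in for update_chain()
        st.write("rebuilding with k =", st.session_state["k"])

    with st.form("options"):
        k = st.number_input("k", min_value=1, max_value=100, value=10)
        if st.form_submit_button("Apply"):
            st.session_state["k"] = k
            rebuild()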

@@ -99,11 +165,6 @@ with st.sidebar:
     if ENABLE_ADVANCED_OPTIONS:
         advanced_options_form()
 
-# the chain can only be initialized after authentication is OK
-if "chain" not in st.session_state:
-    update_chain()
-
 if clear_button:
     # resets all chat history related caches
     st.session_state["past"] = []
@@ -118,6 +179,10 @@ data_source = st.text_input(
     placeholder="Any path or url pointing to a file or directory of files",
 )
 
+# the chain can only be initialized after authentication is OK
+if "chain" not in st.session_state:
+    update_chain()
+
 # generate new chain for new data source / uploaded file
 # make sure to do this only once per input / on change
 if data_source and data_source != st.session_state["data_source"]:
@@ -128,28 +193,28 @@ if data_source and data_source != st.session_state["data_source"]:
 if uploaded_file and uploaded_file != st.session_state["uploaded_file"]:
     logger.info(f"Uploaded file: '{uploaded_file.name}'")
     st.session_state["uploaded_file"] = uploaded_file
-    data_source = save_uploaded_file(uploaded_file)
+    data_source = save_uploaded_file()
     st.session_state["data_source"] = data_source
     update_chain()
-    delete_uploaded_file(uploaded_file)
+    delete_uploaded_file()
 
 # container for chat history
 response_container = st.container()
 # container for text box
-container = st.container()
+text_container = st.container()
 
 # As streamlit reruns the whole script on each change
 # it is necessary to repopulate the chat containers
-with container:
+with text_container:
     with st.form(key="prompt_input", clear_on_submit=True):
         user_input = st.text_area("You:", key="input", height=100)
         submit_button = st.form_submit_button(label="Send")
 
 if submit_button and user_input:
     output = generate_response(user_input)
     st.session_state["past"].append(user_input)
     st.session_state["generated"].append(output)
 
 if st.session_state["generated"]:
     with response_container:
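The hunk ends at `with response_container:`, so the replay loop itself is not part of this diff. In apps built on the `streamlit-chat` package, the body that follows usually looks roughly like this (a hedged sketch, not the commit's actual code):

    from streamlit_chat import message  # pip install streamlit-chat

    with response_container:
        # replay both histories so the conversation survives Streamlit reruns
        for i in range(len(st.session_state["generated"])):
            message(st.session_state["past"][i], is_user=True, key=f"{i}_user")
            message(st.session_state["generated"][i], key=str(i))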

constants.py

@@ -2,11 +2,15 @@ from pathlib import Path
 APP_NAME = "DataChad"
 MODEL = "gpt-3.5-turbo"
+EMBEDDINGS = "openai"
 PAGE_ICON = "🤖"
+PROJECT_URL = "https://github.com/gustavz/DataChad"
 
 K = 10
 FETCH_K = 20
 CHUNK_SIZE = 1000
+CHUNK_OVERLAP = 0
 TEMPERATURE = 0.7
 MAX_TOKENS = 3357
 ENABLE_ADVANCED_OPTIONS = True
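The new `CHUNK_OVERLAP = 0` constant feeds straight into the text splitter, so neighboring chunks currently share no text at their boundaries. A quick way to see what the two knobs do, using the constant values above on made-up filler text:

    from langchain.text_splitter import RecursiveCharacterTextSplitter

    splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
    chunks = splitter.split_text("word " * 1000)  # ~5000 characters of filler
    print(len(chunks), max(len(c) for c in chunks))  # several chunks, each <= 1000 chars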

@@ -14,12 +18,10 @@ ENABLE_ADVANCED_OPTIONS = True
 DATA_PATH = Path.cwd() / "data"
 DEFAULT_DATA_SOURCE = "https://github.com/gustavz/DataChad.git"
-REPO_URL = "https://github.com/gustavz/DataChad"
 
 AUTHENTICATION_HELP = f"""
 Your credentials are only stored in your session state.\n
 The keys are neither exposed nor made visible or stored permanently in any way.\n
-Feel free to check out [the code base]({REPO_URL}) to validate how things work.
+Feel free to check out [the code base]({PROJECT_URL}) to validate how things work.
 """
 
 USAGE_HELP = f"""

utils.py

@@ -9,41 +9,36 @@ import deeplake
 import openai
 import streamlit as st
 from dotenv import load_dotenv
+from langchain.base_language import BaseLanguageModel
 from langchain.callbacks import OpenAICallbackHandler, get_openai_callback
 from langchain.chains import ConversationalRetrievalChain
 from langchain.chat_models import ChatOpenAI
 from langchain.document_loaders import (
     CSVLoader,
     DirectoryLoader,
+    EverNoteLoader,
     GitLoader,
     NotebookLoader,
     OnlinePDFLoader,
+    PDFMinerLoader,
     PythonLoader,
     TextLoader,
+    UnstructuredEPubLoader,
     UnstructuredFileLoader,
     UnstructuredHTMLLoader,
-    UnstructuredPDFLoader,
+    UnstructuredMarkdownLoader,
+    UnstructuredODTLoader,
+    UnstructuredPowerPointLoader,
     UnstructuredWordDocumentLoader,
     WebBaseLoader,
 )
-from langchain.embeddings.openai import OpenAIEmbeddings
+from langchain.document_loaders.base import BaseLoader
+from langchain.embeddings.openai import Embeddings, OpenAIEmbeddings
 from langchain.schema import Document
 from langchain.text_splitter import RecursiveCharacterTextSplitter
 from langchain.vectorstores import DeepLake, VectorStore
-from streamlit.runtime.uploaded_file_manager import UploadedFile
 
-from constants import (
-    APP_NAME,
-    CHUNK_SIZE,
-    DATA_PATH,
-    FETCH_K,
-    MAX_TOKENS,
-    MODEL,
-    PAGE_ICON,
-    REPO_URL,
-    TEMPERATURE,
-    K,
-)
+from constants import APP_NAME, DATA_PATH, PAGE_ICON, PROJECT_URL
 
 # loads environment variables
 load_dotenv()
@@ -116,66 +111,10 @@ def authenticate(
     logger.info("Authentification successful!")
 
-
-def advanced_options_form() -> None:
-    # Input Form that takes advanced options and rebuilds chain with them
-    advanced_options = st.checkbox(
-        "Advanced Options", help="Caution! This may break things!"
-    )
-    if advanced_options:
-        with st.form("advanced_options"):
-            temperature = st.slider(
-                "temperature",
-                min_value=0.0,
-                max_value=1.0,
-                value=TEMPERATURE,
-                help="Controls the randomness of the language model output",
-            )
-            col1, col2 = st.columns(2)
-            fetch_k = col1.number_input(
-                "k_fetch",
-                min_value=1,
-                max_value=1000,
-                value=FETCH_K,
-                help="The number of documents to pull from the vector database",
-            )
-            k = col2.number_input(
-                "k",
-                min_value=1,
-                max_value=100,
-                value=K,
-                help="The number of most similar documents to build the context from",
-            )
-            chunk_size = col1.number_input(
-                "chunk_size",
-                min_value=1,
-                max_value=100000,
-                value=CHUNK_SIZE,
-                help=(
-                    "The size at which the text is divided into smaller chunks "
-                    "before being embedded.\n\nChanging this parameter makes re-embedding "
-                    "and re-uploading the data to the database necessary "
-                ),
-            )
-            max_tokens = col2.number_input(
-                "max_tokens",
-                min_value=1,
-                max_value=4069,
-                value=MAX_TOKENS,
-                help="Limits the documents returned from database based on number of tokens",
-            )
-            applied = st.form_submit_button("Apply")
-            if applied:
-                st.session_state["k"] = k
-                st.session_state["fetch_k"] = fetch_k
-                st.session_state["chunk_size"] = chunk_size
-                st.session_state["temperature"] = temperature
-                st.session_state["max_tokens"] = max_tokens
-                update_chain()
-
-
-def save_uploaded_file(uploaded_file: UploadedFile) -> str:
+def save_uploaded_file() -> str:
     # streamlit uploaded files need to be stored locally
     # before embedded and uploaded to the hub
+    uploaded_file = st.session_state["uploaded_file"]
     if not os.path.exists(DATA_PATH):
         os.makedirs(DATA_PATH)
     file_path = str(DATA_PATH / uploaded_file.name)
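`save_uploaded_file` and `delete_uploaded_file` now read the upload out of `st.session_state` instead of taking it as a parameter, which shortens the call sites but couples both helpers to Streamlit. Exercising them outside the app means priming session state first, roughly like this (a sketch; the fake object only mimics the `.name` attribute the diff shows, plus the buffer API real `UploadedFile` objects expose):

    import streamlit as st

    class FakeUpload:
        # hypothetical stand-in for streamlit's UploadedFile
        name = "sample.txt"

        def getbuffer(self) -> memoryview:
            return memoryview(b"hello")

    st.session_state["uploaded_file"] = FakeUpload()
    path = save_uploaded_file()  # writes data/sample.txt and returns the path
    delete_uploaded_file()       # removes it again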

@@ -188,128 +127,162 @@ def save_uploaded_file(uploaded_file: UploadedFile) -> str:
     return file_path
 
-def delete_uploaded_file(uploaded_file: UploadedFile) -> None:
+def delete_uploaded_file() -> None:
     # cleanup locally stored files
-    file_path = DATA_PATH / uploaded_file.name
+    file_path = DATA_PATH / st.session_state["uploaded_file"].name
     if os.path.exists(DATA_PATH):
         os.remove(file_path)
         logger.info(f"Removed: {file_path}")
 
-def handle_load_error(e: str = None) -> None:
-    error_msg = f"Failed to load '{st.session_state['data_source']}':\n\n{e}"
-    st.error(error_msg, icon=PAGE_ICON)
-    logger.error(error_msg)
-    st.stop()
-
-
-def load_git(data_source: str, chunk_size: int = CHUNK_SIZE) -> List[Document]:
-    # We need to try both common main branches
-    # Thank you github for the "master" to "main" switch
-    # we need to make sure the data path exists
-    if not os.path.exists(DATA_PATH):
-        os.makedirs(DATA_PATH)
-    repo_name = data_source.split("/")[-1].split(".")[0]
-    repo_path = str(DATA_PATH / repo_name)
-    clone_url = data_source
-    if os.path.exists(repo_path):
-        clone_url = None
-    text_splitter = RecursiveCharacterTextSplitter(
-        chunk_size=chunk_size, chunk_overlap=0
-    )
-    branches = ["main", "master"]
-    for branch in branches:
-        try:
-            docs = GitLoader(repo_path, clone_url, branch).load_and_split(text_splitter)
-            break
-        except Exception as e:
-            logger.error(f"Error loading git: {e}")
-    if os.path.exists(repo_path):
-        # cleanup repo afterwards
-        shutil.rmtree(repo_path)
-    try:
-        return docs
-    except:
-        msg = "Make sure to use HTTPS git repo links"
-        handle_load_error(msg)
+class AutoGitLoader:
+    def __init__(self, data_source: str) -> None:
+        self.data_source = data_source
+
+    def load(self) -> List[Document]:
+        # We need to try both common main branches
+        # Thank you github for the "master" to "main" switch
+        # we need to make sure the data path exists
+        if not os.path.exists(DATA_PATH):
+            os.makedirs(DATA_PATH)
+        repo_name = self.data_source.split("/")[-1].split(".")[0]
+        repo_path = str(DATA_PATH / repo_name)
+        clone_url = self.data_source
+        if os.path.exists(repo_path):
+            clone_url = None
+        branches = ["main", "master"]
+        for branch in branches:
+            try:
+                docs = GitLoader(repo_path, clone_url, branch).load()
+                break
+            except Exception as e:
+                logger.error(f"Error loading git: {e}")
+        if os.path.exists(repo_path):
+            # cleanup repo afterwards
+            shutil.rmtree(repo_path)
+        try:
+            return docs
+        except:
+            raise RuntimeError("Make sure to use HTTPS GitHub repo links")
+
+
+FILE_LOADER_MAPPING = {
+    ".csv": (CSVLoader, {"encoding": "utf-8"}),
+    ".doc": (UnstructuredWordDocumentLoader, {}),
+    ".docx": (UnstructuredWordDocumentLoader, {}),
+    ".enex": (EverNoteLoader, {}),
+    ".epub": (UnstructuredEPubLoader, {}),
+    ".html": (UnstructuredHTMLLoader, {}),
+    ".md": (UnstructuredMarkdownLoader, {}),
+    ".odt": (UnstructuredODTLoader, {}),
+    ".pdf": (PDFMinerLoader, {}),
+    ".ppt": (UnstructuredPowerPointLoader, {}),
+    ".pptx": (UnstructuredPowerPointLoader, {}),
+    ".txt": (TextLoader, {"encoding": "utf8"}),
+    ".ipynb": (NotebookLoader, {}),
+    ".py": (PythonLoader, {}),
+    # Add more mappings for other file extensions and loaders as needed
+}
+
+WEB_LOADER_MAPPING = {
+    ".git": (AutoGitLoader, {}),
+    ".pdf": (OnlinePDFLoader, {}),
+}
+
+
+def get_loader(file_path: str, mapping: dict, default_loader: BaseLoader) -> BaseLoader:
+    # Choose loader from mapping, load default if no match found
+    ext = "." + file_path.rsplit(".", 1)[-1]
+    if ext in mapping:
+        loader_class, loader_args = mapping[ext]
+        loader = loader_class(file_path, **loader_args)
+    else:
+        loader = default_loader(file_path)
+    return loader
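The mapping-plus-default dispatch replaces the long `elif` chain below. Two details worth noting: `AutoGitLoader` is not a `BaseLoader` subclass, it merely duck-types the `load()` interface, and the `default_loader: BaseLoader` annotation really describes a class, not an instance. A minimal usage sketch, assuming the corresponding langchain extras are installed:

    from langchain.document_loaders import TextLoader, UnstructuredFileLoader

    # ".txt" hits the mapping and is constructed with its loader_args...
    loader = get_loader("data/notes.txt", FILE_LOADER_MAPPING, UnstructuredFileLoader)
    assert isinstance(loader, TextLoader)

    # ...while unknown extensions fall through to the default loader
    loader = get_loader("data/archive.xyz", FILE_LOADER_MAPPING, UnstructuredFileLoader)
    assert isinstance(loader, UnstructuredFileLoader)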

-def load_any_data_source(
-    data_source: str, chunk_size: int = CHUNK_SIZE
-) -> List[Document]:
+def load_data_source() -> List[Document]:
     # Ugly thing that decides how to load data
     # It aint much, but it's honest work
-    is_text = data_source.endswith(".txt")
+    data_source = st.session_state["data_source"]
     is_web = data_source.startswith("http")
-    is_pdf = data_source.endswith(".pdf")
-    is_csv = data_source.endswith("csv")
-    is_html = data_source.endswith(".html")
-    is_git = data_source.endswith(".git")
-    is_notebook = data_source.endswith(".ipynb")
-    is_doc = data_source.endswith(".doc")
-    is_py = data_source.endswith(".py")
     is_dir = os.path.isdir(data_source)
     is_file = os.path.isfile(data_source)
 
     loader = None
     if is_dir:
         loader = DirectoryLoader(data_source, recursive=True, silent_errors=True)
-    elif is_git:
-        return load_git(data_source, chunk_size)
     elif is_web:
-        if is_pdf:
-            loader = OnlinePDFLoader(data_source)
-        else:
-            loader = WebBaseLoader(data_source)
+        loader = get_loader(data_source, WEB_LOADER_MAPPING, WebBaseLoader)
     elif is_file:
-        if is_text:
-            loader = TextLoader(data_source, encoding="utf-8")
-        elif is_notebook:
-            loader = NotebookLoader(data_source)
-        elif is_pdf:
-            loader = UnstructuredPDFLoader(data_source)
-        elif is_html:
-            loader = UnstructuredHTMLLoader(data_source)
-        elif is_doc:
-            loader = UnstructuredWordDocumentLoader(data_source)
-        elif is_csv:
-            loader = CSVLoader(data_source, encoding="utf-8")
-        elif is_py:
-            loader = PythonLoader(data_source)
-        else:
-            loader = UnstructuredFileLoader(data_source)
+        loader = get_loader(data_source, FILE_LOADER_MAPPING, UnstructuredFileLoader)
     try:
         # Chunk size is a major trade-off parameter to control result accuracy over computaion
         text_splitter = RecursiveCharacterTextSplitter(
-            chunk_size=chunk_size, chunk_overlap=0
+            chunk_size=st.session_state["chunk_size"],
+            chunk_overlap=st.session_state["chunk_overlap"],
         )
-        docs = loader.load_and_split(text_splitter)
+        docs = loader.load()
+        docs = text_splitter.split_documents(docs)
         logger.info(f"Loaded: {len(docs)} document chucks")
         return docs
     except Exception as e:
         msg = (
             e
             if loader
-            else f"No Loader found for your data source. Consider contributing:  {REPO_URL}!"
+            else f"No Loader found for your data source. Consider contributing:  {PROJECT_URL}!"
         )
-        handle_load_error(msg)
+        error_msg = f"Failed to load '{st.session_state['data_source']}':\n\n{msg}"
+        st.error(error_msg, icon=PAGE_ICON)
+        logger.error(error_msg)
+        st.stop()
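For web sources the same dispatch runs on the URL, so the suffix alone decides: `.git` links go to `AutoGitLoader`, `.pdf` to `OnlinePDFLoader`, and everything else (including a plain `https://example.com`, whose "extension" parses as `.com`) falls back to `WebBaseLoader`. A dependency-free sketch of that suffix logic:

    WEB_EXTENSIONS = {".git", ".pdf"}  # the keys of WEB_LOADER_MAPPING

    for src in (
        "https://github.com/gustavz/DataChad.git",
        "https://example.com/paper.pdf",
        "https://example.com",
    ):
        ext = "." + src.rsplit(".", 1)[-1]
        print(src, "->", ext if ext in WEB_EXTENSIONS else "default (WebBaseLoader)")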

-def clean_data_source_string(data_source_string: str) -> str:
+def get_data_source_string() -> str:
     # replace all non-word characters with dashes
     # to get a string that can be used to create a new dataset
-    dashed_string = re.sub(r"\W+", "-", data_source_string)
+    dashed_string = re.sub(r"\W+", "-", st.session_state["data_source"])
     cleaned_string = re.sub(r"--+", "- ", dashed_string).strip("-")
     return cleaned_string
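The rename to `get_data_source_string` reflects that the input now comes from session state; the transformation itself is unchanged: collapse every non-word run into a dash so the result is a legal Deep Lake dataset name. For example:

    import re

    data_source = "https://github.com/gustavz/DataChad.git"
    dashed = re.sub(r"\W+", "-", data_source)
    print(dashed.strip("-"))  # https-github-com-gustavz-DataChad-git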

-def setup_vector_store(data_source: str, chunk_size: int = CHUNK_SIZE) -> VectorStore:
+def get_model() -> BaseLanguageModel:
+    match st.session_state["model"]:
+        case "gpt-3.5-turbo":
+            model = ChatOpenAI(
+                model_name=st.session_state["model"],
+                temperature=st.session_state["temperature"],
+                openai_api_key=st.session_state["openai_api_key"],
+            )
+        # Add more models as needed
+        case _default:
+            msg = f"Model {st.session_state['model']} not supported!"
+            logger.error(msg)
+            st.error(msg)
+            exit
+    return model
+
+
+def get_embeddings() -> Embeddings:
+    match st.session_state["embeddings"]:
+        case "openai":
+            embeddings = OpenAIEmbeddings(
+                disallowed_special=(), openai_api_key=st.session_state["openai_api_key"]
+            )
+        # Add more embeddings as needed
+        case _default:
+            msg = f"Embeddings {st.session_state['embeddings']} not supported!"
+            logger.error(msg)
+            st.error(msg)
+            exit
+    return embeddings
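The `match` factories (Python 3.10+) give the app a single registration point for models and embeddings. One subtlety: `case _default:` is a capture pattern, so it matches anything just like the conventional `case _:` while also binding the value to `_default`. Extending a factory is then a one-case change; a hypothetical sketch for a second OpenAI chat model that this commit does not wire up:

    match st.session_state["model"]:
        case "gpt-3.5-turbo" | "gpt-4":  # "gpt-4" is illustrative only
            model = ChatOpenAI(
                model_name=st.session_state["model"],
                temperature=st.session_state["temperature"],
                openai_api_key=st.session_state["openai_api_key"],
            )
        case _:
            raise ValueError(f"Model {st.session_state['model']} not supported!")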

+def get_vector_store() -> VectorStore:
     # either load existing vector store or upload a new one to the hub
-    embeddings = OpenAIEmbeddings(
-        disallowed_special=(), openai_api_key=st.session_state["openai_api_key"]
-    )
-    data_source_name = clean_data_source_string(data_source)
-    dataset_path = f"hub://{st.session_state['activeloop_org_name']}/{data_source_name}-{chunk_size}"
+    embeddings = get_embeddings()
+    data_source_name = get_data_source_string()
+    dataset_path = f"hub://{st.session_state['activeloop_org_name']}/{data_source_name}-{st.session_state['chunk_size']}"
     if deeplake.exists(dataset_path, token=st.session_state["activeloop_token"]):
         with st.spinner("Loading vector store..."):
             logger.info(f"Dataset '{dataset_path}' exists -> loading")

@@ -322,7 +295,7 @@ def setup_vector_store(data_source: str, chunk_size: int = CHUNK_SIZE) -> VectorStore:
     else:
         with st.spinner("Reading, embedding and uploading data to hub..."):
             logger.info(f"Dataset '{dataset_path}' does not exist -> uploading")
-            docs = load_any_data_source(data_source, chunk_size)
+            docs = load_data_source()
             vector_store = DeepLake.from_documents(
                 docs,
                 embeddings,
@@ -332,16 +305,9 @@ def setup_vector_store(data_source: str, chunk_size: int = CHUNK_SIZE) -> VectorStore:
     return vector_store
 
-def build_chain(
-    data_source: str,
-    k: int = K,
-    fetch_k: int = FETCH_K,
-    chunk_size: int = CHUNK_SIZE,
-    temperature: float = TEMPERATURE,
-    max_tokens: int = MAX_TOKENS,
-) -> ConversationalRetrievalChain:
+def get_chain() -> ConversationalRetrievalChain:
     # create the langchain that will be called to generate responses
-    vector_store = setup_vector_store(data_source, chunk_size)
+    vector_store = get_vector_store()
     retriever = vector_store.as_retriever()
     # Search params "fetch_k" and "k" define how many documents are pulled from the hub
     # and selected after the document matching to build the context
@@ -349,15 +315,11 @@ def build_chain(
     search_kwargs = {
         "maximal_marginal_relevance": True,
         "distance_metric": "cos",
-        "fetch_k": fetch_k,
-        "k": k,
+        "fetch_k": st.session_state["fetch_k"],
+        "k": st.session_state["k"],
     }
     retriever.search_kwargs.update(search_kwargs)
-    model = ChatOpenAI(
-        model_name=MODEL,
-        temperature=temperature,
-        openai_api_key=st.session_state["openai_api_key"],
-    )
+    model = get_model()
     chain = ConversationalRetrievalChain.from_llm(
         model,
         retriever=retriever,
@@ -365,9 +327,8 @@ def build_chain(
         verbose=True,
         # we limit the maximum number of used tokens
         # to prevent running into the models token limit of 4096
-        max_tokens_limit=max_tokens,
+        max_tokens_limit=st.session_state["max_tokens"],
     )
-    logger.info(f"Data source '{data_source}' is ready to go!")
     return chain
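With every knob read from session state, `get_chain` collapses to a zero-argument factory. Downstream, a `ConversationalRetrievalChain` is invoked with the question plus the accumulated history; `generate_response` is not part of this diff, but the call presumably looks something like this sketch:

    chain = st.session_state["chain"]
    result = chain(
        {"question": user_input, "chat_history": st.session_state["chat_history"]}
    )
    st.session_state["chat_history"].append((user_input, result["answer"]))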

@@ -375,15 +336,11 @@ def update_chain() -> None:
     # Build chain with parameters from session state and store it back
     # Also delete chat history to not confuse the bot with old context
     try:
-        st.session_state["chain"] = build_chain(
-            data_source=st.session_state["data_source"],
-            k=st.session_state["k"],
-            fetch_k=st.session_state["fetch_k"],
-            chunk_size=st.session_state["chunk_size"],
-            temperature=st.session_state["temperature"],
-            max_tokens=st.session_state["max_tokens"],
-        )
+        st.session_state["chain"] = get_chain()
         st.session_state["chat_history"] = []
+        msg = f"Data source '{st.session_state['data_source']}' is ready to go!"
+        logger.info(msg)
+        st.info(msg, icon=PAGE_ICON)
     except Exception as e:
         msg = f"Failed to build chain for data source '{st.session_state['data_source']}' with error: {e}"
         logger.error(msg)
