@@ -34,6 +34,7 @@ logger = logging.getLogger(APP_NAME)
def configure_logger(debug=0):
    # Boilerplate code to enable logging in the Streamlit app console
    log_level = logging.DEBUG if debug == 1 else logging.INFO
    logger.setLevel(log_level)
@@ -115,6 +116,7 @@ def delete_uploaded_file(uploaded_file):
def load_git(data_source):
    # We need to try both common main branches
    # Thank you GitHub for the "master" to "main" switch
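    # e.g. a clone URL like "https://github.com/user/repo.git" yields repo_name "repo" (illustrative URL)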
    repo_name = data_source.split("/")[-1].split(".")[0]
    repo_path = str(DATA_PATH / repo_name)
@@ -137,7 +139,8 @@ def load_git(data_source):
def load_any_data_source(data_source):
    # Ugly thing that decides how to load data
    # It ain't much, but it's honest work
    is_text = data_source.endswith(".txt")
    is_web = data_source.startswith("http")
    is_pdf = data_source.endswith(".pdf")
@@ -178,6 +181,7 @@ def load_any_data_source(data_source):
    else:
        loader = UnstructuredFileLoader(data_source)
    if loader:
        # Chunk size is a major trade-off parameter to control result accuracy over computation
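        # Note: by default this splitter counts chunk_size in characters, not tokens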
        text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
        docs = loader.load_and_split(text_splitter)
        logger.info(f"Loaded: {len(docs)} document chunks")
@@ -233,10 +237,13 @@ def get_chain(data_source):
    # Create the LangChain chain that will be called to generate responses
    vector_store = setup_vector_store(data_source)
    retriever = vector_store.as_retriever()
    # Search params "fetch_k" and "k" define how many documents are pulled from the hub
    # and selected after the document matching to build the context
    # that is fed to the model together with your prompt
    search_kwargs = {
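        # MMR re-ranks the fetch_k candidates for diversity before keeping k of them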
        "maximal_marginal_relevance": True,
        "distance_metric": "cos",
        "fetch_k": 20,
        "k": 10,
    }
    retriever.search_kwargs.update(search_kwargs)
@@ -249,6 +256,8 @@ def get_chain(data_source):
        retriever=retriever,
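        # "stuff" simply stuffs all retrieved chunks into a single prompt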
        chain_type="stuff",
        verbose=True,
        # We limit the maximum number of used tokens
        # to prevent running into the model's token limit of 4096
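        # (4096 - 3375 leaves roughly 700 tokens of headroom for the question and the generated answer)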
        max_tokens_limit=3375,
    )
    logger.info(f"Data source '{data_source}' is ready to go!")