import os
import shutil
from typing import List

import streamlit as st
from langchain.document_loaders import (
    CSVLoader,
    DirectoryLoader,
    EverNoteLoader,
    GitLoader,
    NotebookLoader,
    OnlinePDFLoader,
    PDFMinerLoader,
    PythonLoader,
    TextLoader,
    UnstructuredEPubLoader,
    UnstructuredFileLoader,
    UnstructuredHTMLLoader,
    UnstructuredMarkdownLoader,
    UnstructuredODTLoader,
    UnstructuredPowerPointLoader,
    UnstructuredWordDocumentLoader,
    WebBaseLoader,
)
from langchain.document_loaders.base import BaseLoader
from langchain.schema import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter

from datachad.constants import DATA_PATH, PAGE_ICON, PROJECT_URL
from datachad.utils import logger


class AutoGitLoader:
    def __init__(self, data_source: str) -> None:
        self.data_source = data_source

    def load(self) -> List[Document]:
        # We need to try both common main branches
        # Thank you github for the "master" to "main" switch
        # Make sure the data path exists
        if not os.path.exists(DATA_PATH):
            os.makedirs(DATA_PATH)
        repo_name = self.data_source.split("/")[-1].split(".")[0]
        repo_path = str(DATA_PATH / repo_name)
        # Only clone if the repo does not exist locally yet
        clone_url = self.data_source
        if os.path.exists(repo_path):
            clone_url = None
        docs = None
        branches = ["main", "master"]
        for branch in branches:
            try:
                docs = GitLoader(repo_path, clone_url, branch).load()
                break
            except Exception as e:
                logger.error(f"Error loading git: {e}")
        if os.path.exists(repo_path):
            # Clean up the cloned repo afterwards
            shutil.rmtree(repo_path)
        if docs is None:
            raise RuntimeError("Make sure to use HTTPS GitHub repo links")
        return docs


FILE_LOADER_MAPPING = {
    ".csv": (CSVLoader, {"encoding": "utf-8"}),
    ".doc": (UnstructuredWordDocumentLoader, {}),
    ".docx": (UnstructuredWordDocumentLoader, {}),
    ".enex": (EverNoteLoader, {}),
    ".epub": (UnstructuredEPubLoader, {}),
    ".html": (UnstructuredHTMLLoader, {}),
    ".md": (UnstructuredMarkdownLoader, {}),
    ".odt": (UnstructuredODTLoader, {}),
    ".pdf": (PDFMinerLoader, {}),
    ".ppt": (UnstructuredPowerPointLoader, {}),
    ".pptx": (UnstructuredPowerPointLoader, {}),
    ".txt": (TextLoader, {"encoding": "utf8"}),
    ".ipynb": (NotebookLoader, {}),
    ".py": (PythonLoader, {}),
    # Add more mappings for other file extensions and loaders as needed,
    # as sketched below the web loader mapping
}

WEB_LOADER_MAPPING = {
    ".git": (AutoGitLoader, {}),
    ".pdf": (OnlinePDFLoader, {}),
}
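# A minimal sketch of how a new extension could be registered, assuming the
# targeted loader class (here langchain's UnstructuredEmailLoader) is
# available in the installed langchain version; this entry is hypothetical
# and not part of the original module:
#
#   from langchain.document_loaders import UnstructuredEmailLoader
#   FILE_LOADER_MAPPING[".eml"] = (UnstructuredEmailLoader, {})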
def get_loader(file_path: str, mapping: dict, default_loader: BaseLoader) -> BaseLoader:
    # Choose a loader from the mapping; fall back to the default if no match is found
    ext = "." + file_path.rsplit(".", 1)[-1]
    if ext in mapping:
        loader_class, loader_args = mapping[ext]
        loader = loader_class(file_path, **loader_args)
    else:
        loader = default_loader(file_path)
    return loader


def load_data_source() -> List[Document]:
    # Ugly thing that decides how to load data
    # It ain't much, but it's honest work
    data_source = st.session_state["data_source"]
    is_web = data_source.startswith("http")
    is_dir = os.path.isdir(data_source)
    is_file = os.path.isfile(data_source)
    loader = None
    if is_dir:
        loader = DirectoryLoader(data_source, recursive=True, silent_errors=True)
    elif is_web:
        loader = get_loader(data_source, WEB_LOADER_MAPPING, WebBaseLoader)
    elif is_file:
        loader = get_loader(data_source, FILE_LOADER_MAPPING, UnstructuredFileLoader)
    try:
        # Chunk size is a major trade-off parameter controlling result accuracy vs. computation
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=st.session_state["chunk_size"],
            chunk_overlap=st.session_state["chunk_overlap"],
        )
        docs = loader.load()
        docs = text_splitter.split_documents(docs)
        logger.info(f"Loaded: {len(docs)} document chunks")
        return docs
    except Exception as e:
        msg = (
            e
            if loader
            else f"No Loader found for your data source. Consider contributing: {PROJECT_URL}!"
        )
        error_msg = f"Failed to load '{st.session_state['data_source']}':\n\n{msg}"
        st.error(error_msg, icon=PAGE_ICON)
        logger.error(error_msg)
        st.stop()
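# Example usage, a sketch only: it assumes this module runs inside a Streamlit
# session (e.g. launched via `streamlit run app.py`) where the app has already
# populated session state; the values below are hypothetical placeholders:
#
#   st.session_state["data_source"] = "https://github.com/user/repo.git"
#   st.session_state["chunk_size"] = 1000
#   st.session_state["chunk_overlap"] = 100
#   docs = load_data_source()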