You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

134 lines
4.4 KiB
Python

This file contains invisible Unicode characters!

This file contains invisible Unicode characters that may be processed differently from what appears below. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to reveal hidden characters.

import os
import shutil
from typing import List
import streamlit as st
from langchain.document_loaders import (
CSVLoader,
DirectoryLoader,
EverNoteLoader,
GitLoader,
NotebookLoader,
OnlinePDFLoader,
PDFMinerLoader,
PythonLoader,
TextLoader,
UnstructuredEPubLoader,
UnstructuredFileLoader,
UnstructuredHTMLLoader,
UnstructuredMarkdownLoader,
UnstructuredODTLoader,
UnstructuredPowerPointLoader,
UnstructuredWordDocumentLoader,
WebBaseLoader,
)
from langchain.document_loaders.base import BaseLoader
from langchain.schema import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from datachad.constants import DATA_PATH, PAGE_ICON, PROJECT_URL
from datachad.utils import logger
class AutoGitLoader:
    """Load documents from a Git repository, trying the common default branches.

    The repository is checked out below ``DATA_PATH`` via ``GitLoader`` and the
    checkout directory is removed again once loading has finished.
    """

    def __init__(self, data_source: str) -> None:
        """Store the repository link to load from.

        Args:
            data_source: Link to the repository; its last path segment (up to
                the first dot) is used as the local directory name.
        """
        self.data_source = data_source

    def load(self) -> List[Document]:
        """Load the repository's documents from ``main`` or, failing that, ``master``.

        Returns:
            Whatever ``GitLoader(...).load()`` returns for the first branch
            that loads successfully.

        Raises:
            RuntimeError: If no usable repository name can be derived from the
                link, or if loading failed for both branches. In the latter
                case the last underlying error is attached as the cause.
        """
        # Make sure the data path exists (no check-then-create race).
        os.makedirs(DATA_PATH, exist_ok=True)

        repo_name = self.data_source.split("/")[-1].split(".")[0]
        if not repo_name:
            # An empty name would make repo_path equal DATA_PATH itself, and
            # the cleanup below would then wipe the whole data directory.
            raise RuntimeError("Make sure to use HTTPS GitHub repo links")
        repo_path = str(DATA_PATH / repo_name)

        # If the target directory already exists, hand GitLoader no clone URL
        # (presumably so it reuses the checkout instead of cloning over it).
        clone_url = None if os.path.exists(repo_path) else self.data_source

        last_error = None
        try:
            # We need to try both common main branches,
            # thank you GitHub for the "master" to "main" switch.
            for branch in ("main", "master"):
                try:
                    return GitLoader(repo_path, clone_url, branch).load()
                except Exception as e:
                    last_error = e
                    logger.error(f"Error loading git: {e}")
            raise RuntimeError(
                "Make sure to use HTTPS GitHub repo links"
            ) from last_error
        finally:
            # Clean up the repo afterwards, on success and on failure alike.
            if os.path.exists(repo_path):
                shutil.rmtree(repo_path)
# Loaders for local files, keyed by file extension (including the leading dot).
# Each value is a (loader class, keyword arguments) pair; get_loader builds the
# loader as loader_class(file_path, **keyword_arguments).
# Note: "utf-8" and "utf8" below are aliases for the same Python codec.
FILE_LOADER_MAPPING = {
    ".csv": (CSVLoader, {"encoding": "utf-8"}),
    ".doc": (UnstructuredWordDocumentLoader, {}),
    ".docx": (UnstructuredWordDocumentLoader, {}),
    ".enex": (EverNoteLoader, {}),
    ".epub": (UnstructuredEPubLoader, {}),
    ".html": (UnstructuredHTMLLoader, {}),
    ".md": (UnstructuredMarkdownLoader, {}),
    ".odt": (UnstructuredODTLoader, {}),
    ".pdf": (PDFMinerLoader, {}),
    ".ppt": (UnstructuredPowerPointLoader, {}),
    ".pptx": (UnstructuredPowerPointLoader, {}),
    ".txt": (TextLoader, {"encoding": "utf8"}),
    ".ipynb": (NotebookLoader, {}),
    ".py": (PythonLoader, {}),
    # Add more mappings for other file extensions and loaders as needed
}
# Loaders for web data sources (those starting with "http"), keyed by the
# extension the link ends with. Values are (loader class, keyword arguments)
# pairs, consumed by get_loader in the same way as FILE_LOADER_MAPPING.
WEB_LOADER_MAPPING = {
    ".git": (AutoGitLoader, {}),
    ".pdf": (OnlinePDFLoader, {}),
}
def get_loader(
    file_path: str, mapping: dict, default_loader: "type[BaseLoader]"
) -> "BaseLoader":
    """Instantiate the loader registered for the extension of ``file_path``.

    Args:
        file_path: Path or link of the data source; passed to the loader as
            its first positional argument.
        mapping: Maps an extension (with leading dot) to a
            ``(loader class, keyword arguments)`` pair.
        default_loader: Loader class used when the extension is not in
            ``mapping``; it is called with ``file_path`` only.

    Returns:
        The constructed loader instance.
    """
    # Everything after the last dot counts as the extension. Without any dot
    # this yields "." + the whole path, which simply matches nothing.
    ext = "." + file_path.rsplit(".", 1)[-1]

    # Try the exact spelling first (keeps mixed-case keys in custom mappings
    # working), then fall back to lowercase so "REPORT.PDF" resolves the same
    # way as "report.pdf" instead of silently using the default loader.
    for candidate in (ext, ext.lower()):
        if candidate in mapping:
            loader_class, loader_args = mapping[candidate]
            return loader_class(file_path, **loader_args)

    # Load default if no match was found.
    return default_loader(file_path)
def load_data_source() -> List[Document]:
    """Load the data source chosen in the Streamlit session and split it into chunks.

    Reads ``data_source``, ``chunk_size`` and ``chunk_overlap`` from
    ``st.session_state``. The source is treated as a directory, a web link
    (anything starting with ``"http"``) or a single file, checked in that order.

    Returns:
        The loaded documents after splitting them with
        ``RecursiveCharacterTextSplitter``.

    If no loader fits the data source, or loading/splitting raises, the error
    is shown via ``st.error``, logged, and ``st.stop()`` is called instead of
    returning documents.
    """
    data_source = st.session_state["data_source"]

    # Decide how to load the data; directories take precedence over links,
    # and links over plain files.
    loader = None
    if os.path.isdir(data_source):
        loader = DirectoryLoader(data_source, recursive=True, silent_errors=True)
    elif data_source.startswith("http"):
        loader = get_loader(data_source, WEB_LOADER_MAPPING, WebBaseLoader)
    elif os.path.isfile(data_source):
        loader = get_loader(data_source, FILE_LOADER_MAPPING, UnstructuredFileLoader)

    if loader is None:
        # Report the missing loader explicitly rather than waiting for
        # ``None.load()`` to blow up inside the try block.
        msg = f"No Loader found for your data source. Consider contributing:  {PROJECT_URL}!"
    else:
        try:
            # Chunk size is a major trade-off parameter to control result
            # accuracy over computation
            text_splitter = RecursiveCharacterTextSplitter(
                chunk_size=st.session_state["chunk_size"],
                chunk_overlap=st.session_state["chunk_overlap"],
            )
            docs = text_splitter.split_documents(loader.load())
            logger.info(f"Loaded: {len(docs)} document chunks")
            return docs
        except Exception as e:
            msg = e

    error_msg = f"Failed to load '{data_source}':\n\n{msg}"
    st.error(error_msg, icon=PAGE_ICON)
    logger.error(error_msg)
    st.stop()