diff --git a/.gitignore b/.gitignore index 5724a5c..4f8dd20 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,7 @@ data +models __pycache__ .env .ipynb_checkpoints +.DS_Store testing.ipynb \ No newline at end of file diff --git a/app.py b/app.py index a091ecd..5c99d97 100644 --- a/app.py +++ b/app.py @@ -161,6 +161,7 @@ def advanced_options_form() -> None: # Sidebar with Authentication and Advanced Options with st.sidebar: mode = st.selectbox("Mode", MODES.all(), key="mode", help=MODE_HELP) + st.session_state["model"] = MODELS.for_mode(mode)[0] if mode == MODES.LOCAL and not ENABLE_LOCAL_MODE: st.error(LOCAL_MODE_DISABLED_HELP, icon=PAGE_ICON) st.stop() diff --git a/datachad/chain.py b/datachad/chain.py index 2bc91c4..5a110e1 100644 --- a/datachad/chain.py +++ b/datachad/chain.py @@ -41,11 +41,10 @@ def update_chain() -> None: try: st.session_state["chain"] = get_chain() st.session_state["chat_history"] = [] - msg = f"Data source '{st.session_state['data_source']}' is ready to go!" - logger.info(msg) + msg = f"Data source '{st.session_state['data_source']}' is ready to go with model '{st.session_state['model']}'!" st.info(msg, icon=PAGE_ICON) except Exception as e: - msg = f"Failed to build chain for data source '{st.session_state['data_source']}' with error: {e}" + msg = f"Failed to build chain for data source '{st.session_state['data_source']}' with model '{st.session_state['model']}': {e}" logger.error(msg) st.error(msg, icon=PAGE_ICON) diff --git a/datachad/constants.py b/datachad/constants.py index 6022321..2a4ee9b 100644 --- a/datachad/constants.py +++ b/datachad/constants.py @@ -14,9 +14,10 @@ MAX_TOKENS = 3357 MODEL_N_CTX = 1000 ENABLE_ADVANCED_OPTIONS = True -ENABLE_LOCAL_MODE = False +ENABLE_LOCAL_MODE = True -GPT4ALL_MODEL_PATH = "models/ggml-gpt4all-j-v1.3-groovy.bin" +MODEL_PATH = Path.cwd() / "models" +GPT4ALL_BINARY = "ggml-gpt4all-j-v1.3-groovy.bin" DATA_PATH = Path.cwd() / "data" DEFAULT_DATA_SOURCE = "https://github.com/gustavz/DataChad.git" diff --git a/datachad/models.py b/datachad/models.py index fa8f561..d90e4ac 100644 --- a/datachad/models.py +++ b/datachad/models.py @@ -5,9 +5,9 @@ from langchain.base_language import BaseLanguageModel from langchain.chat_models import ChatOpenAI from langchain.embeddings import HuggingFaceEmbeddings from langchain.embeddings.openai import Embeddings, OpenAIEmbeddings -from langchain.llms import GPT4All, LlamaCpp +from langchain.llms import GPT4All -from datachad.constants import GPT4ALL_MODEL_PATH +from datachad.constants import GPT4ALL_BINARY, MODEL_PATH from datachad.utils import logger @@ -50,7 +50,7 @@ class MODELS(Enum): name="GPT4All", mode=MODES.LOCAL, embedding=EMBEDDINGS.HUGGINGFACE, - path=GPT4ALL_MODEL_PATH, + path=str(MODEL_PATH / GPT4ALL_BINARY), ) @classmethod @@ -96,7 +96,9 @@ def get_embeddings() -> Embeddings: disallowed_special=(), openai_api_key=st.session_state["openai_api_key"] ) case EMBEDDINGS.HUGGINGFACE: - embeddings = HuggingFaceEmbeddings(model_name=EMBEDDINGS.HUGGINGFACE) + embeddings = HuggingFaceEmbeddings( + model_name=EMBEDDINGS.HUGGINGFACE, cache_folder=str(MODEL_PATH) + ) # Added embeddings need to be cased here case _default: msg = f"Embeddings {st.session_state['embeddings']} not supported!" diff --git a/requirements.txt b/requirements.txt index 0f45702..1927023 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,4 +9,6 @@ pdf2image==1.16.3 pytesseract==0.3.10 beautifulsoup4==4.12.2 bs4==0.0.1 -python-dotenv==1.0.0 \ No newline at end of file +python-dotenv==1.0.0 +sentence-transformers==2.2.2 +pygpt4all==1.1.0 \ No newline at end of file