From bcd43959073f68bc45fb5669300ade5d12084138 Mon Sep 17 00:00:00 2001
From: Gustav von Zitzewitz
Date: Tue, 23 May 2023 09:37:41 +0200
Subject: [PATCH] Add local models

---
 app.py       |  6 ++++--
 constants.py |  5 +++++
 utils.py     | 29 ++++++++++++++++++++++++-----
 3 files changed, 33 insertions(+), 7 deletions(-)

diff --git a/app.py b/app.py
index 871f29e..501950c 100644
--- a/app.py
+++ b/app.py
@@ -18,6 +18,7 @@ from constants import (
     PROJECT_URL,
     TEMPERATURE,
     USAGE_HELP,
+    MODEL_N_CTX,
     K,
 )
 from utils import (
@@ -45,12 +46,12 @@ SESSION_DEFAULTS = {
     "usage": {},
     "chat_history": [],
     "generated": [],
-    "data_source": DEFAULT_DATA_SOURCE,
-    "uploaded_file": None,
     "auth_ok": False,
     "openai_api_key": None,
     "activeloop_token": None,
     "activeloop_org_name": None,
+    "uploaded_file": None,
+    "data_source": DEFAULT_DATA_SOURCE,
     "model": MODEL,
     "embeddings": EMBEDDINGS,
     "k": K,
@@ -59,6 +60,7 @@ SESSION_DEFAULTS = {
     "chunk_overlap": CHUNK_OVERLAP,
     "temperature": TEMPERATURE,
     "max_tokens": MAX_TOKENS,
+    "model_n_ctx": MODEL_N_CTX,
 }
 # Initialise session state variables
 for k, v in SESSION_DEFAULTS.items():
diff --git a/constants.py b/constants.py
index 215d6c0..bc90e3f 100644
--- a/constants.py
+++ b/constants.py
@@ -15,6 +15,11 @@
 TEMPERATURE = 0.7
 MAX_TOKENS = 3357
 ENABLE_ADVANCED_OPTIONS = True
 
+MODEL_N_CTX = 1000
+LLAMACPP_MODEL_PATH = ""
+GPT4ALL_MODEL_PATH = ""
+ENABLE_LOCAL_MODELS = False
+
 DATA_PATH = Path.cwd() / "data"
 DEFAULT_DATA_SOURCE = "https://github.com/gustavz/DataChad.git"
diff --git a/utils.py b/utils.py
index 63f94c6..7aed86e 100644
--- a/utils.py
+++ b/utils.py
@@ -37,8 +37,10 @@ from langchain.embeddings.openai import Embeddings, OpenAIEmbeddings
 from langchain.schema import Document
 from langchain.text_splitter import RecursiveCharacterTextSplitter
 from langchain.vectorstores import DeepLake, VectorStore
+from langchain.llms import GPT4All, LlamaCpp
+from langchain.embeddings import HuggingFaceEmbeddings
 
-from constants import APP_NAME, DATA_PATH, PAGE_ICON, PROJECT_URL
+from constants import APP_NAME, DATA_PATH, PAGE_ICON, PROJECT_URL, LLAMACPP_MODEL_PATH, GPT4ALL_MODEL_PATH
 
 # loads environment variables
 load_dotenv()
@@ -190,7 +192,7 @@ WEB_LOADER_MAPPING = {
 }
 
 
-def get_loader(file_path: str, mapping: dict, default_loader:BaseLoader) -> BaseLoader:
+def get_loader(file_path: str, mapping: dict, default_loader: BaseLoader) -> BaseLoader:
     # Choose loader from mapping, load default if no match found
     ext = "." + file_path.rsplit(".", 1)[-1]
     if ext in mapping:
@@ -238,7 +240,7 @@ def load_data_source() -> List[Document]:
         st.stop()
 
 
-def get_data_source_string() -> str:
+def get_dataset_name() -> str:
     # replace all non-word characters with dashes
     # to get a string that can be used to create a new dataset
     dashed_string = re.sub(r"\W+", "-", st.session_state["data_source"])
@@ -254,6 +256,21 @@ def get_model() -> BaseLanguageModel:
                 temperature=st.session_state["temperature"],
                 openai_api_key=st.session_state["openai_api_key"],
             )
+        case "LlamaCpp":
+            model = LlamaCpp(
+                model_path=LLAMACPP_MODEL_PATH,
+                n_ctx=st.session_state["model_n_ctx"],
+                temperature=st.session_state["temperature"],
+                verbose=True,
+            )
+        case "GPT4All":
+            model = GPT4All(
+                model=GPT4ALL_MODEL_PATH,
+                n_ctx=st.session_state["model_n_ctx"],
+                backend="gptj",
+                temp=st.session_state["temperature"],
+                verbose=True,
+            )
         # Add more models as needed
         case _default:
             msg = f"Model {st.session_state['model']} not supported!"
@@ -269,6 +286,8 @@ def get_embeddings() -> Embeddings:
             embeddings = OpenAIEmbeddings(
                 disallowed_special=(), openai_api_key=st.session_state["openai_api_key"]
             )
+        case "huggingface-all-MiniLM-L6-v2":
+            embeddings = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
         # Add more embeddings as needed
         case _default:
             msg = f"Embeddings {st.session_state['embeddings']} not supported!"
@@ -281,8 +300,8 @@ def get_embeddings() -> Embeddings:
 def get_vector_store() -> VectorStore:
     # either load existing vector store or upload a new one to the hub
     embeddings = get_embeddings()
-    data_source_name = get_data_source_string()
-    dataset_path = f"hub://{st.session_state['activeloop_org_name']}/{data_source_name}-{st.session_state['chunk_size']}"
+    dataset_name = get_dataset_name()
+    dataset_path = f"hub://{st.session_state['activeloop_org_name']}/{dataset_name}-{st.session_state['chunk_size']}"
     if deeplake.exists(dataset_path, token=st.session_state["activeloop_token"]):
         with st.spinner("Loading vector store..."):
             logger.info(f"Dataset '{dataset_path}' exists -> loading")
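
Note (not part of the patch): a minimal sketch of what the new LlamaCpp
branch in get_model() enables, assuming langchain's LlamaCpp wrapper and a
local ggml model file. The model path and prompt are placeholders; the
parameter values mirror the defaults this patch adds to constants.py.

    from langchain.llms import LlamaCpp

    llm = LlamaCpp(
        model_path="models/ggml-model-q4_0.bin",  # placeholder for LLAMACPP_MODEL_PATH
        n_ctx=1000,       # MODEL_N_CTX default
        temperature=0.7,  # TEMPERATURE default
        verbose=True,
    )
    print(llm("What does DataChad do?"))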
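
Note (not part of the patch): the new embeddings branch wraps a local
sentence-transformers model via langchain's HuggingFaceEmbeddings. A quick
sanity check, assuming the sentence-transformers package is installed
(all-MiniLM-L6-v2 produces 384-dimensional vectors):

    from langchain.embeddings import HuggingFaceEmbeddings

    embeddings = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
    vector = embeddings.embed_query("hello world")
    print(len(vector))  # 384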
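
Note (not part of the patch): get_dataset_name() (renamed from
get_data_source_string()) slugifies the data source before get_vector_store()
builds the Deep Lake path. For the default data source, the regex step shown
in the diff yields:

    import re

    dashed = re.sub(r"\W+", "-", "https://github.com/gustavz/DataChad.git")
    print(dashed)  # https-github-com-gustavz-DataChad-git
    # get_vector_store() then builds a path of the form
    # hub://<activeloop_org_name>/<dataset_name>-<chunk_size>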