diff --git a/app.py b/app.py index 5c99d97..c016572 100644 --- a/app.py +++ b/app.py @@ -107,6 +107,7 @@ def advanced_options_form() -> None: "model", options=MODELS.for_mode(st.session_state["mode"]), help=f"Learn more about which models are supported [here]({PROJECT_URL})", + key="model", ) col2.number_input( "temperature", @@ -158,10 +159,20 @@ def advanced_options_form() -> None: update_chain() +def update_model_on_mode_change(): + # callback for mode selectbox + st.session_state["model"] = MODELS.for_mode(st.session_state["mode"])[0] + + # Sidebar with Authentication and Advanced Options with st.sidebar: - mode = st.selectbox("Mode", MODES.all(), key="mode", help=MODE_HELP) - st.session_state["model"] = MODELS.for_mode(mode)[0] + mode = st.selectbox( + "Mode", + MODES.all(), + key="mode", + help=MODE_HELP, + on_change=update_model_on_mode_change(), + ) if mode == MODES.LOCAL and not ENABLE_LOCAL_MODE: st.error(LOCAL_MODE_DISABLED_HELP, icon=PAGE_ICON) st.stop() diff --git a/datachad/constants.py b/datachad/constants.py index dc89a55..2a4ee9b 100644 --- a/datachad/constants.py +++ b/datachad/constants.py @@ -14,7 +14,7 @@ MAX_TOKENS = 3357 MODEL_N_CTX = 1000 ENABLE_ADVANCED_OPTIONS = True -ENABLE_LOCAL_MODE = False +ENABLE_LOCAL_MODE = True MODEL_PATH = Path.cwd() / "models" GPT4ALL_BINARY = "ggml-gpt4all-j-v1.3-groovy.bin" diff --git a/datachad/models.py b/datachad/models.py index 9f3a883..eb960fe 100644 --- a/datachad/models.py +++ b/datachad/models.py @@ -82,7 +82,7 @@ def get_model() -> BaseLanguageModel: backend="gptj", temp=st.session_state["temperature"], verbose=True, - callbacks=StreamingStdOutCallbackHandler(), + callbacks=[StreamingStdOutCallbackHandler()], ) # Added models need to be cased here case _default: