|
|
|
@ -59,33 +59,34 @@ class MODELS(Enum):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_model() -> BaseLanguageModel:
|
|
|
|
|
match st.session_state["model"].name:
|
|
|
|
|
case MODELS.GPT35TURBO.name:
|
|
|
|
|
model = ChatOpenAI(
|
|
|
|
|
model_name=st.session_state["model"].name,
|
|
|
|
|
temperature=st.session_state["temperature"],
|
|
|
|
|
openai_api_key=st.session_state["openai_api_key"],
|
|
|
|
|
)
|
|
|
|
|
case MODELS.GPT4.name:
|
|
|
|
|
model = ChatOpenAI(
|
|
|
|
|
model_name=st.session_state["model"].name,
|
|
|
|
|
temperature=st.session_state["temperature"],
|
|
|
|
|
openai_api_key=st.session_state["openai_api_key"],
|
|
|
|
|
)
|
|
|
|
|
case MODELS.GPT4ALL.name:
|
|
|
|
|
model = GPT4All(
|
|
|
|
|
model=st.session_state["model"].path,
|
|
|
|
|
n_ctx=st.session_state["model_n_ctx"],
|
|
|
|
|
backend="gptj",
|
|
|
|
|
temp=st.session_state["temperature"],
|
|
|
|
|
verbose=True,
|
|
|
|
|
)
|
|
|
|
|
# Added models need to be cased here
|
|
|
|
|
case _default:
|
|
|
|
|
msg = f"Model {st.session_state['model']} not supported!"
|
|
|
|
|
logger.error(msg)
|
|
|
|
|
st.error(msg)
|
|
|
|
|
exit
|
|
|
|
|
with st.spinner("Loading Model..."):
|
|
|
|
|
match st.session_state["model"].name:
|
|
|
|
|
case MODELS.GPT35TURBO.name:
|
|
|
|
|
model = ChatOpenAI(
|
|
|
|
|
model_name=st.session_state["model"].name,
|
|
|
|
|
temperature=st.session_state["temperature"],
|
|
|
|
|
openai_api_key=st.session_state["openai_api_key"],
|
|
|
|
|
)
|
|
|
|
|
case MODELS.GPT4.name:
|
|
|
|
|
model = ChatOpenAI(
|
|
|
|
|
model_name=st.session_state["model"].name,
|
|
|
|
|
temperature=st.session_state["temperature"],
|
|
|
|
|
openai_api_key=st.session_state["openai_api_key"],
|
|
|
|
|
)
|
|
|
|
|
case MODELS.GPT4ALL.name:
|
|
|
|
|
model = GPT4All(
|
|
|
|
|
model=st.session_state["model"].path,
|
|
|
|
|
n_ctx=st.session_state["model_n_ctx"],
|
|
|
|
|
backend="gptj",
|
|
|
|
|
temp=st.session_state["temperature"],
|
|
|
|
|
verbose=True,
|
|
|
|
|
)
|
|
|
|
|
# Added models need to be cased here
|
|
|
|
|
case _default:
|
|
|
|
|
msg = f"Model {st.session_state['model']} not supported!"
|
|
|
|
|
logger.error(msg)
|
|
|
|
|
st.error(msg)
|
|
|
|
|
exit
|
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|