|
|
|
@ -5,9 +5,9 @@ from langchain.base_language import BaseLanguageModel
|
|
|
|
|
from langchain.chat_models import ChatOpenAI
|
|
|
|
|
from langchain.embeddings import HuggingFaceEmbeddings
|
|
|
|
|
from langchain.embeddings.openai import Embeddings, OpenAIEmbeddings
|
|
|
|
|
from langchain.llms import GPT4All, LlamaCpp
|
|
|
|
|
from langchain.llms import GPT4All
|
|
|
|
|
|
|
|
|
|
from datachad.constants import GPT4ALL_MODEL_PATH
|
|
|
|
|
from datachad.constants import GPT4ALL_BINARY, MODEL_PATH
|
|
|
|
|
from datachad.utils import logger
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -50,7 +50,7 @@ class MODELS(Enum):
|
|
|
|
|
name="GPT4All",
|
|
|
|
|
mode=MODES.LOCAL,
|
|
|
|
|
embedding=EMBEDDINGS.HUGGINGFACE,
|
|
|
|
|
path=GPT4ALL_MODEL_PATH,
|
|
|
|
|
path=str(MODEL_PATH / GPT4ALL_BINARY),
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
@ -96,7 +96,9 @@ def get_embeddings() -> Embeddings:
|
|
|
|
|
disallowed_special=(), openai_api_key=st.session_state["openai_api_key"]
|
|
|
|
|
)
|
|
|
|
|
case EMBEDDINGS.HUGGINGFACE:
|
|
|
|
|
embeddings = HuggingFaceEmbeddings(model_name=EMBEDDINGS.HUGGINGFACE)
|
|
|
|
|
embeddings = HuggingFaceEmbeddings(
|
|
|
|
|
model_name=EMBEDDINGS.HUGGINGFACE, cache_folder=str(MODEL_PATH)
|
|
|
|
|
)
|
|
|
|
|
# Added embeddings need to be cased here
|
|
|
|
|
case _default:
|
|
|
|
|
msg = f"Embeddings {st.session_state['embeddings']} not supported!"
|
|
|
|
|