diff --git a/gpt4all-bindings/python/gpt4all/gpt4all.py b/gpt4all-bindings/python/gpt4all/gpt4all.py
index c485f2de..97ed37ab 100644
--- a/gpt4all-bindings/python/gpt4all/gpt4all.py
+++ b/gpt4all-bindings/python/gpt4all/gpt4all.py
@@ -22,7 +22,7 @@ class GPT4All():
         model: Pointer to underlying C model.
     """
 
-    def __init__(self, model_name: str, model_path: str = None, model_type: str = None, allow_download=True):
+    def __init__(self, model_name: str, model_path: str = None, model_type: str = None, allow_download=True):
         """
         Constructor
 
@@ -30,20 +30,12 @@ class GPT4All():
             model_name: Name of GPT4All or custom model. Including ".bin" file extension is optional but encouraged.
             model_path: Path to directory containing model file or, if file does not exist,
                 where to download model. Default is None, in which case models will be stored in `~/.cache/gpt4all/`.
-            model_type: Model architecture to use - currently, options are 'llama', 'gptj', or 'mpt'. Only required if model
-                is custom. Note that these models still must be built from llama.cpp or GPTJ ggml architecture.
-                Default is None.
+            model_type: Model architecture. This argument currently does not have any functionality and is just used as
+                descriptive identifier for user. Default is None.
             allow_download: Allow API to download models from gpt4all.io. Default is True.
         """
-        self.model = None
-
-        # Model type provided for when model is custom
-        if model_type:
-            self.model = GPT4All.get_model_from_type(model_type)
-        # Else get model from gpt4all model filenames
-        else:
-            self.model = GPT4All.get_model_from_name(model_name)
-
+        self.model_type = model_type
+        self.model = pyllmodel.LLModel()
         # Retrieve model and download if allowed
         model_dest = self.retrieve_model(model_name, model_path=model_path, allow_download=allow_download)
         self.model.load_model(model_dest)
@@ -265,61 +257,6 @@ class GPT4All():
 
         return full_prompt
 
-    @staticmethod
-    def get_model_from_type(model_type: str) -> pyllmodel.LLModel:
-        # This needs to be updated for each new model type
-        # TODO: Might be worth converting model_type to enum
-
-        if model_type == "gptj":
-            return pyllmodel.GPTJModel()
-        elif model_type == "llama":
-            return pyllmodel.LlamaModel()
-        elif model_type == "mpt":
-            return pyllmodel.MPTModel()
-        else:
-            raise ValueError(f"No corresponding model for model_type: {model_type}")
-
-    @staticmethod
-    def get_model_from_name(model_name: str) -> pyllmodel.LLModel:
-        # This needs to be updated for each new model
-
-        # NOTE: We are doing this preprocessing a lot, maybe there's a better way to organize
-        model_name = append_bin_suffix_if_missing(model_name)
-
-        GPTJ_MODELS = [
-            "ggml-gpt4all-j-v1.3-groovy.bin",
-            "ggml-gpt4all-j-v1.2-jazzy.bin",
-            "ggml-gpt4all-j-v1.1-breezy.bin",
-            "ggml-gpt4all-j.bin"
-        ]
-
-        LLAMA_MODELS = [
-            "ggml-gpt4all-l13b-snoozy.bin",
-            "ggml-vicuna-7b-1.1-q4_2.bin",
-            "ggml-vicuna-13b-1.1-q4_2.bin",
-            "ggml-wizardLM-7B.q4_2.bin",
-            "ggml-stable-vicuna-13B.q4_2.bin",
-            "ggml-nous-gpt4-vicuna-13b.bin"
-        ]
-
-        MPT_MODELS = [
-            "ggml-mpt-7b-base.bin",
-            "ggml-mpt-7b-chat.bin",
-            "ggml-mpt-7b-instruct.bin"
-        ]
-
-        if model_name in GPTJ_MODELS:
-            return pyllmodel.GPTJModel()
-        elif model_name in LLAMA_MODELS:
-            return pyllmodel.LlamaModel()
-        elif model_name in MPT_MODELS:
-            return pyllmodel.MPTModel()
-
-        err_msg = (f"No corresponding model for provided filename {model_name}.\n"
-                   f"If this is a custom model, make sure to specify a valid model_type.\n")
-
-        raise ValueError(err_msg)
-
 
 def append_bin_suffix_if_missing(model_name):
     if not model_name.endswith(".bin"):
diff --git a/gpt4all-bindings/python/gpt4all/pyllmodel.py b/gpt4all-bindings/python/gpt4all/pyllmodel.py
index 6117c9fa..2e312ffc 100644
--- a/gpt4all-bindings/python/gpt4all/pyllmodel.py
+++ b/gpt4all-bindings/python/gpt4all/pyllmodel.py
@@ -54,19 +54,9 @@ def load_llmodel_library():
 
 llmodel, llama = load_llmodel_library()
 
-# Define C function signatures using ctypes
-llmodel.llmodel_gptj_create.restype = ctypes.c_void_p
-llmodel.llmodel_gptj_destroy.argtypes = [ctypes.c_void_p]
-llmodel.llmodel_llama_create.restype = ctypes.c_void_p
-llmodel.llmodel_llama_destroy.argtypes = [ctypes.c_void_p]
-llmodel.llmodel_mpt_create.restype = ctypes.c_void_p
-llmodel.llmodel_mpt_destroy.argtypes = [ctypes.c_void_p]
-
-
-llmodel.llmodel_loadModel.argtypes = [ctypes.c_void_p, ctypes.c_char_p]
-llmodel.llmodel_loadModel.restype = ctypes.c_bool
-llmodel.llmodel_isModelLoaded.argtypes = [ctypes.c_void_p]
-llmodel.llmodel_isModelLoaded.restype = ctypes.c_bool
+class LLModelError(ctypes.Structure):
+    _fields_ = [("message", ctypes.c_char_p),
+                ("code", ctypes.c_int32)]
 
 class LLModelPromptContext(ctypes.Structure):
     _fields_ = [("logits", ctypes.POINTER(ctypes.c_float)),
@@ -83,7 +73,17 @@ class LLModelPromptContext(ctypes.Structure):
                 ("repeat_penalty", ctypes.c_float),
                 ("repeat_last_n", ctypes.c_int32),
                 ("context_erase", ctypes.c_float)]
-
+
+# Define C function signatures using ctypes
+
+llmodel.llmodel_model_create2.argtypes = [ctypes.c_char_p, ctypes.c_char_p, ctypes.POINTER(LLModelError)]
+llmodel.llmodel_model_create2.restype = ctypes.c_void_p
+
+llmodel.llmodel_loadModel.argtypes = [ctypes.c_void_p, ctypes.c_char_p]
+llmodel.llmodel_loadModel.restype = ctypes.c_bool
+llmodel.llmodel_isModelLoaded.argtypes = [ctypes.c_void_p]
+llmodel.llmodel_isModelLoaded.restype = ctypes.c_bool
+
 PromptCallback = ctypes.CFUNCTYPE(ctypes.c_bool, ctypes.c_int32)
 ResponseCallback = ctypes.CFUNCTYPE(ctypes.c_bool, ctypes.c_int32, ctypes.c_char_p)
 RecalculateCallback = ctypes.CFUNCTYPE(ctypes.c_bool, ctypes.c_bool)
@@ -113,18 +113,17 @@ class LLModel:
     ----------
     model: llmodel_model
         Ctype pointer to underlying model
-    model_type : str
-        Model architecture identifier
+    model_name: str
+        Model name.
     """
 
-    model_type: str = None
-
     def __init__(self):
         self.model = None
         self.model_name = None
 
     def __del__(self):
-        pass
+        if self.model is not None and llmodel is not None:
+            llmodel.llmodel_model_destroy(self.model)
 
     def load_model(self, model_path: str) -> bool:
         """
@@ -139,7 +138,10 @@ class LLModel:
         -------
         True if model loaded successfully, False otherwise
         """
-        llmodel.llmodel_loadModel(self.model, model_path.encode('utf-8'))
+        model_path_enc = model_path.encode("utf-8")
+        build_var = "auto".encode("utf-8")
+        self.model = llmodel.llmodel_model_create2(model_path_enc, build_var, None)
+        llmodel.llmodel_loadModel(self.model, model_path_enc)
 
         filename = os.path.basename(model_path)
         self.model_name = os.path.splitext(filename)[0]
@@ -148,7 +150,6 @@ class LLModel:
         else:
             return False
 
-
     def set_thread_count(self, n_threads):
         if not llmodel.llmodel_isModelLoaded(self.model):
             raise Exception("Model not loaded")
@@ -159,7 +160,6 @@ class LLModel:
             raise Exception("Model not loaded")
         return llmodel.llmodel_threadCount(self.model)
 
-
    def generate(self,
                 prompt: str,
                 logits_size: int = 0,
@@ -246,45 +246,3 @@ class LLModel:
     @staticmethod
     def _recalculate_callback(is_recalculating):
         return is_recalculating
-
-
-class GPTJModel(LLModel):
-
-    model_type = "gptj"
-
-    def __init__(self):
-        super().__init__()
-        self.model = llmodel.llmodel_gptj_create()
-
-    def __del__(self):
-        if self.model is not None and llmodel is not None:
-            llmodel.llmodel_gptj_destroy(self.model)
-        super().__del__()
-
-
-class LlamaModel(LLModel):
-
-    model_type = "llama"
-
-    def __init__(self):
-        super().__init__()
-        self.model = llmodel.llmodel_llama_create()
-
-    def __del__(self):
-        if self.model is not None and llmodel is not None:
-            llmodel.llmodel_llama_destroy(self.model)
-        super().__del__()
-
-
-class MPTModel(LLModel):
-
-    model_type = "mpt"
-
-    def __init__(self):
-        super().__init__()
-        self.model = llmodel.llmodel_mpt_create()
-
-    def __del__(self):
-        if self.model is not None and llmodel is not None:
-            llmodel.llmodel_mpt_destroy(self.model)
-        super().__del__()