updated bindings code for updated C api

pull/815/head
Richard Guo 1 year ago committed by AT
parent f0be66a221
commit ae42805d49

gpt4all-bindings/python/gpt4all/gpt4all.py
@@ -22,7 +22,7 @@ class GPT4All():
         model: Pointer to underlying C model.
     """

     def __init__(self, model_name: str, model_path: str = None, model_type: str = None, allow_download=True):
         """
         Constructor
@@ -30,20 +30,12 @@ class GPT4All():
             model_name: Name of GPT4All or custom model. Including ".bin" file extension is optional but encouraged.
             model_path: Path to directory containing model file or, if file does not exist, where to download model.
                 Default is None, in which case models will be stored in `~/.cache/gpt4all/`.
-            model_type: Model architecture to use - currently, options are 'llama', 'gptj', or 'mpt'. Only required if model
-                is custom. Note that these models still must be built from llama.cpp or GPTJ ggml architecture.
-                Default is None.
+            model_type: Model architecture. This argument currently has no effect and is kept only as a
+                descriptive identifier for the user. Default is None.
             allow_download: Allow API to download models from gpt4all.io. Default is True.
         """
-        self.model = None
-
-        # Model type provided for when model is custom
-        if model_type:
-            self.model = GPT4All.get_model_from_type(model_type)
-        # Else get model from gpt4all model filenames
-        else:
-            self.model = GPT4All.get_model_from_name(model_name)
-
+        self.model_type = model_type
+        self.model = pyllmodel.LLModel()
         # Retrieve model and download if allowed
         model_dest = self.retrieve_model(model_name, model_path=model_path, allow_download=allow_download)
         self.model.load_model(model_dest)
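
A minimal usage sketch (not part of the diff) of the constructor after this change, assuming the installed `gpt4all` package and the `ggml-gpt4all-j-v1.3-groovy.bin` model name that appears elsewhere in this commit:

    from gpt4all import GPT4All

    # model_type no longer selects a backend; it is stored only as a
    # descriptive label, and the C library picks the architecture at load time
    model = GPT4All("ggml-gpt4all-j-v1.3-groovy.bin", model_type="gptj")
    print(model.model_type)  # "gptj" -- informational only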
@@ -265,61 +257,6 @@ class GPT4All():
         return full_prompt

-    @staticmethod
-    def get_model_from_type(model_type: str) -> pyllmodel.LLModel:
-        # This needs to be updated for each new model type
-        # TODO: Might be worth converting model_type to enum
-        if model_type == "gptj":
-            return pyllmodel.GPTJModel()
-        elif model_type == "llama":
-            return pyllmodel.LlamaModel()
-        elif model_type == "mpt":
-            return pyllmodel.MPTModel()
-        else:
-            raise ValueError(f"No corresponding model for model_type: {model_type}")
-
-    @staticmethod
-    def get_model_from_name(model_name: str) -> pyllmodel.LLModel:
-        # This needs to be updated for each new model
-        # NOTE: We are doing this preprocessing a lot, maybe there's a better way to organize
-        model_name = append_bin_suffix_if_missing(model_name)
-
-        GPTJ_MODELS = [
-            "ggml-gpt4all-j-v1.3-groovy.bin",
-            "ggml-gpt4all-j-v1.2-jazzy.bin",
-            "ggml-gpt4all-j-v1.1-breezy.bin",
-            "ggml-gpt4all-j.bin"
-        ]
-
-        LLAMA_MODELS = [
-            "ggml-gpt4all-l13b-snoozy.bin",
-            "ggml-vicuna-7b-1.1-q4_2.bin",
-            "ggml-vicuna-13b-1.1-q4_2.bin",
-            "ggml-wizardLM-7B.q4_2.bin",
-            "ggml-stable-vicuna-13B.q4_2.bin",
-            "ggml-nous-gpt4-vicuna-13b.bin"
-        ]
-
-        MPT_MODELS = [
-            "ggml-mpt-7b-base.bin",
-            "ggml-mpt-7b-chat.bin",
-            "ggml-mpt-7b-instruct.bin"
-        ]
-
-        if model_name in GPTJ_MODELS:
-            return pyllmodel.GPTJModel()
-        elif model_name in LLAMA_MODELS:
-            return pyllmodel.LlamaModel()
-        elif model_name in MPT_MODELS:
-            return pyllmodel.MPTModel()
-
-        err_msg = (f"No corresponding model for provided filename {model_name}.\n"
-                   f"If this is a custom model, make sure to specify a valid model_type.\n")
-        raise ValueError(err_msg)

 def append_bin_suffix_if_missing(model_name):
     if not model_name.endswith(".bin"):

gpt4all-bindings/python/gpt4all/pyllmodel.py
@@ -54,19 +54,9 @@ def load_llmodel_library():

 llmodel, llama = load_llmodel_library()

-# Define C function signatures using ctypes
-llmodel.llmodel_gptj_create.restype = ctypes.c_void_p
-llmodel.llmodel_gptj_destroy.argtypes = [ctypes.c_void_p]
-llmodel.llmodel_llama_create.restype = ctypes.c_void_p
-llmodel.llmodel_llama_destroy.argtypes = [ctypes.c_void_p]
-llmodel.llmodel_mpt_create.restype = ctypes.c_void_p
-llmodel.llmodel_mpt_destroy.argtypes = [ctypes.c_void_p]
-llmodel.llmodel_loadModel.argtypes = [ctypes.c_void_p, ctypes.c_char_p]
-llmodel.llmodel_loadModel.restype = ctypes.c_bool
-llmodel.llmodel_isModelLoaded.argtypes = [ctypes.c_void_p]
-llmodel.llmodel_isModelLoaded.restype = ctypes.c_bool
+class LLModelError(ctypes.Structure):
+    _fields_ = [("message", ctypes.c_char_p),
+                ("code", ctypes.c_int32)]

 class LLModelPromptContext(ctypes.Structure):
     _fields_ = [("logits", ctypes.POINTER(ctypes.c_float)),
@@ -83,7 +73,17 @@ class LLModelPromptContext(ctypes.Structure):
                 ("repeat_penalty", ctypes.c_float),
                 ("repeat_last_n", ctypes.c_int32),
                 ("context_erase", ctypes.c_float)]

+# Define C function signatures using ctypes
+llmodel.llmodel_model_create2.argtypes = [ctypes.c_char_p, ctypes.c_char_p, ctypes.POINTER(LLModelError)]
+llmodel.llmodel_model_create2.restype = ctypes.c_void_p
+
+llmodel.llmodel_loadModel.argtypes = [ctypes.c_void_p, ctypes.c_char_p]
+llmodel.llmodel_loadModel.restype = ctypes.c_bool
+llmodel.llmodel_isModelLoaded.argtypes = [ctypes.c_void_p]
+llmodel.llmodel_isModelLoaded.restype = ctypes.c_bool
+
 PromptCallback = ctypes.CFUNCTYPE(ctypes.c_bool, ctypes.c_int32)
 ResponseCallback = ctypes.CFUNCTYPE(ctypes.c_bool, ctypes.c_int32, ctypes.c_char_p)
 RecalculateCallback = ctypes.CFUNCTYPE(ctypes.c_bool, ctypes.c_bool)
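
The new `llmodel_model_create2` entry point can report failures through the `LLModelError` out-parameter declared above. A hedged sketch of how a caller might use it, given those signatures (the bindings themselves pass `None` for the error, as `load_model` does below; the model path here is a placeholder):

    import ctypes

    err = LLModelError()
    raw_model = llmodel.llmodel_model_create2(
        b"/path/to/model.bin",  # placeholder model file
        b"auto",                # build variant, matching load_model below
        ctypes.byref(err),
    )
    if not raw_model:
        # message is a c_char_p; decode it for a readable error
        msg = err.message.decode() if err.message else "unknown error"
        raise ValueError(f"Unable to create model: {msg} (code {err.code})")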
@@ -113,18 +113,17 @@ class LLModel:
     ----------
     model: llmodel_model
         Ctype pointer to underlying model
-    model_type : str
-        Model architecture identifier
+    model_name: str
+        Model name.
     """

-    model_type: str = None
-
     def __init__(self):
         self.model = None
         self.model_name = None

     def __del__(self):
-        pass
+        if self.model is not None and llmodel is not None:
+            llmodel.llmodel_model_destroy(self.model)

     def load_model(self, model_path: str) -> bool:
         """
@@ -139,7 +138,10 @@ class LLModel:
         -------
         True if model loaded successfully, False otherwise
         """
-        llmodel.llmodel_loadModel(self.model, model_path.encode('utf-8'))
+        model_path_enc = model_path.encode("utf-8")
+        build_var = "auto".encode("utf-8")
+        self.model = llmodel.llmodel_model_create2(model_path_enc, build_var, None)
+        llmodel.llmodel_loadModel(self.model, model_path_enc)

         filename = os.path.basename(model_path)
         self.model_name = os.path.splitext(filename)[0]
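
A short usage sketch of the updated load path (not part of the diff), assuming `pyllmodel` is importable from the `gpt4all` package and the placeholder path points at a real model file:

    from gpt4all import pyllmodel

    llm = pyllmodel.LLModel()
    # load_model now creates the C-side model via llmodel_model_create2 and
    # then loads the weights; model_name is derived from the filename
    if llm.load_model("/path/to/ggml-gpt4all-j-v1.3-groovy.bin"):
        print(llm.model_name)  # -> "ggml-gpt4all-j-v1.3-groovy"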
@@ -148,7 +150,6 @@ class LLModel:
         else:
             return False

-
     def set_thread_count(self, n_threads):
         if not llmodel.llmodel_isModelLoaded(self.model):
             raise Exception("Model not loaded")
@@ -159,7 +160,6 @@ class LLModel:
             raise Exception("Model not loaded")
         return llmodel.llmodel_threadCount(self.model)

-
     def generate(self,
                  prompt: str,
                  logits_size: int = 0,
@@ -246,45 +246,3 @@ class LLModel:
     @staticmethod
     def _recalculate_callback(is_recalculating):
         return is_recalculating
-
-
-class GPTJModel(LLModel):
-    model_type = "gptj"
-
-    def __init__(self):
-        super().__init__()
-        self.model = llmodel.llmodel_gptj_create()
-
-    def __del__(self):
-        if self.model is not None and llmodel is not None:
-            llmodel.llmodel_gptj_destroy(self.model)
-        super().__del__()
-
-
-class LlamaModel(LLModel):
-    model_type = "llama"
-
-    def __init__(self):
-        super().__init__()
-        self.model = llmodel.llmodel_llama_create()
-
-    def __del__(self):
-        if self.model is not None and llmodel is not None:
-            llmodel.llmodel_llama_destroy(self.model)
-        super().__del__()
-
-
-class MPTModel(LLModel):
-    model_type = "mpt"
-
-    def __init__(self):
-        super().__init__()
-        self.model = llmodel.llmodel_mpt_create()
-
-    def __del__(self):
-        if self.model is not None and llmodel is not None:
-            llmodel.llmodel_mpt_destroy(self.model)
-        super().__del__()
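
With the per-architecture subclasses removed, code that previously instantiated `GPTJModel`, `LlamaModel`, or `MPTModel` directly would migrate to the single `LLModel` class; backend selection presumably happens inside `llmodel_model_create2` based on the model file. A hypothetical before/after:

    # before this commit: the caller chose the backend explicitly
    #   model = pyllmodel.GPTJModel()
    #   model.load_model("ggml-gpt4all-j-v1.3-groovy.bin")

    # after: one class for every architecture; the C library resolves the
    # concrete implementation when llmodel_model_create2 sees the model file
    model = pyllmodel.LLModel()
    model.load_model("ggml-gpt4all-j-v1.3-groovy.bin")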
