|
|
@ -54,19 +54,9 @@ def load_llmodel_library():
|
|
|
|
|
|
|
|
|
|
|
|
llmodel, llama = load_llmodel_library()
|
|
|
|
llmodel, llama = load_llmodel_library()
|
|
|
|
|
|
|
|
|
|
|
|
# Define C function signatures using ctypes
|
|
|
|
class LLModelError(ctypes.Structure):
|
|
|
|
llmodel.llmodel_gptj_create.restype = ctypes.c_void_p
|
|
|
|
_fields_ = [("message", ctypes.c_char_p),
|
|
|
|
llmodel.llmodel_gptj_destroy.argtypes = [ctypes.c_void_p]
|
|
|
|
("code", ctypes.c_int32)]
|
|
|
|
llmodel.llmodel_llama_create.restype = ctypes.c_void_p
|
|
|
|
|
|
|
|
llmodel.llmodel_llama_destroy.argtypes = [ctypes.c_void_p]
|
|
|
|
|
|
|
|
llmodel.llmodel_mpt_create.restype = ctypes.c_void_p
|
|
|
|
|
|
|
|
llmodel.llmodel_mpt_destroy.argtypes = [ctypes.c_void_p]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
llmodel.llmodel_loadModel.argtypes = [ctypes.c_void_p, ctypes.c_char_p]
|
|
|
|
|
|
|
|
llmodel.llmodel_loadModel.restype = ctypes.c_bool
|
|
|
|
|
|
|
|
llmodel.llmodel_isModelLoaded.argtypes = [ctypes.c_void_p]
|
|
|
|
|
|
|
|
llmodel.llmodel_isModelLoaded.restype = ctypes.c_bool
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class LLModelPromptContext(ctypes.Structure):
|
|
|
|
class LLModelPromptContext(ctypes.Structure):
|
|
|
|
_fields_ = [("logits", ctypes.POINTER(ctypes.c_float)),
|
|
|
|
_fields_ = [("logits", ctypes.POINTER(ctypes.c_float)),
|
|
|
@ -84,6 +74,16 @@ class LLModelPromptContext(ctypes.Structure):
|
|
|
|
("repeat_last_n", ctypes.c_int32),
|
|
|
|
("repeat_last_n", ctypes.c_int32),
|
|
|
|
("context_erase", ctypes.c_float)]
|
|
|
|
("context_erase", ctypes.c_float)]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Define C function signatures using ctypes
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
llmodel.llmodel_model_create2.argtypes = [ctypes.c_char_p, ctypes.c_char_p, ctypes.POINTER(LLModelError)]
|
|
|
|
|
|
|
|
llmodel.llmodel_model_create2.restype = ctypes.c_void_p
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
llmodel.llmodel_loadModel.argtypes = [ctypes.c_void_p, ctypes.c_char_p]
|
|
|
|
|
|
|
|
llmodel.llmodel_loadModel.restype = ctypes.c_bool
|
|
|
|
|
|
|
|
llmodel.llmodel_isModelLoaded.argtypes = [ctypes.c_void_p]
|
|
|
|
|
|
|
|
llmodel.llmodel_isModelLoaded.restype = ctypes.c_bool
|
|
|
|
|
|
|
|
|
|
|
|
PromptCallback = ctypes.CFUNCTYPE(ctypes.c_bool, ctypes.c_int32)
|
|
|
|
PromptCallback = ctypes.CFUNCTYPE(ctypes.c_bool, ctypes.c_int32)
|
|
|
|
ResponseCallback = ctypes.CFUNCTYPE(ctypes.c_bool, ctypes.c_int32, ctypes.c_char_p)
|
|
|
|
ResponseCallback = ctypes.CFUNCTYPE(ctypes.c_bool, ctypes.c_int32, ctypes.c_char_p)
|
|
|
|
RecalculateCallback = ctypes.CFUNCTYPE(ctypes.c_bool, ctypes.c_bool)
|
|
|
|
RecalculateCallback = ctypes.CFUNCTYPE(ctypes.c_bool, ctypes.c_bool)
|
|
|
@ -113,18 +113,17 @@ class LLModel:
|
|
|
|
----------
|
|
|
|
----------
|
|
|
|
model: llmodel_model
|
|
|
|
model: llmodel_model
|
|
|
|
Ctype pointer to underlying model
|
|
|
|
Ctype pointer to underlying model
|
|
|
|
model_type : str
|
|
|
|
model_name: str
|
|
|
|
Model architecture identifier
|
|
|
|
Model name.
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
|
|
model_type: str = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def __init__(self):
|
|
|
|
def __init__(self):
|
|
|
|
self.model = None
|
|
|
|
self.model = None
|
|
|
|
self.model_name = None
|
|
|
|
self.model_name = None
|
|
|
|
|
|
|
|
|
|
|
|
def __del__(self):
|
|
|
|
def __del__(self):
|
|
|
|
pass
|
|
|
|
if self.model is not None and llmodel is not None:
|
|
|
|
|
|
|
|
llmodel.llmodel_model_destroy(self.model)
|
|
|
|
|
|
|
|
|
|
|
|
def load_model(self, model_path: str) -> bool:
|
|
|
|
def load_model(self, model_path: str) -> bool:
|
|
|
|
"""
|
|
|
|
"""
|
|
|
@ -139,7 +138,10 @@ class LLModel:
|
|
|
|
-------
|
|
|
|
-------
|
|
|
|
True if model loaded successfully, False otherwise
|
|
|
|
True if model loaded successfully, False otherwise
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
llmodel.llmodel_loadModel(self.model, model_path.encode('utf-8'))
|
|
|
|
model_path_enc = model_path.encode("utf-8")
|
|
|
|
|
|
|
|
build_var = "auto".encode("utf-8")
|
|
|
|
|
|
|
|
self.model = llmodel.llmodel_model_create2(model_path_enc, build_var, None)
|
|
|
|
|
|
|
|
llmodel.llmodel_loadModel(self.model, model_path_enc)
|
|
|
|
filename = os.path.basename(model_path)
|
|
|
|
filename = os.path.basename(model_path)
|
|
|
|
self.model_name = os.path.splitext(filename)[0]
|
|
|
|
self.model_name = os.path.splitext(filename)[0]
|
|
|
|
|
|
|
|
|
|
|
@ -148,7 +150,6 @@ class LLModel:
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
return False
|
|
|
|
return False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def set_thread_count(self, n_threads):
|
|
|
|
def set_thread_count(self, n_threads):
|
|
|
|
if not llmodel.llmodel_isModelLoaded(self.model):
|
|
|
|
if not llmodel.llmodel_isModelLoaded(self.model):
|
|
|
|
raise Exception("Model not loaded")
|
|
|
|
raise Exception("Model not loaded")
|
|
|
@ -159,7 +160,6 @@ class LLModel:
|
|
|
|
raise Exception("Model not loaded")
|
|
|
|
raise Exception("Model not loaded")
|
|
|
|
return llmodel.llmodel_threadCount(self.model)
|
|
|
|
return llmodel.llmodel_threadCount(self.model)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def generate(self,
|
|
|
|
def generate(self,
|
|
|
|
prompt: str,
|
|
|
|
prompt: str,
|
|
|
|
logits_size: int = 0,
|
|
|
|
logits_size: int = 0,
|
|
|
@ -246,45 +246,3 @@ class LLModel:
|
|
|
|
@staticmethod
|
|
|
|
@staticmethod
|
|
|
|
def _recalculate_callback(is_recalculating):
|
|
|
|
def _recalculate_callback(is_recalculating):
|
|
|
|
return is_recalculating
|
|
|
|
return is_recalculating
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class GPTJModel(LLModel):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
model_type = "gptj"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def __init__(self):
|
|
|
|
|
|
|
|
super().__init__()
|
|
|
|
|
|
|
|
self.model = llmodel.llmodel_gptj_create()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def __del__(self):
|
|
|
|
|
|
|
|
if self.model is not None and llmodel is not None:
|
|
|
|
|
|
|
|
llmodel.llmodel_gptj_destroy(self.model)
|
|
|
|
|
|
|
|
super().__del__()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class LlamaModel(LLModel):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
model_type = "llama"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def __init__(self):
|
|
|
|
|
|
|
|
super().__init__()
|
|
|
|
|
|
|
|
self.model = llmodel.llmodel_llama_create()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def __del__(self):
|
|
|
|
|
|
|
|
if self.model is not None and llmodel is not None:
|
|
|
|
|
|
|
|
llmodel.llmodel_llama_destroy(self.model)
|
|
|
|
|
|
|
|
super().__del__()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class MPTModel(LLModel):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
model_type = "mpt"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def __init__(self):
|
|
|
|
|
|
|
|
super().__init__()
|
|
|
|
|
|
|
|
self.model = llmodel.llmodel_mpt_create()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def __del__(self):
|
|
|
|
|
|
|
|
if self.model is not None and llmodel is not None:
|
|
|
|
|
|
|
|
llmodel.llmodel_mpt_destroy(self.model)
|
|
|
|
|
|
|
|
super().__del__()
|
|
|
|
|
|
|
|