mpt bindings

httpserver
Richard Guo authored 1 year ago, committed by Richard Guo
parent d56aada08c
commit 36a6e824f0

@@ -29,7 +29,7 @@ class GPT4All():
         model_name: Name of GPT4All or custom model. Including ".bin" file extension is optional but encouraged.
         model_path: Path to directory containing model file or, if file does not exist, where to download model.
             Default is None, in which case models will be stored in `~/.cache/gpt4all/`.
-        model_type: Model architecture to use - currently, only options are 'llama' or 'gptj'. Only required if model
+        model_type: Model architecture to use - currently, options are 'llama', 'gptj', or 'mpt'. Only required if model
            is custom. Note that these models still must be built from llama.cpp or GPTJ ggml architecture.
            Default is None.
        allow_download: Allow API to download models from gpt4all.io. Default is True.
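
With this change, 'mpt' becomes the third accepted model_type. A minimal usage sketch, assuming the ggml-mpt-7b-base.bin weights file this commit registers below (keyword arguments as in the docstring above):

    from gpt4all import GPT4All

    # 'mpt' is now a valid model_type for custom MPT-architecture models.
    model = GPT4All(model_name="ggml-mpt-7b-base.bin", model_type="mpt")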
@@ -263,6 +263,8 @@ class GPT4All():
             return pyllmodel.GPTJModel()
         elif model_type == "llama":
             return pyllmodel.LlamaModel()
+        elif model_type == "mpt":
+            return pyllmodel.MPTModel()
         else:
             raise ValueError(f"No corresponding model for model_type: {model_type}")
@@ -286,13 +288,20 @@ class GPT4All():
             "ggml-vicuna-7b-1.1-q4_2.bin",
             "ggml-vicuna-13b-1.1-q4_2.bin",
             "ggml-wizardLM-7B.q4_2.bin",
-            "ggml-stable-vicuna-13B.q4_2.bin"
+            "ggml-stable-vicuna-13B.q4_2.bin",
+            "ggml-nous-gpt4-vicuna-13b.bin"
+        ]
+        MPT_MODELS = [
+            "ggml-mpt-7b-base.bin"
         ]
         if model_name in GPTJ_MODELS:
             return pyllmodel.GPTJModel()
         elif model_name in LLAMA_MODELS:
             return pyllmodel.LlamaModel()
+        elif model_name in MPT_MODELS:
+            return pyllmodel.MPTModel()
         else:
             err_msg = f"""No corresponding model for provided filename {model_name}.
             If this is a custom model, make sure to specify a valid model_type.
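
Because ggml-mpt-7b-base.bin is now listed in MPT_MODELS, the filename lookup above can resolve the MPT backend without an explicit model_type. A sketch of both paths (the custom filename below is hypothetical):

    # Known filename: backend inferred from the MPT_MODELS list.
    model = GPT4All("ggml-mpt-7b-base.bin")

    # Custom MPT-architecture file: pass model_type explicitly, otherwise
    # the else branch above raises with err_msg.
    custom = GPT4All("my-finetuned-mpt.bin", model_type="mpt")  # hypothetical file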

@@ -46,6 +46,9 @@ llmodel.llmodel_gptj_create.restype = ctypes.c_void_p
 llmodel.llmodel_gptj_destroy.argtypes = [ctypes.c_void_p]
 llmodel.llmodel_llama_create.restype = ctypes.c_void_p
 llmodel.llmodel_llama_destroy.argtypes = [ctypes.c_void_p]
+llmodel.llmodel_mpt_create.restype = ctypes.c_void_p
+llmodel.llmodel_mpt_destroy.argtypes = [ctypes.c_void_p]
+
 
 llmodel.llmodel_loadModel.argtypes = [ctypes.c_void_p, ctypes.c_char_p]
 llmodel.llmodel_loadModel.restype = ctypes.c_bool
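
These restype/argtypes declarations are load-bearing: ctypes defaults a function's return type to a C int, which can truncate the 64-bit pointer returned by llmodel_mpt_create. A self-contained sketch of the same pattern, assuming a generic shared-library path not shown in this diff:

    import ctypes

    lib = ctypes.CDLL("./libllmodel.so")  # assumed library path

    # Declare C signatures so the opaque handle survives the FFI boundary.
    lib.llmodel_mpt_create.restype = ctypes.c_void_p
    lib.llmodel_mpt_destroy.argtypes = [ctypes.c_void_p]

    handle = lib.llmodel_mpt_create()  # opaque pointer to the C++ model object
    lib.llmodel_mpt_destroy(handle)    # must be freed by the library that allocated it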
@@ -236,3 +239,17 @@ class LlamaModel(LLModel):
         if self.model is not None:
             llmodel.llmodel_llama_destroy(self.model)
         super().__del__()
+
+
+class MPTModel(LLModel):
+
+    model_type = "mpt"
+
+    def __init__(self):
+        super().__init__()
+        self.model = llmodel.llmodel_mpt_create()
+
+    def __del__(self):
+        if self.model is not None:
+            llmodel.llmodel_mpt_destroy(self.model)
+        super().__del__()
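
MPTModel follows the same pattern as GPTJModel and LlamaModel: create the C handle in __init__, free it in __del__. A minimal sketch of driving the wrapper directly, assuming a local weights file and that LLModel exposes a load_model method matching the llmodel_loadModel binding above (prompt is exercised by the tests below):

    from gpt4all import pyllmodel

    mpt = pyllmodel.MPTModel()
    mpt.load_model("./models/ggml-mpt-7b-base.bin")  # assumed local path
    mpt.prompt("hello there")  # generated text is written to stdout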

@@ -14,6 +14,24 @@ def test_create_llama():
     llama = pyllmodel.LlamaModel()
     assert llama.model_type == "llama"
 
+def test_create_mpt():
+    mpt = pyllmodel.MPTModel()
+    assert mpt.model_type == "mpt"
+
+def prompt_unloaded_mpt():
+    mpt = pyllmodel.MPTModel()
+
+    old_stdout = sys.stdout
+    collect_response = StringIO()
+    sys.stdout = collect_response
+
+    mpt.prompt("hello there")
+
+    response = collect_response.getvalue()
+    sys.stdout = old_stdout
+
+    response = response.strip()
+    assert response == "MPT ERROR: prompt won't work with an unloaded model!"
+
 def prompt_unloaded_gptj():
     gptj = pyllmodel.GPTJModel()
     old_stdout = sys.stdout
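
Note that prompt_unloaded_mpt, like the existing prompt_unloaded_gptj, is not prefixed with test_, so pytest will not collect it by default. The tests capture model output by swapping sys.stdout manually; an equivalent sketch using a hypothetical helper built on contextlib.redirect_stdout, which restores the stream even if prompt() raises:

    from io import StringIO
    from contextlib import redirect_stdout

    def capture_prompt_output(model, text):
        # Collect everything the model writes to stdout while prompting.
        buf = StringIO()
        with redirect_stdout(buf):
            model.prompt(text)
        return buf.getvalue().strip()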
