mpt bindings

Authored by Richard Guo on 2023-05-11 15:26:20 -04:00; committed by Richard Guo
parent 77fc90dd4c
commit 35e8b3984f
4 changed files with 46 additions and 2 deletions

@@ -29,7 +29,7 @@ class GPT4All():
         model_name: Name of GPT4All or custom model. Including ".bin" file extension is optional but encouraged.
         model_path: Path to directory containing model file or, if file does not exist, where to download model.
             Default is None, in which case models will be stored in `~/.cache/gpt4all/`.
-        model_type: Model architecture to use - currently, only options are 'llama' or 'gptj'. Only required if model
+        model_type: Model architecture to use - currently, options are 'llama', 'gptj', or 'mpt'. Only required if model
             is custom. Note that these models still must be built from llama.cpp or GPTJ ggml architecture.
             Default is None.
         allow_download: Allow API to download models from gpt4all.io. Default is True.
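Note: a minimal usage sketch (not part of this commit) of the parameters documented above; the model file name and directory below are hypothetical placeholders for a custom MPT-architecture model.

from gpt4all import GPT4All

# Hypothetical custom model: model_type is required because the file name
# is not in the known model lists.
model = GPT4All(
    model_name="my-custom-mpt.bin",   # hypothetical file name
    model_path="/path/to/models",     # hypothetical directory holding the .bin file
    model_type="mpt",                 # the option added by this commit
    allow_download=False,             # custom models are not hosted on gpt4all.io
)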
@@ -263,6 +263,8 @@ class GPT4All():
             return pyllmodel.GPTJModel()
         elif model_type == "llama":
             return pyllmodel.LlamaModel()
+        elif model_type == "mpt":
+            return pyllmodel.MPTModel()
         else:
             raise ValueError(f"No corresponding model for model_type: {model_type}")
@@ -286,13 +288,20 @@ class GPT4All():
             "ggml-vicuna-7b-1.1-q4_2.bin",
             "ggml-vicuna-13b-1.1-q4_2.bin",
             "ggml-wizardLM-7B.q4_2.bin",
-            "ggml-stable-vicuna-13B.q4_2.bin"
+            "ggml-stable-vicuna-13B.q4_2.bin",
+            "ggml-nous-gpt4-vicuna-13b.bin"
         ]
+        MPT_MODELS = [
+            "ggml-mpt-7b-base.bin"
+        ]
         if model_name in GPTJ_MODELS:
             return pyllmodel.GPTJModel()
         elif model_name in LLAMA_MODELS:
             return pyllmodel.LlamaModel()
+        elif model_name in MPT_MODELS:
+            return pyllmodel.MPTModel()
         else:
             err_msg = f"""No corresponding model for provided filename {model_name}.
                 If this is a custom model, make sure to specify a valid model_type.
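Note: since "ggml-mpt-7b-base.bin" now appears in MPT_MODELS, the filename lookup resolves the backend without an explicit model_type. A minimal sketch, assuming allow_download is left at its default of True so the file can be fetched from gpt4all.io:

from gpt4all import GPT4All

# model_type can be omitted: the name matches an entry in MPT_MODELS
model = GPT4All(model_name="ggml-mpt-7b-base.bin")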

@@ -46,6 +46,9 @@ llmodel.llmodel_gptj_create.restype = ctypes.c_void_p
 llmodel.llmodel_gptj_destroy.argtypes = [ctypes.c_void_p]
 llmodel.llmodel_llama_create.restype = ctypes.c_void_p
 llmodel.llmodel_llama_destroy.argtypes = [ctypes.c_void_p]
+llmodel.llmodel_mpt_create.restype = ctypes.c_void_p
+llmodel.llmodel_mpt_destroy.argtypes = [ctypes.c_void_p]
 llmodel.llmodel_loadModel.argtypes = [ctypes.c_void_p, ctypes.c_char_p]
 llmodel.llmodel_loadModel.restype = ctypes.c_bool
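Note: these declarations tell ctypes the C signatures; without restype = ctypes.c_void_p the create functions would default to returning a C int, which can truncate a 64-bit pointer. A sketch of the raw calls the new declarations enable, assuming llmodel is the CDLL handle loaded earlier in this file and the path points at a local model file:

# opaque native handle, preserved as a full-width pointer by the restype above
raw_model = llmodel.llmodel_mpt_create()

# llmodel_loadModel takes (c_void_p, c_char_p), so the path must be bytes
loaded = llmodel.llmodel_loadModel(raw_model, b"./ggml-mpt-7b-base.bin")

# release the native object whether or not loading succeeded
llmodel.llmodel_mpt_destroy(raw_model)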
@@ -236,3 +239,17 @@ class LlamaModel(LLModel):
         if self.model is not None:
             llmodel.llmodel_llama_destroy(self.model)
         super().__del__()
+
+
+class MPTModel(LLModel):
+    model_type = "mpt"
+
+    def __init__(self):
+        super().__init__()
+        self.model = llmodel.llmodel_mpt_create()
+
+    def __del__(self):
+        if self.model is not None:
+            llmodel.llmodel_mpt_destroy(self.model)
+        super().__del__()
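Note: MPTModel mirrors the GPTJModel and LlamaModel wrappers: the native handle is allocated in __init__, and the None check in __del__ keeps teardown safe if construction failed before the pointer was set. A lifecycle sketch, assuming the same import style the tests use:

from gpt4all import pyllmodel   # import path assumed from the test module

mpt = pyllmodel.MPTModel()      # llmodel_mpt_create() runs here
assert mpt.model_type == "mpt"
del mpt                         # __del__ invokes llmodel_mpt_destroy(self.model)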

@@ -14,6 +14,24 @@ def test_create_llama():
     llama = pyllmodel.LlamaModel()
     assert llama.model_type == "llama"
+
+def test_create_mpt():
+    mpt = pyllmodel.MPTModel()
+    assert mpt.model_type == "mpt"
+
+def prompt_unloaded_mpt():
+    mpt = pyllmodel.MPTModel()
+
+    old_stdout = sys.stdout
+    collect_response = StringIO()
+    sys.stdout = collect_response
+
+    mpt.prompt("hello there")
+
+    response = collect_response.getvalue()
+    sys.stdout = old_stdout
+
+    response = response.strip()
+    assert response == "MPT ERROR: prompt won't work with an unloaded model!"

 def prompt_unloaded_gptj():
     gptj = pyllmodel.GPTJModel()

     old_stdout = sys.stdout
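Note: prompt_unloaded_mpt and prompt_unloaded_gptj lack the test_ prefix, so pytest does not collect them by default. The manual sys.stdout swap can also be written with contextlib.redirect_stdout, which restores stdout even if prompt() raises; a minimal equivalent sketch:

from contextlib import redirect_stdout
from io import StringIO

from gpt4all import pyllmodel   # import path assumed from the test module

def prompt_unloaded_mpt():
    buffer = StringIO()
    with redirect_stdout(buffer):   # stdout is restored when the block exits
        pyllmodel.MPTModel().prompt("hello there")
    assert buffer.getvalue().strip() == "MPT ERROR: prompt won't work with an unloaded model!"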