mirror of
https://github.com/nomic-ai/gpt4all
synced 2024-11-06 09:20:33 +00:00
mpt bindings
This commit is contained in:
parent
77fc90dd4c
commit
35e8b3984f
@ -29,7 +29,7 @@ class GPT4All():
|
||||
model_name: Name of GPT4All or custom model. Including ".bin" file extension is optional but encouraged.
|
||||
model_path: Path to directory containing model file or, if file does not exist, where to download model.
|
||||
Default is None, in which case models will be stored in `~/.cache/gpt4all/`.
|
||||
model_type: Model architecture to use - currently, only options are 'llama' or 'gptj'. Only required if model
|
||||
model_type: Model architecture to use - currently, options are 'llama', 'gptj', or 'mpt'. Only required if model
|
||||
is custom. Note that these models still must be built from llama.cpp or GPTJ ggml architecture.
|
||||
Default is None.
|
||||
allow_download: Allow API to download models from gpt4all.io. Default is True.
|
||||
@ -263,6 +263,8 @@ class GPT4All():
|
||||
return pyllmodel.GPTJModel()
|
||||
elif model_type == "llama":
|
||||
return pyllmodel.LlamaModel()
|
||||
elif model_type == "mpt":
|
||||
return pyllmodel.MPTModel()
|
||||
else:
|
||||
raise ValueError(f"No corresponding model for model_type: {model_type}")
|
||||
|
||||
@ -286,13 +288,20 @@ class GPT4All():
|
||||
"ggml-vicuna-7b-1.1-q4_2.bin",
|
||||
"ggml-vicuna-13b-1.1-q4_2.bin",
|
||||
"ggml-wizardLM-7B.q4_2.bin",
|
||||
"ggml-stable-vicuna-13B.q4_2.bin"
|
||||
"ggml-stable-vicuna-13B.q4_2.bin",
|
||||
"ggml-nous-gpt4-vicuna-13b.bin"
|
||||
]
|
||||
|
||||
MPT_MODELS = [
|
||||
"ggml-mpt-7b-base.bin"
|
||||
]
|
||||
|
||||
if model_name in GPTJ_MODELS:
|
||||
return pyllmodel.GPTJModel()
|
||||
elif model_name in LLAMA_MODELS:
|
||||
return pyllmodel.LlamaModel()
|
||||
elif model_name in MPT_MODELS:
|
||||
return pyllmodel.MPTModel()
|
||||
else:
|
||||
err_msg = f"""No corresponding model for provided filename {model_name}.
|
||||
If this is a custom model, make sure to specify a valid model_type.
|
||||
|
@ -46,6 +46,9 @@ llmodel.llmodel_gptj_create.restype = ctypes.c_void_p
|
||||
llmodel.llmodel_gptj_destroy.argtypes = [ctypes.c_void_p]
|
||||
llmodel.llmodel_llama_create.restype = ctypes.c_void_p
|
||||
llmodel.llmodel_llama_destroy.argtypes = [ctypes.c_void_p]
|
||||
llmodel.llmodel_mpt_create.restype = ctypes.c_void_p
|
||||
llmodel.llmodel_mpt_destroy.argtypes = [ctypes.c_void_p]
|
||||
|
||||
|
||||
llmodel.llmodel_loadModel.argtypes = [ctypes.c_void_p, ctypes.c_char_p]
|
||||
llmodel.llmodel_loadModel.restype = ctypes.c_bool
|
||||
@ -236,3 +239,17 @@ class LlamaModel(LLModel):
|
||||
if self.model is not None:
|
||||
llmodel.llmodel_llama_destroy(self.model)
|
||||
super().__del__()
|
||||
|
||||
|
||||
class MPTModel(LLModel):
|
||||
|
||||
model_type = "mpt"
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.model = llmodel.llmodel_mpt_create()
|
||||
|
||||
def __del__(self):
|
||||
if self.model is not None:
|
||||
llmodel.llmodel_mpt_destroy(self.model)
|
||||
super().__del__()
|
||||
|
0
gpt4all-bindings/python/tests/__init__.py
Normal file
0
gpt4all-bindings/python/tests/__init__.py
Normal file
@ -14,6 +14,24 @@ def test_create_llama():
|
||||
llama = pyllmodel.LlamaModel()
|
||||
assert llama.model_type == "llama"
|
||||
|
||||
def test_create_mpt():
|
||||
mpt = pyllmodel.MPTModel()
|
||||
assert mpt.model_type == "mpt"
|
||||
|
||||
def prompt_unloaded_mpt():
|
||||
mpt = pyllmodel.MPTModel()
|
||||
old_stdout = sys.stdout
|
||||
collect_response = StringIO()
|
||||
sys.stdout = collect_response
|
||||
|
||||
mpt.prompt("hello there")
|
||||
|
||||
response = collect_response.getvalue()
|
||||
sys.stdout = old_stdout
|
||||
|
||||
response = response.strip()
|
||||
assert response == "MPT ERROR: prompt won't work with an unloaded model!"
|
||||
|
||||
def prompt_unloaded_gptj():
|
||||
gptj = pyllmodel.GPTJModel()
|
||||
old_stdout = sys.stdout
|
||||
|
Loading…
Reference in New Issue
Block a user