diff --git a/gpt4all-bindings/python/gpt4all/gpt4all.py b/gpt4all-bindings/python/gpt4all/gpt4all.py
index 2b901251..5787d293 100644
--- a/gpt4all-bindings/python/gpt4all/gpt4all.py
+++ b/gpt4all-bindings/python/gpt4all/gpt4all.py
@@ -29,7 +29,7 @@ class GPT4All():
             model_name: Name of GPT4All or custom model. Including ".bin" file extension is optional but encouraged.
             model_path: Path to directory containing model file or, if file does not exist, where to download model.
                 Default is None, in which case models will be stored in `~/.cache/gpt4all/`.
-            model_type: Model architecture to use - currently, only options are 'llama' or 'gptj'. Only required if model
+            model_type: Model architecture to use - currently, options are 'llama', 'gptj', or 'mpt'. Only required if model
                 is custom. Note that these models still must be built from llama.cpp or GPTJ ggml architecture.
                 Default is None.
             allow_download: Allow API to download models from gpt4all.io. Default is True.
@@ -263,6 +263,8 @@ class GPT4All():
             return pyllmodel.GPTJModel()
         elif model_type == "llama":
             return pyllmodel.LlamaModel()
+        elif model_type == "mpt":
+            return pyllmodel.MPTModel()
         else:
             raise ValueError(f"No corresponding model for model_type: {model_type}")
 
@@ -286,13 +288,20 @@ class GPT4All():
             "ggml-vicuna-7b-1.1-q4_2.bin",
             "ggml-vicuna-13b-1.1-q4_2.bin",
             "ggml-wizardLM-7B.q4_2.bin",
-            "ggml-stable-vicuna-13B.q4_2.bin"
+            "ggml-stable-vicuna-13B.q4_2.bin",
+            "ggml-nous-gpt4-vicuna-13b.bin"
+        ]
+
+        MPT_MODELS = [
+            "ggml-mpt-7b-base.bin"
         ]
 
         if model_name in GPTJ_MODELS:
             return pyllmodel.GPTJModel()
         elif model_name in LLAMA_MODELS:
             return pyllmodel.LlamaModel()
+        elif model_name in MPT_MODELS:
+            return pyllmodel.MPTModel()
         else:
             err_msg = f"""No corresponding model for provided filename {model_name}.
             If this is a custom model, make sure to specify a valid model_type.
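Note: with the two new elif branches above, an MPT model is now reachable by either resolution path. A minimal sketch of both entry points (not part of the patch; "my-mpt-model.bin" is a hypothetical placeholder, and the model file must already be present or downloadable):

    from gpt4all import GPT4All

    # Resolved by filename: "ggml-mpt-7b-base.bin" is in the new MPT_MODELS list,
    # so no model_type hint is needed.
    model = GPT4All("ggml-mpt-7b-base.bin")

    # Resolved by architecture: pass model_type="mpt" for a custom MPT ggml file
    # ("my-mpt-model.bin" is hypothetical).
    custom = GPT4All("my-mpt-model.bin", model_type="mpt")
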
diff --git a/gpt4all-bindings/python/gpt4all/pyllmodel.py b/gpt4all-bindings/python/gpt4all/pyllmodel.py
index 49ba184e..d194c234 100644
--- a/gpt4all-bindings/python/gpt4all/pyllmodel.py
+++ b/gpt4all-bindings/python/gpt4all/pyllmodel.py
@@ -46,6 +46,9 @@ llmodel.llmodel_gptj_create.restype = ctypes.c_void_p
 llmodel.llmodel_gptj_destroy.argtypes = [ctypes.c_void_p]
 llmodel.llmodel_llama_create.restype = ctypes.c_void_p
 llmodel.llmodel_llama_destroy.argtypes = [ctypes.c_void_p]
+llmodel.llmodel_mpt_create.restype = ctypes.c_void_p
+llmodel.llmodel_mpt_destroy.argtypes = [ctypes.c_void_p]
+
 llmodel.llmodel_loadModel.argtypes = [ctypes.c_void_p, ctypes.c_char_p]
 llmodel.llmodel_loadModel.restype = ctypes.c_bool
 
@@ -236,3 +239,17 @@ class LlamaModel(LLModel):
         if self.model is not None:
             llmodel.llmodel_llama_destroy(self.model)
         super().__del__()
+
+
+class MPTModel(LLModel):
+
+    model_type = "mpt"
+
+    def __init__(self):
+        super().__init__()
+        self.model = llmodel.llmodel_mpt_create()
+
+    def __del__(self):
+        if self.model is not None:
+            llmodel.llmodel_mpt_destroy(self.model)
+        super().__del__()
diff --git a/gpt4all-bindings/python/tests/__init__.py b/gpt4all-bindings/python/tests/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/gpt4all-bindings/python/tests/test_pyllmodel.py b/gpt4all-bindings/python/tests/test_pyllmodel.py
index 5535cb7b..e446b0b5 100644
--- a/gpt4all-bindings/python/tests/test_pyllmodel.py
+++ b/gpt4all-bindings/python/tests/test_pyllmodel.py
@@ -14,6 +14,24 @@ def test_create_llama():
     llama = pyllmodel.LlamaModel()
     assert llama.model_type == "llama"
 
+def test_create_mpt():
+    mpt = pyllmodel.MPTModel()
+    assert mpt.model_type == "mpt"
+
+def prompt_unloaded_mpt():
+    mpt = pyllmodel.MPTModel()
+    old_stdout = sys.stdout
+    collect_response = StringIO()
+    sys.stdout = collect_response
+
+    mpt.prompt("hello there")
+
+    response = collect_response.getvalue()
+    sys.stdout = old_stdout
+
+    response = response.strip()
+    assert response == "MPT ERROR: prompt won't work with an unloaded model!"
+
 def prompt_unloaded_gptj():
     gptj = pyllmodel.GPTJModel()
     old_stdout = sys.stdout
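For a quick end-to-end check beyond the unit tests, the low-level binding can be exercised directly. A minimal sketch (not part of the patch): it assumes LLModel's existing load_model/prompt helpers, which the unloaded-model tests above rely on, and a local copy of ggml-mpt-7b-base.bin at a path of your choosing:

    from gpt4all import pyllmodel

    # Wrap the MPT implementation added in this patch.
    mpt = pyllmodel.MPTModel()

    # load_model returns False if the weights could not be loaded;
    # prompt() streams the model's response to stdout, which is why
    # the tests above capture sys.stdout to inspect it.
    if mpt.load_model("/path/to/ggml-mpt-7b-base.bin"):
        mpt.prompt("hello there")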