diff --git a/gpt4all-bindings/python/gpt4all/gpt4all.py b/gpt4all-bindings/python/gpt4all/gpt4all.py index 1d58105d..d00c4fef 100644 --- a/gpt4all-bindings/python/gpt4all/gpt4all.py +++ b/gpt4all-bindings/python/gpt4all/gpt4all.py @@ -22,7 +22,7 @@ class GPT4All(): model: Pointer to underlying C model. """ - def __init__(self, model_name: str, model_path: str = None, model_type: str = None, allow_download = True): + def __init__(self, model_name: str, model_path: str = None, model_type: str = None, allow_download = True, n_threads = None): """ Constructor @@ -33,12 +33,16 @@ class GPT4All(): model_type: Model architecture. This argument currently does not have any functionality and is just used as descriptive identifier for user. Default is None. allow_download: Allow API to download models from gpt4all.io. Default is True. + n_threads: number of CPU threads used by GPT4All. Default is None, than the number of threads are determined automatically. """ self.model_type = model_type self.model = pyllmodel.LLModel() # Retrieve model and download if allowed model_dest = self.retrieve_model(model_name, model_path=model_path, allow_download=allow_download) self.model.load_model(model_dest) + # Set n_threads + if n_threads != None: + self.model.set_thread_count(n_threads) @staticmethod def list_models():