set n_threads in GPT4All python bindings (#1042)

* set n_threads in GPT4All

* changed default n_threads to None
This commit is contained in:
EKal-aa 2023-06-23 10:16:35 +02:00 committed by GitHub
parent ae3d91476c
commit aed7b43143
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -22,7 +22,7 @@ class GPT4All():
model: Pointer to underlying C model.
"""
def __init__(self, model_name: str, model_path: str = None, model_type: str = None, allow_download = True):
def __init__(self, model_name: str, model_path: str = None, model_type: str = None, allow_download = True, n_threads = None):
"""
Constructor
@ -33,12 +33,16 @@ class GPT4All():
model_type: Model architecture. This argument currently does not have any functionality and is just used as
descriptive identifier for user. Default is None.
allow_download: Allow API to download models from gpt4all.io. Default is True.
n_threads: number of CPU threads used by GPT4All. Default is None, than the number of threads are determined automatically.
"""
self.model_type = model_type
self.model = pyllmodel.LLModel()
# Retrieve model and download if allowed
model_dest = self.retrieve_model(model_name, model_path=model_path, allow_download=allow_download)
self.model.load_model(model_dest)
# Set n_threads
if n_threads != None:
self.model.set_thread_count(n_threads)
@staticmethod
def list_models():