|
|
|
@ -58,7 +58,7 @@ class GPT4All():
|
|
|
|
|
return requests.get("https://gpt4all.io/models/models.json").json()
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def retrieve_model(model_name: str, model_path: str = None, allow_download: bool = True) -> str:
|
|
|
|
|
def retrieve_model(model_name: str, model_path: str = None, allow_download: bool = True, verbose: bool = True) -> str:
|
|
|
|
|
"""
|
|
|
|
|
Find model file, and if it doesn't exist, download the model.
|
|
|
|
|
|
|
|
|
@ -67,6 +67,7 @@ class GPT4All():
|
|
|
|
|
model_path: Path to find model. Default is None in which case path is set to
|
|
|
|
|
~/.cache/gpt4all/.
|
|
|
|
|
allow_download: Allow API to download model from gpt4all.io. Default is True.
|
|
|
|
|
verbose: If True (default), print debug messages.
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
Model file destination.
|
|
|
|
@ -92,7 +93,8 @@ class GPT4All():
|
|
|
|
|
|
|
|
|
|
model_dest = os.path.join(model_path, model_filename).replace("\\", "\\\\")
|
|
|
|
|
if os.path.exists(model_dest):
|
|
|
|
|
print("Found model file at ", model_dest)
|
|
|
|
|
if verbose:
|
|
|
|
|
print("Found model file at ", model_dest)
|
|
|
|
|
return model_dest
|
|
|
|
|
|
|
|
|
|
# If model file does not exist, download
|
|
|
|
@ -106,13 +108,14 @@ class GPT4All():
|
|
|
|
|
raise ValueError("Failed to retrieve model")
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def download_model(model_filename: str, model_path: str) -> str:
|
|
|
|
|
def download_model(model_filename: str, model_path: str, verbose: bool) -> str:
|
|
|
|
|
"""
|
|
|
|
|
Download model from https://gpt4all.io.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
model_filename: Filename of model (with .bin extension).
|
|
|
|
|
model_path: Path to download model to.
|
|
|
|
|
verbose: If True (default), print debug messages.
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
Model file destination.
|
|
|
|
@ -137,7 +140,8 @@ class GPT4All():
|
|
|
|
|
file.write(data)
|
|
|
|
|
except Exception:
|
|
|
|
|
if os.path.exists(download_path):
|
|
|
|
|
print('Cleaning up the interrupted download...')
|
|
|
|
|
if verbose:
|
|
|
|
|
print('Cleaning up the interrupted download...')
|
|
|
|
|
os.remove(download_path)
|
|
|
|
|
raise
|
|
|
|
|
|
|
|
|
@ -150,7 +154,8 @@ class GPT4All():
|
|
|
|
|
# Sleep for a little bit so OS can remove file lock
|
|
|
|
|
time.sleep(2)
|
|
|
|
|
|
|
|
|
|
print("Model downloaded at: " + download_path)
|
|
|
|
|
if verbose:
|
|
|
|
|
print("Model downloaded at: ", download_path)
|
|
|
|
|
return download_path
|
|
|
|
|
|
|
|
|
|
def generate(self, prompt: str, streaming: bool = True, **generate_kwargs) -> str:
|
|
|
|
|