one funcion to append .bin suffix

This commit is contained in:
Konstantin Gukov 2023-05-26 09:58:00 +02:00 committed by Richard Guo
parent 659244f0a2
commit a6f3e94458

View File

@ -14,6 +14,7 @@ from . import pyllmodel
# TODO: move to config # TODO: move to config
DEFAULT_MODEL_DIRECTORY = os.path.join(str(Path.home()), ".cache", "gpt4all").replace("\\", "\\\\") DEFAULT_MODEL_DIRECTORY = os.path.join(str(Path.home()), ".cache", "gpt4all").replace("\\", "\\\\")
class GPT4All(): class GPT4All():
"""Python API for retrieving and interacting with GPT4All models. """Python API for retrieving and interacting with GPT4All models.
@ -58,7 +59,8 @@ class GPT4All():
return requests.get("https://gpt4all.io/models/models.json").json() return requests.get("https://gpt4all.io/models/models.json").json()
@staticmethod @staticmethod
def retrieve_model(model_name: str, model_path: str = None, allow_download: bool = True, verbose: bool = True) -> str: def retrieve_model(model_name: str, model_path: str = None, allow_download: bool = True,
verbose: bool = True) -> str:
""" """
Find model file, and if it doesn't exist, download the model. Find model file, and if it doesn't exist, download the model.
@ -73,9 +75,7 @@ class GPT4All():
Model file destination. Model file destination.
""" """
model_filename = model_name model_filename = append_bin_suffix_if_missing(model_name)
if not model_filename.endswith(".bin"):
model_filename += ".bin"
# Validate download directory # Validate download directory
if model_path is None: if model_path is None:
@ -207,8 +207,8 @@ class GPT4All():
""" """
full_prompt = self._build_prompt(messages, full_prompt = self._build_prompt(messages,
default_prompt_header=default_prompt_header, default_prompt_header=default_prompt_header,
default_prompt_footer=default_prompt_footer) default_prompt_footer=default_prompt_footer)
if verbose: if verbose:
print(full_prompt) print(full_prompt)
@ -221,7 +221,7 @@ class GPT4All():
"model": self.model.model_name, "model": self.model.model_name,
"usage": {"prompt_tokens": len(full_prompt), "usage": {"prompt_tokens": len(full_prompt),
"completion_tokens": len(response), "completion_tokens": len(response),
"total_tokens" : len(full_prompt) + len(response)}, "total_tokens": len(full_prompt) + len(response)},
"choices": [ "choices": [
{ {
"message": { "message": {
@ -284,8 +284,7 @@ class GPT4All():
# This needs to be updated for each new model # This needs to be updated for each new model
# NOTE: We are doing this preprocessing a lot, maybe there's a better way to organize # NOTE: We are doing this preprocessing a lot, maybe there's a better way to organize
if ".bin" not in model_name: model_name = append_bin_suffix_if_missing(model_name)
model_name += ".bin"
GPTJ_MODELS = [ GPTJ_MODELS = [
"ggml-gpt4all-j-v1.3-groovy.bin", "ggml-gpt4all-j-v1.3-groovy.bin",
@ -320,3 +319,9 @@ class GPT4All():
f"If this is a custom model, make sure to specify a valid model_type.\n") f"If this is a custom model, make sure to specify a valid model_type.\n")
raise ValueError(err_msg) raise ValueError(err_msg)
def append_bin_suffix_if_missing(model_name):
if not model_name.endswith(".bin"):
model_name += ".bin"
return model_name