diff --git a/gpt4all-bindings/python/gpt4all/gpt4all.py b/gpt4all-bindings/python/gpt4all/gpt4all.py
index d8bdfe39..fb5f709e 100644
--- a/gpt4all-bindings/python/gpt4all/gpt4all.py
+++ b/gpt4all-bindings/python/gpt4all/gpt4all.py
@@ -14,6 +14,7 @@ from . import pyllmodel
 
 # TODO: move to config
 DEFAULT_MODEL_DIRECTORY = os.path.join(str(Path.home()), ".cache", "gpt4all").replace("\\", "\\\\")
+
 
 class GPT4All():
     """Python API for retrieving and interacting with GPT4All models.
@@ -58,7 +59,8 @@ class GPT4All():
         return requests.get("https://gpt4all.io/models/models.json").json()
 
     @staticmethod
-    def retrieve_model(model_name: str, model_path: str = None, allow_download: bool = True, verbose: bool = True) -> str:
+    def retrieve_model(model_name: str, model_path: str = None, allow_download: bool = True,
+                       verbose: bool = True) -> str:
         """
         Find model file, and if it doesn't exist, download the model.
 
@@ -72,10 +74,8 @@ class GPT4All():
         Returns:
             Model file destination.
         """
-        
-        model_filename = model_name
-        if not model_filename.endswith(".bin"):
-            model_filename += ".bin"
+
+        model_filename = append_bin_suffix_if_missing(model_name)
 
         # Validate download directory
         if model_path is None:
@@ -123,7 +123,7 @@ class GPT4All():
 
         def get_download_url(model_filename):
             return f"https://gpt4all.io/models/{model_filename}"
-        
+
         # Download model
         download_path = os.path.join(model_path, model_filename).replace("\\", "\\\\")
         download_url = get_download_url(model_filename)
@@ -171,11 +171,11 @@ class GPT4All():
             Raw string of generated model response.
         """
         return self.model.generate(prompt, streaming=streaming, **generate_kwargs)
-    
-    def chat_completion(self, 
-                        messages: List[Dict], 
-                        default_prompt_header: bool = True, 
-                        default_prompt_footer: bool = True, 
+
+    def chat_completion(self,
+                        messages: List[Dict],
+                        default_prompt_header: bool = True,
+                        default_prompt_footer: bool = True,
                         verbose: bool = True,
                         streaming: bool = True,
                         **generate_kwargs) -> dict:
@@ -205,10 +205,10 @@ class GPT4All():
                 "choices": List of message dictionary where "content" is generated response and "role" is set
                     as "assistant". Right now, only one choice is returned by model.
         """
-        
-        full_prompt = self._build_prompt(messages, 
-                                         default_prompt_header=default_prompt_header, 
-                                         default_prompt_footer=default_prompt_footer)
+
+        full_prompt = self._build_prompt(messages,
+                                         default_prompt_header=default_prompt_header,
+                                         default_prompt_footer=default_prompt_footer)
 
         if verbose:
             print(full_prompt)
@@ -219,9 +219,9 @@ class GPT4All():
 
         response_dict = {
             "model": self.model.model_name,
-            "usage": {"prompt_tokens": len(full_prompt), 
-                      "completion_tokens": len(response), 
-                      "total_tokens" : len(full_prompt) + len(response)},
+            "usage": {"prompt_tokens": len(full_prompt),
+                      "completion_tokens": len(response),
+                      "total_tokens": len(full_prompt) + len(response)},
             "choices": [
                 {
                     "message": {
@@ -233,10 +233,10 @@ class GPT4All():
         }
 
         return response_dict
-    
+
     @staticmethod
-    def _build_prompt(messages: List[Dict], 
-                      default_prompt_header=True, 
+    def _build_prompt(messages: List[Dict],
+                      default_prompt_header=True,
                       default_prompt_footer=False) -> str:
         # Helper method to format messages into prompt.
         full_prompt = ""
@@ -278,14 +278,13 @@ class GPT4All():
             return pyllmodel.MPTModel()
         else:
             raise ValueError(f"No corresponding model for model_type: {model_type}")
-    
+
     @staticmethod
     def get_model_from_name(model_name: str) -> pyllmodel.LLModel:
         # This needs to be updated for each new model
         # NOTE: We are doing this preprocessing a lot, maybe there's a better way to organize
 
-        if ".bin" not in model_name:
-            model_name += ".bin"
+        model_name = append_bin_suffix_if_missing(model_name)
 
         GPTJ_MODELS = [
             "ggml-gpt4all-j-v1.3-groovy.bin",
@@ -320,3 +319,9 @@ class GPT4All():
                    f"If this is a custom model, make sure to specify a valid model_type.\n")
 
         raise ValueError(err_msg)
+
+
+def append_bin_suffix_if_missing(model_name):
+    if not model_name.endswith(".bin"):
+        model_name += ".bin"
+    return model_name