One function to append the .bin suffix

This commit is contained in:
Konstantin Gukov 2023-05-26 09:58:00 +02:00 committed by Richard Guo
parent 659244f0a2
commit a6f3e94458

View File

@ -14,6 +14,7 @@ from . import pyllmodel
# TODO: move to config # TODO: move to config
DEFAULT_MODEL_DIRECTORY = os.path.join(str(Path.home()), ".cache", "gpt4all").replace("\\", "\\\\") DEFAULT_MODEL_DIRECTORY = os.path.join(str(Path.home()), ".cache", "gpt4all").replace("\\", "\\\\")
class GPT4All(): class GPT4All():
"""Python API for retrieving and interacting with GPT4All models. """Python API for retrieving and interacting with GPT4All models.
@ -58,7 +59,8 @@ class GPT4All():
return requests.get("https://gpt4all.io/models/models.json").json() return requests.get("https://gpt4all.io/models/models.json").json()
@staticmethod @staticmethod
def retrieve_model(model_name: str, model_path: str = None, allow_download: bool = True, verbose: bool = True) -> str: def retrieve_model(model_name: str, model_path: str = None, allow_download: bool = True,
verbose: bool = True) -> str:
""" """
Find model file, and if it doesn't exist, download the model. Find model file, and if it doesn't exist, download the model.
@ -72,10 +74,8 @@ class GPT4All():
Returns: Returns:
Model file destination. Model file destination.
""" """
model_filename = model_name model_filename = append_bin_suffix_if_missing(model_name)
if not model_filename.endswith(".bin"):
model_filename += ".bin"
# Validate download directory # Validate download directory
if model_path is None: if model_path is None:
@ -123,7 +123,7 @@ class GPT4All():
def get_download_url(model_filename): def get_download_url(model_filename):
return f"https://gpt4all.io/models/{model_filename}" return f"https://gpt4all.io/models/{model_filename}"
# Download model # Download model
download_path = os.path.join(model_path, model_filename).replace("\\", "\\\\") download_path = os.path.join(model_path, model_filename).replace("\\", "\\\\")
download_url = get_download_url(model_filename) download_url = get_download_url(model_filename)
@ -171,11 +171,11 @@ class GPT4All():
Raw string of generated model response. Raw string of generated model response.
""" """
return self.model.generate(prompt, streaming=streaming, **generate_kwargs) return self.model.generate(prompt, streaming=streaming, **generate_kwargs)
def chat_completion(self, def chat_completion(self,
messages: List[Dict], messages: List[Dict],
default_prompt_header: bool = True, default_prompt_header: bool = True,
default_prompt_footer: bool = True, default_prompt_footer: bool = True,
verbose: bool = True, verbose: bool = True,
streaming: bool = True, streaming: bool = True,
**generate_kwargs) -> dict: **generate_kwargs) -> dict:
@ -205,10 +205,10 @@ class GPT4All():
"choices": List of message dictionary where "content" is generated response and "role" is set "choices": List of message dictionary where "content" is generated response and "role" is set
as "assistant". Right now, only one choice is returned by model. as "assistant". Right now, only one choice is returned by model.
""" """
full_prompt = self._build_prompt(messages, full_prompt = self._build_prompt(messages,
default_prompt_header=default_prompt_header, default_prompt_header=default_prompt_header,
default_prompt_footer=default_prompt_footer) default_prompt_footer=default_prompt_footer)
if verbose: if verbose:
print(full_prompt) print(full_prompt)
@ -219,9 +219,9 @@ class GPT4All():
response_dict = { response_dict = {
"model": self.model.model_name, "model": self.model.model_name,
"usage": {"prompt_tokens": len(full_prompt), "usage": {"prompt_tokens": len(full_prompt),
"completion_tokens": len(response), "completion_tokens": len(response),
"total_tokens" : len(full_prompt) + len(response)}, "total_tokens": len(full_prompt) + len(response)},
"choices": [ "choices": [
{ {
"message": { "message": {
@ -233,10 +233,10 @@ class GPT4All():
} }
return response_dict return response_dict
@staticmethod @staticmethod
def _build_prompt(messages: List[Dict], def _build_prompt(messages: List[Dict],
default_prompt_header=True, default_prompt_header=True,
default_prompt_footer=False) -> str: default_prompt_footer=False) -> str:
# Helper method to format messages into prompt. # Helper method to format messages into prompt.
full_prompt = "" full_prompt = ""
@ -278,14 +278,13 @@ class GPT4All():
return pyllmodel.MPTModel() return pyllmodel.MPTModel()
else: else:
raise ValueError(f"No corresponding model for model_type: {model_type}") raise ValueError(f"No corresponding model for model_type: {model_type}")
@staticmethod @staticmethod
def get_model_from_name(model_name: str) -> pyllmodel.LLModel: def get_model_from_name(model_name: str) -> pyllmodel.LLModel:
# This needs to be updated for each new model # This needs to be updated for each new model
# NOTE: We are doing this preprocessing a lot, maybe there's a better way to organize # NOTE: We are doing this preprocessing a lot, maybe there's a better way to organize
if ".bin" not in model_name: model_name = append_bin_suffix_if_missing(model_name)
model_name += ".bin"
GPTJ_MODELS = [ GPTJ_MODELS = [
"ggml-gpt4all-j-v1.3-groovy.bin", "ggml-gpt4all-j-v1.3-groovy.bin",
@ -320,3 +319,9 @@ class GPT4All():
f"If this is a custom model, make sure to specify a valid model_type.\n") f"If this is a custom model, make sure to specify a valid model_type.\n")
raise ValueError(err_msg) raise ValueError(err_msg)
def append_bin_suffix_if_missing(model_name: str) -> str:
    """Return *model_name* with a ``.bin`` suffix, adding it only if absent.

    Args:
        model_name: Model file name, with or without the ``.bin`` extension.

    Returns:
        The name guaranteed to end in ``.bin``; already-suffixed names are
        returned unchanged.
    """
    return model_name if model_name.endswith(".bin") else model_name + ".bin"