@@ -5,7 +5,7 @@ import os
 import time
 from contextlib import contextmanager
 from pathlib import Path
-from typing import Dict, Iterable, List, Union, Optional
+from typing import Any, Dict, Iterable, List, Union, Optional
 
 import requests
 from tqdm import tqdm
@@ -13,7 +13,17 @@ from tqdm import tqdm
 from . import pyllmodel
 
 # TODO: move to config
-DEFAULT_MODEL_DIRECTORY = os.path.join(str(Path.home()), ".cache", "gpt4all").replace("\\", "\\\\")
+DEFAULT_MODEL_DIRECTORY = os.path.join(str(Path.home()), ".cache", "gpt4all").replace(
+    "\\", "\\\\"
+)
+
+DEFAULT_MODEL_CONFIG = {
+    "systemPrompt": "",
+    "promptTemplate": "### Human: \n{0}\n### Assistant:\n",
+}
+
+ConfigType = Dict[str, str]
+MessageType = Dict[str, str]
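# NOTE (editor's sketch, not part of the patch): how these module-level defaults are
# meant to be consumed. "promptTemplate" is a str.format template, so a user turn is
# rendered like this (the message text is a made-up example):
#
#     rendered = DEFAULT_MODEL_CONFIG["promptTemplate"].format("Hello!")
#     # -> "### Human: \nHello!\n### Assistant:\n"
#
# ConfigType and MessageType are both plain Dict[str, str] aliases; a MessageType
# value looks like {"role": "user", "content": "Hello!"}.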
 
 
 class Embed4All:
     """
@@ -34,7 +44,7 @@ class Embed4All:
     def embed(
         self,
         text: str
-    ) -> list[float]:
+    ) -> List[float]:
         """
         Generate an embedding.
@@ -74,17 +84,20 @@ class GPT4All:
         self.model_type = model_type
         self.model = pyllmodel.LLModel()
         # Retrieve model and download if allowed
-        model_dest = self.retrieve_model(model_name, model_path=model_path, allow_download=allow_download)
-        self.model.load_model(model_dest)
+        self.config: ConfigType = self.retrieve_model(
+            model_name, model_path=model_path, allow_download=allow_download
+        )
+        self.model.load_model(self.config["path"])
         # Set n_threads
         if n_threads is not None:
             self.model.set_thread_count(n_threads)
 
-        self._is_chat_session_activated = False
-        self.current_chat_session = []
+        self._is_chat_session_activated: bool = False
+        self.current_chat_session: List[MessageType] = empty_chat_session()
+        self._current_prompt_template: str = "{0}"
 
     @staticmethod
-    def list_models() -> Dict:
+    def list_models() -> List[ConfigType]:
         """
         Fetch model list from https://gpt4all.io/models/models.json.
@@ -95,8 +108,11 @@ class GPT4All:
     @staticmethod
     def retrieve_model(
-        model_name: str, model_path: Optional[str] = None, allow_download: bool = True, verbose: bool = True
-    ) -> str:
+        model_name: str,
+        model_path: Optional[str] = None,
+        allow_download: bool = True,
+        verbose: bool = True,
+    ) -> ConfigType:
         """
         Find model file, and if it doesn't exist, download the model.
@@ -108,11 +124,25 @@ class GPT4All:
             verbose: If True (default), print debug messages.
 
         Returns:
-            Model file destination.
+            Model config.
         """
         model_filename = append_bin_suffix_if_missing(model_name)
 
+        # get the config for the model
+        config: ConfigType = DEFAULT_MODEL_CONFIG.copy()  # copy, so the module-level default is not mutated below
+        if allow_download:
+            available_models = GPT4All.list_models()
+
+            for m in available_models:
+                if model_filename == m["filename"]:
+                    config.update(m)
+                    config["systemPrompt"] = config["systemPrompt"].strip()
+                    config["promptTemplate"] = config["promptTemplate"].replace(
+                        "%1", "{0}", 1
+                    )  # change to Python-style formatting
+                    break
+
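# NOTE (editor's sketch, not part of the patch): the loop above assumes each entry
# in models.json carries at least "filename", "url", "systemPrompt", and
# "promptTemplate" keys, with templates in "%1" style, roughly:
#
#     {"filename": "ggml-example.bin", "url": "https://gpt4all.io/models/ggml-example.bin",
#      "systemPrompt": " ", "promptTemplate": "### Human:\n%1\n### Assistant:\n"}
#
# (The entry values here are hypothetical, not taken from the real list.)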
         # Validate download directory
         if model_path is None:
             try:
@@ -131,31 +161,34 @@ class GPT4All:
         model_dest = os.path.join(model_path, model_filename).replace("\\", "\\\\")
         if os.path.exists(model_dest):
+            config.pop("url", None)
+            config["path"] = model_dest
             if verbose:
                 print("Found model file at", model_dest)
-            return model_dest
 
         # If model file does not exist, download
         elif allow_download:
             # Make sure valid model filename before attempting download
-            available_models = GPT4All.list_models()
-            selected_model = None
-            for m in available_models:
-                if model_filename == m['filename']:
-                    selected_model = m
-                    break
-            if selected_model is None:
+            if "url" not in config:
                 raise ValueError(f"Model filename not in model list: {model_filename}")
-            url = selected_model.pop('url', None)
+            url = config.pop("url", None)
 
-            return GPT4All.download_model(model_filename, model_path, verbose=verbose, url=url)
+            config["path"] = GPT4All.download_model(
+                model_filename, model_path, verbose=verbose, url=url
+            )
         else:
             raise ValueError("Failed to retrieve model")
 
+        return config
+
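# NOTE (editor's sketch, not part of the patch): on success, the returned config
# always contains a "path" entry alongside the prompt fields, e.g. (hypothetical
# values):
#
#     {"path": "/home/user/.cache/gpt4all/ggml-example.bin",
#      "systemPrompt": "", "promptTemplate": "### Human: \n{0}\n### Assistant:\n"}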
     @staticmethod
-    def download_model(model_filename: str, model_path: str, verbose: bool = True, url: Optional[str] = None) -> str:
+    def download_model(
+        model_filename: str,
+        model_path: str,
+        verbose: bool = True,
+        url: Optional[str] = None,
+    ) -> str:
         """
         Download model from https://gpt4all.io.
@@ -191,7 +224,7 @@ class GPT4All:
         except Exception:
             if os.path.exists(download_path):
                 if verbose:
-                    print('Cleaning up the interrupted download...')
+                    print("Cleaning up the interrupted download...")
                 os.remove(download_path)
             raise
@@ -218,7 +251,8 @@ class GPT4All:
         n_batch: int = 8,
         n_predict: Optional[int] = None,
         streaming: bool = False,
-    ) -> Union[str, Iterable]:
+        callback: pyllmodel.ResponseCallbackType = pyllmodel.empty_response_callback,
+    ) -> Union[str, Iterable[str]]:
         """
         Generate outputs from any GPT4All model.
@@ -233,12 +267,14 @@ class GPT4All:
             n_batch: Number of prompt tokens processed in parallel. Larger values decrease latency but increase resource requirements.
             n_predict: Equivalent to max_tokens, exists for backwards compatibility.
             streaming: If True, this method will instead return a generator that yields tokens as the model generates them.
+            callback: A function with arguments token_id: int and response: str, which receives the tokens from the model as they are generated and stops the generation by returning False.
 
         Returns:
             Either the entire completion or a generator that yields the completion token by token.
         """
-        generate_kwargs = dict(
-            prompt=prompt,
+
+        # Preparing the model request
+        generate_kwargs: Dict[str, Any] = dict(
             temp=temp,
             top_k=top_k,
             top_p=top_p,
@@ -249,42 +285,87 @@ class GPT4All:
         )
 
         if self._is_chat_session_activated:
+            generate_kwargs["reset_context"] = len(self.current_chat_session) == 1  # check if there is only one message, i.e. system prompt
             self.current_chat_session.append({"role": "user", "content": prompt})
-            generate_kwargs['prompt'] = self._format_chat_prompt_template(messages=self.current_chat_session[-1:])
-            generate_kwargs['reset_context'] = len(self.current_chat_session) == 1
-        else:
-            generate_kwargs['reset_context'] = True
-
-        if streaming:
-            return self.model.prompt_model_streaming(**generate_kwargs)
-
-        output = self.model.prompt_model(**generate_kwargs)
+            prompt = self._format_chat_prompt_template(
+                messages=self.current_chat_session[-1:],
+                default_prompt_header=self.current_chat_session[0]["content"] if generate_kwargs["reset_context"] else "",
+            )
+        else:
+            generate_kwargs["reset_context"] = True
+
+        # Prepare the callback, process the model response
+        output_collector: List[MessageType]
+        output_collector = [
+            {"content": ""}
+        ]  # placeholder for the self.current_chat_session if chat session is not activated
 
         if self._is_chat_session_activated:
-            self.current_chat_session.append({"role": "assistant", "content": output})
+            self.current_chat_session.append({"role": "assistant", "content": ""})
+            output_collector = self.current_chat_session
 
-        return output
+        def _callback_wrapper(
+            callback: pyllmodel.ResponseCallbackType,
+            output_collector: List[MessageType],
+        ) -> pyllmodel.ResponseCallbackType:
+            def _callback(token_id: int, response: str) -> bool:
+                nonlocal callback, output_collector
+
+                output_collector[-1]["content"] += response
+
+                return callback(token_id, response)
+
+            return _callback
+
+        # Send the request to the model
+        if streaming:
+            return self.model.prompt_model_streaming(
+                prompt=prompt,
+                callback=_callback_wrapper(callback, output_collector),
+                **generate_kwargs,
+            )
+
+        self.model.prompt_model(
+            prompt=prompt,
+            callback=_callback_wrapper(callback, output_collector),
+            **generate_kwargs,
+        )
+
+        return output_collector[-1]["content"]
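# NOTE (editor's sketch, not part of the patch): one way the new callback parameter
# can be used; stop_after_n and the model filename are hypothetical.
#
#     def stop_after_n(n: int) -> pyllmodel.ResponseCallbackType:
#         state = {"count": 0}
#         def _cb(token_id: int, response: str) -> bool:
#             state["count"] += 1
#             return state["count"] < n  # returning False stops generation early
#         return _cb
#
#     model = GPT4All("ggml-example.bin")
#     text = model.generate("Hello", callback=stop_after_n(32))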
     @contextmanager
-    def chat_session(self):
-        '''
+    def chat_session(
+        self,
+        system_prompt: str = "",
+        prompt_template: str = "",
+    ):
+        """
         Context manager to hold an inference optimized chat session with a GPT4All model.
-        '''
+
+        Args:
+            system_prompt: An initial instruction for the model.
+            prompt_template: Template for the prompts with {0} being replaced by the user message.
+        """
         # Code to acquire resource, e.g.:
         self._is_chat_session_activated = True
-        self.current_chat_session = []
+        self.current_chat_session = empty_chat_session(system_prompt or self.config["systemPrompt"])
+        self._current_prompt_template = prompt_template or self.config["promptTemplate"]
         try:
             yield self
         finally:
             # Code to release resource, e.g.:
             self._is_chat_session_activated = False
-            self.current_chat_session = []
+            self.current_chat_session = empty_chat_session()
+            self._current_prompt_template = "{0}"
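# NOTE (editor's sketch, not part of the patch): how the new session overrides are
# intended to be used; the template and filename are hypothetical.
#
#     model = GPT4All("ggml-example.bin")
#     with model.chat_session(
#         system_prompt="You are a helpful assistant.",
#         prompt_template="USER: {0}\nASSISTANT: ",
#     ) as m:
#         print(m.generate("Hi there"))
#         print(m.current_chat_session)  # system, user, and assistant messages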
     def _format_chat_prompt_template(
-        self, messages: List[Dict], default_prompt_header=True, default_prompt_footer=True
+        self,
+        messages: List[MessageType],
+        default_prompt_header: str = "",
+        default_prompt_footer: str = "",
     ) -> str:
         """
-        Helper method for building a prompt using template from list of messages.
+        Helper method for building a prompt from a list of messages, using self._current_prompt_template as the template for each message.
 
         Args:
             messages: List of dictionaries. Each dictionary should have a "role" key
@@ -296,19 +377,44 @@ class GPT4All:
         Returns:
             Formatted prompt.
         """
-        full_prompt = ""
+
+        if isinstance(default_prompt_header, bool):
+            import warnings
+
+            warnings.warn(
+                "Using True/False for the 'default_prompt_header' is deprecated. Use a string instead.",
+                DeprecationWarning,
+            )
+            default_prompt_header = ""
+
+        if isinstance(default_prompt_footer, bool):
+            import warnings
+
+            warnings.warn(
+                "Using True/False for the 'default_prompt_footer' is deprecated. Use a string instead.",
+                DeprecationWarning,
+            )
+            default_prompt_footer = ""
+
+        full_prompt = default_prompt_header + "\n\n" if default_prompt_header != "" else ""
 
         for message in messages:
             if message["role"] == "user":
-                user_message = "### Human: \n" + message["content"] + "\n### Assistant:\n"
+                user_message = self._current_prompt_template.format(message["content"])
                 full_prompt += user_message
             if message["role"] == "assistant":
-                assistant_message = message["content"] + '\n'
+                assistant_message = message["content"] + "\n"
                 full_prompt += assistant_message
 
+        full_prompt += "\n\n" + default_prompt_footer if default_prompt_footer != "" else ""
+
         return full_prompt
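# NOTE (editor's sketch, not part of the patch): with the default template, a short
# exchange renders roughly as follows (message texts are made up):
#
#     template = "### Human: \n{0}\n### Assistant:\n"
#     # user "Hi"          -> "### Human: \nHi\n### Assistant:\n"
#     # assistant "Hello!" -> "Hello!\n"
#     # a non-empty default_prompt_header is prepended, followed by a blank line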
 
 
+def empty_chat_session(system_prompt: str = "") -> List[MessageType]:
+    return [{"role": "system", "content": system_prompt}]
+
+
 def append_bin_suffix_if_missing(model_name):
     if not model_name.endswith(".bin"):
         model_name += ".bin"
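# NOTE (editor's sketch, not part of the patch): the pieces above fit together like
# this; the filename is hypothetical.
#
#     model = GPT4All("ggml-example.bin")   # retrieve_model() fills model.config
#     with model.chat_session() as m:       # picks up systemPrompt/promptTemplate
#         reply = m.generate("What is 2 + 2?")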