@ -12,7 +12,7 @@ from langchain.llms.utils import enforce_stop_tokens
class GPT4All ( LLM ) :
r """ Wrapper around GPT4All language models.
To use , you should have the ` ` py gpt4all` ` python package installed , the
To use , you should have the ` ` gpt4all` ` python package installed , the
pre - trained model file , and the model ' s config information.
Example :
@ -28,7 +28,7 @@ class GPT4All(LLM):
model : str
""" Path to the pre-trained GPT4All model file. """
backend : str = Field ( " llama " , alias = " backend " )
backend : Optional [ str ] = Field ( None , alias = " backend " )
n_ctx : int = Field ( 512 , alias = " n_ctx " )
""" Token context window. """
@ -88,6 +88,10 @@ class GPT4All(LLM):
streaming : bool = False
""" Whether to stream the results or not. """
context_erase : float = 0.5
""" Leave (n_ctx * context_erase) tokens
starting from beginning if the context has run out . """
client : Any = None #: :meta private:
class Config :
@ -95,86 +99,55 @@ class GPT4All(LLM):
extra = Extra . forbid
def _llama_default_params ( self ) - > Dict [ str , Any ] :
""" Get the identifying parameters. """
@staticmethod
def _model_param_names ( ) - > Set [ str ] :
return {
" n_predict " : self . n_predict ,
" n_threads " : self . n_threads ,
" repeat_last_n " : self . repeat_last_n ,
" repeat_penalty " : self . repeat_penalty ,
" top_k " : self . top_k ,
" top_p " : self . top_p ,
" temp " : self . temp ,
" n_ctx " ,
" n_predict " ,
" top_k " ,
" top_p " ,
" temp " ,
" n_batch " ,
" repeat_penalty " ,
" repeat_last_n " ,
" context_erase " ,
}
def _gptj_default_params ( self ) - > Dict [ str , Any ] :
""" Get the identifying parameters. """
def _default_params ( self ) - > Dict [ str , Any ] :
return {
" n_ctx " : self . n_ctx ,
" n_predict " : self . n_predict ,
" n_threads " : self . n_threads ,
" top_k " : self . top_k ,
" top_p " : self . top_p ,
" temp " : self . temp ,
" n_batch " : self . n_batch ,
" repeat_penalty " : self . repeat_penalty ,
" repeat_last_n " : self . repeat_last_n ,
" context_erase " : self . context_erase ,
}
@staticmethod
def _llama_param_names ( ) - > Set [ str ] :
""" Get the identifying parameters. """
return {
" seed " ,
" n_ctx " ,
" n_parts " ,
" f16_kv " ,
" logits_all " ,
" vocab_only " ,
" use_mlock " ,
" embedding " ,
}
@staticmethod
def _gptj_param_names ( ) - > Set [ str ] :
""" Get the identifying parameters. """
return set ( )
@staticmethod
def _model_param_names ( backend : str ) - > Set [ str ] :
if backend == " llama " :
return GPT4All . _llama_param_names ( )
else :
return GPT4All . _gptj_param_names ( )
def _default_params ( self ) - > Dict [ str , Any ] :
if self . backend == " llama " :
return self . _llama_default_params ( )
else :
return self . _gptj_default_params ( )
@root_validator ( )
def validate_environment ( cls , values : Dict ) - > Dict :
""" Validate that the python package exists in the environment. """
try :
backend = values [ " backend " ]
if backend == " llama " :
from pygpt4all import GPT4All as GPT4AllModel
elif backend == " gptj " :
from pygpt4all import GPT4All_J as GPT4AllModel
else :
raise ValueError ( f " Incorrect gpt4all backend { cls . backend } " )
model_kwargs = {
k : v
for k , v in values . items ( )
if k in GPT4All . _model_param_names ( backend )
}
from gpt4all import GPT4All as GPT4AllModel
full_path = values [ " model " ]
model_path , delimiter , model_name = full_path . rpartition ( " / " )
model_path + = delimiter
values [ " client " ] = GPT4AllModel (
model_path = values [ " model " ] ,
* * model_kwargs ,
model_name = model_name ,
model_path = model_path or None ,
model_type = values [ " backend " ] ,
allow_download = False ,
)
values [ " backend " ] = values [ " client " ] . model . model_type
except ImportError :
raise ValueError (
" Could not import py gpt4all python package. "
" Please install it with `pip install py gpt4all`."
" Could not import gpt4all python package. "
" Please install it with `pip install gpt4all`. "
)
return values
@ -185,9 +158,7 @@ class GPT4All(LLM):
" model " : self . model ,
* * self . _default_params ( ) ,
* * {
k : v
for k , v in self . __dict__ . items ( )
if k in self . _model_param_names ( self . backend )
k : v for k , v in self . __dict__ . items ( ) if k in self . _model_param_names ( )
} ,
}