@@ -2,11 +2,20 @@
 #include <cassert>
 #include <iostream>
+#include <regex>
 #include <unordered_set>
 // TODO(cebtenzzre): replace this with llama_kv_cache_seq_shift for llamamodel (GPT-J needs this as-is)
 void LLModel::recalculateContext(PromptContext &promptCtx, std::function<bool(bool)> recalculate) {
-    size_t i = 0;
-    promptCtx.n_past = 0;
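+    // keep the BOS token at position 0 (if the tokenizer emits one) and discard the first
+    // contextErase fraction of the rest of the window from the token cache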
+    int n_keep = shouldAddBOS();
+    const int32_t n_discard = (promptCtx.n_ctx - n_keep) * promptCtx.contextErase;
+    // Erase the first percentage of context from the tokens
+    std::cerr << implementation().modelType() << ": reached the end of the context window so resizing\n";
+    promptCtx.tokens.erase(promptCtx.tokens.begin() + n_keep, promptCtx.tokens.begin() + n_keep + n_discard);
+    size_t i = n_keep;
+    promptCtx.n_past = n_keep;
     while (i < promptCtx.tokens.size()) {
         size_t batch_end = std::min(i + promptCtx.n_batch, promptCtx.tokens.size());
         std::vector<int32_t> batch(promptCtx.tokens.begin() + i, promptCtx.tokens.begin() + batch_end);
@@ -26,11 +35,36 @@ stop_generating:
     recalculate(false);
 }
+static bool parsePromptTemplate(const std::string &tmpl, std::vector<std::smatch> &placeholders, std::string &err) {
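+    // the regex below matches %1 or %2 exactly; the negative lookahead rejects %10, %11, etc.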
+    static const std::regex placeholderRegex(R"(%[1-2](?![0-9]))");
+    auto it = std::sregex_iterator(tmpl.begin(), tmpl.end(), placeholderRegex);
+    placeholders.clear();
+    placeholders.insert(placeholders.end(), it, std::sregex_iterator());
+    if (placeholders.size() > 2) {
+        err = "ERROR: expected at most two placeholders, got " + std::to_string(placeholders.size());
+        return false;
+    }
+    if (placeholders.size() >= 1 && placeholders[0].str() != "%1") {
+        err = "ERROR: first placeholder must be %1, got " + placeholders[0].str();
+        return false;
+    }
+    if (placeholders.size() >= 2 && placeholders[1].str() != "%2") {
+        err = "ERROR: second placeholder must be %2, got " + placeholders[1].str();
+        return false;
+    }
+    return true;
+}
 void LLModel::prompt(const std::string &prompt,
+                     const std::string &promptTemplate,
                      std::function<bool(int32_t)> promptCallback,
                      std::function<bool(int32_t, const std::string&)> responseCallback,
                      std::function<bool(bool)> recalculateCallback,
-                     PromptContext &promptCtx)
+                     PromptContext &promptCtx,
+                     bool special,
+                     std::string *fakeReply)
 {
     if (!isModelLoaded()) {
         std::cerr << implementation().modelType() << " ERROR: prompt won't work with an unloaded model!\n";
@@ -38,15 +72,86 @@ void LLModel::prompt(const std::string &prompt,
     }
     if (!supportsCompletion()) {
-        std::string errorMessage = "ERROR: this model does not support text completion or chat!\n";
+        std::string errorMessage = "ERROR: this model does not support text completion or chat!";
         responseCallback(-1, errorMessage);
-        std::cerr << implementation().modelType() << errorMessage;
+        std::cerr << implementation().modelType() << " " << errorMessage << "\n";
         return;
     }
-    // tokenize the prompt
-    std::vector<Token> embd_inp = tokenize(promptCtx, prompt);
+    // parse the prompt template
+    std::vector<std::smatch> placeholders;
+    {
+        std::string err;
+        if (!parsePromptTemplate(promptTemplate, placeholders, err)) {
+            responseCallback(-1, err);
+            std::cerr << err << "\n";
+            return;
+        }
+    }
+    auto old_n_past = promptCtx.n_past; // prepare to fake n_past for tokenize
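+    // (tokenize() may treat the very start of the context specially, e.g. only adding BOS when
+    //  n_past == 0, so n_past is advanced between the calls below and restored afterwards)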
+    // tokenize the user prompt
+    std::vector<Token> embd_inp;
+    if (placeholders.empty()) {
+        // this is unusual, but well-defined
+        std::cerr << __func__ << ": prompt template has no placeholder\n";
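+        // note: with no placeholder, only the template itself is tokenized; the user prompt is unused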
+        embd_inp = tokenize(promptCtx, promptTemplate, true);
+    } else {
+        // template: beginning of user prompt
+        const auto &phUser = placeholders[0];
+        std::string userPrefix(phUser.prefix());
+        if (!userPrefix.empty()) {
+            embd_inp = tokenize(promptCtx, userPrefix, true);
+            promptCtx.n_past += embd_inp.size();
+        }
+        // user input (shouldn't have special token processing)
+        auto tokens = tokenize(promptCtx, prompt, special);
+        embd_inp.insert(embd_inp.end(), tokens.begin(), tokens.end());
+        promptCtx.n_past += tokens.size();
+        // template: end of user prompt + start of assistant prompt
+        size_t start = phUser.position() + phUser.length();
+        size_t end = placeholders.size() >= 2 ? placeholders[1].position() : promptTemplate.length();
+        auto userToAsst = promptTemplate.substr(start, end - start);
+        if (!userToAsst.empty()) {
+            tokens = tokenize(promptCtx, userToAsst, true);
+            embd_inp.insert(embd_inp.end(), tokens.begin(), tokens.end());
+            promptCtx.n_past += tokens.size();
+        }
+    }
+    promptCtx.n_past = old_n_past; // restore n_past so decodePrompt can increment it
+    // decode the user prompt
+    decodePrompt(promptCallback, responseCallback, recalculateCallback, promptCtx, embd_inp);
+    // decode the assistant's reply, either generated or spoofed
+    if (fakeReply == nullptr) {
+        generateResponse(responseCallback, recalculateCallback, promptCtx);
+    } else {
+        embd_inp = tokenize(promptCtx, *fakeReply, false);
+        decodePrompt(promptCallback, responseCallback, recalculateCallback, promptCtx, embd_inp);
+    }
+    // decode the rest of the prompt template
+    if (placeholders.size() >= 2) {
+        // template: end of assistant prompt
+        size_t start = placeholders[1].position() + placeholders[1].length();
+        auto asstSuffix = promptTemplate.substr(start);
+        if (!asstSuffix.empty()) {
+            embd_inp = tokenize(promptCtx, asstSuffix, true);
+            decodePrompt(promptCallback, responseCallback, recalculateCallback, promptCtx, embd_inp);
+        }
+    }
+}
+void LLModel::decodePrompt(std::function<bool(int32_t)> promptCallback,
+                           std::function<bool(int32_t, const std::string&)> responseCallback,
+                           std::function<bool(bool)> recalculateCallback,
+                           PromptContext &promptCtx,
+                           std::vector<Token> embd_inp) {
     // save the context size
     promptCtx.n_ctx = contextLength();
@@ -69,11 +174,6 @@ void LLModel::prompt(const std::string &prompt,
         // Check if the context has run out...
         if (promptCtx.n_past + int32_t(batch.size()) > promptCtx.n_ctx) {
-            const int32_t erasePoint = promptCtx.n_ctx * promptCtx.contextErase;
-            // Erase the first percentage of context from the tokens...
-            std::cerr << implementation().modelType() << ": reached the end of the context window so resizing\n";
-            promptCtx.tokens.erase(promptCtx.tokens.begin(), promptCtx.tokens.begin() + erasePoint);
-            promptCtx.n_past = promptCtx.tokens.size();
             recalculateContext(promptCtx, recalculateCallback);
             assert(promptCtx.n_past + int32_t(batch.size()) <= promptCtx.n_ctx);
         }
@@ -94,7 +194,11 @@ void LLModel::prompt(const std::string &prompt,
         }
         i = batch_end;
     }
+}
+void LLModel::generateResponse(std::function<bool(int32_t, const std::string&)> responseCallback,
+                               std::function<bool(bool)> recalculateCallback,
+                               PromptContext &promptCtx) {
     std::string cachedResponse;
     std::vector<Token> cachedTokens;
     std::unordered_set<std::string> reversePrompts
@@ -108,11 +212,6 @@ void LLModel::prompt(const std::string &prompt,
         // Check if the context has run out...
         if (promptCtx.n_past + 1 > promptCtx.n_ctx) {
-            const int32_t erasePoint = promptCtx.n_ctx * promptCtx.contextErase;
-            // Erase the first percentage of context from the tokens...
-            std::cerr << implementation().modelType() << ": reached the end of the context window so resizing\n";
-            promptCtx.tokens.erase(promptCtx.tokens.begin(), promptCtx.tokens.begin() + erasePoint);
-            promptCtx.n_past = promptCtx.tokens.size();
             recalculateContext(promptCtx, recalculateCallback);
             assert(promptCtx.n_past + 1 <= promptCtx.n_ctx);
         }
@@ -165,8 +264,9 @@ void LLModel::prompt(const std::string &prompt,
     }
 }
-std::vector<float> LLModel::embedding(const std::string &/*text*/)
+std::vector<float> LLModel::embedding(const std::string &text)
 {
+    (void)text;
     if (!supportsCompletion()) {
         std::string errorMessage = "ERROR: this model does not support generating embeddings!\n";
         std::cerr << implementation().modelType() << errorMessage;