@ -223,7 +223,7 @@ bool mpt_model_load(const std::string &fname, std::istream &fin, mpt_model & mod
// TODO: this only kind-of works, the gpt_tokenize can still incorrectly
// tokenize special tokens
if ( special ) {
vocab . add_special_token ( regex_escape( word) ) ;
vocab . add_special_token ( word) ;
}
}
}
@ -648,7 +648,7 @@ bool mpt_eval(
return true ;
}
std : : vector < int > mpt_tokenize ( const mpt_vocab & vocab , const std : : string & text ) {
std : : vector < int > mpt_tokenize _inner ( const mpt_vocab & vocab , const std : : string & text ) {
// taken from stablelm example in ggml
// they both use the gpt-neox tokenizer
// not sure if this entirely right?
@ -659,21 +659,6 @@ std::vector<int> mpt_tokenize(const mpt_vocab & vocab, const std::string & text)
{
std : : string str = text ;
std : : string pat = R " ('s|'t|'re|'ve|'m|'ll|'d| ?[[:alpha:]]+| ?[[:digit:]]+| ?[^ \ s[:alpha:][:digit:]]+| \ s+(?! \ S)| \ s+) " ;
// Generate the subpattern from the special_tokens vector if it's not empty
if ( ! vocab . special_tokens . empty ( ) ) {
std : : string special_tokens_subpattern ;
for ( const auto & token : vocab . special_tokens ) {
if ( ! special_tokens_subpattern . empty ( ) ) {
special_tokens_subpattern + = " | " ;
}
special_tokens_subpattern + = token ;
}
// Modify the regex pattern with the generated special tokens subpattern
pat = special_tokens_subpattern + " | " + pat ;
}
std : : regex re ( pat ) ;
std : : smatch m ;
@ -721,6 +706,41 @@ std::vector<int> mpt_tokenize(const mpt_vocab & vocab, const std::string & text)
return tokens ;
}
std : : vector < mpt_vocab : : id > mpt_tokenize ( const mpt_vocab & vocab , const std : : string & text ) {
// Generate the subpattern from the special_tokens vector if it's not empty
if ( ! vocab . special_tokens . empty ( ) ) {
std : : vector < mpt_vocab : : id > out ;
std : : vector < std : : string > chunks ;
std : : string str = text ;
std : : string special_tokens_subpattern ;
for ( const auto & token : vocab . special_tokens ) {
if ( ! special_tokens_subpattern . empty ( ) ) {
special_tokens_subpattern + = " | " ;
}
special_tokens_subpattern + = regex_escape ( token ) ;
}
std : : regex re ( special_tokens_subpattern ) ;
std : : smatch m ;
while ( std : : regex_search ( str , m , re ) ) {
auto tok = vocab . token_to_id . find ( m . str ( ) ) ;
if ( tok ! = vocab . token_to_id . end ( ) ) {
auto tokid = tok - > second ;
auto pfxtoks = mpt_tokenize_inner ( vocab , m . prefix ( ) ) ;
out . insert ( out . end ( ) , pfxtoks . begin ( ) , pfxtoks . end ( ) ) ;
out . push_back ( tokid ) ;
str = m . suffix ( ) ;
}
}
if ( ! str . empty ( ) ) {
auto tokrest = mpt_tokenize_inner ( vocab , str ) ;
out . insert ( out . end ( ) , tokrest . begin ( ) , tokrest . end ( ) ) ;
}
return out ;
} else {
return mpt_tokenize_inner ( vocab , text ) ;
}
}
// Upper bound (in bytes) reserved for the serialized RNG state.
// Parenthesized so the macro expands safely inside arithmetic expressions
// (e.g. `x % MPT_MAX_RNG_STATE` would otherwise parse as `(x % 64) * 1024`).
#define MPT_MAX_RNG_STATE (64*1024)
size_t mpt_get_state_size ( const mpt_model & model )