From 49fc7b315ae619d9a3b638f6c50b8dc9676bd73f Mon Sep 17 00:00:00 2001 From: Aaron Miller Date: Mon, 8 May 2023 14:42:20 -0700 Subject: [PATCH] mpt tokenizer: better special token handling closer to the behavior of huggingface `tokenizers`, do not attempt to handle additional tokens as if they were part of the original vocabulary as this cannot prevent them from being split into smaller chunks - handle added tokens *before* the regular tokenizing pass note this is still necessary even with a "proper" tokenizer implementation --- llmodel/mpt.cpp | 54 +++++++++++++++++++++++++++++++++---------------- 1 file changed, 37 insertions(+), 17 deletions(-) diff --git a/llmodel/mpt.cpp b/llmodel/mpt.cpp index a336b921..c8f3c230 100644 --- a/llmodel/mpt.cpp +++ b/llmodel/mpt.cpp @@ -223,7 +223,7 @@ bool mpt_model_load(const std::string &fname, std::istream &fin, mpt_model & mod // TODO: this only kind-of works, the gpt_tokenize can still incorrectly // tokenize special tokens if(special) { - vocab.add_special_token(regex_escape(word)); + vocab.add_special_token(word); } } } @@ -648,7 +648,7 @@ bool mpt_eval( return true; } -std::vector mpt_tokenize(const mpt_vocab & vocab, const std::string & text) { +std::vector mpt_tokenize_inner(const mpt_vocab & vocab, const std::string & text) { // taken from stablelm example in ggml // they both use the gpt-neox tokenizer // not sure if this entirely right? @@ -659,21 +659,6 @@ std::vector mpt_tokenize(const mpt_vocab & vocab, const std::string & text) { std::string str = text; std::string pat = R"('s|'t|'re|'ve|'m|'ll|'d| ?[[:alpha:]]+| ?[[:digit:]]+| ?[^\s[:alpha:][:digit:]]+|\s+(?!\S)|\s+)"; - - // Generate the subpattern from the special_tokens vector if it's not empty - if (!vocab.special_tokens.empty()) { - std::string special_tokens_subpattern; - for (const auto &token : vocab.special_tokens) { - if (!special_tokens_subpattern.empty()) { - special_tokens_subpattern += "|"; - } - special_tokens_subpattern += token; - } - - // Modify the regex pattern with the generated special tokens subpattern - pat = special_tokens_subpattern + "|" + pat; - } - std::regex re(pat); std::smatch m; @@ -721,6 +706,41 @@ std::vector mpt_tokenize(const mpt_vocab & vocab, const std::string & text) return tokens; } +std::vector mpt_tokenize(const mpt_vocab & vocab, const std::string & text) { + // Generate the subpattern from the special_tokens vector if it's not empty + if (!vocab.special_tokens.empty()) { + std::vector out; + std::vector chunks; + std::string str = text; + std::string special_tokens_subpattern; + for (const auto &token : vocab.special_tokens) { + if (!special_tokens_subpattern.empty()) { + special_tokens_subpattern += "|"; + } + special_tokens_subpattern += regex_escape(token); + } + std::regex re(special_tokens_subpattern); + std::smatch m; + while (std::regex_search(str, m, re)) { + auto tok = vocab.token_to_id.find(m.str()); + if (tok != vocab.token_to_id.end()) { + auto tokid = tok->second; + auto pfxtoks = mpt_tokenize_inner(vocab, m.prefix()); + out.insert(out.end(), pfxtoks.begin(), pfxtoks.end()); + out.push_back(tokid); + str = m.suffix(); + } + } + if (!str.empty()) { + auto tokrest = mpt_tokenize_inner(vocab, str); + out.insert(out.end(), tokrest.begin(), tokrest.end()); + } + return out; + } else { + return mpt_tokenize_inner(vocab, text); + } +} + #define MPT_MAX_RNG_STATE 64*1024 size_t mpt_get_state_size(const mpt_model &model)