Fix handling of the special <|im_end|> end-of-turn token emitted by the mpt-7b-chat model, so generation stops when it is produced.

This commit is contained in:
Adam Treat 2023-05-08 18:55:33 -04:00
parent a4bec78ec6
commit 9c66308922

View File

@ -959,6 +959,7 @@ struct MPTPrivate {
int64_t n_threads = 0;
size_t mem_per_token = 0;
std::mt19937 rng;
bool has_im_end = false;
};
MPT::MPT()
@ -982,6 +983,7 @@ bool MPT::loadModel(const std::string &modelPath) {
d_ptr->n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
d_ptr->modelLoaded = true;
d_ptr->has_im_end = d_ptr->vocab.token_to_id.find("<|im_end|>") != d_ptr->vocab.token_to_id.end();
fflush(stdout);
return true;
}
@ -1150,6 +1152,10 @@ void MPT::prompt(const std::string &prompt,
// display text
++totalPredictions;
// mpt-7b-chat has special token for end
if (d_ptr->has_im_end && id == d_ptr->vocab.token_to_id["<|im_end|>"])
goto stop_generating;
if (id == 0 /*end of text*/)
goto stop_generating;