Fix up mpt.

This commit is contained in:
Adam Treat 2023-05-08 12:01:40 -04:00
parent 61e2aabadb
commit b6886c0e31

View File

@ -427,7 +427,7 @@ bool mpt_model_load(const std::string &fname, std::istream &fin, mpt_model & mod
// load the model's weights from a file path // load the model's weights from a file path
bool gptj_model_load(const std::string & fname, mpt_model & model, mpt_vocab & vocab) { bool mpt_model_load(const std::string & fname, mpt_model & model, mpt_vocab & vocab) {
auto fin = std::ifstream(fname, std::ios::binary); auto fin = std::ifstream(fname, std::ios::binary);
if (!fin) { if (!fin) {
@ -771,6 +771,7 @@ size_t mpt_copy_state_data(const mpt_model &model, const std::mt19937 &rng, uint
mpt_vocab::id mpt_sample_top_k_top_p( mpt_vocab::id mpt_sample_top_k_top_p(
const mpt_vocab & vocab, const mpt_vocab & vocab,
const size_t actualVocabSize,
const int32_t * last_n_tokens_data, const int32_t * last_n_tokens_data,
int last_n_tokens_size, int last_n_tokens_size,
const std::vector<float> logits, const std::vector<float> logits,
@ -779,7 +780,7 @@ mpt_vocab::id mpt_sample_top_k_top_p(
double temp, double temp,
float repeat_penalty, float repeat_penalty,
std::mt19937 & rng) { std::mt19937 & rng) {
int n_logits = vocab.id_to_token.size(); int n_logits = actualVocabSize;
const auto last_n_tokens = std::vector<int32_t>(last_n_tokens_data, last_n_tokens_data + last_n_tokens_size); const auto last_n_tokens = std::vector<int32_t>(last_n_tokens_data, last_n_tokens_data + last_n_tokens_size);
const auto * plogits = logits.data() + logits.size() - n_logits; const auto * plogits = logits.data() + logits.size() - n_logits;
@ -1038,7 +1039,7 @@ void MPT::prompt(const std::string &prompt,
if (promptCtx.n_past + batch.size() > promptCtx.n_ctx) { if (promptCtx.n_past + batch.size() > promptCtx.n_ctx) {
const int32_t erasePoint = promptCtx.n_ctx * promptCtx.contextErase; const int32_t erasePoint = promptCtx.n_ctx * promptCtx.contextErase;
// Erase the first percentage of context from the tokens... // Erase the first percentage of context from the tokens...
std::cerr << "GPTJ: reached the end of the context window so resizing\n"; std::cerr << "MPT: reached the end of the context window so resizing\n";
promptCtx.tokens.erase(promptCtx.tokens.begin(), promptCtx.tokens.begin() + erasePoint); promptCtx.tokens.erase(promptCtx.tokens.begin(), promptCtx.tokens.begin() + erasePoint);
promptCtx.n_past = promptCtx.tokens.size(); promptCtx.n_past = promptCtx.tokens.size();
recalculateContext(promptCtx, recalculateCallback); recalculateContext(promptCtx, recalculateCallback);
@ -1081,7 +1082,7 @@ void MPT::prompt(const std::string &prompt,
int id = 0; int id = 0;
{ {
const int64_t t_start_sample_us = ggml_time_us(); const int64_t t_start_sample_us = ggml_time_us();
id = mpt_sample_top_k_top_p(d_ptr->vocab, id = mpt_sample_top_k_top_p(d_ptr->vocab, n_vocab,
promptCtx.tokens.data() + promptCtx.n_ctx - promptCtx.n_ctx, promptCtx.tokens.data() + promptCtx.n_ctx - promptCtx.n_ctx,
promptCtx.n_ctx, promptCtx.n_ctx,
promptCtx.logits, promptCtx.logits,
@ -1096,7 +1097,7 @@ void MPT::prompt(const std::string &prompt,
if (promptCtx.n_past + 1 > promptCtx.n_ctx) { if (promptCtx.n_past + 1 > promptCtx.n_ctx) {
const int32_t erasePoint = promptCtx.n_ctx * promptCtx.contextErase; const int32_t erasePoint = promptCtx.n_ctx * promptCtx.contextErase;
// Erase the first percentage of context from the tokens... // Erase the first percentage of context from the tokens...
std::cerr << "GPTJ: reached the end of the context window so resizing\n"; std::cerr << "MPT: reached the end of the context window so resizing\n";
promptCtx.tokens.erase(promptCtx.tokens.begin(), promptCtx.tokens.begin() + erasePoint); promptCtx.tokens.erase(promptCtx.tokens.begin(), promptCtx.tokens.begin() + erasePoint);
promptCtx.n_past = promptCtx.tokens.size(); promptCtx.n_past = promptCtx.tokens.size();
recalculateContext(promptCtx, recalculateCallback); recalculateContext(promptCtx, recalculateCallback);
@ -1185,7 +1186,7 @@ void MPT::recalculateContext(PromptContext &promptCtx, std::function<bool(bool)>
if (!mpt_eval(*d_ptr->model, d_ptr->n_threads, promptCtx.n_past, batch, promptCtx.logits, if (!mpt_eval(*d_ptr->model, d_ptr->n_threads, promptCtx.n_past, batch, promptCtx.logits,
d_ptr->mem_per_token)) { d_ptr->mem_per_token)) {
std::cerr << "GPTJ ERROR: Failed to process prompt\n"; std::cerr << "MPT ERROR: Failed to process prompt\n";
goto stop_generating; goto stop_generating;
} }
promptCtx.n_past += batch.size(); promptCtx.n_past += batch.size();