From b6886c0e31387d233235ffd24477116e04d5a948 Mon Sep 17 00:00:00 2001
From: Adam Treat
Date: Mon, 8 May 2023 12:01:40 -0400
Subject: [PATCH] Fix up mpt.

Rename the leftover gptj_model_load() overload to mpt_model_load(),
correct the GPTJ-prefixed log messages to say MPT, pass the model's
actual vocab size (n_vocab) into mpt_sample_top_k_top_p() instead of
deriving it from vocab.id_to_token.size(), and strip trailing
whitespace.
---
 llmodel/mpt.cpp | 23 ++++++++++++-----------
 1 file changed, 12 insertions(+), 11 deletions(-)

diff --git a/llmodel/mpt.cpp b/llmodel/mpt.cpp
index 229a595d..fd905be3 100644
--- a/llmodel/mpt.cpp
+++ b/llmodel/mpt.cpp
@@ -40,7 +40,7 @@ struct mpt_layer {
     // attention
     struct ggml_tensor * attn_Wqkv_w;
     struct ggml_tensor * attn_out_proj_w;
-    
+
     // ff
     struct ggml_tensor * ffn_up_proj_w;
     struct ggml_tensor * ffn_down_proj_w;
@@ -87,7 +87,7 @@ struct mpt_model {
     struct ggml_tensor * norm_f_w;
 
     struct ggml_tensor * wte; // position embedding
-    
+
     // mpt does weight tying
     std::vector<mpt_layer> layers;
 
@@ -260,7 +260,7 @@ bool mpt_model_load(const std::string &fname, std::istream &fin, mpt_model & mod
 
         ctx_size += n_layer*(expand*n_embd*n_embd*ggml_type_sizef(wtype)); // ffn_up_proj_w
         ctx_size += n_layer*(expand*n_embd*n_embd*ggml_type_sizef(wtype)); // ffn_down_proj_w
-        
+
         ctx_size += n_ctx*n_layer*n_embd*ggml_type_sizef(GGML_TYPE_F16); // memory_k
         ctx_size += n_ctx*n_layer*n_embd*ggml_type_sizef(GGML_TYPE_F16); // memory_v
 
@@ -427,7 +427,7 @@ bool mpt_model_load(const std::string &fname, std::istream &fin, mpt_model & mod
 }
 
 // load the model's weights from a file path
-bool gptj_model_load(const std::string & fname, mpt_model & model, mpt_vocab & vocab) {
+bool mpt_model_load(const std::string & fname, mpt_model & model, mpt_vocab & vocab) {
     auto fin = std::ifstream(fname, std::ios::binary);
 
     if (!fin) {
@@ -528,7 +528,7 @@ bool mpt_eval(
                         0, 2, 1, 3);
 
         struct ggml_tensor * K =
-            ggml_permute(ctx0, 
+            ggml_permute(ctx0,
                     ggml_reshape_3d(ctx0,
                         ggml_view_1d(ctx0, model.kv_self.k, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(model.kv_self.k)*n_embd),
                         n_embd/n_head, n_head, n_past + N),
@@ -641,7 +641,7 @@ std::vector<mpt_vocab::id> mpt_tokenize(const mpt_vocab & vocab, const std::string & text)
 
     // not sure if this entirely right?
    std::vector<std::string> words;
-    
+
     // first split the text into words
     {
         std::string str = text;
@@ -771,6 +771,7 @@ size_t mpt_copy_state_data(const mpt_model &model, const std::mt19937 &rng, uint
 
 mpt_vocab::id mpt_sample_top_k_top_p(
         const mpt_vocab & vocab,
+        const size_t actualVocabSize,
         const int32_t * last_n_tokens_data,
         int last_n_tokens_size,
         const std::vector<float> logits,
@@ -779,7 +780,7 @@ mpt_vocab::id mpt_sample_top_k_top_p(
         double temp,
         float repeat_penalty,
         std::mt19937 & rng) {
-    int n_logits = vocab.id_to_token.size();
+    int n_logits = actualVocabSize;
 
     const auto last_n_tokens = std::vector<mpt_vocab::id>(last_n_tokens_data, last_n_tokens_data + last_n_tokens_size);
     const auto * plogits = logits.data() + logits.size() - n_logits;
@@ -1038,7 +1039,7 @@ void MPT::prompt(const std::string &prompt,
         if (promptCtx.n_past + batch.size() > promptCtx.n_ctx) {
             const int32_t erasePoint = promptCtx.n_ctx * promptCtx.contextErase;
             // Erase the first percentage of context from the tokens...
-            std::cerr << "GPTJ: reached the end of the context window so resizing\n";
+            std::cerr << "MPT: reached the end of the context window so resizing\n";
            promptCtx.tokens.erase(promptCtx.tokens.begin(), promptCtx.tokens.begin() + erasePoint);
             promptCtx.n_past = promptCtx.tokens.size();
             recalculateContext(promptCtx, recalculateCallback);
@@ -1081,7 +1082,7 @@ void MPT::prompt(const std::string &prompt,
         int id = 0;
         {
             const int64_t t_start_sample_us = ggml_time_us();
-            id = mpt_sample_top_k_top_p(d_ptr->vocab,
+            id = mpt_sample_top_k_top_p(d_ptr->vocab, n_vocab,
                 promptCtx.tokens.data() + promptCtx.n_ctx - promptCtx.n_ctx,
                 promptCtx.n_ctx,
                 promptCtx.logits,
@@ -1096,7 +1097,7 @@ void MPT::prompt(const std::string &prompt,
             if (promptCtx.n_past + 1 > promptCtx.n_ctx) {
                 const int32_t erasePoint = promptCtx.n_ctx * promptCtx.contextErase;
                 // Erase the first percentage of context from the tokens...
-                std::cerr << "GPTJ: reached the end of the context window so resizing\n";
+                std::cerr << "MPT: reached the end of the context window so resizing\n";
                 promptCtx.tokens.erase(promptCtx.tokens.begin(), promptCtx.tokens.begin() + erasePoint);
                 promptCtx.n_past = promptCtx.tokens.size();
                 recalculateContext(promptCtx, recalculateCallback);
@@ -1185,7 +1186,7 @@ void MPT::recalculateContext(PromptContext &promptCtx, std::function<bool(bool)>
 
         if (!mpt_eval(*d_ptr->model, d_ptr->n_threads, promptCtx.n_past, batch,
             promptCtx.logits, d_ptr->mem_per_token)) {
-            std::cerr << "GPTJ ERROR: Failed to process prompt\n";
+            std::cerr << "MPT ERROR: Failed to process prompt\n";
             goto stop_generating;
         }
         promptCtx.n_past += batch.size();