From cd83723ed7b8aef9a74a6910ba8e00c37f25132c Mon Sep 17 00:00:00 2001
From: Adam Treat
Date: Fri, 5 May 2023 10:00:05 -0400
Subject: [PATCH] Persistent state for gpt-j models too.

---
 llmodel/gptj.cpp | 268 +++++++++++++++++++++++++++++++++++++++++------
 llmodel/gptj.h   |   3 +
 2 files changed, 239 insertions(+), 32 deletions(-)

diff --git a/llmodel/gptj.cpp b/llmodel/gptj.cpp
index 0d65c5cb..3ac0bf17 100644
--- a/llmodel/gptj.cpp
+++ b/llmodel/gptj.cpp
@@ -13,8 +13,11 @@
 #include 
 #include 
 #include 
+#include 
 
 // default hparams (GPT-J 6B)
+static const size_t MB = 1024*1024;
+
 struct gptj_hparams {
     int32_t n_vocab = 50400;
     int32_t n_ctx   = 2048;
@@ -45,6 +48,40 @@ struct gptj_layer {
     struct ggml_tensor * c_mlp_proj_b;
 };
 
+struct gptj_buffer {
+    uint8_t * addr = NULL;
+    size_t size = 0;
+
+    void resize(size_t size) {
+        delete[] addr;
+        addr = new uint8_t[size];
+        this->size = size;
+    }
+
+    ~gptj_buffer() {
+        std::cout << "yes we are cleaning up" << std::endl;
+        fflush(stdout);
+        delete[] addr;
+    }
+};
+
+struct gptj_kv_cache {
+    struct ggml_tensor * k;
+    struct ggml_tensor * v;
+
+    struct ggml_context * ctx = NULL;
+
+    gptj_buffer buf;
+
+    int n; // number of tokens currently in the cache
+
+    ~gptj_kv_cache() {
+        if (ctx) {
+            ggml_free(ctx);
+        }
+    }
+};
+
 struct gptj_model {
     gptj_hparams hparams;
 
@@ -60,14 +97,52 @@ struct gptj_model {
     std::vector<gptj_layer> layers;
 
     // key + value memory
-    struct ggml_tensor * memory_k;
-    struct ggml_tensor * memory_v;
+    struct gptj_kv_cache kv_self;
 
     //
     struct ggml_context * ctx;
     std::map<std::string, struct ggml_tensor *> tensors;
+
+    gptj_buffer buf;
+
+    ~gptj_model() {
+        if (ctx) {
+            ggml_free(ctx);
+        }
+    }
 };
 
+static bool kv_cache_init(
+        const struct gptj_hparams & hparams,
+             struct gptj_kv_cache & cache,
+                         ggml_type   wtype,
+                               int   n_ctx) {
+    const int n_embd  = hparams.n_embd;
+    const int n_layer = hparams.n_layer;
+
+    const int64_t n_mem      = (int64_t)n_layer*n_ctx;
+    const int64_t n_elements = n_embd*n_mem;
+
+    cache.buf.resize(2u*n_elements*ggml_type_size(wtype) + 2u*MB);
+
+    struct ggml_init_params params;
+    params.mem_size   = cache.buf.size;
+    params.mem_buffer = cache.buf.addr;
+    params.no_alloc   = false;
+
+    cache.ctx = ggml_init(params);
+
+    if (!cache.ctx) {
+        fprintf(stderr, "%s: failed to allocate memory for kv cache\n", __func__);
+        return false;
+    }
+
+    cache.k = ggml_new_tensor_1d(cache.ctx, wtype, n_elements);
+    cache.v = ggml_new_tensor_1d(cache.ctx, wtype, n_elements);
+
+    return true;
+}
+
 // load the model's weights from a stream
 bool gptj_model_load(const std::string &fname, std::istream &fin, gptj_model & model, gpt_vocab & vocab) {
     printf("%s: loading model from '%s' - please wait ...\n", __func__, fname.c_str());
@@ -277,12 +352,14 @@ bool gptj_model_load(const std::string &fname, std::istream &fin, gptj_model & m
         const int n_mem      = n_layer*n_ctx;
         const int n_elements = n_embd*n_mem;
 
-        model.memory_k = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_elements);
-        model.memory_v = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_elements);
+        if (!kv_cache_init(hparams, model.kv_self, GGML_TYPE_F32, model.hparams.n_ctx)) {
+            fprintf(stderr, "%s: kv_cache_init() failed for self-attention cache\n", __func__);
+            ggml_free(ctx);
+            return false;
+        }
 
-        const size_t memory_size = ggml_nbytes(model.memory_k) + ggml_nbytes(model.memory_v);
-
-        printf("%s: memory_size = %8.2f MB, n_mem = %d\n", __func__, memory_size/1024.0/1024.0, n_mem);
+        const size_t memory_size = ggml_nbytes(model.kv_self.k) + ggml_nbytes(model.kv_self.v);
+        printf("%s: kv self size = %7.2f MB\n", __func__, memory_size / 1024.0 / 1024.0);
     }
 
     // load weights
@@ -400,7 +477,7 @@ bool gptj_model_load(const std::string & fname, gptj_model & model, gpt_vocab &
 //   The GPT-J model requires about 16MB of memory per input token.
 //
 bool gptj_eval(
-        const gptj_model & model,
+        gptj_model & model,
         const int n_threads,
         const int n_past,
         const std::vector<gpt_vocab::id> & embd_inp,
@@ -419,25 +496,25 @@ bool gptj_eval(
 
     const int d_key = n_embd/n_head;
 
-    static size_t buf_size = 1024u*1024*1024;
-    static void * buf = malloc(buf_size);
+    static size_t buf_size = 1024u*MB;
+    if (!model.buf.addr || model.buf.size < buf_size)
+        model.buf.resize(buf_size);
 
-    if (mem_per_token > 0 && mem_per_token*N > buf_size) {
+    if (mem_per_token > 0 && mem_per_token*N > model.buf.size) {
         const size_t buf_size_new = 1.1*(mem_per_token*N); // add 10% to account for ggml object overhead
-        printf("\n%s: reallocating buffer from %zu to %zu bytes\n", __func__, buf_size, buf_size_new);
+        printf("\n%s: reallocating buffer from %zu to %zu bytes\n", __func__, model.buf.size, buf_size_new);
 
         // reallocate
-        buf_size = buf_size_new;
-        buf = realloc(buf, buf_size);
-        if (buf == nullptr) {
-            fprintf(stderr, "%s: failed to allocate %zu bytes\n", __func__, buf_size);
+        model.buf.resize(buf_size_new);
+        if (model.buf.addr == nullptr) {
+            fprintf(stderr, "%s: failed to allocate %zu bytes\n", __func__, model.buf.size);
             return false;
         }
     }
 
     struct ggml_init_params params = {
-        .mem_size   = buf_size,
-        .mem_buffer = buf,
+        .mem_size   = model.buf.size,
+        .mem_buffer = model.buf.addr,
     };
 
     struct ggml_context * ctx0 = ggml_init(params);
@@ -474,8 +551,8 @@ bool gptj_eval(
 
         // store key and value to memory
         if (N >= 1) {
-            struct ggml_tensor * k = ggml_view_1d(ctx0, model.memory_k, N*n_embd, (ggml_element_size(model.memory_k)*n_embd)*(il*n_ctx + n_past));
-            struct ggml_tensor * v = ggml_view_1d(ctx0, model.memory_v, N*n_embd, (ggml_element_size(model.memory_v)*n_embd)*(il*n_ctx + n_past));
+            struct ggml_tensor * k = ggml_view_1d(ctx0, model.kv_self.k, N*n_embd, (ggml_element_size(model.kv_self.k)*n_embd)*(il*n_ctx + n_past));
+            struct ggml_tensor * v = ggml_view_1d(ctx0, model.kv_self.v, N*n_embd, (ggml_element_size(model.kv_self.v)*n_embd)*(il*n_ctx + n_past));
 
             ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Kcur, k));
             ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcur, v));
@@ -496,7 +573,7 @@ bool gptj_eval(
             ggml_permute(ctx0,
                 ggml_rope(ctx0,
                     ggml_reshape_3d(ctx0,
-                        ggml_view_1d(ctx0, model.memory_k, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(model.memory_k)*n_embd),
+                        ggml_view_1d(ctx0, model.kv_self.k, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(model.kv_self.k)*n_embd),
                         n_embd/n_head, n_head, n_past + N),
                     n_past, n_rot, 1),
                 0, 2, 1, 3);
@@ -522,10 +599,10 @@ bool gptj_eval(
             ggml_cpy(ctx0,
                 ggml_permute(ctx0,
                     ggml_reshape_3d(ctx0,
-                        ggml_view_1d(ctx0, model.memory_v, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(model.memory_v)*n_embd),
+                        ggml_view_1d(ctx0, model.kv_self.v, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(model.kv_self.v)*n_embd),
                         n_embd/n_head, n_head, n_past + N),
                     1, 2, 0, 3),
-                ggml_new_tensor_3d(ctx0, model.memory_v->type, n_past + N, n_embd/n_head, n_head));
+                ggml_new_tensor_3d(ctx0, model.kv_self.v->type, n_past + N, n_embd/n_head, n_head));
 
         // KQV = transpose(V) * KQ_soft_max
         struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_trans, KQ_soft_max);
@@ -629,11 +706,122 @@ bool gptj_eval(
     return true;
 }
 
+#define GPTJ_MAX_RNG_STATE 64*1024
+
+size_t gptj_get_state_size(const gptj_model &model)
+{
+    // we don't know size of rng until we actually serialize it. so reserve more than enough memory for its serialized state.
+    // for reference, std::mt19937(1337) serializes to 6701 bytes.
+    const size_t s_rng_size = sizeof(size_t);
+    const size_t s_rng      = GPTJ_MAX_RNG_STATE;
+    const size_t s_kv_size  = sizeof(size_t);
+    const size_t s_kv_ntok  = sizeof(int);
+    const size_t s_kv       = model.kv_self.buf.size;
+    const size_t s_total = (
+        + s_rng_size
+        + s_rng
+        + s_kv_size
+        + s_kv_ntok
+        + s_kv
+    );
+    fflush(stdout);
+    return s_total;
+}
+
+size_t gptj_copy_state_data(const gptj_model &model, const std::mt19937 &rng, uint8_t *dest)
+{
+    uint8_t * out = dest;
+    fflush(stdout);
+    // copy rng
+    {
+        std::stringstream rng_ss;
+        rng_ss << rng;
+
+        const size_t rng_size = rng_ss.str().size();
+        char rng_buf[GPTJ_MAX_RNG_STATE];
+
+        memset(&rng_buf[0], 0, GPTJ_MAX_RNG_STATE);
+        memcpy(&rng_buf[0], rng_ss.str().data(), rng_ss.str().size());
+
+        memcpy(out, &rng_size,   sizeof(rng_size));   out += sizeof(rng_size);
+        memcpy(out, &rng_buf[0], GPTJ_MAX_RNG_STATE); out += GPTJ_MAX_RNG_STATE;
+    }
+
+    // copy kv cache
+    {
+        const size_t kv_size = model.kv_self.buf.size;
+        const int    kv_ntok = model.kv_self.n;
+
+        memcpy(out, &kv_size, sizeof(kv_size)); out += sizeof(kv_size);
+        memcpy(out, &kv_ntok, sizeof(kv_ntok)); out += sizeof(kv_ntok);
+
+        if (kv_size) {
+            memcpy(out, model.kv_self.buf.addr, kv_size); out += kv_size;
+        }
+    }
+
+    const size_t written  = out - dest;
+    const size_t expected = gptj_get_state_size(model);
+    assert(written == expected);
+    fflush(stdout);
+    return written;
+}
+
+size_t gptj_set_state_data(gptj_model *model, std::mt19937 *rng, const uint8_t *src)
+{
+    const uint8_t * in = src;
+
+    // set rng
+    {
+        size_t rng_size;
+        char   rng_buf[GPTJ_MAX_RNG_STATE];
+
+        memcpy(&rng_size,   in, sizeof(rng_size));   in += sizeof(rng_size);
+        memcpy(&rng_buf[0], in, GPTJ_MAX_RNG_STATE); in += GPTJ_MAX_RNG_STATE;
+
+        std::stringstream rng_ss;
+        rng_ss.str(std::string(&rng_buf[0], rng_size));
+        rng_ss >> *rng;
+
+        assert(rng_ss.fail() == false);
+    }
+
+    // set kv cache
+    {
+        size_t kv_size;
+        int    kv_ntok;
+
+        memcpy(&kv_size, in, sizeof(kv_size)); in += sizeof(kv_size);
+        memcpy(&kv_ntok, in, sizeof(kv_ntok)); in += sizeof(kv_ntok);
+
+        if (kv_size) {
+            assert(model->kv_self.buf.size == kv_size);
+
+            void * k_data = model->kv_self.k->data; // remember data pointers
+            void * v_data = model->kv_self.v->data; // because their value is stored in buf and overwritten by memcpy
+
+            memcpy(model->kv_self.buf.addr, in, kv_size); in += kv_size;
+
+            model->kv_self.k->data = k_data; // restore correct data pointers
+            model->kv_self.v->data = v_data;
+
+        }
+
+        model->kv_self.n = kv_ntok;
+    }
+
+    const size_t nread    = in - src;
+    const size_t expected = gptj_get_state_size(*model);
+    assert(nread == expected);
+    fflush(stdout);
+    return nread;
+}
+
 struct GPTJPrivate {
     const std::string modelPath;
     bool modelLoaded;
     gpt_vocab vocab;
-    gptj_model model;
+    gptj_model *model = nullptr;
     int64_t n_threads = 0;
     size_t mem_per_token = 0;
     std::mt19937 rng;
@@ -642,6 +830,7 @@ struct GPTJPrivate {
 GPTJ::GPTJ()
     : d_ptr(new GPTJPrivate) {
 
+    d_ptr->model = new gptj_model;
     d_ptr->modelLoaded = false;
 }
 
@@ -652,7 +841,7 @@ bool GPTJ::loadModel(const std::string &modelPath) {
     auto fin = std::ifstream(modelPath, std::ios::binary);
 
     // load the model
-    if (!gptj_model_load(modelPath, fin, d_ptr->model, d_ptr->vocab)) {
+    if (!gptj_model_load(modelPath, fin, *d_ptr->model, d_ptr->vocab)) {
         std::cerr << "GPT-J ERROR: failed to load model from " << modelPath;
         return false;
     }
 
@@ -673,7 +862,7 @@ int32_t GPTJ::threadCount() {
 
 GPTJ::~GPTJ()
 {
-    ggml_free(d_ptr->model.ctx);
+    delete d_ptr->model;
 }
 
 bool GPTJ::isModelLoaded() const
@@ -681,6 +870,21 @@ bool GPTJ::isModelLoaded() const
     return d_ptr->modelLoaded;
 }
 
+size_t GPTJ::stateSize() const
+{
+    return gptj_get_state_size(*d_ptr->model);
+}
+
+size_t GPTJ::saveState(uint8_t *dest) const
+{
+    return gptj_copy_state_data(*d_ptr->model, d_ptr->rng, dest);
+}
+
+size_t GPTJ::restoreState(const uint8_t *src)
+{
+    return gptj_set_state_data(d_ptr->model, &d_ptr->rng, src);
+}
+
 void GPTJ::prompt(const std::string &prompt,
         std::function promptCallback,
         std::function responseCallback,
@@ -702,7 +906,7 @@ void GPTJ::prompt(const std::string &prompt,
     std::vector<gpt_vocab::id> embd_inp = ::gpt_tokenize(d_ptr->vocab, prompt);
 
     // save the context size
-    promptCtx.n_ctx = d_ptr->model.hparams.n_ctx;
+    promptCtx.n_ctx = d_ptr->model->hparams.n_ctx;
 
     if ((int) embd_inp.size() > promptCtx.n_ctx - 4) {
         responseCallback(-1, "ERROR: The prompt size exceeds the context window size and cannot be processed.");
@@ -719,7 +923,7 @@ void GPTJ::prompt(const std::string &prompt,
     static std::vector<gpt_vocab::id> p_instruct;
     static std::vector<gpt_vocab::id> r_instruct;
     if (!initialized) {
-        gptj_eval(d_ptr->model, d_ptr->n_threads, 0, { 0, 1, 2, 3 }, promptCtx.logits,
+        gptj_eval(*d_ptr->model, d_ptr->n_threads, 0, { 0, 1, 2, 3 }, promptCtx.logits,
             d_ptr->mem_per_token);
         initialized = true;
     }
@@ -742,7 +946,7 @@ void GPTJ::prompt(const std::string &prompt,
             assert(promptCtx.n_past + batch.size() <= promptCtx.n_ctx);
         }
 
-        if (!gptj_eval(d_ptr->model, d_ptr->n_threads, promptCtx.n_past, batch, promptCtx.logits,
+        if (!gptj_eval(*d_ptr->model, d_ptr->n_threads, promptCtx.n_past, batch, promptCtx.logits,
             d_ptr->mem_per_token)) {
             std::cerr << "GPT-J ERROR: Failed to process prompt\n";
             return;
@@ -769,7 +973,7 @@ void GPTJ::prompt(const std::string &prompt,
     for (int i = 0; i < promptCtx.n_predict; i++) {
 
         // sample next token
-        const int n_vocab = d_ptr->model.hparams.n_vocab;
+        const int n_vocab = d_ptr->model->hparams.n_vocab;
         gpt_vocab::id id = 0;
         {
             const int64_t t_start_sample_us = ggml_time_us();
@@ -796,7 +1000,7 @@ void GPTJ::prompt(const std::string &prompt,
         }
 
         const int64_t t_start_predict_us = ggml_time_us();
-        if (!gptj_eval(d_ptr->model, d_ptr->n_threads, promptCtx.n_past, { id }, promptCtx.logits,
+        if (!gptj_eval(*d_ptr->model, d_ptr->n_threads, promptCtx.n_past, { id }, promptCtx.logits,
             d_ptr->mem_per_token)) {
             std::cerr << "GPT-J ERROR: Failed to predict next token\n";
             return;
@@ -846,7 +1050,7 @@ void GPTJ::recalculateContext(PromptContext &promptCtx, std::function
-        if (!gptj_eval(d_ptr->model, d_ptr->n_threads, promptCtx.n_past, batch, promptCtx.logits,
+        if (!gptj_eval(*d_ptr->model, d_ptr->n_threads, promptCtx.n_past, batch, promptCtx.logits,
             d_ptr->mem_per_token)) {
             std::cerr << "GPTJ ERROR: Failed to process prompt\n";
             goto stop_generating;
diff --git a/llmodel/gptj.h b/llmodel/gptj.h
index 70a4655a..3109c1da 100644
--- a/llmodel/gptj.h
+++ b/llmodel/gptj.h
@@ -14,6 +14,9 @@ public:
     bool loadModel(const std::string &modelPath) override;
     bool isModelLoaded() const override;
+    size_t stateSize() const override;
+    size_t saveState(uint8_t *dest) const override;
+    size_t restoreState(const uint8_t *src) override;
     void prompt(const std::string &prompt,
         std::function promptCallback,
         std::function responseCallback,
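
Usage note (not part of the diff): a minimal sketch of how a caller might exercise the new persistence API added by this patch, assuming a GPTJ instance whose model has already been loaded with loadModel(); the helper name snapshot_and_restore is illustrative only.

#include <cstddef>
#include <cstdint>
#include <vector>
#include "gptj.h"

// Hypothetical helper: capture the serialized state (RNG + KV cache) of a
// loaded GPTJ instance, then roll the model back to that snapshot later.
static void snapshot_and_restore(GPTJ &gptj) {
    // gptj_get_state_size() counts sizeof(size_t) + GPTJ_MAX_RNG_STATE for the
    // padded RNG blob, plus sizeof(size_t) + sizeof(int) + kv_self.buf.size for
    // the KV cache, so this buffer is exactly as large as saveState() writes.
    std::vector<uint8_t> state(gptj.stateSize());

    const size_t written = gptj.saveState(state.data());

    // ... call gptj.prompt(...) here; generation mutates the KV cache and RNG ...

    const size_t read = gptj.restoreState(state.data());
    (void)written; (void)read; // both match state.size() per the asserts in the patch
}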