diff --git a/gptj.cpp b/gptj.cpp
index 7319b62e..832d2652 100644
--- a/gptj.cpp
+++ b/gptj.cpp
@@ -419,7 +419,7 @@ bool gptj_eval(
 
     const int d_key = n_embd/n_head;
 
-    static size_t buf_size = 256u*1024*1024;
+    static size_t buf_size = 1024u*1024*1024;
     static void * buf = malloc(buf_size);
 
     if (mem_per_token > 0 && mem_per_token*N > buf_size) {
@@ -670,8 +670,7 @@ bool GPTJ::isModelLoaded() const
 }
 
 void GPTJ::prompt(const std::string &prompt, std::function<bool(const std::string&)> response,
-        int32_t n_predict, int32_t top_k, float top_p, float temp,
-        int32_t n_batch) {
+        PromptContext &ctx, int32_t n_predict, int32_t top_k, float top_p, float temp, int32_t n_batch) {
 
     if (!isModelLoaded()) {
         std::cerr << "GPT-J ERROR: prompt won't work with an unloaded model!\n";
@@ -679,32 +678,38 @@ void GPTJ::prompt(const std::string &prompt, std::function<bool(const std::string&)> response,
-    int n_past = 0;
-
-    std::vector<float> logits;
-
     // tokenize the prompt
     std::vector<gpt_vocab::id> embd_inp = ::gpt_tokenize(d_ptr->vocab, prompt);
 
     n_predict = std::min(n_predict, d_ptr->model.hparams.n_ctx - (int) embd_inp.size());
+    ctx.n_past = std::min(ctx.n_past, 1024);
+//    n_batch = embd_inp.size();
+
+    std::cout << "The past was: " << ctx.n_past;
+    fflush(stdout);
 
     std::vector<gpt_vocab::id> embd;
     std::vector<gpt_vocab::id> resp;
 
     // determine the required inference memory per token:
+    static bool initialized = false;
     size_t mem_per_token = 0;
-    gptj_eval(d_ptr->model, d_ptr->n_threads, 0, { 0, 1, 2, 3 }, logits, mem_per_token);
+    if (!initialized) {
+        gptj_eval(d_ptr->model, d_ptr->n_threads, 0, { 0, 1, 2, 3 }, ctx.logits, mem_per_token);
+        initialized = true;
+    }
 
     for (int i = embd.size(); i < embd_inp.size() + n_predict; i++) {
         // predict
         if (embd.size() > 0) {
             const int64_t t_start_us = ggml_time_us();
 
-            if (!gptj_eval(d_ptr->model, d_ptr->n_threads, n_past, embd, logits, mem_per_token)) {
+            if (!gptj_eval(d_ptr->model, d_ptr->n_threads, ctx.n_past, embd, ctx.logits, mem_per_token)) {
                 std::cerr << "GPT-J ERROR: Failed to predict\n";
                 return;
             }
@@ -712,7 +717,7 @@ void GPTJ::prompt(const std::string &prompt, std::function<bool(const std::string&)> response,
-                id = gpt_sample_top_k_top_p(d_ptr->vocab, logits.data() + (logits.size() - n_vocab), top_k, top_p, temp, d_ptr->rng);
+                id = gpt_sample_top_k_top_p(d_ptr->vocab, ctx.logits.data() + (ctx.logits.size() - n_vocab), top_k, top_p, temp, d_ptr->rng);
 
                 t_sample_us += ggml_time_us() - t_start_sample_us;
             }
diff --git a/gptj.h b/gptj.h
index 3a698180..884c9c4f 100644
--- a/gptj.h
+++ b/gptj.h
@@ -13,9 +13,13 @@ public:
 
     bool loadModel(const std::string &modelPath, std::istream &fin);
     bool isModelLoaded() const;
+    struct PromptContext {
+        std::vector<float> logits;
+        int32_t n_past = 0; // number of tokens in past conversation
+    };
     void prompt(const std::string &prompt, std::function<bool(const std::string&)> response,
-        int32_t n_predict = 200, int32_t top_k = 40, float top_p = 0.9f, float temp = 0.9f,
-        int32_t n_batch = 9);
+        PromptContext &ctx, int32_t n_predict = 200, int32_t top_k = 40, float top_p = 0.9f,
+        float temp = 0.9f, int32_t n_batch = 9);
 
 private:
     GPTJPrivate *d_ptr;
diff --git a/llm.cpp b/llm.cpp
index 6e2ca906..1bf4e287 100644
--- a/llm.cpp
+++ b/llm.cpp
@@ -75,7 +75,8 @@ bool GPTJObject::prompt(const QString &prompt)
     m_stopGenerating = false;
     auto func = std::bind(&GPTJObject::handleResponse, this, std::placeholders::_1);
     emit responseStarted();
-    m_gptj->prompt(prompt.toStdString(), func);
+    static GPTJ::PromptContext ctx;
+    m_gptj->prompt(prompt.toStdString(), func, ctx, 4096 /*number of chars to predict*/);
     emit responseStopped();
     return true;
 }
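
For context, a minimal caller sketch of how the reworked prompt() signature is meant to be used: a single PromptContext instance is kept alive across calls so that ctx.n_past and ctx.logits carry the conversation state from one turn to the next, which is what llm.cpp does above with its static GPTJ::PromptContext. The harness below is illustrative only; the main() wrapper, the model path, and the std::function<bool(const std::string&)> callback type are assumptions, not part of this patch.

    // Hypothetical standalone caller, analogous to GPTJObject::prompt() in llm.cpp.
    #include <fstream>
    #include <iostream>
    #include <string>

    #include "gptj.h"

    int main()
    {
        GPTJ model;
        std::ifstream fin("ggml-model.bin", std::ios::binary); // example path, not from the patch
        if (!model.loadModel("ggml-model.bin", fin)) {
            std::cerr << "failed to load model\n";
            return 1;
        }

        GPTJ::PromptContext ctx; // persists across turns, like the static ctx in llm.cpp
        auto onResponse = [](const std::string &token) {
            std::cout << token << std::flush;
            return true; // returning false is assumed to stop generation
        };

        model.prompt("Hello, who are you?", onResponse, ctx);
        model.prompt("Summarize your last answer.", onResponse, ctx); // ctx.n_past carries the history
        return 0;
    }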