From bb78ee00251028f013dd7df597d1f72aff4d905c Mon Sep 17 00:00:00 2001
From: Adam Treat
Date: Thu, 20 Apr 2023 17:13:00 -0400
Subject: [PATCH] Back out the prompt/response finding in gptj since it doesn't
 seem to help. Guard against reaching the end of the context window which we
 don't handle gracefully except for avoiding a crash.
---
 gptj.cpp       | 78 ++++++++++++++++++++------------------------------
 llamamodel.cpp | 35 +++++++++-------------
 2 files changed, 45 insertions(+), 68 deletions(-)

diff --git a/gptj.cpp b/gptj.cpp
index cf55d5cc..ff921a70 100644
--- a/gptj.cpp
+++ b/gptj.cpp
@@ -684,7 +684,7 @@ bool GPTJ::isModelLoaded() const
 }
 
 void GPTJ::prompt(const std::string &prompt, std::function<bool(const std::string&)> response,
-        PromptContext &ctx, int32_t n_predict, int32_t top_k, float top_p, float temp, int32_t n_batch) {
+        PromptContext &promptCtx, int32_t n_predict, int32_t top_k, float top_p, float temp, int32_t n_batch) {
 
     if (!isModelLoaded()) {
         std::cerr << "GPT-J ERROR: prompt won't work with an unloaded model!\n";
@@ -700,8 +700,10 @@ void GPTJ::prompt(const std::string &prompt, std::function<bool(const std::stri
     // tokenize the prompt
     std::vector<gpt_vocab::id> embd_inp = ::gpt_tokenize(d_ptr->vocab, prompt);
 
-    n_predict = std::min(n_predict, d_ptr->model.hparams.n_ctx - (int) embd_inp.size());
-    ctx.n_past = std::min(ctx.n_past, d_ptr->model.hparams.n_ctx);
+    const int n_ctx = d_ptr->model.hparams.n_ctx;
+
+    n_predict = std::min(n_predict, n_ctx - (int) embd_inp.size());
+    promptCtx.n_past = std::min(promptCtx.n_past, n_ctx);
 
     // determine the required inference memory per token:
     static bool initialized = false;
@@ -709,9 +711,7 @@ void GPTJ::prompt(const std::string &prompt, std::function<bool(const std::stri
     static std::vector<gpt_vocab::id> r_instruct;
     size_t mem_per_token = 0;
     if (!initialized) {
-        gptj_eval(d_ptr->model, d_ptr->n_threads, 0, { 0, 1, 2, 3 }, ctx.logits, mem_per_token);
-        p_instruct = ::gpt_tokenize(d_ptr->vocab, "### Prompt:");
-        r_instruct = ::gpt_tokenize(d_ptr->vocab, "### Response:");
+        gptj_eval(d_ptr->model, d_ptr->n_threads, 0, { 0, 1, 2, 3 }, promptCtx.logits, mem_per_token);
         initialized = true;
     }
 
@@ -721,7 +721,15 @@ void GPTJ::prompt(const std::string &prompt, std::function<bool(const std::stri
     while (i < embd_inp.size()) {
         size_t batch_end = std::min(i + n_batch, embd_inp.size());
         std::vector<gpt_vocab::id> batch(embd_inp.begin() + i, embd_inp.begin() + batch_end);
-        if (!gptj_eval(d_ptr->model, d_ptr->n_threads, ctx.n_past, batch, ctx.logits, mem_per_token)) {
+
+        // Check if the context has run out...
+        if (promptCtx.n_past + batch.size() > n_ctx) {
+            // FIXME: will produce gibberish after this
+            promptCtx.n_past = std::min(promptCtx.n_past, int(n_ctx - batch.size()));
+            std::cerr << "GPT-J WARNING: reached the end of the context window!\n";
+        }
+
+        if (!gptj_eval(d_ptr->model, d_ptr->n_threads, promptCtx.n_past, batch, promptCtx.logits, mem_per_token)) {
             std::cerr << "GPT-J ERROR: Failed to process prompt\n";
             return;
         }
@@ -730,7 +738,7 @@ void GPTJ::prompt(const std::string &prompt, std::function<bool(const std::stri
-    std::vector<gpt_vocab::id> cachedTokens;
-
     // predict next tokens
     int32_t totalPredictions = 0;
     for (int i = 0; i < n_predict; i++) {
@@ -749,52 +755,30 @@ void GPTJ::prompt(const std::string &prompt, std::function<bool(const std::stri
             const int64_t t_start_sample_us = ggml_time_us();
             const int n_vocab = d_ptr->model.hparams.n_vocab;
 
-            id = gpt_sample_top_k_top_p(d_ptr->vocab, ctx.logits.data() + (ctx.logits.size() - n_vocab),
+            id = gpt_sample_top_k_top_p(d_ptr->vocab, promptCtx.logits.data() + (promptCtx.logits.size() - n_vocab),
                 top_k, top_p, temp, d_ptr->rng);
 
             t_sample_us += ggml_time_us() - t_start_sample_us;
         }
 
+        // Check if the context has run out...
+        if (promptCtx.n_past + 1 > n_ctx) {
+            // FIXME: will produce gibberish after this
+            promptCtx.n_past = std::min(promptCtx.n_past, n_ctx - 1);
+            std::cerr << "GPT-J WARNING: reached the end of the context window!\n";
+        }
+
         const int64_t t_start_predict_us = ggml_time_us();
-        if (!gptj_eval(d_ptr->model, d_ptr->n_threads, ctx.n_past, { id }, ctx.logits, mem_per_token)) {
+        if (!gptj_eval(d_ptr->model, d_ptr->n_threads, promptCtx.n_past, { id }, promptCtx.logits, mem_per_token)) {
             std::cerr << "GPT-J ERROR: Failed to predict next token\n";
             return;
         }
-
-        cachedTokens.emplace_back(id);
-
-        // Check if this token is next token for p_instruct or r_instruct
-        if (p_instruct.at(p_instructFound) == id) {
-            ++p_instructFound;
-            if (p_instructFound == p_instruct.size()) {
-                fprintf(stderr, "Warning: Tried to generate \"### Prompt:\" stopping.\n");
-                fflush(stderr);
-                goto stop_generating;
-            }
-            continue;
-        } else
-            p_instructFound = 0;
-
-        if (r_instruct.at(r_instructFound) == id) {
-            ++r_instructFound;
-            if (r_instructFound == r_instruct.size()) {
-                fprintf(stderr, "Warning: Tried to generate \"### Response:\" stopping.\n");
-                fflush(stderr);
-                goto stop_generating;
-            }
-            continue;
-        } else
-            r_instructFound = 0;
-        t_predict_us += ggml_time_us() - t_start_predict_us;
-        for (int j = 0; j < cachedTokens.size(); ++j) {
-            gpt_vocab::id cachedToken = cachedTokens.at(j);
-            ctx.n_past += 1;
-            // display text
-            ++totalPredictions;
-            if (id == 50256 /*end of text*/ || !response(d_ptr->vocab.id_to_token[cachedToken]))
-                goto stop_generating;
-        }
-        cachedTokens.clear();
+
+        promptCtx.n_past += 1;
+        // display text
+        ++totalPredictions;
+        if (id == 50256 /*end of text*/ || !response(d_ptr->vocab.id_to_token[id]))
+            goto stop_generating;
     }
 
 stop_generating:
diff --git a/llamamodel.cpp b/llamamodel.cpp
index 6f7d7cd3..561ec5c8 100644
--- a/llamamodel.cpp
+++ b/llamamodel.cpp
@@ -43,7 +43,7 @@ bool LLamaModel::loadModel(const std::string &modelPath)
     d_ptr->params = llama_context_default_params();
 
     gpt_params params;
-    d_ptr->params.n_ctx = params.n_ctx;
+    d_ptr->params.n_ctx = 2048;
     d_ptr->params.n_parts = params.n_parts;
     d_ptr->params.seed = params.seed;
     d_ptr->params.f16_kv = params.memory_f16;
@@ -114,16 +114,18 @@ void LLamaModel::prompt(const std::string &prompt, std::function<bool(const std
     while (i < embd_inp.size()) {
         size_t batch_end = std::min(i + n_batch, embd_inp.size());
         std::vector<llama_token> batch(embd_inp.begin() + i, embd_inp.begin() + batch_end);
+        // Check if the context has run out...
         if (promptCtx.n_past + batch.size() > n_ctx) {
-            std::cerr << "eval n_ctx " << n_ctx << " n_past " << promptCtx.n_past << std::endl;
+            // FIXME: will produce gibberish after this
             promptCtx.n_past = std::min(promptCtx.n_past, int(n_ctx - batch.size()));
-            std::cerr << "after n_ctx " << n_ctx << " n_past " << promptCtx.n_past << std::endl;
+            std::cerr << "LLAMA WARNING: reached the end of the context window!\n";
         }
 
         if (llama_eval(d_ptr->ctx, batch.data(), batch.size(), promptCtx.n_past, d_ptr->n_threads)) {
             std::cerr << "LLAMA ERROR: Failed to process prompt\n";
             return;
         }
 
+        // We pass a null string for each token to see if the user has asked us to stop...
         size_t tokens = batch_end - i;
         for (size_t t = 0; t < tokens; ++t)
@@ -133,37 +135,28 @@ void LLamaModel::prompt(const std::string &prompt, std::function<bool(const std
-    std::vector<llama_token> cachedTokens;
-
     // predict next tokens
     int32_t totalPredictions = 0;
     for (int i = 0; i < n_predict; i++) {
         // sample next token
         llama_token id = llama_sample_top_p_top_k(d_ptr->ctx, {}, 0, top_k, top_p, temp, 1.0f);
 
+        // Check if the context has run out...
         if (promptCtx.n_past + 1 > n_ctx) {
-            std::cerr << "eval 2 n_ctx " << n_ctx << " n_past " << promptCtx.n_past << std::endl;
+            // FIXME: will produce gibberish after this
             promptCtx.n_past = std::min(promptCtx.n_past, n_ctx - 1);
-            std::cerr << "after 2 n_ctx " << n_ctx << " n_past " << promptCtx.n_past << std::endl;
+            std::cerr << "LLAMA WARNING: reached the end of the context window!\n";
         }
 
         if (llama_eval(d_ptr->ctx, &id, 1, promptCtx.n_past, d_ptr->n_threads)) {
             std::cerr << "LLAMA ERROR: Failed to predict next token\n";
             return;
         }
 
-        cachedTokens.emplace_back(id);
-
-        for (int j = 0; j < cachedTokens.size(); ++j) {
-            llama_token cachedToken = cachedTokens.at(j);
-            promptCtx.n_past += 1;
-            // display text
-            ++totalPredictions;
-            if (id == llama_token_eos() || !response(llama_token_to_str(d_ptr->ctx, cachedToken)))
-                goto stop_generating;
-        }
-        cachedTokens.clear();
-    }
-stop_generating:
-    return;
+        promptCtx.n_past += 1;
+        // display text
+        ++totalPredictions;
+        if (id == llama_token_eos() || !response(llama_token_to_str(d_ptr->ctx, id)))
+            return;
+    }
 }
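Both files now apply the same guard before every eval call: if advancing n_past by the number of tokens about to be evaluated would step past n_ctx, n_past is clamped back so the eval stays inside the context window and a warning is printed (output quality still degrades once the window is full, hence the FIXME). Below is a minimal, self-contained sketch of that clamping logic; the helper name guardContextWindow and the example values are illustrative only and are not part of the patch.

// Standalone sketch of the context-window guard above (hypothetical helper).
// Clamps n_past so that n_past + n_tokens never exceeds n_ctx.
#include <algorithm>
#include <cstddef>
#include <iostream>

static void guardContextWindow(int &n_past, std::size_t n_tokens, int n_ctx) {
    if (n_past + static_cast<int>(n_tokens) > n_ctx) {
        // As in the patch: this only avoids walking past the end of the
        // context; generation quality still degrades once the window is full.
        n_past = std::min(n_past, n_ctx - static_cast<int>(n_tokens));
        std::cerr << "WARNING: reached the end of the context window!\n";
    }
}

int main() {
    int n_past = 2047;
    guardContextWindow(n_past, 9, 2048);              // a 9-token batch no longer fits
    std::cout << "n_past clamped to " << n_past << "\n";  // prints 2039
    return 0;
}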