From 185dc2460edf4fd052ea993eaeaa4e2f8c5d049d Mon Sep 17 00:00:00 2001
From: Adam Treat
Date: Sun, 16 Apr 2023 11:12:22 -0400
Subject: [PATCH] Check for "### Prompt:" or "### Response:" and stop
 generating, and modify the default template a little bit.

---
 gptj.cpp | 52 ++++++++++++++++++++++++++++++++++++++++++++++------
 main.qml | 16 ++++++++--------
 2 files changed, 54 insertions(+), 14 deletions(-)

diff --git a/gptj.cpp b/gptj.cpp
index 34aa16f9..db3da3b7 100644
--- a/gptj.cpp
+++ b/gptj.cpp
@@ -691,9 +691,13 @@ void GPTJ::prompt(const std::string &prompt, std::function
 
     // determine the required inference memory per token:
     static bool initialized = false;
+    static std::vector<gpt_vocab::id> p_instruct;
+    static std::vector<gpt_vocab::id> r_instruct;
     size_t mem_per_token = 0;
     if (!initialized) {
         gptj_eval(d_ptr->model, d_ptr->n_threads, 0, { 0, 1, 2, 3 }, ctx.logits, mem_per_token);
+        p_instruct = ::gpt_tokenize(d_ptr->vocab, "### Prompt:");
+        r_instruct = ::gpt_tokenize(d_ptr->vocab, "### Response:");
         initialized = true;
     }
 
@@ -717,6 +721,11 @@ void GPTJ::prompt(const std::string &prompt, std::function
         }
     }
 
+    size_t p_instructFound = 0;
+    size_t r_instructFound = 0;
+
+    std::vector<gpt_vocab::id> cachedTokens;
+
     // predict next tokens
     int32_t totalPredictions = 0;
     for (int i = 0; i < n_predict; i++) {
@@ -736,15 +745,46 @@ void GPTJ::prompt(const std::string &prompt, std::function
             t_sample_us += ggml_time_us() - t_start_sample_us;
         }
 
-        t_predict_us += ggml_time_us() - t_start_predict_us;
-
-        ctx.n_past += 1;
-        // display text
-        ++totalPredictions;
-        if (id == 50256 /*end of text*/ || !response(d_ptr->vocab.id_to_token[id]))
-            break;
+        cachedTokens.emplace_back(id);
+
+        // Check if this token is the next token for p_instruct or r_instruct
+        if (p_instruct.at(p_instructFound) == id) {
+            ++p_instructFound;
+            if (p_instructFound == p_instruct.size()) {
+                fprintf(stderr, "Warning: Tried to generate \"### Prompt:\" stopping.\n");
+                fflush(stderr);
+                goto stop_generating;
+            }
+            continue;
+        } else
+            p_instructFound = 0;
+
+        if (r_instruct.at(r_instructFound) == id) {
+            ++r_instructFound;
+            if (r_instructFound == r_instruct.size()) {
+                fprintf(stderr, "Warning: Tried to generate \"### Response:\" stopping.\n");
+                fflush(stderr);
+                goto stop_generating;
+            }
+            continue;
+        } else
+            r_instructFound = 0;
+
+        t_predict_us += ggml_time_us() - t_start_predict_us;
+
+        for (int j = 0; j < cachedTokens.size(); ++j) {
+            gpt_vocab::id cachedToken = cachedTokens.at(j);
+            ctx.n_past += 1;
+            // display text
+            ++totalPredictions;
+            if (id == 50256 /*end of text*/ || !response(d_ptr->vocab.id_to_token[cachedToken]))
+                goto stop_generating;
+        }
+        cachedTokens.clear();
     }
 
+stop_generating:
+
 #if 0
     // report timing
     {
diff --git a/main.qml b/main.qml
index 63b23875..0f3f97de 100644
--- a/main.qml
+++ b/main.qml
@@ -59,12 +59,13 @@ Window {
     property int topK: 40
     property int maxLength: 4096
     property int promptBatchSize: 9
-    property string promptTemplate: "Below is a prompt for either a task to complete or a piece of conversation. Decide which and write an appropriate response to the prompt.
+    property string defaultPromptTemplate: "The prompt below is a question to answer, a task to complete, or a conversation to respond to; decide which and write an appropriate response.
 
 ### Prompt:
 %1
-### Response:
-"
+### Response:\n"
+
+    property string promptTemplate: ""
 
     function restoreDefaults() {
         temperature = 0.9;
@@ -72,12 +73,11 @@ Window {
         topK = 40;
         maxLength = 4096;
         promptBatchSize = 9;
-        promptTemplate = "Below is a prompt for either a task to complete or a piece of conversation. Decide which and write an appropriate response to the prompt.
+        promptTemplate = defaultPromptTemplate;
+    }
 
-### Prompt:
-%1
-### Response:
-";
+    Component.onCompleted: {
+        promptTemplate = defaultPromptTemplate;
     }
 
     GridLayout {
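
A note on the stop-sequence logic in the gptj.cpp hunk above: each sampled token is appended to cachedTokens and matched against the tokenized "### Prompt:" / "### Response:" markers; while a partial match is alive the loop continues, withholding the cached tokens from response(), and flushes them only once the match breaks. Two subtleties: (1) on a mismatch the counter resets to zero but the current token is not re-tested against the start of the marker, so a stream like 10, 10, 11, 12 defeats a 10, 11, 12 marker; (2) a partial "### Prompt:" match continues before the "### Response:" check ever runs, so if the two markers tokenize with a shared "###" prefix, the response marker can go undetected. What follows is a minimal standalone sketch of the same rolling-match technique with the restart re-test added; the plain int token ids, the made-up token values, and the helper name matchStop are illustrative assumptions, not gpt4all code.

// Minimal sketch of the rolling stop-sequence match, with the restart
// re-test added. Token ids, values, and matchStop are illustrative only.
#include <cstdio>
#include <vector>

using TokenId = int;

// Returns true when `id` completes the stop sequence `stop`; `found`
// counts how many leading tokens of `stop` have matched so far.
static bool matchStop(const std::vector<TokenId> &stop, size_t &found, TokenId id) {
    if (stop.at(found) == id) {
        if (++found == stop.size())
            return true;            // full marker generated: stop
    } else {
        found = 0;                  // partial match broken...
        if (stop.at(0) == id)       // ...but this token may start a new one
            found = 1;
    }
    return false;
}

int main() {
    // Pretend "### Prompt:" tokenizes to {10, 11, 12}.
    const std::vector<TokenId> p_instruct = {10, 11, 12};
    size_t p_instructFound = 0;
    std::vector<TokenId> cachedTokens;   // withheld while a partial match is alive

    const TokenId stream[] = {5, 10, 11, 7, 10, 11, 12};
    for (TokenId id : stream) {
        cachedTokens.push_back(id);
        if (matchStop(p_instruct, p_instructFound, id)) {
            std::fprintf(stderr, "Warning: tried to generate a stop marker, stopping.\n");
            break;                       // cached marker tokens are never emitted
        }
        if (p_instructFound > 0)
            continue;                    // mid-match: keep withholding
        for (TokenId t : cachedTokens)   // no partial match pending: flush
            std::printf("emit %d\n", t);
        cachedTokens.clear();
    }
    return 0;
}

Testing both marker sequences on every token, rather than skipping the second check while a partial first match is alive, would also address the shared-prefix interaction noted above.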