From bbe195ee0207549c4a763ee67f464952af2fac02 Mon Sep 17 00:00:00 2001 From: AT Date: Sun, 4 Jun 2023 08:59:24 -0400 Subject: [PATCH] Backend prompt dedup (#822) * Deduplicated prompt() function code --- gpt4all-backend/CMakeLists.txt | 2 +- gpt4all-backend/gptj.cpp | 169 +++++----------------------- gpt4all-backend/gptj_impl.h | 14 ++- gpt4all-backend/llamamodel.cpp | 172 +++++----------------------- gpt4all-backend/llamamodel_impl.h | 14 ++- gpt4all-backend/llmodel.h | 29 +++-- gpt4all-backend/llmodel_shared.cpp | 133 ++++++++++++++++++++++ gpt4all-backend/mpt.cpp | 173 +++++------------------------ gpt4all-backend/mpt_impl.h | 14 ++- gpt4all-chat/chatgpt.h | 13 ++- 10 files changed, 281 insertions(+), 452 deletions(-) diff --git a/gpt4all-backend/CMakeLists.txt b/gpt4all-backend/CMakeLists.txt index defcc127..c2190674 100644 --- a/gpt4all-backend/CMakeLists.txt +++ b/gpt4all-backend/CMakeLists.txt @@ -103,7 +103,7 @@ foreach(BUILD_VARIANT IN LISTS BUILD_VARIANTS) endforeach() add_library(llmodel - llmodel.h llmodel.cpp + llmodel.h llmodel.cpp llmodel_shared.cpp llmodel_c.h llmodel_c.cpp dlhandle.h ) diff --git a/gpt4all-backend/gptj.cpp b/gpt4all-backend/gptj.cpp index f0ce58a2..16a8e88f 100644 --- a/gpt4all-backend/gptj.cpp +++ b/gpt4all-backend/gptj.cpp @@ -890,159 +890,50 @@ size_t GPTJ::restoreState(const uint8_t *src) return gptj_set_state_data(d_ptr->model, &d_ptr->rng, src); } -void GPTJ::prompt(const std::string &prompt, - std::function promptCallback, - std::function responseCallback, - std::function recalculateCallback, - PromptContext &promptCtx) { - - if (!isModelLoaded()) { - std::cerr << "GPT-J ERROR: prompt won't work with an unloaded model!\n"; - return; - } - - // tokenize the prompt - std::vector embd_inp = ::gpt_tokenize(d_ptr->vocab, prompt); - - // save the context size - promptCtx.n_ctx = d_ptr->model->hparams.n_ctx; +std::vector GPTJ::tokenize(const std::string &str) const +{ + return ::gpt_tokenize(d_ptr->vocab, str); +} - if ((int) embd_inp.size() > promptCtx.n_ctx - 4) { - responseCallback(-1, "ERROR: The prompt size exceeds the context window size and cannot be processed."); - std::cerr << "GPT-J ERROR: The prompt is" << embd_inp.size() << - "tokens and the context window is" << promptCtx.n_ctx << "!\n"; - return; - } +LLModel::Token GPTJ::sampleToken(PromptContext &promptCtx) const +{ + const size_t n_prev_toks = std::min((size_t) promptCtx.repeat_last_n, promptCtx.tokens.size()); + return gpt_sample_top_k_top_p(d_ptr->model->hparams.n_vocab, + promptCtx.tokens.data() + promptCtx.tokens.size() - n_prev_toks, + n_prev_toks, + promptCtx.logits, + promptCtx.top_k, promptCtx.top_p, promptCtx.temp, + promptCtx.repeat_penalty, + d_ptr->rng); +} - promptCtx.n_predict = std::min(promptCtx.n_predict, promptCtx.n_ctx - (int) embd_inp.size()); - promptCtx.n_past = std::min(promptCtx.n_past, promptCtx.n_ctx); +std::string_view GPTJ::tokenToString(Token id) const +{ + return d_ptr->vocab.id_to_token[id]; +} +bool GPTJ::evalTokens(PromptContext &ctx, const std::vector &tokens) const +{ // determine the required inference memory per token: static bool initialized = false; - static std::vector p_instruct; - static std::vector r_instruct; if (!initialized) { - gptj_eval(*d_ptr->model, d_ptr->n_threads, 0, { 0, 1, 2, 3 }, promptCtx.logits, + gptj_eval(*d_ptr->model, d_ptr->n_threads, 0, { 0, 1, 2, 3 }, ctx.logits, d_ptr->mem_per_token); initialized = true; } - // process the prompt in batches - size_t i = 0; - while (i < embd_inp.size()) { - size_t batch_end = std::min(i + promptCtx.n_batch, embd_inp.size()); - std::vector batch(embd_inp.begin() + i, embd_inp.begin() + batch_end); - - // Check if the context has run out... - if (promptCtx.n_past + int32_t(batch.size()) > promptCtx.n_ctx) { - const int32_t erasePoint = promptCtx.n_ctx * promptCtx.contextErase; - // Erase the first percentage of context from the tokens... - std::cerr << "GPTJ: reached the end of the context window so resizing\n"; - promptCtx.tokens.erase(promptCtx.tokens.begin(), promptCtx.tokens.begin() + erasePoint); - promptCtx.n_past = promptCtx.tokens.size(); - recalculateContext(promptCtx, recalculateCallback); - assert(promptCtx.n_past + int32_t(batch.size()) <= promptCtx.n_ctx); - } - - if (!evalTokens(promptCtx, batch)) { - std::cerr << "GPT-J ERROR: Failed to process prompt\n"; - return; - } - - size_t tokens = batch_end - i; - for (size_t t = 0; t < tokens; ++t) { - if (int32_t(promptCtx.tokens.size()) == promptCtx.n_ctx) - promptCtx.tokens.erase(promptCtx.tokens.begin()); - promptCtx.tokens.push_back(batch.at(t)); - if (!promptCallback(batch.at(t))) - return; - } - promptCtx.n_past += batch.size(); - i = batch_end; - } - - std::string cachedResponse; - std::vector cachedTokens; - std::unordered_set reversePrompts - = { "### Instruction", "### Prompt", "### Response", "### Human", "### Assistant", "### Context" }; - - // predict next tokens - for (int i = 0; i < promptCtx.n_predict; i++) { - - // sample next token - const int n_vocab = d_ptr->model->hparams.n_vocab; - gpt_vocab::id id = 0; - { - const size_t n_prev_toks = std::min((size_t) promptCtx.repeat_last_n, promptCtx.tokens.size()); - id = gpt_sample_top_k_top_p(n_vocab, - promptCtx.tokens.data() + promptCtx.tokens.size() - n_prev_toks, - n_prev_toks, - promptCtx.logits, - promptCtx.top_k, promptCtx.top_p, promptCtx.temp, - promptCtx.repeat_penalty, - d_ptr->rng); - } - - // Check if the context has run out... - if (promptCtx.n_past + 1 > promptCtx.n_ctx) { - const int32_t erasePoint = promptCtx.n_ctx * promptCtx.contextErase; - // Erase the first percentage of context from the tokens... - std::cerr << "GPTJ: reached the end of the context window so resizing\n"; - promptCtx.tokens.erase(promptCtx.tokens.begin(), promptCtx.tokens.begin() + erasePoint); - promptCtx.n_past = promptCtx.tokens.size(); - recalculateContext(promptCtx, recalculateCallback); - assert(promptCtx.n_past + 1 <= promptCtx.n_ctx); - } - - if (!evalTokens(promptCtx, { id })) { - std::cerr << "GPT-J ERROR: Failed to predict next token\n"; - return; - } - - promptCtx.n_past += 1; - // display text - if (id == 50256 /*end of text*/) - return; - - const std::string str = d_ptr->vocab.id_to_token[id]; - - // Check if the provided str is part of our reverse prompts - bool foundPartialReversePrompt = false; - const std::string completed = cachedResponse + str; - if (reversePrompts.find(completed) != reversePrompts.end()) - return; - - // Check if it partially matches our reverse prompts and if so, cache - for (const auto &s : reversePrompts) { - if (s.compare(0, completed.size(), completed) == 0) { - foundPartialReversePrompt = true; - cachedResponse = completed; - break; - } - } - - // Regardless the token gets added to our cache - cachedTokens.push_back(id); - - // Continue if we have found a partial match - if (foundPartialReversePrompt) - continue; + return gptj_eval(*d_ptr->model, d_ptr->n_threads, ctx.n_past, tokens, ctx.logits, d_ptr->mem_per_token); +} - // Empty the cache - for (auto t : cachedTokens) { - if (int32_t(promptCtx.tokens.size()) == promptCtx.n_ctx) - promptCtx.tokens.erase(promptCtx.tokens.begin()); - promptCtx.tokens.push_back(t); - if (!responseCallback(t, d_ptr->vocab.id_to_token[t])) - return; - } - cachedTokens.clear(); - } +int32_t GPTJ::contextLength() const +{ + return d_ptr->model->hparams.n_ctx; } -bool GPTJ::evalTokens(PromptContext &ctx, const std::vector &tokens) +const std::vector &GPTJ::endTokens() const { - return gptj_eval(*d_ptr->model, d_ptr->n_threads, ctx.n_past, tokens, ctx.logits, d_ptr->mem_per_token); + static const std::vector fres = {50256}; + return fres; } #if defined(_WIN32) diff --git a/gpt4all-backend/gptj_impl.h b/gpt4all-backend/gptj_impl.h index 4b209bd2..270d65bb 100644 --- a/gpt4all-backend/gptj_impl.h +++ b/gpt4all-backend/gptj_impl.h @@ -20,17 +20,19 @@ public: size_t stateSize() const override; size_t saveState(uint8_t *dest) const override; size_t restoreState(const uint8_t *src) override; - void prompt(const std::string &prompt, - std::function promptCallback, - std::function responseCallback, - std::function recalculateCallback, - PromptContext &ctx) override; - bool evalTokens(PromptContext &ctx, const std::vector &tokens) override; void setThreadCount(int32_t n_threads) override; int32_t threadCount() const override; private: GPTJPrivate *d_ptr; + +protected: + std::vector tokenize(const std::string&) const override; + Token sampleToken(PromptContext &ctx) const override; + std::string_view tokenToString(Token) const override; + bool evalTokens(PromptContext &ctx, const std::vector &tokens) const override; + int32_t contextLength() const override; + const std::vector& endTokens() const override; }; #endif // GPTJ_H diff --git a/gpt4all-backend/llamamodel.cpp b/gpt4all-backend/llamamodel.cpp index 6c94b1d8..e9822d33 100644 --- a/gpt4all-backend/llamamodel.cpp +++ b/gpt4all-backend/llamamodel.cpp @@ -112,7 +112,7 @@ bool LLamaModel::loadModel(const std::string &modelPath) d_ptr->params.use_mlock = true; #else d_ptr->params.use_mlock = params.use_mlock; -#endif +#endif #if LLAMA_DATE <= 230511 d_ptr->params.n_parts = params.n_parts; #endif @@ -163,155 +163,43 @@ size_t LLamaModel::restoreState(const uint8_t *src) return llama_set_state_data(d_ptr->ctx, const_cast(src)); } -void LLamaModel::prompt(const std::string &prompt, - std::function promptCallback, - std::function responseCallback, - std::function recalculateCallback, - PromptContext &promptCtx) { - - if (!isModelLoaded()) { - std::cerr << "LLAMA ERROR: prompt won't work with an unloaded model!\n"; - return; - } +std::vector LLamaModel::tokenize(const std::string &str) const +{ + std::vector fres(str.size()+4); + auto fres_len = llama_tokenize(d_ptr->ctx, str.c_str(), fres.data(), fres.size(), d_ptr->empty); + fres.resize(fres_len); + return fres; +} - gpt_params params; - params.prompt = prompt; +std::string_view LLamaModel::tokenToString(Token id) const +{ + return llama_token_to_str(d_ptr->ctx, id); +} - // Add a space in front of the first character to match OG llama tokenizer behavior - params.prompt.insert(0, 1, ' '); +LLModel::Token LLamaModel::sampleToken(PromptContext &promptCtx) const +{ + const size_t n_prev_toks = std::min((size_t) promptCtx.repeat_last_n, promptCtx.tokens.size()); + return llama_sample_top_p_top_k(d_ptr->ctx, + promptCtx.tokens.data() + promptCtx.tokens.size() - n_prev_toks, + n_prev_toks, promptCtx.top_k, promptCtx.top_p, promptCtx.temp, + promptCtx.repeat_penalty); +} - // tokenize the prompt - std::vector embd_inp(params.prompt.size() + 4); - int n = llama_tokenize(d_ptr->ctx, params.prompt.c_str(), embd_inp.data(), embd_inp.size(), d_ptr->empty); - assert(n >= 0); - embd_inp.resize(n); +bool LLamaModel::evalTokens(PromptContext &ctx, const std::vector &tokens) const +{ d_ptr->empty = false; + return llama_eval(d_ptr->ctx, tokens.data(), tokens.size(), ctx.n_past, d_ptr->n_threads) == 0; +} - // save the context size - promptCtx.n_ctx = llama_n_ctx(d_ptr->ctx); - - if ((int) embd_inp.size() > promptCtx.n_ctx - 4) { - responseCallback(-1, "The prompt size exceeds the context window size and cannot be processed."); - std::cerr << "LLAMA ERROR: The prompt is" << embd_inp.size() << - "tokens and the context window is" << promptCtx.n_ctx << "!\n"; - return; - } - - promptCtx.n_predict = std::min(promptCtx.n_predict, promptCtx.n_ctx - (int) embd_inp.size()); - promptCtx.n_past = std::min(promptCtx.n_past, promptCtx.n_ctx); - - // number of tokens to keep when resetting context - params.n_keep = (int)embd_inp.size(); - - // process the prompt in batches - size_t i = 0; - while (i < embd_inp.size()) { - size_t batch_end = std::min(i + promptCtx.n_batch, embd_inp.size()); - std::vector batch(embd_inp.begin() + i, embd_inp.begin() + batch_end); - - // Check if the context has run out... - if (promptCtx.n_past + int32_t(batch.size()) > promptCtx.n_ctx) { - const int32_t erasePoint = promptCtx.n_ctx * promptCtx.contextErase; - // Erase the first percentage of context from the tokens... - std::cerr << "LLAMA: reached the end of the context window so resizing\n"; - promptCtx.tokens.erase(promptCtx.tokens.begin(), promptCtx.tokens.begin() + erasePoint); - promptCtx.n_past = promptCtx.tokens.size(); - recalculateContext(promptCtx, recalculateCallback); - assert(promptCtx.n_past + int32_t(batch.size()) <= promptCtx.n_ctx); - } - - if (!evalTokens(promptCtx, batch)) { - std::cerr << "LLAMA ERROR: Failed to process prompt\n"; - return; - } - - size_t tokens = batch_end - i; - for (size_t t = 0; t < tokens; ++t) { - if (int32_t(promptCtx.tokens.size()) == promptCtx.n_ctx) - promptCtx.tokens.erase(promptCtx.tokens.begin()); - promptCtx.tokens.push_back(batch.at(t)); - if (!promptCallback(batch.at(t))) - return; - } - promptCtx.n_past += batch.size(); - i = batch_end; - } - - std::string cachedResponse; - std::vector cachedTokens; - std::unordered_set reversePrompts - = { "### Instruction", "### Prompt", "### Response", "### Human", "### Assistant" }; - - // predict next tokens - for (int i = 0; i < promptCtx.n_predict; i++) { - // sample next token - const size_t n_prev_toks = std::min((size_t) promptCtx.repeat_last_n, promptCtx.tokens.size()); - llama_token id = llama_sample_top_p_top_k(d_ptr->ctx, - promptCtx.tokens.data() + promptCtx.tokens.size() - n_prev_toks, - n_prev_toks, promptCtx.top_k, promptCtx.top_p, promptCtx.temp, - promptCtx.repeat_penalty); - - // Check if the context has run out... - if (promptCtx.n_past + 1 > promptCtx.n_ctx) { - const int32_t erasePoint = promptCtx.n_ctx * promptCtx.contextErase; - // Erase the first percentage of context from the tokens... - std::cerr << "LLAMA: reached the end of the context window so resizing\n"; - promptCtx.tokens.erase(promptCtx.tokens.begin(), promptCtx.tokens.begin() + erasePoint); - promptCtx.n_past = promptCtx.tokens.size(); - recalculateContext(promptCtx, recalculateCallback); - assert(promptCtx.n_past + 1 <= promptCtx.n_ctx); - } - - if (!evalTokens(promptCtx, { id })) { - std::cerr << "LLAMA ERROR: Failed to predict next token\n"; - return; - } - - promptCtx.n_past += 1; - // display text - if (id == llama_token_eos()) - return; - - const std::string str = llama_token_to_str(d_ptr->ctx, id); - - // Check if the provided str is part of our reverse prompts - bool foundPartialReversePrompt = false; - const std::string completed = cachedResponse + str; - if (reversePrompts.find(completed) != reversePrompts.end()) { - return; - } - - // Check if it partially matches our reverse prompts and if so, cache - for (const auto &s : reversePrompts) { - if (s.compare(0, completed.size(), completed) == 0) { - foundPartialReversePrompt = true; - cachedResponse = completed; - break; - } - } - - // Regardless the token gets added to our cache - cachedTokens.push_back(id); - - // Continue if we have found a partial match - if (foundPartialReversePrompt) - continue; - - // Empty the cache - for (auto t : cachedTokens) { - if (int32_t(promptCtx.tokens.size()) == promptCtx.n_ctx) - promptCtx.tokens.erase(promptCtx.tokens.begin()); - promptCtx.tokens.push_back(t); - if (!responseCallback(t, llama_token_to_str(d_ptr->ctx, t))) - return; - } - cachedTokens.clear(); - } +int32_t LLamaModel::contextLength() const +{ + return llama_n_ctx(d_ptr->ctx); } -bool LLamaModel::evalTokens(PromptContext &ctx, const std::vector &tokens) +const std::vector &LLamaModel::endTokens() const { - return llama_eval(d_ptr->ctx, tokens.data(), tokens.size(), ctx.n_past, d_ptr->n_threads) == 0; + static const std::vector fres = {llama_token_eos()}; + return fres; } #if defined(_WIN32) diff --git a/gpt4all-backend/llamamodel_impl.h b/gpt4all-backend/llamamodel_impl.h index 3c27fff8..a39f6ffb 100644 --- a/gpt4all-backend/llamamodel_impl.h +++ b/gpt4all-backend/llamamodel_impl.h @@ -20,17 +20,19 @@ public: size_t stateSize() const override; size_t saveState(uint8_t *dest) const override; size_t restoreState(const uint8_t *src) override; - void prompt(const std::string &prompt, - std::function promptCallback, - std::function responseCallback, - std::function recalculateCallback, - PromptContext &ctx) override; - bool evalTokens(PromptContext &ctx, const std::vector &tokens) override; void setThreadCount(int32_t n_threads) override; int32_t threadCount() const override; private: LLamaPrivate *d_ptr; + +protected: + std::vector tokenize(const std::string&) const override; + std::string_view tokenToString(Token) const override; + Token sampleToken(PromptContext& ctx) const override; + bool evalTokens(PromptContext& ctx, const std::vector &tokens) const override; + int32_t contextLength() const override; + const std::vector& endTokens() const override; }; #endif // LLAMAMODEL_H diff --git a/gpt4all-backend/llmodel.h b/gpt4all-backend/llmodel.h index 28b53ff8..634626e2 100644 --- a/gpt4all-backend/llmodel.h +++ b/gpt4all-backend/llmodel.h @@ -7,11 +7,14 @@ #include #include #include +#include class Dlhandle; class LLModel { public: + using Token = int32_t; + class Implementation { LLModel *(*construct_)(); @@ -60,11 +63,11 @@ public: virtual size_t saveState(uint8_t */*dest*/) const { return 0; } virtual size_t restoreState(const uint8_t */*src*/) { return 0; } virtual void prompt(const std::string &prompt, - std::function promptCallback, - std::function responseCallback, - std::function recalculateCallback, - PromptContext &ctx) = 0; - virtual bool evalTokens(PromptContext &ctx, const std::vector &tokens) = 0; + std::function promptCallback, + std::function responseCallback, + std::function recalculateCallback, + PromptContext &ctx); + virtual void setThreadCount(int32_t /*n_threads*/) {} virtual int32_t threadCount() const { return 1; } @@ -84,10 +87,20 @@ public: } protected: - const Implementation *m_implementation = nullptr; - + // These are pure virtual because subclasses need to implement as the default implementation of + // 'prompt' above calls these functions + virtual std::vector tokenize(const std::string&) const = 0; + virtual std::string_view tokenToString(Token) const = 0; + virtual Token sampleToken(PromptContext &ctx) const = 0; + virtual bool evalTokens(PromptContext &/*ctx*/, const std::vector& /*tokens*/) const = 0; + virtual int32_t contextLength() const = 0; + virtual const std::vector& endTokens() const = 0; + + // This is a helper function called from the default implementation of 'prompt' but it can be + // shared by all base classes so it isn't virtual void recalculateContext(PromptContext &promptCtx, std::function recalculate); - static std::string m_implementations_search_path; + const Implementation *m_implementation = nullptr; + static std::string m_implementations_search_path; }; #endif // LLMODEL_H diff --git a/gpt4all-backend/llmodel_shared.cpp b/gpt4all-backend/llmodel_shared.cpp index 18690c89..c84adfe5 100644 --- a/gpt4all-backend/llmodel_shared.cpp +++ b/gpt4all-backend/llmodel_shared.cpp @@ -2,6 +2,7 @@ #include #include +#include void LLModel::recalculateContext(PromptContext &promptCtx, std::function recalculate) { size_t i = 0; @@ -24,3 +25,135 @@ void LLModel::recalculateContext(PromptContext &promptCtx, std::function promptCallback, + std::function responseCallback, + std::function recalculateCallback, + PromptContext &promptCtx) +{ + if (!isModelLoaded()) { + std::cerr << implementation().modelType << " ERROR: prompt won't work with an unloaded model!\n"; + return; + } + + // tokenize the prompt + std::vector embd_inp = tokenize(prompt); + + // save the context size + promptCtx.n_ctx = contextLength(); + + if ((int) embd_inp.size() > promptCtx.n_ctx - 4) { + responseCallback(-1, "ERROR: The prompt size exceeds the context window size and cannot be processed."); + std::cerr << implementation().modelType << " ERROR: The prompt is" << embd_inp.size() << + "tokens and the context window is" << promptCtx.n_ctx << "!\n"; + return; + } + + promptCtx.n_predict = std::min(promptCtx.n_predict, promptCtx.n_ctx - (int) embd_inp.size()); + promptCtx.n_past = std::min(promptCtx.n_past, promptCtx.n_ctx); + + // process the prompt in batches + size_t i = 0; + while (i < embd_inp.size()) { + size_t batch_end = std::min(i + promptCtx.n_batch, embd_inp.size()); + std::vector batch(embd_inp.begin() + i, embd_inp.begin() + batch_end); + + // Check if the context has run out... + if (promptCtx.n_past + int32_t(batch.size()) > promptCtx.n_ctx) { + const int32_t erasePoint = promptCtx.n_ctx * promptCtx.contextErase; + // Erase the first percentage of context from the tokens... + std::cerr << implementation().modelType << ": reached the end of the context window so resizing\n"; + promptCtx.tokens.erase(promptCtx.tokens.begin(), promptCtx.tokens.begin() + erasePoint); + promptCtx.n_past = promptCtx.tokens.size(); + recalculateContext(promptCtx, recalculateCallback); + assert(promptCtx.n_past + int32_t(batch.size()) <= promptCtx.n_ctx); + } + + if (!evalTokens(promptCtx, batch)) { + std::cerr << implementation().modelType << " ERROR: Failed to process prompt\n"; + return; + } + + size_t tokens = batch_end - i; + for (size_t t = 0; t < tokens; ++t) { + if (int32_t(promptCtx.tokens.size()) == promptCtx.n_ctx) + promptCtx.tokens.erase(promptCtx.tokens.begin()); + promptCtx.tokens.push_back(batch.at(t)); + if (!promptCallback(batch.at(t))) + return; + } + promptCtx.n_past += batch.size(); + i = batch_end; + } + + std::string cachedResponse; + std::vector cachedTokens; + std::unordered_set reversePrompts + = { "### Instruction", "### Prompt", "### Response", "### Human", "### Assistant", "### Context" }; + + // predict next tokens + for (int i = 0; i < promptCtx.n_predict; i++) { + + // sample next token + auto id = sampleToken(promptCtx); + + // Check if the context has run out... + if (promptCtx.n_past + 1 > promptCtx.n_ctx) { + const int32_t erasePoint = promptCtx.n_ctx * promptCtx.contextErase; + // Erase the first percentage of context from the tokens... + std::cerr << implementation().modelType << ": reached the end of the context window so resizing\n"; + promptCtx.tokens.erase(promptCtx.tokens.begin(), promptCtx.tokens.begin() + erasePoint); + promptCtx.n_past = promptCtx.tokens.size(); + recalculateContext(promptCtx, recalculateCallback); + assert(promptCtx.n_past + 1 <= promptCtx.n_ctx); + } + + if (!evalTokens(promptCtx, { id })) { + std::cerr << implementation().modelType << " ERROR: Failed to predict next token\n"; + return; + } + + promptCtx.n_past += 1; + + // display text + for (const auto token : endTokens()) { + if (id == token) return; + } + + const std::string_view str = tokenToString(id); + + // Check if the provided str is part of our reverse prompts + bool foundPartialReversePrompt = false; + const std::string completed = cachedResponse + std::string(str); + if (reversePrompts.find(completed) != reversePrompts.end()) + return; + + // Check if it partially matches our reverse prompts and if so, cache + for (const auto& s : reversePrompts) { + if (s.compare(0, completed.size(), completed) == 0) { + foundPartialReversePrompt = true; + cachedResponse = completed; + break; + } + } + + // Regardless the token gets added to our cache + cachedTokens.push_back(id); + + // Continue if we have found a partial match + if (foundPartialReversePrompt) + continue; + + // Empty the cache + for (auto t : cachedTokens) { + if (int32_t(promptCtx.tokens.size()) == promptCtx.n_ctx) + promptCtx.tokens.erase(promptCtx.tokens.begin()); + promptCtx.tokens.push_back(t); + //TODO: Conversion to std::string can be avoided here... + if (!responseCallback(t, std::string(tokenToString(t)))) + return; + } + cachedTokens.clear(); + } +} diff --git a/gpt4all-backend/mpt.cpp b/gpt4all-backend/mpt.cpp index 4145c9c2..018189a5 100644 --- a/gpt4all-backend/mpt.cpp +++ b/gpt4all-backend/mpt.cpp @@ -815,163 +815,50 @@ size_t MPT::restoreState(const uint8_t *src) return mpt_set_state_data(d_ptr->model, &d_ptr->rng, src); } -void MPT::prompt(const std::string &prompt, - std::function promptCallback, - std::function responseCallback, - std::function recalculateCallback, - PromptContext &promptCtx) { - - if (!isModelLoaded()) { - std::cerr << "GPT-J ERROR: prompt won't work with an unloaded model!\n"; - return; - } - - // tokenize the prompt - std::vector embd_inp = gpt_tokenize(d_ptr->vocab, prompt); - - // save the context size - promptCtx.n_ctx = d_ptr->model->hparams.n_ctx; +std::vector MPT::tokenize(const std::string &str) const +{ + return ::gpt_tokenize(d_ptr->vocab, str); +} - if ((int) embd_inp.size() > promptCtx.n_ctx - 4) { - responseCallback(-1, "ERROR: The prompt size exceeds the context window size and cannot be processed."); - std::cerr << "GPT-J ERROR: The prompt is" << embd_inp.size() << - "tokens and the context window is" << promptCtx.n_ctx << "!\n"; - return; - } +std::string_view MPT::tokenToString(Token id) const +{ + return d_ptr->vocab.id_to_token[id]; +} - promptCtx.n_predict = std::min(promptCtx.n_predict, promptCtx.n_ctx - (int) embd_inp.size()); - promptCtx.n_past = std::min(promptCtx.n_past, promptCtx.n_ctx); +LLModel::Token MPT::sampleToken(PromptContext &promptCtx) const +{ + const size_t n_prev_toks = std::min((size_t) promptCtx.repeat_last_n, promptCtx.tokens.size()); + return gpt_sample_top_k_top_p(d_ptr->model->hparams.n_vocab, + promptCtx.tokens.data() + promptCtx.tokens.size() - n_prev_toks, + n_prev_toks, + promptCtx.logits, + promptCtx.top_k, promptCtx.top_p, promptCtx.temp, + promptCtx.repeat_penalty, + d_ptr->rng); +} +bool MPT::evalTokens(PromptContext &ctx, const std::vector &tokens) const +{ // determine the required inference memory per token: static bool initialized = false; - static std::vector p_instruct; - static std::vector r_instruct; if (!initialized) { - mpt_eval(*d_ptr->model, d_ptr->n_threads, 0, { 0, 1, 2, 3 }, promptCtx.logits, + mpt_eval(*d_ptr->model, d_ptr->n_threads, 0, { 0, 1, 2, 3 }, ctx.logits, d_ptr->mem_per_token); initialized = true; } - // process the prompt in batches - size_t i = 0; - while (i < embd_inp.size()) { - size_t batch_end = std::min(i + promptCtx.n_batch, embd_inp.size()); - std::vector batch(embd_inp.begin() + i, embd_inp.begin() + batch_end); - - // Check if the context has run out... - if (promptCtx.n_past + int32_t(batch.size()) > promptCtx.n_ctx) { - const int32_t erasePoint = promptCtx.n_ctx * promptCtx.contextErase; - // Erase the first percentage of context from the tokens... - std::cerr << "MPT: reached the end of the context window so resizing\n"; - promptCtx.tokens.erase(promptCtx.tokens.begin(), promptCtx.tokens.begin() + erasePoint); - promptCtx.n_past = promptCtx.tokens.size(); - recalculateContext(promptCtx, recalculateCallback); - assert(promptCtx.n_past + int32_t(batch.size()) <= promptCtx.n_ctx); - } - - if (!evalTokens(promptCtx, batch)) { - std::cerr << "GPT-J ERROR: Failed to process prompt\n"; - return; - } - - size_t tokens = batch_end - i; - for (size_t t = 0; t < tokens; ++t) { - if (int32_t(promptCtx.tokens.size()) == promptCtx.n_ctx) - promptCtx.tokens.erase(promptCtx.tokens.begin()); - promptCtx.tokens.push_back(batch.at(t)); - if (!promptCallback(batch.at(t))) - return; - } - promptCtx.n_past += batch.size(); - i = batch_end; - } - - std::string cachedResponse; - std::vector cachedTokens; - std::unordered_set reversePrompts - = { "### Instruction", "### Prompt", "### Response", "### Human", "### Assistant", "### Context" }; - - // predict next tokens - for (int i = 0; i < promptCtx.n_predict; i++) { - - // sample next token - const int n_vocab = d_ptr->model->hparams.n_vocab; - int id = 0; - { - const size_t n_prev_toks = std::min((size_t) promptCtx.repeat_last_n, promptCtx.tokens.size()); - id = gpt_sample_top_k_top_p(n_vocab, - promptCtx.tokens.data() + promptCtx.tokens.size() - n_prev_toks, - n_prev_toks, - promptCtx.logits, - promptCtx.top_k, promptCtx.top_p, promptCtx.temp, - promptCtx.repeat_penalty, - d_ptr->rng); - } - - // Check if the context has run out... - if (promptCtx.n_past + 1 > promptCtx.n_ctx) { - const int32_t erasePoint = promptCtx.n_ctx * promptCtx.contextErase; - // Erase the first percentage of context from the tokens... - std::cerr << "MPT: reached the end of the context window so resizing\n"; - promptCtx.tokens.erase(promptCtx.tokens.begin(), promptCtx.tokens.begin() + erasePoint); - promptCtx.n_past = promptCtx.tokens.size(); - recalculateContext(promptCtx, recalculateCallback); - assert(promptCtx.n_past + 1 <= promptCtx.n_ctx); - } - - if (!evalTokens(promptCtx, { id })) { - std::cerr << "GPT-J ERROR: Failed to predict next token\n"; - return; - } - - promptCtx.n_past += 1; - // display tex - // mpt-7b-chat has special token for end - if (d_ptr->has_im_end && id == d_ptr->vocab.token_to_id["<|im_end|>"]) - return; - - if (id == 0 /*end of text*/) - return; - - const std::string str = d_ptr->vocab.id_to_token[id]; - - // Check if the provided str is part of our reverse prompts - bool foundPartialReversePrompt = false; - const std::string completed = cachedResponse + str; - if (reversePrompts.find(completed) != reversePrompts.end()) - return; - - // Check if it partially matches our reverse prompts and if so, cache - for (const auto &s : reversePrompts) { - if (s.compare(0, completed.size(), completed) == 0) { - foundPartialReversePrompt = true; - cachedResponse = completed; - break; - } - } - - // Regardless the token gets added to our cache - cachedTokens.push_back(id); - - // Continue if we have found a partial match - if (foundPartialReversePrompt) - continue; + return mpt_eval(*d_ptr->model, d_ptr->n_threads, ctx.n_past, tokens, ctx.logits, d_ptr->mem_per_token); +} - // Empty the cache - for (auto t : cachedTokens) { - if (int32_t(promptCtx.tokens.size()) == promptCtx.n_ctx) - promptCtx.tokens.erase(promptCtx.tokens.begin()); - promptCtx.tokens.push_back(t); - if (!responseCallback(t, d_ptr->vocab.id_to_token[t])) - return; - } - cachedTokens.clear(); - } +int32_t MPT::contextLength() const +{ + return d_ptr->model->hparams.n_ctx; } -bool MPT::evalTokens(PromptContext &ctx, const std::vector &tokens) +const std::vector &MPT::endTokens() const { - return mpt_eval(*d_ptr->model, d_ptr->n_threads, ctx.n_past, tokens, ctx.logits, d_ptr->mem_per_token); + static const std::vector fres = {0, d_ptr->vocab.token_to_id["<|im_end|>"]}; + return fres; } #if defined(_WIN32) diff --git a/gpt4all-backend/mpt_impl.h b/gpt4all-backend/mpt_impl.h index f645b8bf..ec39c92c 100644 --- a/gpt4all-backend/mpt_impl.h +++ b/gpt4all-backend/mpt_impl.h @@ -20,17 +20,19 @@ public: size_t stateSize() const override; size_t saveState(uint8_t *dest) const override; size_t restoreState(const uint8_t *src) override; - void prompt(const std::string &prompt, - std::function promptCallback, - std::function responseCallback, - std::function recalculateCallback, - PromptContext &ctx) override; - bool evalTokens(PromptContext &ctx, const std::vector &tokens) override; void setThreadCount(int32_t n_threads) override; int32_t threadCount() const override; private: MPTPrivate *d_ptr; + +protected: + std::vector tokenize(const std::string&) const override; + std::string_view tokenToString(Token) const override; + Token sampleToken(PromptContext &ctx) const override; + bool evalTokens(PromptContext &ctx, const std::vector &tokens) const override; + int32_t contextLength() const override; + const std::vector& endTokens() const override; }; #endif // MPT_H diff --git a/gpt4all-chat/chatgpt.h b/gpt4all-chat/chatgpt.h index eb20a722..87abb01d 100644 --- a/gpt4all-chat/chatgpt.h +++ b/gpt4all-chat/chatgpt.h @@ -24,7 +24,7 @@ public: std::function responseCallback, std::function recalculateCallback, PromptContext &ctx) override; - bool evalTokens(PromptContext &ctx, const std::vector &tokens) override { return true; } + void setThreadCount(int32_t n_threads) override; int32_t threadCount() const override; @@ -34,6 +34,17 @@ public: QList context() const { return m_context; } void setContext(const QList &context) { m_context = context; } +protected: + // We have to implement these as they are pure virtual in base class, but we don't actually use + // them as they are only called from the default implementation of 'prompt' which we override and + // completely replace + std::vector tokenize(const std::string&) const override { return std::vector(); } + std::string_view tokenToString(Token) const override { return std::string_view(); } + Token sampleToken(PromptContext &ctx) const override { return -1; } + bool evalTokens(PromptContext &/*ctx*/, const std::vector& /*tokens*/) const override { return false; } + int32_t contextLength() const override { return -1; } + const std::vector& endTokens() const override { static const std::vector fres; return fres; } + private Q_SLOTS: void handleFinished(); void handleReadyRead();