diff --git a/llmodel/gptj.cpp b/llmodel/gptj.cpp
index 3ac0bf17..74977854 100644
--- a/llmodel/gptj.cpp
+++ b/llmodel/gptj.cpp
@@ -14,6 +14,7 @@
 #include
 #include
 #include
+#include <unordered_set>
 
 // default hparams (GPT-J 6B)
 static const size_t MB = 1024*1024;
@@ -968,6 +969,11 @@ void GPTJ::prompt(const std::string &prompt,
     int p_instructFound = 0;
     int r_instructFound = 0;
 
+    std::string cachedResponse;
+    std::vector<gpt_vocab::id> cachedTokens;
+    std::unordered_set<std::string> reversePrompts
+        = { "### Instruction", "### Prompt", "### Response", "### Human", "### Assistant" };
+
     // predict next tokens
     int32_t totalPredictions = 0;
     for (int i = 0; i < promptCtx.n_predict; i++) {
@@ -1014,11 +1020,40 @@
         if (id == 50256 /*end of text*/)
             goto stop_generating;
 
-        if (promptCtx.tokens.size() == promptCtx.n_ctx)
-            promptCtx.tokens.erase(promptCtx.tokens.begin());
-        promptCtx.tokens.push_back(id);
-        if (!responseCallback(id, d_ptr->vocab.id_to_token[id]))
+        const std::string str = d_ptr->vocab.id_to_token[id];
+
+        // Check if the provided str is part of our reverse prompts
+        bool foundPartialReversePrompt = false;
+        const std::string completed = cachedResponse + str;
+        if (reversePrompts.find(completed) != reversePrompts.end()) {
             goto stop_generating;
+        }
+
+        // Check if it partially matches our reverse prompts and if so, cache
+        for (auto s : reversePrompts) {
+            if (s.compare(0, completed.size(), completed) == 0) {
+                foundPartialReversePrompt = true;
+                cachedResponse = completed;
+                break;
+            }
+        }
+
+        // Regardless the token gets added to our cache
+        cachedTokens.push_back(id);
+
+        // Continue if we have found a partial match
+        if (foundPartialReversePrompt)
+            continue;
+
+        // Empty the cache
+        for (auto t : cachedTokens) {
+            if (promptCtx.tokens.size() == promptCtx.n_ctx)
+                promptCtx.tokens.erase(promptCtx.tokens.begin());
+            promptCtx.tokens.push_back(t);
+            if (!responseCallback(t, d_ptr->vocab.id_to_token[t]))
+                goto stop_generating;
+        }
+        cachedTokens.clear();
     }
 
 stop_generating:
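
Note on the buffering logic in the last hunk: each sampled token is held back while cachedResponse + str is still a prefix of one of the reverse prompts, discarded entirely on an exact match (generation stops), and only otherwise flushed to responseCallback. The standalone sketch below replays that same prefix-match-and-flush idea on a hypothetical stream of already-decoded token pieces, with std::cout standing in for the callback; the piece strings and variable names here are illustrative only and are not part of the patch.

// reverse_prompt_sketch.cpp -- illustration only, not part of the patch.
// Pieces are buffered while their concatenation could still grow into a
// reverse prompt, printed once that becomes impossible, and discarded
// (with generation stopped) on an exact match.
#include <iostream>
#include <string>
#include <unordered_set>
#include <vector>

int main() {
    const std::unordered_set<std::string> reversePrompts
        = { "### Instruction", "### Prompt", "### Response", "### Human", "### Assistant" };

    // Hypothetical decoded pieces; in the patch these come from
    // d_ptr->vocab.id_to_token[id], one sampled token at a time.
    const std::vector<std::string> pieces
        = { "Sure", ".", "\n", "###", " Inst", "ruction" };

    std::string cachedResponse;            // text of the held-back pieces
    std::vector<std::string> cachedPieces; // pieces not yet shown to the user

    for (const std::string &piece : pieces) {
        const std::string completed = cachedResponse + piece;

        // Exact reverse prompt: stop without emitting the held-back pieces.
        if (reversePrompts.count(completed)) {
            std::cout << "\n[stopped: reverse prompt detected]\n";
            return 0;
        }

        cachedPieces.push_back(piece);

        // Still a prefix of some reverse prompt: keep buffering.
        bool partial = false;
        for (const auto &rp : reversePrompts) {
            if (rp.compare(0, completed.size(), completed) == 0) {
                cachedResponse = completed;
                partial = true;
                break;
            }
        }
        if (partial)
            continue;

        // The buffer can no longer become a reverse prompt: flush it.
        for (const auto &p : cachedPieces)
            std::cout << p;
        cachedPieces.clear();
        cachedResponse.clear();
    }
    std::cout << '\n';
    return 0;
}

Running this prints "Sure." and a newline, then stops once "###", " Inst", and "ruction" assemble into "### Instruction"; the partially matched pieces are never shown, which is the behavior the patch is after.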