@@ -14,6 +14,7 @@
 #include <iostream>
 #include <unistd.h>
 #include <sstream>
+#include <unordered_set>
 
 // default hparams (GPT-J 6B)
 static const size_t MB = 1024*1024;
@@ -968,6 +969,11 @@ void GPTJ::prompt(const std::string &prompt,
     int p_instructFound = 0;
     int r_instructFound = 0;
 
+    std::string cachedResponse;
+    std::vector<gpt_vocab::id> cachedTokens;
+    std::unordered_set<std::string> reversePrompts
+        = { "### Instruction", "### Prompt", "### Response", "### Human", "### Assistant" };
+
     // predict next tokens
     int32_t totalPredictions = 0;
     for (int i = 0; i < promptCtx.n_predict; i++) {
@@ -1014,12 +1020,41 @@ void GPTJ::prompt(const std::string &prompt,
         if (id == 50256 /*end of text*/)
             goto stop_generating;
 
-        if (promptCtx.tokens.size() == promptCtx.n_ctx)
-            promptCtx.tokens.erase(promptCtx.tokens.begin());
-        promptCtx.tokens.push_back(id);
-        if (!responseCallback(id, d_ptr->vocab.id_to_token[id]))
-            goto stop_generating;
-    }
+        const std::string str = d_ptr->vocab.id_to_token[id];
+
+        // Check if the provided str is part of our reverse prompts
+        bool foundPartialReversePrompt = false;
+        const std::string completed = cachedResponse + str;
+        if (reversePrompts.find(completed) != reversePrompts.end()) {
+            goto stop_generating;
+        }
+
+        // Check if it partially matches our reverse prompts and if so, cache
+        for (auto s : reversePrompts) {
+            if (s.compare(0, completed.size(), completed) == 0) {
+                foundPartialReversePrompt = true;
+                cachedResponse = completed;
+                break;
+            }
+        }
+
+        // Regardless the token gets added to our cache
+        cachedTokens.push_back(id);
+
+        // Continue if we have found a partial match
+        if (foundPartialReversePrompt)
+            continue;
+
+        // Empty the cache
+        for (auto t : cachedTokens) {
+            if (promptCtx.tokens.size() == promptCtx.n_ctx)
+                promptCtx.tokens.erase(promptCtx.tokens.begin());
+            promptCtx.tokens.push_back(t);
+            if (!responseCallback(t, d_ptr->vocab.id_to_token[t]))
+                goto stop_generating;
+        }
+        cachedTokens.clear();
+    }
 
 stop_generating:
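
The last hunk's buffering technique can be exercised in isolation. Below is a minimal standalone sketch of the same control flow, assuming a hypothetical Token struct and emit callback in place of the real vocab lookup and responseCallback; it illustrates the idea and is not the GPTJ implementation. One deliberate difference: the sketch resets cachedResponse whenever the buffer is flushed, so a stale partial prefix cannot leak into later comparisons.

// Standalone sketch of the reverse-prompt buffering added above.
// `Token` and `emit` are hypothetical stand-ins for the model's vocab and
// response callback; the control flow mirrors the diff.
#include <iostream>
#include <string>
#include <unordered_set>
#include <vector>

struct Token {
    int id;
    std::string text;
};

int main() {
    const std::unordered_set<std::string> reversePrompts
        = { "### Instruction", "### Prompt", "### Response", "### Human", "### Assistant" };

    std::string cachedResponse;       // text of the tokens currently withheld
    std::vector<Token> cachedTokens;  // the withheld tokens themselves

    // Stand-in for responseCallback: just print the token text.
    const auto emit = [](const Token &t) { std::cout << t.text; };

    // A made-up stream ending in "### Human" split across three tokens;
    // those three tokens should be swallowed rather than emitted.
    const std::vector<Token> stream = {
        {1, "The"}, {2, " answer"}, {3, " is 42."}, {4, "\n"},
        {5, "###"}, {6, " Hum"}, {7, "an"},
    };

    for (const Token &tok : stream) {
        const std::string completed = cachedResponse + tok.text;

        // Buffered text plus this token spells a full reverse prompt: stop,
        // discarding the buffer so the reverse prompt never reaches the user.
        if (reversePrompts.find(completed) != reversePrompts.end())
            break;

        // Is `completed` a prefix of some reverse prompt? Then keep
        // withholding output until the match either completes or breaks.
        bool foundPartialReversePrompt = false;
        for (const auto &s : reversePrompts) {
            if (s.compare(0, completed.size(), completed) == 0) {
                foundPartialReversePrompt = true;
                cachedResponse = completed;
                break;
            }
        }

        // Regardless, the token joins the cache first.
        cachedTokens.push_back(tok);
        if (foundPartialReversePrompt)
            continue;

        // The prefix chain broke: release everything that was withheld.
        for (const Token &t : cachedTokens)
            emit(t);
        cachedTokens.clear();
        cachedResponse.clear();  // unlike the diff, drop the stale prefix too
    }
    std::cout << std::endl;
    return 0;  // prints "The answer is 42." and swallows "### Human"
}

Note the ordering: the token is appended to the cache before the partial-match branch decides whether to continue, so when a prefix chain breaks, the flush releases the partially matched tokens together with the token that broke the chain, and nothing is lost.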