mirror of
https://github.com/nomic-ai/gpt4all
synced 2024-11-02 09:40:42 +00:00
Add reverse prompt support for gptj too.
This commit is contained in:
parent
06bb6960d4
commit
d0d5d84e06
@ -14,6 +14,7 @@
|
||||
#include <iostream>
|
||||
#include <unistd.h>
|
||||
#include <sstream>
|
||||
#include <unordered_set>
|
||||
|
||||
// default hparams (GPT-J 6B)
|
||||
static const size_t MB = 1024*1024;
|
||||
@ -968,6 +969,11 @@ void GPTJ::prompt(const std::string &prompt,
|
||||
int p_instructFound = 0;
|
||||
int r_instructFound = 0;
|
||||
|
||||
std::string cachedResponse;
|
||||
std::vector<gpt_vocab::id> cachedTokens;
|
||||
std::unordered_set<std::string> reversePrompts
|
||||
= { "### Instruction", "### Prompt", "### Response", "### Human", "### Assistant" };
|
||||
|
||||
// predict next tokens
|
||||
int32_t totalPredictions = 0;
|
||||
for (int i = 0; i < promptCtx.n_predict; i++) {
|
||||
@ -1014,11 +1020,40 @@ void GPTJ::prompt(const std::string &prompt,
|
||||
if (id == 50256 /*end of text*/)
|
||||
goto stop_generating;
|
||||
|
||||
if (promptCtx.tokens.size() == promptCtx.n_ctx)
|
||||
promptCtx.tokens.erase(promptCtx.tokens.begin());
|
||||
promptCtx.tokens.push_back(id);
|
||||
if (!responseCallback(id, d_ptr->vocab.id_to_token[id]))
|
||||
const std::string str = d_ptr->vocab.id_to_token[id];
|
||||
|
||||
// Check if the provided str is part of our reverse prompts
|
||||
bool foundPartialReversePrompt = false;
|
||||
const std::string completed = cachedResponse + str;
|
||||
if (reversePrompts.find(completed) != reversePrompts.end()) {
|
||||
goto stop_generating;
|
||||
}
|
||||
|
||||
// Check if it partially matches our reverse prompts and if so, cache
|
||||
for (auto s : reversePrompts) {
|
||||
if (s.compare(0, completed.size(), completed) == 0) {
|
||||
foundPartialReversePrompt = true;
|
||||
cachedResponse = completed;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Regardless the token gets added to our cache
|
||||
cachedTokens.push_back(id);
|
||||
|
||||
// Continue if we have found a partial match
|
||||
if (foundPartialReversePrompt)
|
||||
continue;
|
||||
|
||||
// Empty the cache
|
||||
for (auto t : cachedTokens) {
|
||||
if (promptCtx.tokens.size() == promptCtx.n_ctx)
|
||||
promptCtx.tokens.erase(promptCtx.tokens.begin());
|
||||
promptCtx.tokens.push_back(t);
|
||||
if (!responseCallback(t, d_ptr->vocab.id_to_token[t]))
|
||||
goto stop_generating;
|
||||
}
|
||||
cachedTokens.clear();
|
||||
}
|
||||
|
||||
stop_generating:
|
||||
|
Loading…
Reference in New Issue
Block a user