Implement repeat penalty for both llama and gptj in gui.

This commit is contained in:
Adam Treat 2023-04-25 08:38:29 -04:00
parent cd2e559db4
commit 8b1ddabe3e
9 changed files with 107 additions and 50 deletions

View File

@ -683,8 +683,9 @@ bool GPTJ::isModelLoaded() const
return d_ptr->modelLoaded;
}
void GPTJ::prompt(const std::string &prompt, std::function<bool(const std::string&)> response,
PromptContext &promptCtx, int32_t n_predict, int32_t top_k, float top_p, float temp, int32_t n_batch) {
void GPTJ::prompt(const std::string &prompt,
std::function<bool(int32_t, const std::string&)> response,
PromptContext &promptCtx) {
if (!isModelLoaded()) {
std::cerr << "GPT-J ERROR: prompt won't work with an unloaded model!\n";
@ -700,10 +701,11 @@ void GPTJ::prompt(const std::string &prompt, std::function<bool(const std::strin
// tokenize the prompt
std::vector<gpt_vocab::id> embd_inp = ::gpt_tokenize(d_ptr->vocab, prompt);
const int n_ctx = d_ptr->model.hparams.n_ctx;
// save the context size
promptCtx.n_ctx = d_ptr->model.hparams.n_ctx;
n_predict = std::min(n_predict, n_ctx - (int) embd_inp.size());
promptCtx.n_past = std::min(promptCtx.n_past, n_ctx);
promptCtx.n_predict = std::min(promptCtx.n_predict, promptCtx.n_ctx - (int) embd_inp.size());
promptCtx.n_past = std::min(promptCtx.n_past, promptCtx.n_ctx);
// determine the required inference memory per token:
static bool initialized = false;
@ -719,13 +721,13 @@ void GPTJ::prompt(const std::string &prompt, std::function<bool(const std::strin
size_t i = 0;
const int64_t t_start_prompt_us = ggml_time_us();
while (i < embd_inp.size()) {
size_t batch_end = std::min(i + n_batch, embd_inp.size());
size_t batch_end = std::min(i + promptCtx.n_batch, embd_inp.size());
std::vector<gpt_vocab::id> batch(embd_inp.begin() + i, embd_inp.begin() + batch_end);
// Check if the context has run out...
if (promptCtx.n_past + batch.size() > n_ctx) {
if (promptCtx.n_past + batch.size() > promptCtx.n_ctx) {
// FIXME: will produce gibberish after this
promptCtx.n_past = std::min(promptCtx.n_past, int(n_ctx - batch.size()));
promptCtx.n_past = std::min(promptCtx.n_past, int(promptCtx.n_ctx - batch.size()));
std::cerr << "GPT-J WARNING: reached the end of the context window!\n";
}
@ -736,7 +738,7 @@ void GPTJ::prompt(const std::string &prompt, std::function<bool(const std::strin
// We pass a null string for each token to see if the user has asked us to stop...
size_t tokens = batch_end - i;
for (size_t t = 0; t < tokens; ++t)
if (!response(""))
if (!response(batch.at(t), ""))
return;
promptCtx.n_past += batch.size();
i = batch_end;
@ -748,22 +750,28 @@ void GPTJ::prompt(const std::string &prompt, std::function<bool(const std::strin
// predict next tokens
int32_t totalPredictions = 0;
for (int i = 0; i < n_predict; i++) {
for (int i = 0; i < promptCtx.n_predict; i++) {
// sample next token
const int n_vocab = d_ptr->model.hparams.n_vocab;
gpt_vocab::id id = 0;
{
const int64_t t_start_sample_us = ggml_time_us();
id = gpt_sample_top_k_top_p(d_ptr->vocab, promptCtx.logits.data() + (promptCtx.logits.size() - n_vocab),
top_k, top_p, temp, d_ptr->rng);
id = gpt_sample_top_k_top_p(d_ptr->vocab,
promptCtx.tokens.data() + promptCtx.n_ctx - promptCtx.n_ctx,
promptCtx.n_ctx,
promptCtx.logits,
promptCtx.top_k, promptCtx.top_p, promptCtx.temp,
promptCtx.repeat_penalty,
d_ptr->rng);
t_sample_us += ggml_time_us() - t_start_sample_us;
}
// Check if the context has run out...
if (promptCtx.n_past + 1 > n_ctx) {
if (promptCtx.n_past + 1 > promptCtx.n_ctx) {
// FIXME: will produce gibberish after this
promptCtx.n_past = std::min(promptCtx.n_past, n_ctx - 1);
promptCtx.n_past = std::min(promptCtx.n_past, promptCtx.n_ctx - 1);
std::cerr << "GPT-J WARNING: reached the end of the context window!\n";
}
@ -777,7 +785,7 @@ void GPTJ::prompt(const std::string &prompt, std::function<bool(const std::strin
promptCtx.n_past += 1;
// display text
++totalPredictions;
if (id == 50256 /*end of text*/ || !response(d_ptr->vocab.id_to_token[id]))
if (id == 50256 /*end of text*/ || !response(id, d_ptr->vocab.id_to_token[id]))
goto stop_generating;
}

6
gptj.h
View File

@ -15,9 +15,9 @@ public:
bool loadModel(const std::string &modelPath) override;
bool loadModel(const std::string &modelPath, std::istream &fin) override;
bool isModelLoaded() const override;
void prompt(const std::string &prompt, std::function<bool(const std::string&)> response,
PromptContext &ctx, int32_t n_predict = 200, int32_t top_k = 50400, float top_p = 1.0f,
float temp = 0.0f, int32_t n_batch = 9) override;
void prompt(const std::string &prompt,
std::function<bool(int32_t, const std::string&)> response,
PromptContext &ctx) override;
void setThreadCount(int32_t n_threads) override;
int32_t threadCount() override;

View File

@ -78,8 +78,9 @@ bool LLamaModel::isModelLoaded() const
return d_ptr->modelLoaded;
}
void LLamaModel::prompt(const std::string &prompt, std::function<bool(const std::string&)> response,
PromptContext &promptCtx, int32_t n_predict, int32_t top_k, float top_p, float temp, int32_t n_batch) {
void LLamaModel::prompt(const std::string &prompt,
std::function<bool(int32_t, const std::string&)> response,
PromptContext &promptCtx) {
if (!isModelLoaded()) {
std::cerr << "LLAMA ERROR: prompt won't work with an unloaded model!\n";
@ -94,15 +95,17 @@ void LLamaModel::prompt(const std::string &prompt, std::function<bool(const std:
// tokenize the prompt
auto embd_inp = ::llama_tokenize(d_ptr->ctx, params.prompt, false);
const int n_ctx = llama_n_ctx(d_ptr->ctx);
if ((int) embd_inp.size() > n_ctx - 4) {
// save the context size
promptCtx.n_ctx = llama_n_ctx(d_ptr->ctx);
if ((int) embd_inp.size() > promptCtx.n_ctx - 4) {
std::cerr << "LLAMA ERROR: prompt is too long\n";
return;
}
n_predict = std::min(n_predict, n_ctx - (int) embd_inp.size());
promptCtx.n_past = std::min(promptCtx.n_past, n_ctx);
promptCtx.n_predict = std::min(promptCtx.n_predict, promptCtx.n_ctx - (int) embd_inp.size());
promptCtx.n_past = std::min(promptCtx.n_past, promptCtx.n_ctx);
// number of tokens to keep when resetting context
params.n_keep = (int)embd_inp.size();
@ -111,13 +114,13 @@ void LLamaModel::prompt(const std::string &prompt, std::function<bool(const std:
size_t i = 0;
const int64_t t_start_prompt_us = ggml_time_us();
while (i < embd_inp.size()) {
size_t batch_end = std::min(i + n_batch, embd_inp.size());
size_t batch_end = std::min(i + promptCtx.n_batch, embd_inp.size());
std::vector<llama_token> batch(embd_inp.begin() + i, embd_inp.begin() + batch_end);
// Check if the context has run out...
if (promptCtx.n_past + batch.size() > n_ctx) {
if (promptCtx.n_past + batch.size() > promptCtx.n_ctx) {
// FIXME: will produce gibberish after this
promptCtx.n_past = std::min(promptCtx.n_past, int(n_ctx - batch.size()));
promptCtx.n_past = std::min(promptCtx.n_past, int(promptCtx.n_ctx - batch.size()));
std::cerr << "LLAMA WARNING: reached the end of the context window!\n";
}
@ -129,7 +132,7 @@ void LLamaModel::prompt(const std::string &prompt, std::function<bool(const std:
// We pass a null string for each token to see if the user has asked us to stop...
size_t tokens = batch_end - i;
for (size_t t = 0; t < tokens; ++t)
if (!response(""))
if (!response(batch.at(t), ""))
return;
promptCtx.n_past += batch.size();
i = batch_end;
@ -137,14 +140,17 @@ void LLamaModel::prompt(const std::string &prompt, std::function<bool(const std:
// predict next tokens
int32_t totalPredictions = 0;
for (int i = 0; i < n_predict; i++) {
for (int i = 0; i < promptCtx.n_predict; i++) {
// sample next token
llama_token id = llama_sample_top_p_top_k(d_ptr->ctx, {}, 0, top_k, top_p, temp, 1.0f);
llama_token id = llama_sample_top_p_top_k(d_ptr->ctx,
promptCtx.tokens.data() + promptCtx.n_ctx - promptCtx.repeat_last_n,
promptCtx.repeat_last_n, promptCtx.top_k, promptCtx.top_p, promptCtx.temp,
promptCtx.repeat_penalty);
// Check if the context has run out...
if (promptCtx.n_past + 1 > n_ctx) {
if (promptCtx.n_past + 1 > promptCtx.n_ctx) {
// FIXME: will produce gibberish after this
promptCtx.n_past = std::min(promptCtx.n_past, n_ctx - 1);
promptCtx.n_past = std::min(promptCtx.n_past, promptCtx.n_ctx - 1);
std::cerr << "LLAMA WARNING: reached the end of the context window!\n";
}
@ -156,7 +162,7 @@ void LLamaModel::prompt(const std::string &prompt, std::function<bool(const std:
promptCtx.n_past += 1;
// display text
++totalPredictions;
if (id == llama_token_eos() || !response(llama_token_to_str(d_ptr->ctx, id)))
if (id == llama_token_eos() || !response(id, llama_token_to_str(d_ptr->ctx, id)))
return;
}
}

View File

@ -15,9 +15,9 @@ public:
bool loadModel(const std::string &modelPath) override;
bool loadModel(const std::string &modelPath, std::istream &fin) override;
bool isModelLoaded() const override;
void prompt(const std::string &prompt, std::function<bool(const std::string&)> response,
PromptContext &ctx, int32_t n_predict = 200, int32_t top_k = 50400, float top_p = 1.0f,
float temp = 0.0f, int32_t n_batch = 9) override;
void prompt(const std::string &prompt,
std::function<bool(int32_t, const std::string&)> response,
PromptContext &ctx) override;
void setThreadCount(int32_t n_threads) override;
int32_t threadCount() override;

20
llm.cpp
View File

@ -124,6 +124,7 @@ void LLMObject::regenerateResponse()
s_ctx.n_past = std::max(0, s_ctx.n_past);
// FIXME: This does not seem to be needed in my testing and llama models don't to it. Remove?
s_ctx.logits.erase(s_ctx.logits.end() -= m_responseLogits, s_ctx.logits.end());
s_ctx.tokens.erase(s_ctx.tokens.end() -= m_responseTokens, s_ctx.tokens.end());
m_responseTokens = 0;
m_responseLogits = 0;
m_response = std::string();
@ -243,12 +244,20 @@ QList<QString> LLMObject::modelList() const
return list;
}
bool LLMObject::handleResponse(const std::string &response)
bool LLMObject::handleResponse(int32_t token, const std::string &response)
{
#if 0
printf("%s", response.c_str());
fflush(stdout);
#endif
// Save the token to our prompt ctxt
if (s_ctx.tokens.size() == s_ctx.n_ctx)
s_ctx.tokens.erase(s_ctx.tokens.begin());
s_ctx.tokens.push_back(token);
// m_responseTokens and m_responseLogits are related to last prompt/response not
// the entire context window which we can reset on regenerate prompt
++m_responseTokens;
if (!response.empty()) {
m_response.append(response);
@ -271,10 +280,15 @@ bool LLMObject::prompt(const QString &prompt, const QString &prompt_template, in
QString instructPrompt = prompt_template.arg(prompt);
m_stopGenerating = false;
auto func = std::bind(&LLMObject::handleResponse, this, std::placeholders::_1);
auto func = std::bind(&LLMObject::handleResponse, this, std::placeholders::_1, std::placeholders::_2);
emit responseStarted();
qint32 logitsBefore = s_ctx.logits.size();
m_llmodel->prompt(instructPrompt.toStdString(), func, s_ctx, n_predict, top_k, top_p, temp, n_batch);
s_ctx.n_predict = n_predict;
s_ctx.top_k = top_k;
s_ctx.top_p = top_p;
s_ctx.temp = temp;
s_ctx.n_batch = n_batch;
m_llmodel->prompt(instructPrompt.toStdString(), func, s_ctx);
m_responseLogits += s_ctx.logits.size() - logitsBefore;
std::string trimmed = trim_whitespace(m_response);
if (trimmed != m_response) {

2
llm.h
View File

@ -50,7 +50,7 @@ Q_SIGNALS:
private:
bool loadModelPrivate(const QString &modelName);
bool handleResponse(const std::string &response);
bool handleResponse(int32_t token, const std::string &response);
private:
LLModel *m_llmodel;

View File

@ -14,12 +14,22 @@ public:
virtual bool loadModel(const std::string &modelPath, std::istream &fin) = 0;
virtual bool isModelLoaded() const = 0;
struct PromptContext {
std::vector<float> logits;
int32_t n_past = 0; // number of tokens in past conversation
std::vector<float> logits; // logits of current context
std::vector<int32_t> tokens; // current tokens in the context window
int32_t n_past = 0; // number of tokens in past conversation
int32_t n_ctx = 0; // number of tokens possible in context window
int32_t n_predict = 200;
int32_t top_k = 40;
float top_p = 0.9f;
float temp = 0.9f;
int32_t n_batch = 9;
float repeat_penalty = 1.10f;
int32_t repeat_last_n = 64; // last n tokens to penalize
};
virtual void prompt(const std::string &prompt, std::function<bool(const std::string&)> response,
PromptContext &ctx, int32_t n_predict = 200, int32_t top_k = 40, float top_p = 0.9f,
float temp = 0.9f, int32_t n_batch = 9) = 0;
virtual void prompt(const std::string &prompt,
std::function<bool(int32_t, const std::string&)> response,
PromptContext &ctx) = 0;
virtual void setThreadCount(int32_t n_threads) {}
virtual int32_t threadCount() { return 1; }
};

View File

@ -178,20 +178,37 @@ bool gpt_vocab_init(const std::string & fname, gpt_vocab & vocab) {
gpt_vocab::id gpt_sample_top_k_top_p(
const gpt_vocab & vocab,
const float * logits,
const int32_t * last_n_tokens_data,
int last_n_tokens_size,
const std::vector<float> logits,
int top_k,
double top_p,
double temp,
float repeat_penalty,
std::mt19937 & rng) {
int n_logits = vocab.id_to_token.size();
const auto last_n_tokens = std::vector<int32_t>(last_n_tokens_data, last_n_tokens_data + last_n_tokens_size);
const auto * plogits = logits.data() + logits.size() - n_logits;
std::vector<std::pair<double, gpt_vocab::id>> logits_id;
logits_id.reserve(n_logits);
{
const double scale = 1.0/temp;
const float scale = 1.0f/temp;
for (int i = 0; i < n_logits; ++i) {
logits_id.push_back(std::make_pair(logits[i]*scale, i));
// repetition penalty from ctrl paper (https://arxiv.org/abs/1909.05858)
// credit https://github.com/facebookresearch/llama/compare/main...shawwn:llama:main
if (std::find(last_n_tokens.begin(), last_n_tokens.end(), i) != last_n_tokens.end()) {
// if score < 0 then repetition penalty has to multiplied to reduce the previous token probability
if (plogits[i] < 0.0f) {
logits_id.push_back(std::make_pair(plogits[i]*scale*repeat_penalty, i));
} else {
logits_id.push_back(std::make_pair(plogits[i]*scale/repeat_penalty, i));
}
} else {
logits_id.push_back(std::make_pair(plogits[i]*scale, i));
}
}
}

View File

@ -72,12 +72,14 @@ bool gpt_vocab_init(const std::string & fname, gpt_vocab & vocab);
// - from them, consider only the top tokens with cumulative probability > P
//
// TODO: not sure if this implementation is correct
// TODO: temperature is not implemented
//
gpt_vocab::id gpt_sample_top_k_top_p(
const gpt_vocab & vocab,
const float * logits,
const int32_t * last_n_tokens_data,
int last_n_tokens_size,
const std::vector<float> logits,
int top_k,
double top_p,
double temp,
float repeat_penalty,
std::mt19937 & rng);