mirror of
https://github.com/nomic-ai/gpt4all
synced 2024-11-02 09:40:42 +00:00
Implement repeat penalty for both llama and gptj in gui.
This commit is contained in:
parent
cd2e559db4
commit
8b1ddabe3e
38
gptj.cpp
38
gptj.cpp
@ -683,8 +683,9 @@ bool GPTJ::isModelLoaded() const
|
||||
return d_ptr->modelLoaded;
|
||||
}
|
||||
|
||||
void GPTJ::prompt(const std::string &prompt, std::function<bool(const std::string&)> response,
|
||||
PromptContext &promptCtx, int32_t n_predict, int32_t top_k, float top_p, float temp, int32_t n_batch) {
|
||||
void GPTJ::prompt(const std::string &prompt,
|
||||
std::function<bool(int32_t, const std::string&)> response,
|
||||
PromptContext &promptCtx) {
|
||||
|
||||
if (!isModelLoaded()) {
|
||||
std::cerr << "GPT-J ERROR: prompt won't work with an unloaded model!\n";
|
||||
@ -700,10 +701,11 @@ void GPTJ::prompt(const std::string &prompt, std::function<bool(const std::strin
|
||||
// tokenize the prompt
|
||||
std::vector<gpt_vocab::id> embd_inp = ::gpt_tokenize(d_ptr->vocab, prompt);
|
||||
|
||||
const int n_ctx = d_ptr->model.hparams.n_ctx;
|
||||
// save the context size
|
||||
promptCtx.n_ctx = d_ptr->model.hparams.n_ctx;
|
||||
|
||||
n_predict = std::min(n_predict, n_ctx - (int) embd_inp.size());
|
||||
promptCtx.n_past = std::min(promptCtx.n_past, n_ctx);
|
||||
promptCtx.n_predict = std::min(promptCtx.n_predict, promptCtx.n_ctx - (int) embd_inp.size());
|
||||
promptCtx.n_past = std::min(promptCtx.n_past, promptCtx.n_ctx);
|
||||
|
||||
// determine the required inference memory per token:
|
||||
static bool initialized = false;
|
||||
@ -719,13 +721,13 @@ void GPTJ::prompt(const std::string &prompt, std::function<bool(const std::strin
|
||||
size_t i = 0;
|
||||
const int64_t t_start_prompt_us = ggml_time_us();
|
||||
while (i < embd_inp.size()) {
|
||||
size_t batch_end = std::min(i + n_batch, embd_inp.size());
|
||||
size_t batch_end = std::min(i + promptCtx.n_batch, embd_inp.size());
|
||||
std::vector<gpt_vocab::id> batch(embd_inp.begin() + i, embd_inp.begin() + batch_end);
|
||||
|
||||
// Check if the context has run out...
|
||||
if (promptCtx.n_past + batch.size() > n_ctx) {
|
||||
if (promptCtx.n_past + batch.size() > promptCtx.n_ctx) {
|
||||
// FIXME: will produce gibberish after this
|
||||
promptCtx.n_past = std::min(promptCtx.n_past, int(n_ctx - batch.size()));
|
||||
promptCtx.n_past = std::min(promptCtx.n_past, int(promptCtx.n_ctx - batch.size()));
|
||||
std::cerr << "GPT-J WARNING: reached the end of the context window!\n";
|
||||
}
|
||||
|
||||
@ -736,7 +738,7 @@ void GPTJ::prompt(const std::string &prompt, std::function<bool(const std::strin
|
||||
// We pass a null string for each token to see if the user has asked us to stop...
|
||||
size_t tokens = batch_end - i;
|
||||
for (size_t t = 0; t < tokens; ++t)
|
||||
if (!response(""))
|
||||
if (!response(batch.at(t), ""))
|
||||
return;
|
||||
promptCtx.n_past += batch.size();
|
||||
i = batch_end;
|
||||
@ -748,22 +750,28 @@ void GPTJ::prompt(const std::string &prompt, std::function<bool(const std::strin
|
||||
|
||||
// predict next tokens
|
||||
int32_t totalPredictions = 0;
|
||||
for (int i = 0; i < n_predict; i++) {
|
||||
for (int i = 0; i < promptCtx.n_predict; i++) {
|
||||
|
||||
// sample next token
|
||||
const int n_vocab = d_ptr->model.hparams.n_vocab;
|
||||
gpt_vocab::id id = 0;
|
||||
{
|
||||
const int64_t t_start_sample_us = ggml_time_us();
|
||||
id = gpt_sample_top_k_top_p(d_ptr->vocab, promptCtx.logits.data() + (promptCtx.logits.size() - n_vocab),
|
||||
top_k, top_p, temp, d_ptr->rng);
|
||||
id = gpt_sample_top_k_top_p(d_ptr->vocab,
|
||||
promptCtx.tokens.data() + promptCtx.n_ctx - promptCtx.n_ctx,
|
||||
promptCtx.n_ctx,
|
||||
promptCtx.logits,
|
||||
promptCtx.top_k, promptCtx.top_p, promptCtx.temp,
|
||||
promptCtx.repeat_penalty,
|
||||
d_ptr->rng);
|
||||
|
||||
t_sample_us += ggml_time_us() - t_start_sample_us;
|
||||
}
|
||||
|
||||
// Check if the context has run out...
|
||||
if (promptCtx.n_past + 1 > n_ctx) {
|
||||
if (promptCtx.n_past + 1 > promptCtx.n_ctx) {
|
||||
// FIXME: will produce gibberish after this
|
||||
promptCtx.n_past = std::min(promptCtx.n_past, n_ctx - 1);
|
||||
promptCtx.n_past = std::min(promptCtx.n_past, promptCtx.n_ctx - 1);
|
||||
std::cerr << "GPT-J WARNING: reached the end of the context window!\n";
|
||||
}
|
||||
|
||||
@ -777,7 +785,7 @@ void GPTJ::prompt(const std::string &prompt, std::function<bool(const std::strin
|
||||
promptCtx.n_past += 1;
|
||||
// display text
|
||||
++totalPredictions;
|
||||
if (id == 50256 /*end of text*/ || !response(d_ptr->vocab.id_to_token[id]))
|
||||
if (id == 50256 /*end of text*/ || !response(id, d_ptr->vocab.id_to_token[id]))
|
||||
goto stop_generating;
|
||||
}
|
||||
|
||||
|
6
gptj.h
6
gptj.h
@ -15,9 +15,9 @@ public:
|
||||
bool loadModel(const std::string &modelPath) override;
|
||||
bool loadModel(const std::string &modelPath, std::istream &fin) override;
|
||||
bool isModelLoaded() const override;
|
||||
void prompt(const std::string &prompt, std::function<bool(const std::string&)> response,
|
||||
PromptContext &ctx, int32_t n_predict = 200, int32_t top_k = 50400, float top_p = 1.0f,
|
||||
float temp = 0.0f, int32_t n_batch = 9) override;
|
||||
void prompt(const std::string &prompt,
|
||||
std::function<bool(int32_t, const std::string&)> response,
|
||||
PromptContext &ctx) override;
|
||||
void setThreadCount(int32_t n_threads) override;
|
||||
int32_t threadCount() override;
|
||||
|
||||
|
@ -78,8 +78,9 @@ bool LLamaModel::isModelLoaded() const
|
||||
return d_ptr->modelLoaded;
|
||||
}
|
||||
|
||||
void LLamaModel::prompt(const std::string &prompt, std::function<bool(const std::string&)> response,
|
||||
PromptContext &promptCtx, int32_t n_predict, int32_t top_k, float top_p, float temp, int32_t n_batch) {
|
||||
void LLamaModel::prompt(const std::string &prompt,
|
||||
std::function<bool(int32_t, const std::string&)> response,
|
||||
PromptContext &promptCtx) {
|
||||
|
||||
if (!isModelLoaded()) {
|
||||
std::cerr << "LLAMA ERROR: prompt won't work with an unloaded model!\n";
|
||||
@ -94,15 +95,17 @@ void LLamaModel::prompt(const std::string &prompt, std::function<bool(const std:
|
||||
|
||||
// tokenize the prompt
|
||||
auto embd_inp = ::llama_tokenize(d_ptr->ctx, params.prompt, false);
|
||||
const int n_ctx = llama_n_ctx(d_ptr->ctx);
|
||||
|
||||
if ((int) embd_inp.size() > n_ctx - 4) {
|
||||
// save the context size
|
||||
promptCtx.n_ctx = llama_n_ctx(d_ptr->ctx);
|
||||
|
||||
if ((int) embd_inp.size() > promptCtx.n_ctx - 4) {
|
||||
std::cerr << "LLAMA ERROR: prompt is too long\n";
|
||||
return;
|
||||
}
|
||||
|
||||
n_predict = std::min(n_predict, n_ctx - (int) embd_inp.size());
|
||||
promptCtx.n_past = std::min(promptCtx.n_past, n_ctx);
|
||||
promptCtx.n_predict = std::min(promptCtx.n_predict, promptCtx.n_ctx - (int) embd_inp.size());
|
||||
promptCtx.n_past = std::min(promptCtx.n_past, promptCtx.n_ctx);
|
||||
|
||||
// number of tokens to keep when resetting context
|
||||
params.n_keep = (int)embd_inp.size();
|
||||
@ -111,13 +114,13 @@ void LLamaModel::prompt(const std::string &prompt, std::function<bool(const std:
|
||||
size_t i = 0;
|
||||
const int64_t t_start_prompt_us = ggml_time_us();
|
||||
while (i < embd_inp.size()) {
|
||||
size_t batch_end = std::min(i + n_batch, embd_inp.size());
|
||||
size_t batch_end = std::min(i + promptCtx.n_batch, embd_inp.size());
|
||||
std::vector<llama_token> batch(embd_inp.begin() + i, embd_inp.begin() + batch_end);
|
||||
|
||||
// Check if the context has run out...
|
||||
if (promptCtx.n_past + batch.size() > n_ctx) {
|
||||
if (promptCtx.n_past + batch.size() > promptCtx.n_ctx) {
|
||||
// FIXME: will produce gibberish after this
|
||||
promptCtx.n_past = std::min(promptCtx.n_past, int(n_ctx - batch.size()));
|
||||
promptCtx.n_past = std::min(promptCtx.n_past, int(promptCtx.n_ctx - batch.size()));
|
||||
std::cerr << "LLAMA WARNING: reached the end of the context window!\n";
|
||||
}
|
||||
|
||||
@ -129,7 +132,7 @@ void LLamaModel::prompt(const std::string &prompt, std::function<bool(const std:
|
||||
// We pass a null string for each token to see if the user has asked us to stop...
|
||||
size_t tokens = batch_end - i;
|
||||
for (size_t t = 0; t < tokens; ++t)
|
||||
if (!response(""))
|
||||
if (!response(batch.at(t), ""))
|
||||
return;
|
||||
promptCtx.n_past += batch.size();
|
||||
i = batch_end;
|
||||
@ -137,14 +140,17 @@ void LLamaModel::prompt(const std::string &prompt, std::function<bool(const std:
|
||||
|
||||
// predict next tokens
|
||||
int32_t totalPredictions = 0;
|
||||
for (int i = 0; i < n_predict; i++) {
|
||||
for (int i = 0; i < promptCtx.n_predict; i++) {
|
||||
// sample next token
|
||||
llama_token id = llama_sample_top_p_top_k(d_ptr->ctx, {}, 0, top_k, top_p, temp, 1.0f);
|
||||
llama_token id = llama_sample_top_p_top_k(d_ptr->ctx,
|
||||
promptCtx.tokens.data() + promptCtx.n_ctx - promptCtx.repeat_last_n,
|
||||
promptCtx.repeat_last_n, promptCtx.top_k, promptCtx.top_p, promptCtx.temp,
|
||||
promptCtx.repeat_penalty);
|
||||
|
||||
// Check if the context has run out...
|
||||
if (promptCtx.n_past + 1 > n_ctx) {
|
||||
if (promptCtx.n_past + 1 > promptCtx.n_ctx) {
|
||||
// FIXME: will produce gibberish after this
|
||||
promptCtx.n_past = std::min(promptCtx.n_past, n_ctx - 1);
|
||||
promptCtx.n_past = std::min(promptCtx.n_past, promptCtx.n_ctx - 1);
|
||||
std::cerr << "LLAMA WARNING: reached the end of the context window!\n";
|
||||
}
|
||||
|
||||
@ -156,7 +162,7 @@ void LLamaModel::prompt(const std::string &prompt, std::function<bool(const std:
|
||||
promptCtx.n_past += 1;
|
||||
// display text
|
||||
++totalPredictions;
|
||||
if (id == llama_token_eos() || !response(llama_token_to_str(d_ptr->ctx, id)))
|
||||
if (id == llama_token_eos() || !response(id, llama_token_to_str(d_ptr->ctx, id)))
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
@ -15,9 +15,9 @@ public:
|
||||
bool loadModel(const std::string &modelPath) override;
|
||||
bool loadModel(const std::string &modelPath, std::istream &fin) override;
|
||||
bool isModelLoaded() const override;
|
||||
void prompt(const std::string &prompt, std::function<bool(const std::string&)> response,
|
||||
PromptContext &ctx, int32_t n_predict = 200, int32_t top_k = 50400, float top_p = 1.0f,
|
||||
float temp = 0.0f, int32_t n_batch = 9) override;
|
||||
void prompt(const std::string &prompt,
|
||||
std::function<bool(int32_t, const std::string&)> response,
|
||||
PromptContext &ctx) override;
|
||||
void setThreadCount(int32_t n_threads) override;
|
||||
int32_t threadCount() override;
|
||||
|
||||
|
20
llm.cpp
20
llm.cpp
@ -124,6 +124,7 @@ void LLMObject::regenerateResponse()
|
||||
s_ctx.n_past = std::max(0, s_ctx.n_past);
|
||||
// FIXME: This does not seem to be needed in my testing and llama models don't do it. Remove?
|
||||
s_ctx.logits.erase(s_ctx.logits.end() -= m_responseLogits, s_ctx.logits.end());
|
||||
s_ctx.tokens.erase(s_ctx.tokens.end() -= m_responseTokens, s_ctx.tokens.end());
|
||||
m_responseTokens = 0;
|
||||
m_responseLogits = 0;
|
||||
m_response = std::string();
|
||||
@ -243,12 +244,20 @@ QList<QString> LLMObject::modelList() const
|
||||
return list;
|
||||
}
|
||||
|
||||
bool LLMObject::handleResponse(const std::string &response)
|
||||
bool LLMObject::handleResponse(int32_t token, const std::string &response)
|
||||
{
|
||||
#if 0
|
||||
printf("%s", response.c_str());
|
||||
fflush(stdout);
|
||||
#endif
|
||||
|
||||
// Save the token to our prompt context
|
||||
if (s_ctx.tokens.size() == s_ctx.n_ctx)
|
||||
s_ctx.tokens.erase(s_ctx.tokens.begin());
|
||||
s_ctx.tokens.push_back(token);
|
||||
|
||||
// m_responseTokens and m_responseLogits are related to last prompt/response not
|
||||
// the entire context window which we can reset on regenerate prompt
|
||||
++m_responseTokens;
|
||||
if (!response.empty()) {
|
||||
m_response.append(response);
|
||||
@ -271,10 +280,15 @@ bool LLMObject::prompt(const QString &prompt, const QString &prompt_template, in
|
||||
QString instructPrompt = prompt_template.arg(prompt);
|
||||
|
||||
m_stopGenerating = false;
|
||||
auto func = std::bind(&LLMObject::handleResponse, this, std::placeholders::_1);
|
||||
auto func = std::bind(&LLMObject::handleResponse, this, std::placeholders::_1, std::placeholders::_2);
|
||||
emit responseStarted();
|
||||
qint32 logitsBefore = s_ctx.logits.size();
|
||||
m_llmodel->prompt(instructPrompt.toStdString(), func, s_ctx, n_predict, top_k, top_p, temp, n_batch);
|
||||
s_ctx.n_predict = n_predict;
|
||||
s_ctx.top_k = top_k;
|
||||
s_ctx.top_p = top_p;
|
||||
s_ctx.temp = temp;
|
||||
s_ctx.n_batch = n_batch;
|
||||
m_llmodel->prompt(instructPrompt.toStdString(), func, s_ctx);
|
||||
m_responseLogits += s_ctx.logits.size() - logitsBefore;
|
||||
std::string trimmed = trim_whitespace(m_response);
|
||||
if (trimmed != m_response) {
|
||||
|
2
llm.h
2
llm.h
@ -50,7 +50,7 @@ Q_SIGNALS:
|
||||
|
||||
private:
|
||||
bool loadModelPrivate(const QString &modelName);
|
||||
bool handleResponse(const std::string &response);
|
||||
bool handleResponse(int32_t token, const std::string &response);
|
||||
|
||||
private:
|
||||
LLModel *m_llmodel;
|
||||
|
20
llmodel.h
20
llmodel.h
@ -14,12 +14,22 @@ public:
|
||||
virtual bool loadModel(const std::string &modelPath, std::istream &fin) = 0;
|
||||
virtual bool isModelLoaded() const = 0;
|
||||
struct PromptContext {
|
||||
std::vector<float> logits;
|
||||
int32_t n_past = 0; // number of tokens in past conversation
|
||||
std::vector<float> logits; // logits of current context
|
||||
std::vector<int32_t> tokens; // current tokens in the context window
|
||||
int32_t n_past = 0; // number of tokens in past conversation
|
||||
int32_t n_ctx = 0; // number of tokens possible in context window
|
||||
int32_t n_predict = 200;
|
||||
int32_t top_k = 40;
|
||||
float top_p = 0.9f;
|
||||
float temp = 0.9f;
|
||||
int32_t n_batch = 9;
|
||||
float repeat_penalty = 1.10f;
|
||||
int32_t repeat_last_n = 64; // last n tokens to penalize
|
||||
|
||||
};
|
||||
virtual void prompt(const std::string &prompt, std::function<bool(const std::string&)> response,
|
||||
PromptContext &ctx, int32_t n_predict = 200, int32_t top_k = 40, float top_p = 0.9f,
|
||||
float temp = 0.9f, int32_t n_batch = 9) = 0;
|
||||
virtual void prompt(const std::string &prompt,
|
||||
std::function<bool(int32_t, const std::string&)> response,
|
||||
PromptContext &ctx) = 0;
|
||||
virtual void setThreadCount(int32_t n_threads) {}
|
||||
virtual int32_t threadCount() { return 1; }
|
||||
};
|
||||
|
23
utils.cpp
23
utils.cpp
@ -178,20 +178,37 @@ bool gpt_vocab_init(const std::string & fname, gpt_vocab & vocab) {
|
||||
|
||||
gpt_vocab::id gpt_sample_top_k_top_p(
|
||||
const gpt_vocab & vocab,
|
||||
const float * logits,
|
||||
const int32_t * last_n_tokens_data,
|
||||
int last_n_tokens_size,
|
||||
const std::vector<float> logits,
|
||||
int top_k,
|
||||
double top_p,
|
||||
double temp,
|
||||
float repeat_penalty,
|
||||
std::mt19937 & rng) {
|
||||
int n_logits = vocab.id_to_token.size();
|
||||
|
||||
const auto last_n_tokens = std::vector<int32_t>(last_n_tokens_data, last_n_tokens_data + last_n_tokens_size);
|
||||
const auto * plogits = logits.data() + logits.size() - n_logits;
|
||||
|
||||
std::vector<std::pair<double, gpt_vocab::id>> logits_id;
|
||||
logits_id.reserve(n_logits);
|
||||
|
||||
{
|
||||
const double scale = 1.0/temp;
|
||||
const float scale = 1.0f/temp;
|
||||
for (int i = 0; i < n_logits; ++i) {
|
||||
logits_id.push_back(std::make_pair(logits[i]*scale, i));
|
||||
// repetition penalty from ctrl paper (https://arxiv.org/abs/1909.05858)
|
||||
// credit https://github.com/facebookresearch/llama/compare/main...shawwn:llama:main
|
||||
if (std::find(last_n_tokens.begin(), last_n_tokens.end(), i) != last_n_tokens.end()) {
|
||||
// if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability
|
||||
if (plogits[i] < 0.0f) {
|
||||
logits_id.push_back(std::make_pair(plogits[i]*scale*repeat_penalty, i));
|
||||
} else {
|
||||
logits_id.push_back(std::make_pair(plogits[i]*scale/repeat_penalty, i));
|
||||
}
|
||||
} else {
|
||||
logits_id.push_back(std::make_pair(plogits[i]*scale, i));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
6
utils.h
6
utils.h
@ -72,12 +72,14 @@ bool gpt_vocab_init(const std::string & fname, gpt_vocab & vocab);
|
||||
// - from them, consider only the top tokens with cumulative probability > P
|
||||
//
|
||||
// TODO: not sure if this implementation is correct
|
||||
// TODO: temperature is not implemented
|
||||
//
|
||||
gpt_vocab::id gpt_sample_top_k_top_p(
|
||||
const gpt_vocab & vocab,
|
||||
const float * logits,
|
||||
const int32_t * last_n_tokens_data,
|
||||
int last_n_tokens_size,
|
||||
const std::vector<float> logits,
|
||||
int top_k,
|
||||
double top_p,
|
||||
double temp,
|
||||
float repeat_penalty,
|
||||
std::mt19937 & rng);
|
||||
|
Loading…
Reference in New Issue
Block a user