Infinite context window through trimming.

Adam Treat 2023-04-25 11:20:51 -04:00
parent 8b1ddabe3e
commit b6937c39db
9 changed files with 187 additions and 29 deletions
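
The change itself is simple: when the next batch of tokens would overflow the model's context window, the oldest contextErase fraction of the token history is erased and the surviving tokens are re-evaluated so the model's internal state matches the trimmed history, with a recalculate callback reporting progress back to the UI. Below is a minimal, standalone sketch of that flow, assuming a simplified PromptContext and a hypothetical evalTokens() stand-in for gptj_eval()/llama_eval(); it is not the shipped code, which lives in gptj.cpp and llamamodel.cpp further down.

#include <algorithm>
#include <cstdint>
#include <functional>
#include <vector>

struct PromptContext {
    std::vector<int32_t> tokens; // token history evaluated so far
    int32_t n_ctx = 2048;        // model context window
    int32_t n_past = 0;          // tokens already held in the model state
    int32_t n_batch = 9;         // tokens fed per evaluation call
    float contextErase = 0.75f;  // fraction of context to drop when full
};

// Hypothetical stand-in for gptj_eval()/llama_eval(): evaluates a batch at
// position ctx.n_past and returns false on failure.
bool evalTokens(PromptContext &ctx, const std::vector<int32_t> &batch)
{
    (void) ctx; (void) batch; // stub: the real code calls the model here
    return true;
}

// Re-feed the surviving tokens so the model state matches the trimmed history,
// reporting progress through the recalculate callback (true = in progress).
void recalculateContext(PromptContext &ctx, std::function<bool(bool)> recalculate)
{
    size_t i = 0;
    ctx.n_past = 0;
    while (i < ctx.tokens.size()) {
        size_t batch_end = std::min(i + (size_t) ctx.n_batch, ctx.tokens.size());
        std::vector<int32_t> batch(ctx.tokens.begin() + i, ctx.tokens.begin() + batch_end);
        if (!evalTokens(ctx, batch))
            break;
        ctx.n_past += (int32_t) batch.size();
        if (!recalculate(true)) // caller may cancel
            break;
        i = batch_end;
    }
    recalculate(false); // signal that recalculation has finished
}

// Called when the next batch would overflow n_ctx: erase the oldest
// contextErase fraction of the history, then rebuild the model state.
// Precondition: tokens.size() > n_ctx >= erasePoint.
void trimContext(PromptContext &ctx, std::function<bool(bool)> recalculate)
{
    const int32_t erasePoint = (int32_t) (ctx.n_ctx * ctx.contextErase);
    ctx.tokens.erase(ctx.tokens.begin(), ctx.tokens.begin() + erasePoint);
    recalculateContext(ctx, recalculate);
}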

gptj.cpp

@@ -635,6 +635,7 @@ struct GPTJPrivate {
     gpt_vocab vocab;
     gptj_model model;
     int64_t n_threads = 0;
+    size_t mem_per_token = 0;
     std::mt19937 rng;
 };
@@ -662,6 +663,7 @@ bool GPTJ::loadModel(const std::string &modelPath, std::istream &fin) {
     d_ptr->n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
     d_ptr->modelLoaded = true;
+    fflush(stdout);
     return true;
 }
@@ -685,6 +687,7 @@ bool GPTJ::isModelLoaded() const
 void GPTJ::prompt(const std::string &prompt,
         std::function<bool(int32_t, const std::string&)> response,
+        std::function<bool(bool)> recalculate,
         PromptContext &promptCtx) {

     if (!isModelLoaded()) {
@@ -711,9 +714,9 @@ void GPTJ::prompt(const std::string &prompt,
     static bool initialized = false;
     static std::vector<gpt_vocab::id> p_instruct;
     static std::vector<gpt_vocab::id> r_instruct;
-    size_t mem_per_token = 0;
     if (!initialized) {
-        gptj_eval(d_ptr->model, d_ptr->n_threads, 0, { 0, 1, 2, 3 }, promptCtx.logits, mem_per_token);
+        gptj_eval(d_ptr->model, d_ptr->n_threads, 0, { 0, 1, 2, 3 }, promptCtx.logits,
+            d_ptr->mem_per_token);
         initialized = true;
     }
@@ -726,12 +729,17 @@ void GPTJ::prompt(const std::string &prompt,
         // Check if the context has run out...
         if (promptCtx.n_past + batch.size() > promptCtx.n_ctx) {
-            // FIXME: will produce gibberish after this
-            promptCtx.n_past = std::min(promptCtx.n_past, int(promptCtx.n_ctx - batch.size()));
-            std::cerr << "GPT-J WARNING: reached the end of the context window!\n";
+            const int32_t erasePoint = promptCtx.n_ctx * promptCtx.contextErase;
+            // Erase the first percentage of context from the tokens...
+            std::cerr << "GPTJ: reached the end of the context window so resizing\n";
+            promptCtx.tokens.erase(promptCtx.tokens.begin(), promptCtx.tokens.begin() + erasePoint);
+            promptCtx.n_past = promptCtx.tokens.size();
+            recalculateContext(promptCtx, recalculate);
+            assert(promptCtx.n_past + batch.size() <= promptCtx.n_ctx);
         }

-        if (!gptj_eval(d_ptr->model, d_ptr->n_threads, promptCtx.n_past, batch, promptCtx.logits, mem_per_token)) {
+        if (!gptj_eval(d_ptr->model, d_ptr->n_threads, promptCtx.n_past, batch, promptCtx.logits,
+                d_ptr->mem_per_token)) {
             std::cerr << "GPT-J ERROR: Failed to process prompt\n";
             return;
         }
@@ -770,13 +778,18 @@ void GPTJ::prompt(const std::string &prompt,
         // Check if the context has run out...
         if (promptCtx.n_past + 1 > promptCtx.n_ctx) {
-            // FIXME: will produce gibberish after this
-            promptCtx.n_past = std::min(promptCtx.n_past, promptCtx.n_ctx - 1);
-            std::cerr << "GPT-J WARNING: reached the end of the context window!\n";
+            const int32_t erasePoint = promptCtx.n_ctx * promptCtx.contextErase;
+            // Erase the first percentage of context from the tokens...
+            std::cerr << "GPTJ: reached the end of the context window so resizing\n";
+            promptCtx.tokens.erase(promptCtx.tokens.begin(), promptCtx.tokens.begin() + erasePoint);
+            promptCtx.n_past = promptCtx.tokens.size();
+            recalculateContext(promptCtx, recalculate);
+            assert(promptCtx.n_past + batch.size() <= promptCtx.n_ctx);
         }

         const int64_t t_start_predict_us = ggml_time_us();
-        if (!gptj_eval(d_ptr->model, d_ptr->n_threads, promptCtx.n_past, { id }, promptCtx.logits, mem_per_token)) {
+        if (!gptj_eval(d_ptr->model, d_ptr->n_threads, promptCtx.n_past, { id }, promptCtx.logits,
+                d_ptr->mem_per_token)) {
             std::cerr << "GPT-J ERROR: Failed to predict next token\n";
             return;
         }
@@ -807,3 +820,29 @@ stop_generating:
     return;
 }
+
+void GPTJ::recalculateContext(PromptContext &promptCtx, std::function<bool(bool)> recalculate)
+{
+    size_t i = 0;
+    promptCtx.n_past = 0;
+    while (i < promptCtx.tokens.size()) {
+        size_t batch_end = std::min(i + promptCtx.n_batch, promptCtx.tokens.size());
+        std::vector<gpt_vocab::id> batch(promptCtx.tokens.begin() + i, promptCtx.tokens.begin() + batch_end);
+
+        assert(promptCtx.n_past + batch.size() <= promptCtx.n_ctx);
+        if (!gptj_eval(d_ptr->model, d_ptr->n_threads, promptCtx.n_past, batch, promptCtx.logits,
+                d_ptr->mem_per_token)) {
+            std::cerr << "GPTJ ERROR: Failed to process prompt\n";
+            goto stop_generating;
+        }
+        promptCtx.n_past += batch.size();
+        if (!recalculate(true))
+            goto stop_generating;
+        i = batch_end;
+    }
+    assert(promptCtx.n_past == promptCtx.tokens.size());
+
+stop_generating:
+    recalculate(false);
+}

gptj.h

@@ -17,10 +17,15 @@ public:
     bool isModelLoaded() const override;
     void prompt(const std::string &prompt,
         std::function<bool(int32_t, const std::string&)> response,
+        std::function<bool(bool)> recalculate,
         PromptContext &ctx) override;
     void setThreadCount(int32_t n_threads) override;
     int32_t threadCount() override;

+protected:
+    void recalculateContext(PromptContext &promptCtx,
+        std::function<bool(bool)> recalculate) override;
+
 private:
     GPTJPrivate *d_ptr;
 };

llamamodel.cpp

@@ -58,6 +58,7 @@ bool LLamaModel::loadModel(const std::string &modelPath)
     d_ptr->n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
     d_ptr->modelLoaded = true;
+    fflush(stderr);
     return true;
 }
@@ -80,6 +81,7 @@ bool LLamaModel::isModelLoaded() const
 void LLamaModel::prompt(const std::string &prompt,
         std::function<bool(int32_t, const std::string&)> response,
+        std::function<bool(bool)> recalculate,
         PromptContext &promptCtx) {

     if (!isModelLoaded()) {
@@ -119,9 +121,13 @@ void LLamaModel::prompt(const std::string &prompt,
         // Check if the context has run out...
         if (promptCtx.n_past + batch.size() > promptCtx.n_ctx) {
-            // FIXME: will produce gibberish after this
-            promptCtx.n_past = std::min(promptCtx.n_past, int(promptCtx.n_ctx - batch.size()));
-            std::cerr << "LLAMA WARNING: reached the end of the context window!\n";
+            const int32_t erasePoint = promptCtx.n_ctx * promptCtx.contextErase;
+            // Erase the first percentage of context from the tokens...
+            std::cerr << "LLAMA: reached the end of the context window so resizing\n";
+            promptCtx.tokens.erase(promptCtx.tokens.begin(), promptCtx.tokens.begin() + erasePoint);
+            promptCtx.n_past = promptCtx.tokens.size();
+            recalculateContext(promptCtx, recalculate);
+            assert(promptCtx.n_past + batch.size() <= promptCtx.n_ctx);
         }

         if (llama_eval(d_ptr->ctx, batch.data(), batch.size(), promptCtx.n_past, d_ptr->n_threads)) {
@@ -149,9 +155,13 @@ void LLamaModel::prompt(const std::string &prompt,
         // Check if the context has run out...
         if (promptCtx.n_past + 1 > promptCtx.n_ctx) {
-            // FIXME: will produce gibberish after this
-            promptCtx.n_past = std::min(promptCtx.n_past, promptCtx.n_ctx - 1);
-            std::cerr << "LLAMA WARNING: reached the end of the context window!\n";
+            const int32_t erasePoint = promptCtx.n_ctx * promptCtx.contextErase;
+            // Erase the first percentage of context from the tokens...
+            std::cerr << "LLAMA: reached the end of the context window so resizing\n";
+            promptCtx.tokens.erase(promptCtx.tokens.begin(), promptCtx.tokens.begin() + erasePoint);
+            promptCtx.n_past = promptCtx.tokens.size();
+            recalculateContext(promptCtx, recalculate);
+            assert(promptCtx.n_past + batch.size() <= promptCtx.n_ctx);
         }

         if (llama_eval(d_ptr->ctx, &id, 1, promptCtx.n_past, d_ptr->n_threads)) {
@@ -166,3 +176,28 @@ void LLamaModel::prompt(const std::string &prompt,
             return;
     }
 }
+
+void LLamaModel::recalculateContext(PromptContext &promptCtx, std::function<bool(bool)> recalculate)
+{
+    size_t i = 0;
+    promptCtx.n_past = 0;
+    while (i < promptCtx.tokens.size()) {
+        size_t batch_end = std::min(i + promptCtx.n_batch, promptCtx.tokens.size());
+        std::vector<llama_token> batch(promptCtx.tokens.begin() + i, promptCtx.tokens.begin() + batch_end);
+
+        assert(promptCtx.n_past + batch.size() <= promptCtx.n_ctx);
+        if (llama_eval(d_ptr->ctx, batch.data(), batch.size(), promptCtx.n_past, d_ptr->n_threads)) {
+            std::cerr << "LLAMA ERROR: Failed to process prompt\n";
+            goto stop_generating;
+        }
+        promptCtx.n_past += batch.size();
+        if (!recalculate(true))
+            goto stop_generating;
+        i = batch_end;
+    }
+    assert(promptCtx.n_past == promptCtx.tokens.size());
+
+stop_generating:
+    recalculate(false);
+}

llamamodel.h

@@ -17,10 +17,15 @@ public:
     bool isModelLoaded() const override;
     void prompt(const std::string &prompt,
         std::function<bool(int32_t, const std::string&)> response,
+        std::function<bool(bool)> recalculate,
         PromptContext &ctx) override;
     void setThreadCount(int32_t n_threads) override;
     int32_t threadCount() override;

+protected:
+    void recalculateContext(PromptContext &promptCtx,
+        std::function<bool(bool)> recalculate) override;
+
 private:
     LLamaPrivate *d_ptr;
 };

llm.cpp

@@ -39,6 +39,7 @@ LLMObject::LLMObject()
     , m_llmodel(nullptr)
     , m_responseTokens(0)
     , m_responseLogits(0)
+    , m_isRecalc(false)
 {
     moveToThread(&m_llmThread);
     connect(&m_llmThread, &QThread::started, this, &LLMObject::loadModel);
@@ -271,6 +272,15 @@ bool LLMObject::handleResponse(int32_t token, const std::string &response)
     return !m_stopGenerating;
 }

+bool LLMObject::handleRecalculate(bool isRecalc)
+{
+    if (m_isRecalc != isRecalc) {
+        m_isRecalc = isRecalc;
+        emit recalcChanged();
+    }
+    return !m_stopGenerating;
+}
+
 bool LLMObject::prompt(const QString &prompt, const QString &prompt_template, int32_t n_predict, int32_t top_k, float top_p,
                        float temp, int32_t n_batch)
 {
@@ -280,7 +290,9 @@ bool LLMObject::prompt(const QString &prompt, const QString &prompt_template, in
     QString instructPrompt = prompt_template.arg(prompt);

     m_stopGenerating = false;
-    auto func = std::bind(&LLMObject::handleResponse, this, std::placeholders::_1, std::placeholders::_2);
+    auto responseFunc = std::bind(&LLMObject::handleResponse, this, std::placeholders::_1,
+        std::placeholders::_2);
+    auto recalcFunc = std::bind(&LLMObject::handleRecalculate, this, std::placeholders::_1);
     emit responseStarted();
     qint32 logitsBefore = s_ctx.logits.size();
     s_ctx.n_predict = n_predict;
@@ -288,7 +300,7 @@ bool LLMObject::prompt(const QString &prompt, const QString &prompt_template, in
     s_ctx.top_p = top_p;
     s_ctx.temp = temp;
     s_ctx.n_batch = n_batch;
-    m_llmodel->prompt(instructPrompt.toStdString(), func, s_ctx);
+    m_llmodel->prompt(instructPrompt.toStdString(), responseFunc, recalcFunc, s_ctx);
     m_responseLogits += s_ctx.logits.size() - logitsBefore;
     std::string trimmed = trim_whitespace(m_response);
     if (trimmed != m_response) {
@@ -314,7 +326,7 @@ LLM::LLM()
     connect(m_llmodel, &LLMObject::modelListChanged, this, &LLM::modelListChanged, Qt::QueuedConnection);
     connect(m_llmodel, &LLMObject::threadCountChanged, this, &LLM::threadCountChanged, Qt::QueuedConnection);
     connect(m_llmodel, &LLMObject::threadCountChanged, this, &LLM::syncThreadCount, Qt::QueuedConnection);
+    connect(m_llmodel, &LLMObject::recalcChanged, this, &LLM::recalcChanged, Qt::QueuedConnection);
     connect(this, &LLM::promptRequested, m_llmodel, &LLMObject::prompt, Qt::QueuedConnection);
     connect(this, &LLM::modelNameChangeRequested, m_llmodel, &LLMObject::modelNameChangeRequested, Qt::QueuedConnection);
@@ -428,3 +440,7 @@ bool LLM::checkForUpdates() const
     return QProcess::startDetached(fileName);
 }
+
+bool LLM::isRecalc() const
+{
+    return m_llmodel->isRecalc();
+}

llm.h

@@ -14,6 +14,7 @@ class LLMObject : public QObject
     Q_PROPERTY(QString response READ response NOTIFY responseChanged)
     Q_PROPERTY(QString modelName READ modelName WRITE setModelName NOTIFY modelNameChanged)
     Q_PROPERTY(int32_t threadCount READ threadCount WRITE setThreadCount NOTIFY threadCountChanged)
+    Q_PROPERTY(bool isRecalc READ isRecalc NOTIFY recalcChanged)

 public:
@@ -33,6 +34,8 @@ public:
     QList<QString> modelList() const;
     void setModelName(const QString &modelName);

+    bool isRecalc() const { return m_isRecalc; }
+
 public Q_SLOTS:
     bool prompt(const QString &prompt, const QString &prompt_template, int32_t n_predict, int32_t top_k, float top_p,
                 float temp, int32_t n_batch);
@@ -47,10 +50,12 @@ Q_SIGNALS:
     void modelNameChanged();
     void modelListChanged();
     void threadCountChanged();
+    void recalcChanged();

 private:
     bool loadModelPrivate(const QString &modelName);
     bool handleResponse(int32_t token, const std::string &response);
+    bool handleRecalculate(bool isRecalc);

 private:
     LLModel *m_llmodel;
@@ -60,6 +65,7 @@ private:
     QString m_modelName;
     QThread m_llmThread;
     std::atomic<bool> m_stopGenerating;
+    bool m_isRecalc;
 };

 class LLM : public QObject
@@ -71,6 +77,8 @@ class LLM : public QObject
     Q_PROPERTY(QString modelName READ modelName WRITE setModelName NOTIFY modelNameChanged)
     Q_PROPERTY(bool responseInProgress READ responseInProgress NOTIFY responseInProgressChanged)
     Q_PROPERTY(int32_t threadCount READ threadCount WRITE setThreadCount NOTIFY threadCountChanged)
+    Q_PROPERTY(bool isRecalc READ isRecalc NOTIFY recalcChanged)

 public:
     static LLM *globalInstance();
@@ -96,6 +104,8 @@ public:
     Q_INVOKABLE bool checkForUpdates() const;

+    bool isRecalc() const;
+
 Q_SIGNALS:
     void isModelLoadedChanged();
     void responseChanged();
@@ -110,6 +120,7 @@ Q_SIGNALS:
     void modelListChanged();
     void threadCountChanged();
     void setThreadCountRequested(int32_t threadCount);
+    void recalcChanged();

 private Q_SLOTS:
     void responseStarted();

llmodel.h

@@ -25,13 +25,19 @@ public:
         int32_t n_batch = 9;
         float repeat_penalty = 1.10f;
         int32_t repeat_last_n = 64;  // last n tokens to penalize
+        float contextErase = 0.75f;  // percent of context to erase if we exceed the context
+                                     // window
     };

     virtual void prompt(const std::string &prompt,
         std::function<bool(int32_t, const std::string&)> response,
+        std::function<bool(bool)> recalculate,
         PromptContext &ctx) = 0;
     virtual void setThreadCount(int32_t n_threads) {}
     virtual int32_t threadCount() { return 1; }
+
+protected:
+    virtual void recalculateContext(PromptContext &promptCtx,
+        std::function<bool(bool)> recalculate) = 0;
 };

 #endif // LLMODEL_H

main.qml

@@ -288,6 +288,24 @@ Window {
         text: qsTr("Connection to datalake failed.")
     }

+    PopupDialog {
+        id: recalcPopup
+        anchors.centerIn: parent
+        shouldTimeOut: false
+        shouldShowBusy: true
+        text: qsTr("Recalculating context.")
+
+        Connections {
+            target: LLM
+            function onRecalcChanged() {
+                if (LLM.isRecalc)
+                    recalcPopup.open()
+                else
+                    recalcPopup.close()
+            }
+        }
+    }
+
     Button {
         id: copyButton
         anchors.right: settingsButton.left

PopupDialog.qml

@@ -7,23 +7,45 @@ import QtQuick.Layouts
 Dialog {
     id: popupDialog
     anchors.centerIn: parent
-    modal: false
     opacity: 0.9
     padding: 20
     property alias text: textField.text
+    property bool shouldTimeOut: true
+    property bool shouldShowBusy: false
+    modal: shouldShowBusy
+    closePolicy: shouldShowBusy ? Popup.NoAutoClose : (Popup.CloseOnEscape | Popup.CloseOnPressOutside)

     Theme {
         id: theme
     }

-    Text {
-        id: textField
-        horizontalAlignment: Text.AlignJustify
-        color: theme.textColor
-        Accessible.role: Accessible.HelpBalloon
-        Accessible.name: text
-        Accessible.description: qsTr("Reveals a shortlived help balloon")
+    Row {
+        anchors.centerIn: parent
+        width: childrenRect.width
+        height: childrenRect.height
+        spacing: 20
+
+        Text {
+            id: textField
+            anchors.verticalCenter: busyIndicator.verticalCenter
+            horizontalAlignment: Text.AlignJustify
+            color: theme.textColor
+            Accessible.role: Accessible.HelpBalloon
+            Accessible.name: text
+            Accessible.description: qsTr("Reveals a shortlived help balloon")
+        }
+
+        BusyIndicator {
+            id: busyIndicator
+            visible: shouldShowBusy
+            running: shouldShowBusy
+            Accessible.role: Accessible.Animation
+            Accessible.name: qsTr("Busy indicator")
+            Accessible.description: qsTr("Displayed when the popup is showing busy")
+        }
     }

     background: Rectangle {
         anchors.fill: parent
         color: theme.backgroundDarkest
@@ -37,7 +59,8 @@ Dialog {
     }

     onOpened: {
-        timer.start()
+        if (shouldTimeOut)
+            timer.start()
     }

     Timer {