From b6937c39db7121dd8137f34f37b6bdebae251fb3 Mon Sep 17 00:00:00 2001
From: Adam Treat
Date: Tue, 25 Apr 2023 11:20:51 -0400
Subject: [PATCH] Infinite context window through trimming.

---
 gptj.cpp            | 59 ++++++++++++++++++++++++++++++++++++++-------
 gptj.h              |  5 ++++
 llamamodel.cpp      | 47 ++++++++++++++++++++++++++++-----
 llamamodel.h        |  5 ++++
 llm.cpp             | 22 ++++++++++++++---
 llm.h               | 11 +++++++++
 llmodel.h           |  8 +++++-
 main.qml            | 18 ++++++++++++++
 qml/PopupDialog.qml | 41 ++++++++++++++++++++-------
 9 files changed, 187 insertions(+), 29 deletions(-)

diff --git a/gptj.cpp b/gptj.cpp
index e4ceacfe..6f353d75 100644
--- a/gptj.cpp
+++ b/gptj.cpp
@@ -635,6 +635,7 @@ struct GPTJPrivate {
     gpt_vocab vocab;
     gptj_model model;
     int64_t n_threads = 0;
+    size_t mem_per_token = 0;
     std::mt19937 rng;
 };
 
@@ -662,6 +663,7 @@ bool GPTJ::loadModel(const std::string &modelPath, std::istream &fin) {
 
     d_ptr->n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
     d_ptr->modelLoaded = true;
+    fflush(stdout);
     return true;
 }
 
@@ -685,6 +687,7 @@ bool GPTJ::isModelLoaded() const
 
 void GPTJ::prompt(const std::string &prompt,
         std::function<bool(int32_t, const std::string&)> response,
+        std::function<bool(bool)> recalculate,
         PromptContext &promptCtx) {
 
     if (!isModelLoaded()) {
@@ -711,9 +714,9 @@ void GPTJ::prompt(const std::string &prompt,
     static bool initialized = false;
     static std::vector<gpt_vocab::id> p_instruct;
     static std::vector<gpt_vocab::id> r_instruct;
-    size_t mem_per_token = 0;
     if (!initialized) {
-        gptj_eval(d_ptr->model, d_ptr->n_threads, 0, { 0, 1, 2, 3 }, promptCtx.logits, mem_per_token);
+        gptj_eval(d_ptr->model, d_ptr->n_threads, 0, { 0, 1, 2, 3 }, promptCtx.logits,
+            d_ptr->mem_per_token);
         initialized = true;
     }
 
@@ -726,12 +729,17 @@
 
         // Check if the context has run out...
         if (promptCtx.n_past + batch.size() > promptCtx.n_ctx) {
-            // FIXME: will produce gibberish after this
-            promptCtx.n_past = std::min(promptCtx.n_past, int(promptCtx.n_ctx - batch.size()));
-            std::cerr << "GPT-J WARNING: reached the end of the context window!\n";
+            const int32_t erasePoint = promptCtx.n_ctx * promptCtx.contextErase;
+            // Erase the first percentage of context from the tokens...
+            std::cerr << "GPTJ: reached the end of the context window so resizing\n";
+            promptCtx.tokens.erase(promptCtx.tokens.begin(), promptCtx.tokens.begin() + erasePoint);
+            promptCtx.n_past = promptCtx.tokens.size();
+            recalculateContext(promptCtx, recalculate);
+            assert(promptCtx.n_past + batch.size() <= promptCtx.n_ctx);
         }
 
-        if (!gptj_eval(d_ptr->model, d_ptr->n_threads, promptCtx.n_past, batch, promptCtx.logits, mem_per_token)) {
+        if (!gptj_eval(d_ptr->model, d_ptr->n_threads, promptCtx.n_past, batch, promptCtx.logits,
+            d_ptr->mem_per_token)) {
             std::cerr << "GPT-J ERROR: Failed to process prompt\n";
             return;
         }
@@ -770,13 +778,18 @@
 
         // Check if the context has run out...
         if (promptCtx.n_past + 1 > promptCtx.n_ctx) {
-            // FIXME: will produce gibberish after this
-            promptCtx.n_past = std::min(promptCtx.n_past, promptCtx.n_ctx - 1);
-            std::cerr << "GPT-J WARNING: reached the end of the context window!\n";
+            const int32_t erasePoint = promptCtx.n_ctx * promptCtx.contextErase;
+            // Erase the first percentage of context from the tokens...
+ std::cerr << "GPTJ: reached the end of the context window so resizing\n"; + promptCtx.tokens.erase(promptCtx.tokens.begin(), promptCtx.tokens.begin() + erasePoint); + promptCtx.n_past = promptCtx.tokens.size(); + recalculateContext(promptCtx, recalculate); + assert(promptCtx.n_past + batch.size() <= promptCtx.n_ctx); } const int64_t t_start_predict_us = ggml_time_us(); - if (!gptj_eval(d_ptr->model, d_ptr->n_threads, promptCtx.n_past, { id }, promptCtx.logits, mem_per_token)) { + if (!gptj_eval(d_ptr->model, d_ptr->n_threads, promptCtx.n_past, { id }, promptCtx.logits, + d_ptr->mem_per_token)) { std::cerr << "GPT-J ERROR: Failed to predict next token\n"; return; } @@ -807,3 +820,29 @@ stop_generating: return; } + +void GPTJ::recalculateContext(PromptContext &promptCtx, std::function recalculate) +{ + size_t i = 0; + promptCtx.n_past = 0; + while (i < promptCtx.tokens.size()) { + size_t batch_end = std::min(i + promptCtx.n_batch, promptCtx.tokens.size()); + std::vector batch(promptCtx.tokens.begin() + i, promptCtx.tokens.begin() + batch_end); + + assert(promptCtx.n_past + batch.size() <= promptCtx.n_ctx); + + if (!gptj_eval(d_ptr->model, d_ptr->n_threads, promptCtx.n_past, batch, promptCtx.logits, + d_ptr->mem_per_token)) { + std::cerr << "GPTJ ERROR: Failed to process prompt\n"; + goto stop_generating; + } + promptCtx.n_past += batch.size(); + if (!recalculate(true)) + goto stop_generating; + i = batch_end; + } + assert(promptCtx.n_past == promptCtx.tokens.size()); + +stop_generating: + recalculate(false); +} diff --git a/gptj.h b/gptj.h index a6a0b8dc..17cb069c 100644 --- a/gptj.h +++ b/gptj.h @@ -17,10 +17,15 @@ public: bool isModelLoaded() const override; void prompt(const std::string &prompt, std::function response, + std::function recalculate, PromptContext &ctx) override; void setThreadCount(int32_t n_threads) override; int32_t threadCount() override; +protected: + void recalculateContext(PromptContext &promptCtx, + std::function recalculate) override; + private: GPTJPrivate *d_ptr; }; diff --git a/llamamodel.cpp b/llamamodel.cpp index 693c05ea..06e4aced 100644 --- a/llamamodel.cpp +++ b/llamamodel.cpp @@ -58,6 +58,7 @@ bool LLamaModel::loadModel(const std::string &modelPath) d_ptr->n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency()); d_ptr->modelLoaded = true; + fflush(stderr); return true; } @@ -80,6 +81,7 @@ bool LLamaModel::isModelLoaded() const void LLamaModel::prompt(const std::string &prompt, std::function response, + std::function recalculate, PromptContext &promptCtx) { if (!isModelLoaded()) { @@ -119,9 +121,13 @@ void LLamaModel::prompt(const std::string &prompt, // Check if the context has run out... if (promptCtx.n_past + batch.size() > promptCtx.n_ctx) { - // FIXME: will produce gibberish after this - promptCtx.n_past = std::min(promptCtx.n_past, int(promptCtx.n_ctx - batch.size())); - std::cerr << "LLAMA WARNING: reached the end of the context window!\n"; + const int32_t erasePoint = promptCtx.n_ctx * promptCtx.contextErase; + // Erase the first percentage of context from the tokens... 
+ std::cerr << "LLAMA: reached the end of the context window so resizing\n"; + promptCtx.tokens.erase(promptCtx.tokens.begin(), promptCtx.tokens.begin() + erasePoint); + promptCtx.n_past = promptCtx.tokens.size(); + recalculateContext(promptCtx, recalculate); + assert(promptCtx.n_past + batch.size() <= promptCtx.n_ctx); } if (llama_eval(d_ptr->ctx, batch.data(), batch.size(), promptCtx.n_past, d_ptr->n_threads)) { @@ -149,9 +155,13 @@ void LLamaModel::prompt(const std::string &prompt, // Check if the context has run out... if (promptCtx.n_past + 1 > promptCtx.n_ctx) { - // FIXME: will produce gibberish after this - promptCtx.n_past = std::min(promptCtx.n_past, promptCtx.n_ctx - 1); - std::cerr << "LLAMA WARNING: reached the end of the context window!\n"; + const int32_t erasePoint = promptCtx.n_ctx * promptCtx.contextErase; + // Erase the first percentage of context from the tokens... + std::cerr << "LLAMA: reached the end of the context window so resizing\n"; + promptCtx.tokens.erase(promptCtx.tokens.begin(), promptCtx.tokens.begin() + erasePoint); + promptCtx.n_past = promptCtx.tokens.size(); + recalculateContext(promptCtx, recalculate); + assert(promptCtx.n_past + batch.size() <= promptCtx.n_ctx); } if (llama_eval(d_ptr->ctx, &id, 1, promptCtx.n_past, d_ptr->n_threads)) { @@ -166,3 +176,28 @@ void LLamaModel::prompt(const std::string &prompt, return; } } + +void LLamaModel::recalculateContext(PromptContext &promptCtx, std::function recalculate) +{ + size_t i = 0; + promptCtx.n_past = 0; + while (i < promptCtx.tokens.size()) { + size_t batch_end = std::min(i + promptCtx.n_batch, promptCtx.tokens.size()); + std::vector batch(promptCtx.tokens.begin() + i, promptCtx.tokens.begin() + batch_end); + + assert(promptCtx.n_past + batch.size() <= promptCtx.n_ctx); + + if (llama_eval(d_ptr->ctx, batch.data(), batch.size(), promptCtx.n_past, d_ptr->n_threads)) { + std::cerr << "LLAMA ERROR: Failed to process prompt\n"; + goto stop_generating; + } + promptCtx.n_past += batch.size(); + if (!recalculate(true)) + goto stop_generating; + i = batch_end; + } + assert(promptCtx.n_past == promptCtx.tokens.size()); + +stop_generating: + recalculate(false); +} diff --git a/llamamodel.h b/llamamodel.h index 57eb4194..163260bb 100644 --- a/llamamodel.h +++ b/llamamodel.h @@ -17,10 +17,15 @@ public: bool isModelLoaded() const override; void prompt(const std::string &prompt, std::function response, + std::function recalculate, PromptContext &ctx) override; void setThreadCount(int32_t n_threads) override; int32_t threadCount() override; +protected: + void recalculateContext(PromptContext &promptCtx, + std::function recalculate) override; + private: LLamaPrivate *d_ptr; }; diff --git a/llm.cpp b/llm.cpp index 332b1e85..ae0e4fa4 100644 --- a/llm.cpp +++ b/llm.cpp @@ -39,6 +39,7 @@ LLMObject::LLMObject() , m_llmodel(nullptr) , m_responseTokens(0) , m_responseLogits(0) + , m_isRecalc(false) { moveToThread(&m_llmThread); connect(&m_llmThread, &QThread::started, this, &LLMObject::loadModel); @@ -271,6 +272,15 @@ bool LLMObject::handleResponse(int32_t token, const std::string &response) return !m_stopGenerating; } +bool LLMObject::handleRecalculate(bool isRecalc) +{ + if (m_isRecalc != isRecalc) { + m_isRecalc = isRecalc; + emit recalcChanged(); + } + return !m_stopGenerating; +} + bool LLMObject::prompt(const QString &prompt, const QString &prompt_template, int32_t n_predict, int32_t top_k, float top_p, float temp, int32_t n_batch) { @@ -280,7 +290,9 @@ bool LLMObject::prompt(const QString &prompt, const QString 
&prompt_template, in QString instructPrompt = prompt_template.arg(prompt); m_stopGenerating = false; - auto func = std::bind(&LLMObject::handleResponse, this, std::placeholders::_1, std::placeholders::_2); + auto responseFunc = std::bind(&LLMObject::handleResponse, this, std::placeholders::_1, + std::placeholders::_2); + auto recalcFunc = std::bind(&LLMObject::handleRecalculate, this, std::placeholders::_1); emit responseStarted(); qint32 logitsBefore = s_ctx.logits.size(); s_ctx.n_predict = n_predict; @@ -288,7 +300,7 @@ bool LLMObject::prompt(const QString &prompt, const QString &prompt_template, in s_ctx.top_p = top_p; s_ctx.temp = temp; s_ctx.n_batch = n_batch; - m_llmodel->prompt(instructPrompt.toStdString(), func, s_ctx); + m_llmodel->prompt(instructPrompt.toStdString(), responseFunc, recalcFunc, s_ctx); m_responseLogits += s_ctx.logits.size() - logitsBefore; std::string trimmed = trim_whitespace(m_response); if (trimmed != m_response) { @@ -314,7 +326,7 @@ LLM::LLM() connect(m_llmodel, &LLMObject::modelListChanged, this, &LLM::modelListChanged, Qt::QueuedConnection); connect(m_llmodel, &LLMObject::threadCountChanged, this, &LLM::threadCountChanged, Qt::QueuedConnection); connect(m_llmodel, &LLMObject::threadCountChanged, this, &LLM::syncThreadCount, Qt::QueuedConnection); - + connect(m_llmodel, &LLMObject::recalcChanged, this, &LLM::recalcChanged, Qt::QueuedConnection); connect(this, &LLM::promptRequested, m_llmodel, &LLMObject::prompt, Qt::QueuedConnection); connect(this, &LLM::modelNameChangeRequested, m_llmodel, &LLMObject::modelNameChangeRequested, Qt::QueuedConnection); @@ -428,3 +440,7 @@ bool LLM::checkForUpdates() const return QProcess::startDetached(fileName); } +bool LLM::isRecalc() const +{ + return m_llmodel->isRecalc(); +} diff --git a/llm.h b/llm.h index bf95a348..c2061eca 100644 --- a/llm.h +++ b/llm.h @@ -14,6 +14,7 @@ class LLMObject : public QObject Q_PROPERTY(QString response READ response NOTIFY responseChanged) Q_PROPERTY(QString modelName READ modelName WRITE setModelName NOTIFY modelNameChanged) Q_PROPERTY(int32_t threadCount READ threadCount WRITE setThreadCount NOTIFY threadCountChanged) + Q_PROPERTY(bool isRecalc READ isRecalc NOTIFY recalcChanged) public: @@ -33,6 +34,8 @@ public: QList modelList() const; void setModelName(const QString &modelName); + bool isRecalc() const { return m_isRecalc; } + public Q_SLOTS: bool prompt(const QString &prompt, const QString &prompt_template, int32_t n_predict, int32_t top_k, float top_p, float temp, int32_t n_batch); @@ -47,10 +50,12 @@ Q_SIGNALS: void modelNameChanged(); void modelListChanged(); void threadCountChanged(); + void recalcChanged(); private: bool loadModelPrivate(const QString &modelName); bool handleResponse(int32_t token, const std::string &response); + bool handleRecalculate(bool isRecalc); private: LLModel *m_llmodel; @@ -60,6 +65,7 @@ private: QString m_modelName; QThread m_llmThread; std::atomic m_stopGenerating; + bool m_isRecalc; }; class LLM : public QObject @@ -71,6 +77,8 @@ class LLM : public QObject Q_PROPERTY(QString modelName READ modelName WRITE setModelName NOTIFY modelNameChanged) Q_PROPERTY(bool responseInProgress READ responseInProgress NOTIFY responseInProgressChanged) Q_PROPERTY(int32_t threadCount READ threadCount WRITE setThreadCount NOTIFY threadCountChanged) + Q_PROPERTY(bool isRecalc READ isRecalc NOTIFY recalcChanged) + public: static LLM *globalInstance(); @@ -96,6 +104,8 @@ public: Q_INVOKABLE bool checkForUpdates() const; + bool isRecalc() const; + Q_SIGNALS: void 
     void isModelLoadedChanged();
     void responseChanged();
@@ -110,6 +120,7 @@ Q_SIGNALS:
     void modelListChanged();
     void threadCountChanged();
     void setThreadCountRequested(int32_t threadCount);
+    void recalcChanged();
 
 private Q_SLOTS:
     void responseStarted();
diff --git a/llmodel.h b/llmodel.h
index cacd23aa..5945eb98 100644
--- a/llmodel.h
+++ b/llmodel.h
@@ -25,13 +25,19 @@ public:
         int32_t n_batch = 9;
         float repeat_penalty = 1.10f;
        int32_t repeat_last_n = 64; // last n tokens to penalize
-
+        float contextErase = 0.75f; // percent of context to erase if we exceed the context
+                                    // window
     };
 
     virtual void prompt(const std::string &prompt, std::function<bool(int32_t, const std::string&)> response,
+        std::function<bool(bool)> recalculate,
         PromptContext &ctx) = 0;
 
     virtual void setThreadCount(int32_t n_threads) {}
     virtual int32_t threadCount() { return 1; }
+
+protected:
+    virtual void recalculateContext(PromptContext &promptCtx,
+        std::function<bool(bool)> recalculate) = 0;
 };
 
 #endif // LLMODEL_H
diff --git a/main.qml b/main.qml
index d7bd17c5..8c3596fe 100644
--- a/main.qml
+++ b/main.qml
@@ -288,6 +288,24 @@ Window {
         text: qsTr("Connection to datalake failed.")
     }
 
+    PopupDialog {
+        id: recalcPopup
+        anchors.centerIn: parent
+        shouldTimeOut: false
+        shouldShowBusy: true
+        text: qsTr("Recalculating context.")
+
+        Connections {
+            target: LLM
+            function onRecalcChanged() {
+                if (LLM.isRecalc)
+                    recalcPopup.open()
+                else
+                    recalcPopup.close()
+            }
+        }
+    }
+
     Button {
         id: copyButton
         anchors.right: settingsButton.left
diff --git a/qml/PopupDialog.qml b/qml/PopupDialog.qml
index 3633755b..dfd80d54 100644
--- a/qml/PopupDialog.qml
+++ b/qml/PopupDialog.qml
@@ -7,23 +7,45 @@ import QtQuick.Layouts
 Dialog {
     id: popupDialog
     anchors.centerIn: parent
-    modal: false
     opacity: 0.9
     padding: 20
     property alias text: textField.text
+    property bool shouldTimeOut: true
+    property bool shouldShowBusy: false
+    modal: shouldShowBusy
+    closePolicy: shouldShowBusy ? Popup.NoAutoClose : (Popup.CloseOnEscape | Popup.CloseOnPressOutside)
 
     Theme {
         id: theme
     }
 
-    Text {
-        id: textField
-        horizontalAlignment: Text.AlignJustify
-        color: theme.textColor
-        Accessible.role: Accessible.HelpBalloon
-        Accessible.name: text
-        Accessible.description: qsTr("Reveals a shortlived help balloon")
+    Row {
+        anchors.centerIn: parent
+        width: childrenRect.width
+        height: childrenRect.height
+        spacing: 20
+
+        Text {
+            id: textField
+            anchors.verticalCenter: busyIndicator.verticalCenter
+            horizontalAlignment: Text.AlignJustify
+            color: theme.textColor
+            Accessible.role: Accessible.HelpBalloon
+            Accessible.name: text
+            Accessible.description: qsTr("Reveals a shortlived help balloon")
+        }
+
+        BusyIndicator {
+            id: busyIndicator
+            visible: shouldShowBusy
+            running: shouldShowBusy
+
+            Accessible.role: Accessible.Animation
+            Accessible.name: qsTr("Busy indicator")
+            Accessible.description: qsTr("Displayed when the popup is showing busy")
+        }
     }
+
     background: Rectangle {
         anchors.fill: parent
         color: theme.backgroundDarkest
@@ -37,7 +59,8 @@ Dialog {
     }
 
     onOpened: {
-        timer.start()
+        if (shouldTimeOut)
+            timer.start()
     }
 
     Timer {
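
Both backends handle overflow the same way: when the next batch of tokens would not fit in n_ctx, the oldest contextErase fraction of the saved tokens is dropped, n_past is reset to the number of tokens kept, and recalculateContext() re-evaluates the remaining tokens batch by batch so the model's state matches the shortened history, calling recalculate(true/false) so the UI can show and hide the busy popup. The sketch below is an illustrative, self-contained rendering of that flow, not code from the patch: Ctx, evalBatch(), and ensureRoom() are hypothetical stand-ins for PromptContext, gptj_eval()/llama_eval(), and the overflow check inside prompt().

// Illustrative sketch only: mirrors the trim-and-recalculate flow of the patch,
// with a stub evalBatch() in place of gptj_eval()/llama_eval().
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <functional>
#include <iostream>
#include <vector>

struct Ctx {                      // hypothetical stand-in for PromptContext
    std::vector<int32_t> tokens;  // tokens currently in the window
    size_t n_past = 0;            // tokens already evaluated by the model
    size_t n_ctx = 64;            // context window size
    size_t n_batch = 8;           // evaluation batch size
    float contextErase = 0.75f;   // fraction of the window to drop on overflow
};

static bool evalBatch(const std::vector<int32_t> &batch, size_t n_past) {
    (void)batch; (void)n_past;    // a real backend would run the model here
    return true;
}

// Re-evaluate the trimmed history batch by batch, reporting progress to the UI.
static void recalculateContext(Ctx &ctx, const std::function<bool(bool)> &recalculate) {
    size_t i = 0;
    ctx.n_past = 0;
    while (i < ctx.tokens.size()) {
        const size_t end = std::min(i + ctx.n_batch, ctx.tokens.size());
        std::vector<int32_t> batch(ctx.tokens.begin() + i, ctx.tokens.begin() + end);
        if (!evalBatch(batch, ctx.n_past))
            break;
        ctx.n_past += batch.size();
        if (!recalculate(true))   // returning false cancels, like m_stopGenerating
            break;
        i = end;
    }
    recalculate(false);           // tell the UI the recalculation is finished
}

// Called before evaluating `incoming` new tokens; trims the window if it would overflow.
static void ensureRoom(Ctx &ctx, size_t incoming, const std::function<bool(bool)> &recalculate) {
    if (ctx.n_past + incoming <= ctx.n_ctx)
        return;
    const size_t erasePoint = static_cast<size_t>(ctx.n_ctx * ctx.contextErase);
    ctx.tokens.erase(ctx.tokens.begin(), ctx.tokens.begin() + erasePoint);
    ctx.n_past = ctx.tokens.size();
    recalculateContext(ctx, recalculate);
    assert(ctx.n_past + incoming <= ctx.n_ctx);
}

int main() {
    Ctx ctx;
    ctx.tokens.assign(ctx.n_ctx, 42);  // pretend the window is already full
    ctx.n_past = ctx.tokens.size();
    ensureRoom(ctx, 4, [](bool busy) {
        std::cout << (busy ? "recalculating...\n" : "done\n");
        return true;
    });
    std::cout << "tokens kept: " << ctx.tokens.size() << "\n";  // 64 - 48 = 16
    return 0;
}

As in the patch, the erase point is computed from n_ctx rather than from the number of stored tokens, and the recalculation callback doubles as a cancellation check, which is why LLMObject::handleRecalculate returns !m_stopGenerating.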