Mirror of https://github.com/nomic-ai/gpt4all (synced 2024-11-06 09:20:33 +00:00)
Infinite context window through trimming.
Commit b6937c39db (parent 8b1ddabe3e)
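The diff below replaces the old overflow behavior (clamp n_past and keep generating, which produced gibberish) with trimming: when the next batch would overflow the model's context window, the oldest contextErase fraction of the token history is erased and the remainder is re-evaluated, so generation can continue indefinitely. A minimal sketch of that flow follows; the names here (SketchCtx, makeRoom, evalTokens) and the 2048-token window are illustrative assumptions, not the project's API.

// Sketch only: evalTokens stands in for gptj_eval / llama_eval.
#include <cstdint>
#include <functional>
#include <iostream>
#include <vector>

struct SketchCtx {
    std::vector<int32_t> tokens;  // token history the host keeps alongside the model state
    int32_t n_ctx = 2048;         // assumed window size; the real value is model dependent
    int32_t n_past = 0;           // number of tokens the model has already evaluated
    float contextErase = 0.75f;   // fraction to drop, matching the PromptContext default below
};

// Feed a batch of tokens to the model at position n_past; returns false on failure.
using EvalFn = std::function<bool(const std::vector<int32_t> &batch, int32_t n_past)>;

bool makeRoom(SketchCtx &ctx, int32_t incoming, const EvalFn &evalTokens)
{
    if (ctx.n_past + incoming <= ctx.n_ctx)
        return true;  // the next batch still fits, nothing to trim

    // Drop the oldest contextErase fraction of the history...
    const int32_t erasePoint = ctx.n_ctx * ctx.contextErase;
    ctx.tokens.erase(ctx.tokens.begin(), ctx.tokens.begin() + erasePoint);

    // ...then replay what survives so the model's state matches the trimmed
    // history again (the commit does this batch by batch in recalculateContext).
    ctx.n_past = 0;
    if (!evalTokens(ctx.tokens, ctx.n_past)) {
        std::cerr << "failed to recalculate context\n";
        return false;
    }
    ctx.n_past = static_cast<int32_t>(ctx.tokens.size());
    return true;
}

In the diff itself the replay runs in n_batch-sized chunks inside the new recalculateContext overrides, and progress is reported through the added recalculate callback, which the Qt layer surfaces as the "Recalculating context." popup in main.qml.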
gptj.cpp (59 changed lines)
@@ -635,6 +635,7 @@ struct GPTJPrivate {
gpt_vocab vocab;
gptj_model model;
int64_t n_threads = 0;
size_t mem_per_token = 0;
std::mt19937 rng;
};

@@ -662,6 +663,7 @@ bool GPTJ::loadModel(const std::string &modelPath, std::istream &fin) {
d_ptr->n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
d_ptr->modelLoaded = true;
fflush(stdout);
return true;
}

@@ -685,6 +687,7 @@ bool GPTJ::isModelLoaded() const
void GPTJ::prompt(const std::string &prompt,
std::function<bool(int32_t, const std::string&)> response,
std::function<bool(bool)> recalculate,
PromptContext &promptCtx) {
if (!isModelLoaded()) {

@@ -711,9 +714,9 @@ void GPTJ::prompt(const std::string &prompt,
static bool initialized = false;
static std::vector<gpt_vocab::id> p_instruct;
static std::vector<gpt_vocab::id> r_instruct;
size_t mem_per_token = 0;
if (!initialized) {
gptj_eval(d_ptr->model, d_ptr->n_threads, 0, { 0, 1, 2, 3 }, promptCtx.logits, mem_per_token);
gptj_eval(d_ptr->model, d_ptr->n_threads, 0, { 0, 1, 2, 3 }, promptCtx.logits,
    d_ptr->mem_per_token);
initialized = true;
}

@@ -726,12 +729,17 @@ void GPTJ::prompt(const std::string &prompt,
// Check if the context has run out...
if (promptCtx.n_past + batch.size() > promptCtx.n_ctx) {
// FIXME: will produce gibberish after this
promptCtx.n_past = std::min(promptCtx.n_past, int(promptCtx.n_ctx - batch.size()));
std::cerr << "GPT-J WARNING: reached the end of the context window!\n";
const int32_t erasePoint = promptCtx.n_ctx * promptCtx.contextErase;
// Erase the first percentage of context from the tokens...
std::cerr << "GPTJ: reached the end of the context window so resizing\n";
promptCtx.tokens.erase(promptCtx.tokens.begin(), promptCtx.tokens.begin() + erasePoint);
promptCtx.n_past = promptCtx.tokens.size();
recalculateContext(promptCtx, recalculate);
assert(promptCtx.n_past + batch.size() <= promptCtx.n_ctx);
}

if (!gptj_eval(d_ptr->model, d_ptr->n_threads, promptCtx.n_past, batch, promptCtx.logits, mem_per_token)) {
if (!gptj_eval(d_ptr->model, d_ptr->n_threads, promptCtx.n_past, batch, promptCtx.logits,
    d_ptr->mem_per_token)) {
std::cerr << "GPT-J ERROR: Failed to process prompt\n";
return;
}

@@ -770,13 +778,18 @@ void GPTJ::prompt(const std::string &prompt,
// Check if the context has run out...
if (promptCtx.n_past + 1 > promptCtx.n_ctx) {
// FIXME: will produce gibberish after this
promptCtx.n_past = std::min(promptCtx.n_past, promptCtx.n_ctx - 1);
std::cerr << "GPT-J WARNING: reached the end of the context window!\n";
const int32_t erasePoint = promptCtx.n_ctx * promptCtx.contextErase;
// Erase the first percentage of context from the tokens...
std::cerr << "GPTJ: reached the end of the context window so resizing\n";
promptCtx.tokens.erase(promptCtx.tokens.begin(), promptCtx.tokens.begin() + erasePoint);
promptCtx.n_past = promptCtx.tokens.size();
recalculateContext(promptCtx, recalculate);
assert(promptCtx.n_past + batch.size() <= promptCtx.n_ctx);
}

const int64_t t_start_predict_us = ggml_time_us();
if (!gptj_eval(d_ptr->model, d_ptr->n_threads, promptCtx.n_past, { id }, promptCtx.logits, mem_per_token)) {
if (!gptj_eval(d_ptr->model, d_ptr->n_threads, promptCtx.n_past, { id }, promptCtx.logits,
    d_ptr->mem_per_token)) {
std::cerr << "GPT-J ERROR: Failed to predict next token\n";
return;
}

@@ -807,3 +820,29 @@ stop_generating:
return;
}

void GPTJ::recalculateContext(PromptContext &promptCtx, std::function<bool(bool)> recalculate)
{
size_t i = 0;
promptCtx.n_past = 0;
while (i < promptCtx.tokens.size()) {
size_t batch_end = std::min(i + promptCtx.n_batch, promptCtx.tokens.size());
std::vector<gpt_vocab::id> batch(promptCtx.tokens.begin() + i, promptCtx.tokens.begin() + batch_end);

assert(promptCtx.n_past + batch.size() <= promptCtx.n_ctx);

if (!gptj_eval(d_ptr->model, d_ptr->n_threads, promptCtx.n_past, batch, promptCtx.logits,
    d_ptr->mem_per_token)) {
std::cerr << "GPTJ ERROR: Failed to process prompt\n";
goto stop_generating;
}
promptCtx.n_past += batch.size();
if (!recalculate(true))
goto stop_generating;
i = batch_end;
}
assert(promptCtx.n_past == promptCtx.tokens.size());

stop_generating:
recalculate(false);
}
gptj.h (5 changed lines)
@@ -17,10 +17,15 @@ public:
bool isModelLoaded() const override;
void prompt(const std::string &prompt,
std::function<bool(int32_t, const std::string&)> response,
std::function<bool(bool)> recalculate,
PromptContext &ctx) override;
void setThreadCount(int32_t n_threads) override;
int32_t threadCount() override;

protected:
void recalculateContext(PromptContext &promptCtx,
std::function<bool(bool)> recalculate) override;

private:
GPTJPrivate *d_ptr;
};
llamamodel.cpp

@@ -58,6 +58,7 @@ bool LLamaModel::loadModel(const std::string &modelPath)
d_ptr->n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
d_ptr->modelLoaded = true;
fflush(stderr);
return true;
}

@@ -80,6 +81,7 @@ bool LLamaModel::isModelLoaded() const
void LLamaModel::prompt(const std::string &prompt,
std::function<bool(int32_t, const std::string&)> response,
std::function<bool(bool)> recalculate,
PromptContext &promptCtx) {
if (!isModelLoaded()) {

@@ -119,9 +121,13 @@ void LLamaModel::prompt(const std::string &prompt,
// Check if the context has run out...
if (promptCtx.n_past + batch.size() > promptCtx.n_ctx) {
// FIXME: will produce gibberish after this
promptCtx.n_past = std::min(promptCtx.n_past, int(promptCtx.n_ctx - batch.size()));
std::cerr << "LLAMA WARNING: reached the end of the context window!\n";
const int32_t erasePoint = promptCtx.n_ctx * promptCtx.contextErase;
// Erase the first percentage of context from the tokens...
std::cerr << "LLAMA: reached the end of the context window so resizing\n";
promptCtx.tokens.erase(promptCtx.tokens.begin(), promptCtx.tokens.begin() + erasePoint);
promptCtx.n_past = promptCtx.tokens.size();
recalculateContext(promptCtx, recalculate);
assert(promptCtx.n_past + batch.size() <= promptCtx.n_ctx);
}

if (llama_eval(d_ptr->ctx, batch.data(), batch.size(), promptCtx.n_past, d_ptr->n_threads)) {

@@ -149,9 +155,13 @@ void LLamaModel::prompt(const std::string &prompt,
// Check if the context has run out...
if (promptCtx.n_past + 1 > promptCtx.n_ctx) {
// FIXME: will produce gibberish after this
promptCtx.n_past = std::min(promptCtx.n_past, promptCtx.n_ctx - 1);
std::cerr << "LLAMA WARNING: reached the end of the context window!\n";
const int32_t erasePoint = promptCtx.n_ctx * promptCtx.contextErase;
// Erase the first percentage of context from the tokens...
std::cerr << "LLAMA: reached the end of the context window so resizing\n";
promptCtx.tokens.erase(promptCtx.tokens.begin(), promptCtx.tokens.begin() + erasePoint);
promptCtx.n_past = promptCtx.tokens.size();
recalculateContext(promptCtx, recalculate);
assert(promptCtx.n_past + batch.size() <= promptCtx.n_ctx);
}

if (llama_eval(d_ptr->ctx, &id, 1, promptCtx.n_past, d_ptr->n_threads)) {

@@ -166,3 +176,28 @@ void LLamaModel::prompt(const std::string &prompt,
return;
}
}

void LLamaModel::recalculateContext(PromptContext &promptCtx, std::function<bool(bool)> recalculate)
{
size_t i = 0;
promptCtx.n_past = 0;
while (i < promptCtx.tokens.size()) {
size_t batch_end = std::min(i + promptCtx.n_batch, promptCtx.tokens.size());
std::vector<llama_token> batch(promptCtx.tokens.begin() + i, promptCtx.tokens.begin() + batch_end);

assert(promptCtx.n_past + batch.size() <= promptCtx.n_ctx);

if (llama_eval(d_ptr->ctx, batch.data(), batch.size(), promptCtx.n_past, d_ptr->n_threads)) {
std::cerr << "LLAMA ERROR: Failed to process prompt\n";
goto stop_generating;
}
promptCtx.n_past += batch.size();
if (!recalculate(true))
goto stop_generating;
i = batch_end;
}
assert(promptCtx.n_past == promptCtx.tokens.size());

stop_generating:
recalculate(false);
}
llamamodel.h

@@ -17,10 +17,15 @@ public:
bool isModelLoaded() const override;
void prompt(const std::string &prompt,
std::function<bool(int32_t, const std::string&)> response,
std::function<bool(bool)> recalculate,
PromptContext &ctx) override;
void setThreadCount(int32_t n_threads) override;
int32_t threadCount() override;

protected:
void recalculateContext(PromptContext &promptCtx,
std::function<bool(bool)> recalculate) override;

private:
LLamaPrivate *d_ptr;
};
llm.cpp (22 changed lines)
@@ -39,6 +39,7 @@ LLMObject::LLMObject()
, m_llmodel(nullptr)
, m_responseTokens(0)
, m_responseLogits(0)
, m_isRecalc(false)
{
moveToThread(&m_llmThread);
connect(&m_llmThread, &QThread::started, this, &LLMObject::loadModel);

@@ -271,6 +272,15 @@ bool LLMObject::handleResponse(int32_t token, const std::string &response)
return !m_stopGenerating;
}

bool LLMObject::handleRecalculate(bool isRecalc)
{
if (m_isRecalc != isRecalc) {
m_isRecalc = isRecalc;
emit recalcChanged();
}
return !m_stopGenerating;
}

bool LLMObject::prompt(const QString &prompt, const QString &prompt_template, int32_t n_predict, int32_t top_k, float top_p,
float temp, int32_t n_batch)
{

@@ -280,7 +290,9 @@ bool LLMObject::prompt(const QString &prompt, const QString &prompt_template, in
QString instructPrompt = prompt_template.arg(prompt);

m_stopGenerating = false;
auto func = std::bind(&LLMObject::handleResponse, this, std::placeholders::_1, std::placeholders::_2);
auto responseFunc = std::bind(&LLMObject::handleResponse, this, std::placeholders::_1,
    std::placeholders::_2);
auto recalcFunc = std::bind(&LLMObject::handleRecalculate, this, std::placeholders::_1);
emit responseStarted();
qint32 logitsBefore = s_ctx.logits.size();
s_ctx.n_predict = n_predict;

@@ -288,7 +300,7 @@ bool LLMObject::prompt(const QString &prompt, const QString &prompt_template, in
s_ctx.top_p = top_p;
s_ctx.temp = temp;
s_ctx.n_batch = n_batch;
m_llmodel->prompt(instructPrompt.toStdString(), func, s_ctx);
m_llmodel->prompt(instructPrompt.toStdString(), responseFunc, recalcFunc, s_ctx);
m_responseLogits += s_ctx.logits.size() - logitsBefore;
std::string trimmed = trim_whitespace(m_response);
if (trimmed != m_response) {

@@ -314,7 +326,7 @@ LLM::LLM()
connect(m_llmodel, &LLMObject::modelListChanged, this, &LLM::modelListChanged, Qt::QueuedConnection);
connect(m_llmodel, &LLMObject::threadCountChanged, this, &LLM::threadCountChanged, Qt::QueuedConnection);
connect(m_llmodel, &LLMObject::threadCountChanged, this, &LLM::syncThreadCount, Qt::QueuedConnection);
connect(m_llmodel, &LLMObject::recalcChanged, this, &LLM::recalcChanged, Qt::QueuedConnection);

connect(this, &LLM::promptRequested, m_llmodel, &LLMObject::prompt, Qt::QueuedConnection);
connect(this, &LLM::modelNameChangeRequested, m_llmodel, &LLMObject::modelNameChangeRequested, Qt::QueuedConnection);

@@ -428,3 +440,7 @@ bool LLM::checkForUpdates() const
return QProcess::startDetached(fileName);
}

bool LLM::isRecalc() const
{
return m_llmodel->isRecalc();
}
llm.h (11 changed lines)
@@ -14,6 +14,7 @@ class LLMObject : public QObject
Q_PROPERTY(QString response READ response NOTIFY responseChanged)
Q_PROPERTY(QString modelName READ modelName WRITE setModelName NOTIFY modelNameChanged)
Q_PROPERTY(int32_t threadCount READ threadCount WRITE setThreadCount NOTIFY threadCountChanged)
Q_PROPERTY(bool isRecalc READ isRecalc NOTIFY recalcChanged)

public:

@@ -33,6 +34,8 @@ public:
QList<QString> modelList() const;
void setModelName(const QString &modelName);

bool isRecalc() const { return m_isRecalc; }

public Q_SLOTS:
bool prompt(const QString &prompt, const QString &prompt_template, int32_t n_predict, int32_t top_k, float top_p,
float temp, int32_t n_batch);

@@ -47,10 +50,12 @@ Q_SIGNALS:
void modelNameChanged();
void modelListChanged();
void threadCountChanged();
void recalcChanged();

private:
bool loadModelPrivate(const QString &modelName);
bool handleResponse(int32_t token, const std::string &response);
bool handleRecalculate(bool isRecalc);

private:
LLModel *m_llmodel;

@@ -60,6 +65,7 @@ private:
QString m_modelName;
QThread m_llmThread;
std::atomic<bool> m_stopGenerating;
bool m_isRecalc;
};

class LLM : public QObject

@@ -71,6 +77,8 @@ class LLM : public QObject
Q_PROPERTY(QString modelName READ modelName WRITE setModelName NOTIFY modelNameChanged)
Q_PROPERTY(bool responseInProgress READ responseInProgress NOTIFY responseInProgressChanged)
Q_PROPERTY(int32_t threadCount READ threadCount WRITE setThreadCount NOTIFY threadCountChanged)
Q_PROPERTY(bool isRecalc READ isRecalc NOTIFY recalcChanged)

public:

static LLM *globalInstance();

@@ -96,6 +104,8 @@ public:
Q_INVOKABLE bool checkForUpdates() const;

bool isRecalc() const;

Q_SIGNALS:
void isModelLoadedChanged();
void responseChanged();

@@ -110,6 +120,7 @@ Q_SIGNALS:
void modelListChanged();
void threadCountChanged();
void setThreadCountRequested(int32_t threadCount);
void recalcChanged();

private Q_SLOTS:
void responseStarted();
llmodel.h

@@ -25,13 +25,19 @@ public:
int32_t n_batch = 9;
float repeat_penalty = 1.10f;
int32_t repeat_last_n = 64; // last n tokens to penalize

float contextErase = 0.75f; // percent of context to erase if we exceed the context
                            // window
};
virtual void prompt(const std::string &prompt,
std::function<bool(int32_t, const std::string&)> response,
std::function<bool(bool)> recalculate,
PromptContext &ctx) = 0;
virtual void setThreadCount(int32_t n_threads) {}
virtual int32_t threadCount() { return 1; }

protected:
virtual void recalculateContext(PromptContext &promptCtx,
std::function<bool(bool)> recalculate) = 0;
};

#endif // LLMODEL_H
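For a sense of the numbers behind that default: with contextErase at 0.75 and an assumed 2048-token window (the real n_ctx depends on the model), an overflow drops the oldest 1536 tokens and leaves roughly the newest quarter to be replayed. A small illustrative calculation, not taken from the source:

// illustrative values only; n_ctx is model dependent
const int32_t n_ctx = 2048;
const float contextErase = 0.75f;                 // default added to PromptContext above
const int32_t erasePoint = n_ctx * contextErase;  // 1536: the oldest tokens get erased
// roughly n_ctx - erasePoint = 512 tokens remain and are re-evaluated in n_batch chunks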
main.qml (18 changed lines)
@@ -288,6 +288,24 @@ Window {
text: qsTr("Connection to datalake failed.")
}

PopupDialog {
id: recalcPopup
anchors.centerIn: parent
shouldTimeOut: false
shouldShowBusy: true
text: qsTr("Recalculating context.")

Connections {
target: LLM
function onRecalcChanged() {
if (LLM.isRecalc)
recalcPopup.open()
else
recalcPopup.close()
}
}
}

Button {
id: copyButton
anchors.right: settingsButton.left
PopupDialog.qml

@@ -7,23 +7,45 @@ import QtQuick.Layouts
Dialog {
id: popupDialog
anchors.centerIn: parent
modal: false
opacity: 0.9
padding: 20
property alias text: textField.text
property bool shouldTimeOut: true
property bool shouldShowBusy: false
modal: shouldShowBusy
closePolicy: shouldShowBusy ? Popup.NoAutoClose : (Popup.CloseOnEscape | Popup.CloseOnPressOutside)

Theme {
id: theme
}

Text {
id: textField
horizontalAlignment: Text.AlignJustify
color: theme.textColor
Accessible.role: Accessible.HelpBalloon
Accessible.name: text
Accessible.description: qsTr("Reveals a shortlived help balloon")
Row {
anchors.centerIn: parent
width: childrenRect.width
height: childrenRect.height
spacing: 20

Text {
id: textField
anchors.verticalCenter: busyIndicator.verticalCenter
horizontalAlignment: Text.AlignJustify
color: theme.textColor
Accessible.role: Accessible.HelpBalloon
Accessible.name: text
Accessible.description: qsTr("Reveals a shortlived help balloon")
}

BusyIndicator {
id: busyIndicator
visible: shouldShowBusy
running: shouldShowBusy

Accessible.role: Accessible.Animation
Accessible.name: qsTr("Busy indicator")
Accessible.description: qsTr("Displayed when the popup is showing busy")
}
}

background: Rectangle {
anchors.fill: parent
color: theme.backgroundDarkest

@@ -37,7 +59,8 @@ Dialog {
}

onOpened: {
timer.start()
if (shouldTimeOut)
timer.start()
}

Timer {