Mirror of https://github.com/nomic-ai/gpt4all (synced 2024-11-06)
Infinite context window through trimming.

Commit b6937c39db (parent 8b1ddabe3e)
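In outline, the commit replaces the old behaviour on context overflow (clamp n_past and warn that output will be gibberish) with trimming: when the next batch of tokens would exceed the context window, the oldest contextErase fraction of the token history is erased, n_past is reset to the number of surviving tokens, and those survivors are re-evaluated through the model so its cached state matches the shortened history, with a recalculate callback reporting progress to the UI. The following is a minimal sketch of that policy, assuming a simplified PromptContext and a hypothetical ensureRoomFor() helper standing in for the checks the diff adds inside prompt() (the real code calls gptj_eval/llama_eval during re-evaluation):

#include <cassert>
#include <cstdint>
#include <cstdio>
#include <functional>
#include <vector>

// Simplified stand-in for the PromptContext fields this commit relies on.
struct PromptContext {
    std::vector<int> tokens;     // token history kept alongside the model state
    int32_t n_ctx = 2048;        // context window size (model dependent, illustrative here)
    int32_t n_past = 0;          // tokens already evaluated by the model
    float contextErase = 0.75f;  // fraction of the window to drop on overflow
};

// Stand-in for recalculateContext(): the real version feeds ctx.tokens back
// through the model in n_batch-sized chunks, calling recalculate(true) per
// chunk so the UI can show progress, and recalculate(false) when done.
static void recalculateContext(PromptContext &ctx, const std::function<bool(bool)> &recalculate)
{
    ctx.n_past = 0;
    // ... re-evaluate ctx.tokens here, advancing ctx.n_past per batch ...
    ctx.n_past = (int32_t) ctx.tokens.size();
    recalculate(false);
}

// Hypothetical helper: the same trimming logic prompt() now performs inline
// before evaluating a batch of `batchSize` new tokens.
static void ensureRoomFor(size_t batchSize, PromptContext &ctx,
                          const std::function<bool(bool)> &recalculate)
{
    if (ctx.n_past + batchSize <= (size_t) ctx.n_ctx)
        return;  // the batch still fits; nothing to trim

    const int32_t erasePoint = ctx.n_ctx * ctx.contextErase;
    std::fprintf(stderr, "reached the end of the context window so resizing\n");
    // Drop the oldest erasePoint tokens (the history is ~n_ctx long at this point).
    ctx.tokens.erase(ctx.tokens.begin(), ctx.tokens.begin() + erasePoint);
    ctx.n_past = (int32_t) ctx.tokens.size();
    recalculateContext(ctx, recalculate);
    assert(ctx.n_past + batchSize <= (size_t) ctx.n_ctx);
}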
gptj.cpp (59 lines changed)

@@ -635,6 +635,7 @@ struct GPTJPrivate {
     gpt_vocab vocab;
     gptj_model model;
     int64_t n_threads = 0;
+    size_t mem_per_token = 0;
     std::mt19937 rng;
 };
 
@@ -662,6 +663,7 @@ bool GPTJ::loadModel(const std::string &modelPath, std::istream &fin) {
 
     d_ptr->n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
     d_ptr->modelLoaded = true;
+    fflush(stdout);
     return true;
 }
 
@@ -685,6 +687,7 @@ bool GPTJ::isModelLoaded() const
 
 void GPTJ::prompt(const std::string &prompt,
         std::function<bool(int32_t, const std::string&)> response,
+        std::function<bool(bool)> recalculate,
         PromptContext &promptCtx) {
 
     if (!isModelLoaded()) {
@@ -711,9 +714,9 @@ void GPTJ::prompt(const std::string &prompt,
     static bool initialized = false;
     static std::vector<gpt_vocab::id> p_instruct;
     static std::vector<gpt_vocab::id> r_instruct;
-    size_t mem_per_token = 0;
     if (!initialized) {
-        gptj_eval(d_ptr->model, d_ptr->n_threads, 0, { 0, 1, 2, 3 }, promptCtx.logits, mem_per_token);
+        gptj_eval(d_ptr->model, d_ptr->n_threads, 0, { 0, 1, 2, 3 }, promptCtx.logits,
+            d_ptr->mem_per_token);
         initialized = true;
     }
 
@@ -726,12 +729,17 @@ void GPTJ::prompt(const std::string &prompt,
 
         // Check if the context has run out...
         if (promptCtx.n_past + batch.size() > promptCtx.n_ctx) {
-            // FIXME: will produce gibberish after this
-            promptCtx.n_past = std::min(promptCtx.n_past, int(promptCtx.n_ctx - batch.size()));
-            std::cerr << "GPT-J WARNING: reached the end of the context window!\n";
+            const int32_t erasePoint = promptCtx.n_ctx * promptCtx.contextErase;
+            // Erase the first percentage of context from the tokens...
+            std::cerr << "GPTJ: reached the end of the context window so resizing\n";
+            promptCtx.tokens.erase(promptCtx.tokens.begin(), promptCtx.tokens.begin() + erasePoint);
+            promptCtx.n_past = promptCtx.tokens.size();
+            recalculateContext(promptCtx, recalculate);
+            assert(promptCtx.n_past + batch.size() <= promptCtx.n_ctx);
         }
 
-        if (!gptj_eval(d_ptr->model, d_ptr->n_threads, promptCtx.n_past, batch, promptCtx.logits, mem_per_token)) {
+        if (!gptj_eval(d_ptr->model, d_ptr->n_threads, promptCtx.n_past, batch, promptCtx.logits,
+            d_ptr->mem_per_token)) {
             std::cerr << "GPT-J ERROR: Failed to process prompt\n";
             return;
         }
@@ -770,13 +778,18 @@ void GPTJ::prompt(const std::string &prompt,
 
         // Check if the context has run out...
         if (promptCtx.n_past + 1 > promptCtx.n_ctx) {
-            // FIXME: will produce gibberish after this
-            promptCtx.n_past = std::min(promptCtx.n_past, promptCtx.n_ctx - 1);
-            std::cerr << "GPT-J WARNING: reached the end of the context window!\n";
+            const int32_t erasePoint = promptCtx.n_ctx * promptCtx.contextErase;
+            // Erase the first percentage of context from the tokens...
+            std::cerr << "GPTJ: reached the end of the context window so resizing\n";
+            promptCtx.tokens.erase(promptCtx.tokens.begin(), promptCtx.tokens.begin() + erasePoint);
+            promptCtx.n_past = promptCtx.tokens.size();
+            recalculateContext(promptCtx, recalculate);
+            assert(promptCtx.n_past + batch.size() <= promptCtx.n_ctx);
         }
 
         const int64_t t_start_predict_us = ggml_time_us();
-        if (!gptj_eval(d_ptr->model, d_ptr->n_threads, promptCtx.n_past, { id }, promptCtx.logits, mem_per_token)) {
+        if (!gptj_eval(d_ptr->model, d_ptr->n_threads, promptCtx.n_past, { id }, promptCtx.logits,
+            d_ptr->mem_per_token)) {
             std::cerr << "GPT-J ERROR: Failed to predict next token\n";
             return;
         }
@@ -807,3 +820,29 @@ stop_generating:
 
     return;
 }
 
+void GPTJ::recalculateContext(PromptContext &promptCtx, std::function<bool(bool)> recalculate)
+{
+    size_t i = 0;
+    promptCtx.n_past = 0;
+    while (i < promptCtx.tokens.size()) {
+        size_t batch_end = std::min(i + promptCtx.n_batch, promptCtx.tokens.size());
+        std::vector<gpt_vocab::id> batch(promptCtx.tokens.begin() + i, promptCtx.tokens.begin() + batch_end);
+
+        assert(promptCtx.n_past + batch.size() <= promptCtx.n_ctx);
+
+        if (!gptj_eval(d_ptr->model, d_ptr->n_threads, promptCtx.n_past, batch, promptCtx.logits,
+            d_ptr->mem_per_token)) {
+            std::cerr << "GPTJ ERROR: Failed to process prompt\n";
+            goto stop_generating;
+        }
+        promptCtx.n_past += batch.size();
+        if (!recalculate(true))
+            goto stop_generating;
+        i = batch_end;
+    }
+    assert(promptCtx.n_past == promptCtx.tokens.size());
+
+stop_generating:
+    recalculate(false);
+}
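From a caller's perspective, the visible change in gptj.cpp/gptj.h is the extra std::function<bool(bool)> recalculate parameter on prompt(): it is invoked with true for each batch while the trimmed history is being re-evaluated and with false when recalculation ends, and returning false aborts generation (llm.cpp further down binds it to LLMObject::handleRecalculate). A hedged usage sketch with hypothetical caller code; the header name, the GPTJ type, and the nested LLModel::PromptContext are assumed from the surrounding diff:

#include "gptj.h"   // assumed project header providing GPTJ and its PromptContext

#include <cstdint>
#include <iostream>
#include <string>

// Hypothetical caller; `model` is assumed to be loaded and `ctx` configured
// (n_predict, temperature, etc.) as the application normally does.
void askModel(GPTJ &model, LLModel::PromptContext &ctx, const std::string &text)
{
    auto onToken = [](int32_t /*tokenId*/, const std::string &piece) {
        std::cout << piece << std::flush;
        return true;   // keep generating
    };
    auto onRecalc = [](bool isRecalc) {
        if (isRecalc)
            std::cerr << "[recalculating trimmed context...]\n";
        return true;   // returning false would stop generation early
    };
    model.prompt(text, onToken, onRecalc, ctx);
}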
gptj.h (5 lines changed)

@@ -17,10 +17,15 @@ public:
     bool isModelLoaded() const override;
     void prompt(const std::string &prompt,
         std::function<bool(int32_t, const std::string&)> response,
+        std::function<bool(bool)> recalculate,
         PromptContext &ctx) override;
     void setThreadCount(int32_t n_threads) override;
     int32_t threadCount() override;
 
+protected:
+    void recalculateContext(PromptContext &promptCtx,
+        std::function<bool(bool)> recalculate) override;
+
 private:
     GPTJPrivate *d_ptr;
 };
llamamodel.cpp

@@ -58,6 +58,7 @@ bool LLamaModel::loadModel(const std::string &modelPath)
 
     d_ptr->n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
     d_ptr->modelLoaded = true;
+    fflush(stderr);
     return true;
 }
 
@@ -80,6 +81,7 @@ bool LLamaModel::isModelLoaded() const
 
 void LLamaModel::prompt(const std::string &prompt,
         std::function<bool(int32_t, const std::string&)> response,
+        std::function<bool(bool)> recalculate,
         PromptContext &promptCtx) {
 
     if (!isModelLoaded()) {
@@ -119,9 +121,13 @@ void LLamaModel::prompt(const std::string &prompt,
 
         // Check if the context has run out...
         if (promptCtx.n_past + batch.size() > promptCtx.n_ctx) {
-            // FIXME: will produce gibberish after this
-            promptCtx.n_past = std::min(promptCtx.n_past, int(promptCtx.n_ctx - batch.size()));
-            std::cerr << "LLAMA WARNING: reached the end of the context window!\n";
+            const int32_t erasePoint = promptCtx.n_ctx * promptCtx.contextErase;
+            // Erase the first percentage of context from the tokens...
+            std::cerr << "LLAMA: reached the end of the context window so resizing\n";
+            promptCtx.tokens.erase(promptCtx.tokens.begin(), promptCtx.tokens.begin() + erasePoint);
+            promptCtx.n_past = promptCtx.tokens.size();
+            recalculateContext(promptCtx, recalculate);
+            assert(promptCtx.n_past + batch.size() <= promptCtx.n_ctx);
         }
 
         if (llama_eval(d_ptr->ctx, batch.data(), batch.size(), promptCtx.n_past, d_ptr->n_threads)) {
@@ -149,9 +155,13 @@ void LLamaModel::prompt(const std::string &prompt,
 
         // Check if the context has run out...
         if (promptCtx.n_past + 1 > promptCtx.n_ctx) {
-            // FIXME: will produce gibberish after this
-            promptCtx.n_past = std::min(promptCtx.n_past, promptCtx.n_ctx - 1);
-            std::cerr << "LLAMA WARNING: reached the end of the context window!\n";
+            const int32_t erasePoint = promptCtx.n_ctx * promptCtx.contextErase;
+            // Erase the first percentage of context from the tokens...
+            std::cerr << "LLAMA: reached the end of the context window so resizing\n";
+            promptCtx.tokens.erase(promptCtx.tokens.begin(), promptCtx.tokens.begin() + erasePoint);
+            promptCtx.n_past = promptCtx.tokens.size();
+            recalculateContext(promptCtx, recalculate);
+            assert(promptCtx.n_past + batch.size() <= promptCtx.n_ctx);
         }
 
         if (llama_eval(d_ptr->ctx, &id, 1, promptCtx.n_past, d_ptr->n_threads)) {
@@ -166,3 +176,28 @@ void LLamaModel::prompt(const std::string &prompt,
             return;
         }
     }
 
+void LLamaModel::recalculateContext(PromptContext &promptCtx, std::function<bool(bool)> recalculate)
+{
+    size_t i = 0;
+    promptCtx.n_past = 0;
+    while (i < promptCtx.tokens.size()) {
+        size_t batch_end = std::min(i + promptCtx.n_batch, promptCtx.tokens.size());
+        std::vector<llama_token> batch(promptCtx.tokens.begin() + i, promptCtx.tokens.begin() + batch_end);
+
+        assert(promptCtx.n_past + batch.size() <= promptCtx.n_ctx);
+
+        if (llama_eval(d_ptr->ctx, batch.data(), batch.size(), promptCtx.n_past, d_ptr->n_threads)) {
+            std::cerr << "LLAMA ERROR: Failed to process prompt\n";
+            goto stop_generating;
+        }
+        promptCtx.n_past += batch.size();
+        if (!recalculate(true))
+            goto stop_generating;
+        i = batch_end;
+    }
+    assert(promptCtx.n_past == promptCtx.tokens.size());
+
+stop_generating:
+    recalculate(false);
+}
llamamodel.h

@@ -17,10 +17,15 @@ public:
     bool isModelLoaded() const override;
     void prompt(const std::string &prompt,
         std::function<bool(int32_t, const std::string&)> response,
+        std::function<bool(bool)> recalculate,
         PromptContext &ctx) override;
     void setThreadCount(int32_t n_threads) override;
     int32_t threadCount() override;
 
+protected:
+    void recalculateContext(PromptContext &promptCtx,
+        std::function<bool(bool)> recalculate) override;
+
 private:
     LLamaPrivate *d_ptr;
 };
llm.cpp (22 lines changed)

@@ -39,6 +39,7 @@ LLMObject::LLMObject()
     , m_llmodel(nullptr)
     , m_responseTokens(0)
     , m_responseLogits(0)
+    , m_isRecalc(false)
 {
     moveToThread(&m_llmThread);
     connect(&m_llmThread, &QThread::started, this, &LLMObject::loadModel);
@@ -271,6 +272,15 @@ bool LLMObject::handleResponse(int32_t token, const std::string &response)
     return !m_stopGenerating;
 }
 
+bool LLMObject::handleRecalculate(bool isRecalc)
+{
+    if (m_isRecalc != isRecalc) {
+        m_isRecalc = isRecalc;
+        emit recalcChanged();
+    }
+    return !m_stopGenerating;
+}
+
 bool LLMObject::prompt(const QString &prompt, const QString &prompt_template, int32_t n_predict, int32_t top_k, float top_p,
     float temp, int32_t n_batch)
 {
@@ -280,7 +290,9 @@ bool LLMObject::prompt(const QString &prompt, const QString &prompt_template, in
     QString instructPrompt = prompt_template.arg(prompt);
 
     m_stopGenerating = false;
-    auto func = std::bind(&LLMObject::handleResponse, this, std::placeholders::_1, std::placeholders::_2);
+    auto responseFunc = std::bind(&LLMObject::handleResponse, this, std::placeholders::_1,
+        std::placeholders::_2);
+    auto recalcFunc = std::bind(&LLMObject::handleRecalculate, this, std::placeholders::_1);
     emit responseStarted();
     qint32 logitsBefore = s_ctx.logits.size();
     s_ctx.n_predict = n_predict;
@@ -288,7 +300,7 @@ bool LLMObject::prompt(const QString &prompt, const QString &prompt_template, in
     s_ctx.top_p = top_p;
     s_ctx.temp = temp;
     s_ctx.n_batch = n_batch;
-    m_llmodel->prompt(instructPrompt.toStdString(), func, s_ctx);
+    m_llmodel->prompt(instructPrompt.toStdString(), responseFunc, recalcFunc, s_ctx);
     m_responseLogits += s_ctx.logits.size() - logitsBefore;
     std::string trimmed = trim_whitespace(m_response);
     if (trimmed != m_response) {
@@ -314,7 +326,7 @@ LLM::LLM()
     connect(m_llmodel, &LLMObject::modelListChanged, this, &LLM::modelListChanged, Qt::QueuedConnection);
     connect(m_llmodel, &LLMObject::threadCountChanged, this, &LLM::threadCountChanged, Qt::QueuedConnection);
     connect(m_llmodel, &LLMObject::threadCountChanged, this, &LLM::syncThreadCount, Qt::QueuedConnection);
+    connect(m_llmodel, &LLMObject::recalcChanged, this, &LLM::recalcChanged, Qt::QueuedConnection);
 
     connect(this, &LLM::promptRequested, m_llmodel, &LLMObject::prompt, Qt::QueuedConnection);
     connect(this, &LLM::modelNameChangeRequested, m_llmodel, &LLMObject::modelNameChangeRequested, Qt::QueuedConnection);
@@ -428,3 +440,7 @@ bool LLM::checkForUpdates() const
     return QProcess::startDetached(fileName);
 }
 
+bool LLM::isRecalc() const
+{
+    return m_llmodel->isRecalc();
+}
llm.h (11 lines changed)

@@ -14,6 +14,7 @@ class LLMObject : public QObject
     Q_PROPERTY(QString response READ response NOTIFY responseChanged)
     Q_PROPERTY(QString modelName READ modelName WRITE setModelName NOTIFY modelNameChanged)
     Q_PROPERTY(int32_t threadCount READ threadCount WRITE setThreadCount NOTIFY threadCountChanged)
+    Q_PROPERTY(bool isRecalc READ isRecalc NOTIFY recalcChanged)
 
 public:
 
@@ -33,6 +34,8 @@ public:
     QList<QString> modelList() const;
     void setModelName(const QString &modelName);
 
+    bool isRecalc() const { return m_isRecalc; }
+
 public Q_SLOTS:
     bool prompt(const QString &prompt, const QString &prompt_template, int32_t n_predict, int32_t top_k, float top_p,
         float temp, int32_t n_batch);
@@ -47,10 +50,12 @@ Q_SIGNALS:
     void modelNameChanged();
     void modelListChanged();
     void threadCountChanged();
+    void recalcChanged();
 
 private:
     bool loadModelPrivate(const QString &modelName);
     bool handleResponse(int32_t token, const std::string &response);
+    bool handleRecalculate(bool isRecalc);
 
 private:
     LLModel *m_llmodel;
@@ -60,6 +65,7 @@ private:
     QString m_modelName;
     QThread m_llmThread;
     std::atomic<bool> m_stopGenerating;
+    bool m_isRecalc;
 };
 
 class LLM : public QObject
@@ -71,6 +77,8 @@ class LLM : public QObject
     Q_PROPERTY(QString modelName READ modelName WRITE setModelName NOTIFY modelNameChanged)
     Q_PROPERTY(bool responseInProgress READ responseInProgress NOTIFY responseInProgressChanged)
     Q_PROPERTY(int32_t threadCount READ threadCount WRITE setThreadCount NOTIFY threadCountChanged)
+    Q_PROPERTY(bool isRecalc READ isRecalc NOTIFY recalcChanged)
 
 public:
 
     static LLM *globalInstance();
@@ -96,6 +104,8 @@ public:
 
     Q_INVOKABLE bool checkForUpdates() const;
 
+    bool isRecalc() const;
+
 Q_SIGNALS:
     void isModelLoadedChanged();
     void responseChanged();
@@ -110,6 +120,7 @@ Q_SIGNALS:
     void modelListChanged();
     void threadCountChanged();
     void setThreadCountRequested(int32_t threadCount);
+    void recalcChanged();
 
 private Q_SLOTS:
     void responseStarted();
llmodel.h

@@ -25,13 +25,19 @@ public:
         int32_t n_batch = 9;
         float repeat_penalty = 1.10f;
         int32_t repeat_last_n = 64; // last n tokens to penalize
+        float contextErase = 0.75f; // percent of context to erase if we exceed the context
+                                    // window
     };
     virtual void prompt(const std::string &prompt,
         std::function<bool(int32_t, const std::string&)> response,
+        std::function<bool(bool)> recalculate,
         PromptContext &ctx) = 0;
     virtual void setThreadCount(int32_t n_threads) {}
     virtual int32_t threadCount() { return 1; }
 
+protected:
+    virtual void recalculateContext(PromptContext &promptCtx,
+        std::function<bool(bool)> recalculate) = 0;
 };
 
 #endif // LLMODEL_H
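For a sense of scale with the new default: taking a hypothetical 2,048-token window, contextErase = 0.75 gives an erasePoint of 1,536, so the oldest 1,536 tokens are dropped on overflow and roughly 512 surviving tokens are re-evaluated before generation continues; a smaller contextErase keeps more history but makes the recalculation pass longer and more frequent. A tiny worked sketch with these assumed numbers:

// Worked example with assumed numbers (n_ctx is model dependent; 2048 is illustrative).
const int   n_ctx        = 2048;
const float contextErase = 0.75f;                // new default from llmodel.h
const int   erasePoint   = n_ctx * contextErase; // 1536 oldest tokens erased
const int   surviving    = n_ctx - erasePoint;   // ~512 tokens re-evaluated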
main.qml (18 lines changed)

@@ -288,6 +288,24 @@ Window {
         text: qsTr("Connection to datalake failed.")
     }
 
+    PopupDialog {
+        id: recalcPopup
+        anchors.centerIn: parent
+        shouldTimeOut: false
+        shouldShowBusy: true
+        text: qsTr("Recalculating context.")
+
+        Connections {
+            target: LLM
+            function onRecalcChanged() {
+                if (LLM.isRecalc)
+                    recalcPopup.open()
+                else
+                    recalcPopup.close()
+            }
+        }
+    }
+
     Button {
         id: copyButton
         anchors.right: settingsButton.left
PopupDialog.qml

@@ -7,23 +7,45 @@ import QtQuick.Layouts
 Dialog {
     id: popupDialog
     anchors.centerIn: parent
-    modal: false
     opacity: 0.9
     padding: 20
     property alias text: textField.text
+    property bool shouldTimeOut: true
+    property bool shouldShowBusy: false
+    modal: shouldShowBusy
+    closePolicy: shouldShowBusy ? Popup.NoAutoClose : (Popup.CloseOnEscape | Popup.CloseOnPressOutside)
 
     Theme {
         id: theme
     }
 
-    Text {
-        id: textField
-        horizontalAlignment: Text.AlignJustify
-        color: theme.textColor
-        Accessible.role: Accessible.HelpBalloon
-        Accessible.name: text
-        Accessible.description: qsTr("Reveals a shortlived help balloon")
+    Row {
+        anchors.centerIn: parent
+        width: childrenRect.width
+        height: childrenRect.height
+        spacing: 20
+
+        Text {
+            id: textField
+            anchors.verticalCenter: busyIndicator.verticalCenter
+            horizontalAlignment: Text.AlignJustify
+            color: theme.textColor
+            Accessible.role: Accessible.HelpBalloon
+            Accessible.name: text
+            Accessible.description: qsTr("Reveals a shortlived help balloon")
+        }
+
+        BusyIndicator {
+            id: busyIndicator
+            visible: shouldShowBusy
+            running: shouldShowBusy
+
+            Accessible.role: Accessible.Animation
+            Accessible.name: qsTr("Busy indicator")
+            Accessible.description: qsTr("Displayed when the popup is showing busy")
+        }
     }
 
     background: Rectangle {
         anchors.fill: parent
         color: theme.backgroundDarkest
@@ -37,7 +59,8 @@ Dialog {
     }
 
     onOpened: {
-        timer.start()
+        if (shouldTimeOut)
+            timer.start()
     }
 
     Timer {