From 15ae0a44414be7b38ad7f992ab4de1d836174b2d Mon Sep 17 00:00:00 2001 From: Adam Treat Date: Mon, 17 Apr 2023 14:11:41 -0400 Subject: [PATCH] Fix the context. --- llm.cpp | 19 ++++++++++++++++--- llm.h | 5 ++++- main.qml | 2 +- 3 files changed, 21 insertions(+), 5 deletions(-) diff --git a/llm.cpp b/llm.cpp index 1d904511..db13807a 100644 --- a/llm.cpp +++ b/llm.cpp @@ -64,7 +64,7 @@ bool LLMObject::isModelLoaded() const return m_llmodel->isModelLoaded(); } -void LLMObject::resetResponse() +void LLMObject::regenerateResponse() { s_ctx.n_past -= m_responseTokens; s_ctx.n_past = std::max(0, s_ctx.n_past); @@ -75,9 +75,17 @@ void LLMObject::resetResponse() emit responseChanged(); } +void LLMObject::resetResponse() +{ + m_responseTokens = 0; + m_responseLogits = 0; + m_response = std::string(); + emit responseChanged(); +} + void LLMObject::resetContext() { - resetResponse(); + regenerateResponse(); s_ctx = LLModel::PromptContext(); } @@ -142,7 +150,6 @@ bool LLMObject::prompt(const QString &prompt, const QString &prompt_template, in auto func = std::bind(&LLMObject::handleResponse, this, std::placeholders::_1); emit responseStarted(); qint32 logitsBefore = s_ctx.logits.size(); - qInfo() << instructPrompt << "\n"; m_llmodel->prompt(instructPrompt.toStdString(), func, s_ctx, n_predict, top_k, top_p, temp, n_batch); m_responseLogits += s_ctx.logits.size() - logitsBefore; std::string trimmed = trim_whitespace(m_response); @@ -167,6 +174,7 @@ LLM::LLM() connect(m_llmodel, &LLMObject::modelNameChanged, this, &LLM::modelNameChanged, Qt::QueuedConnection); connect(this, &LLM::promptRequested, m_llmodel, &LLMObject::prompt, Qt::QueuedConnection); + connect(this, &LLM::regenerateResponseRequested, m_llmodel, &LLMObject::regenerateResponse, Qt::BlockingQueuedConnection); connect(this, &LLM::resetResponseRequested, m_llmodel, &LLMObject::resetResponse, Qt::BlockingQueuedConnection); connect(this, &LLM::resetContextRequested, m_llmodel, &LLMObject::resetContext, Qt::BlockingQueuedConnection); } @@ -182,6 +190,11 @@ void LLM::prompt(const QString &prompt, const QString &prompt_template, int32_t emit promptRequested(prompt, prompt_template, n_predict, top_k, top_p, temp, n_batch); } +void LLM::regenerateResponse() +{ + emit regenerateResponseRequested(); // blocking queued connection +} + void LLM::resetResponse() { emit resetResponseRequested(); // blocking queued connection diff --git a/llm.h b/llm.h index 03db9f9a..33aa95c6 100644 --- a/llm.h +++ b/llm.h @@ -18,6 +18,7 @@ public: bool loadModel(); bool isModelLoaded() const; + void regenerateResponse(); void resetResponse(); void resetContext(); void stopGenerating() { m_stopGenerating = true; } @@ -63,8 +64,9 @@ public: Q_INVOKABLE bool isModelLoaded() const; Q_INVOKABLE void prompt(const QString &prompt, const QString &prompt_template, int32_t n_predict, int32_t top_k, float top_p, float temp, int32_t n_batch); - Q_INVOKABLE void resetContext(); + Q_INVOKABLE void regenerateResponse(); Q_INVOKABLE void resetResponse(); + Q_INVOKABLE void resetContext(); Q_INVOKABLE void stopGenerating(); QString response() const; @@ -80,6 +82,7 @@ Q_SIGNALS: void responseInProgressChanged(); void promptRequested(const QString &prompt, const QString &prompt_template, int32_t n_predict, int32_t top_k, float top_p, float temp, int32_t n_batch); + void regenerateResponseRequested(); void resetResponseRequested(); void resetContextRequested(); void modelNameChanged(); diff --git a/main.qml b/main.qml index a1d2be79..18b48af1 100644 --- a/main.qml +++ b/main.qml @@ -666,7 +666,7 @@ Window { if (LLM.responseInProgress) LLM.stopGenerating() else { - LLM.resetResponse() + LLM.regenerateResponse() if (chatModel.count) { var listElement = chatModel.get(chatModel.count - 1) if (listElement.name === qsTr("Response: ")) {