From 2f3a46c17facc4eded69e5e16d6e982f18ba6158 Mon Sep 17 00:00:00 2001 From: Adam Treat Date: Sat, 15 Apr 2023 09:19:06 -0400 Subject: [PATCH] Erase the correct number of logits when regenerating, which is not the same as the number of tokens. --- llm.cpp | 6 +++++- llm.h | 1 + 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/llm.cpp b/llm.cpp index 68b596b6..58766a37 100644 --- a/llm.cpp +++ b/llm.cpp @@ -20,6 +20,7 @@ LLMObject::LLMObject() : QObject{nullptr} , m_llmodel(new GPTJ) , m_responseTokens(0) + , m_responseLogits(0) { moveToThread(&m_llmThread); connect(&m_llmThread, &QThread::started, this, &LLMObject::loadModel); @@ -66,8 +67,9 @@ bool LLMObject::isModelLoaded() const void LLMObject::resetResponse() { s_ctx.n_past -= m_responseTokens; - s_ctx.logits.erase(s_ctx.logits.end() -= m_responseTokens, s_ctx.logits.end()); + s_ctx.logits.erase(s_ctx.logits.end() -= m_responseLogits, s_ctx.logits.end()); m_responseTokens = 0; + m_responseLogits = 0; m_response = std::string(); emit responseChanged(); } @@ -110,7 +112,9 @@ bool LLMObject::prompt(const QString &prompt) m_stopGenerating = false; auto func = std::bind(&LLMObject::handleResponse, this, std::placeholders::_1); emit responseStarted(); + qint32 logitsBefore = s_ctx.logits.size(); m_llmodel->prompt(prompt.toStdString(), func, s_ctx, 4096 /*number of chars to predict*/); + m_responseLogits += s_ctx.logits.size() - logitsBefore; emit responseStopped(); return true; } diff --git a/llm.h b/llm.h index d47ab148..f5f03378 100644 --- a/llm.h +++ b/llm.h @@ -42,6 +42,7 @@ private: LLModel *m_llmodel; std::string m_response; quint32 m_responseTokens; + quint32 m_responseLogits; QString m_modelName; QThread m_llmThread; std::atomic m_stopGenerating;