From c800291e7f4f82c6e4eed3cc27642bce55bc1046 Mon Sep 17 00:00:00 2001
From: Adam Treat
Date: Sat, 20 May 2023 20:04:36 -0400
Subject: [PATCH] Add prompt processing and localdocs to the busy indicator in UI.

---
 gpt4all-chat/chat.cpp    | 41 +++++++++++++++++++++++++++++++++++-----
 gpt4all-chat/chat.h      | 16 +++++++++++++++-
 gpt4all-chat/chatllm.cpp |  2 +-
 gpt4all-chat/chatllm.h   |  2 +-
 gpt4all-chat/main.qml    | 24 +++++++++++++++++------
 5 files changed, 71 insertions(+), 14 deletions(-)

diff --git a/gpt4all-chat/chat.cpp b/gpt4all-chat/chat.cpp
index ac714a7c..9c9e9e14 100644
--- a/gpt4all-chat/chat.cpp
+++ b/gpt4all-chat/chat.cpp
@@ -10,10 +10,12 @@ Chat::Chat(QObject *parent)
     , m_name(tr("New Chat"))
     , m_chatModel(new ChatModel(this))
     , m_responseInProgress(false)
+    , m_responseState(Chat::ResponseStopped)
     , m_creationDate(QDateTime::currentSecsSinceEpoch())
     , m_llmodel(new ChatLLM(this))
     , m_isServer(false)
     , m_shouldDeleteLater(false)
+    , m_contextContainsLocalDocs(false)
 {
     connectLLM();
 }
@@ -24,10 +26,12 @@ Chat::Chat(bool isServer, QObject *parent)
     , m_name(tr("Server Chat"))
     , m_chatModel(new ChatModel(this))
     , m_responseInProgress(false)
+    , m_responseState(Chat::ResponseStopped)
     , m_creationDate(QDateTime::currentSecsSinceEpoch())
     , m_llmodel(new Server(this))
     , m_isServer(true)
     , m_shouldDeleteLater(false)
+    , m_contextContainsLocalDocs(false)
 {
     connectLLM();
 }
@@ -49,7 +53,7 @@ void Chat::connectLLM()
     connect(m_llmodel, &ChatLLM::isModelLoadedChanged, this, &Chat::isModelLoadedChanged, Qt::QueuedConnection);
     connect(m_llmodel, &ChatLLM::isModelLoadedChanged, this, &Chat::handleModelLoadedChanged, Qt::QueuedConnection);
     connect(m_llmodel, &ChatLLM::responseChanged, this, &Chat::handleResponseChanged, Qt::QueuedConnection);
-    connect(m_llmodel, &ChatLLM::responseStarted, this, &Chat::responseStarted, Qt::QueuedConnection);
+    connect(m_llmodel, &ChatLLM::promptProcessing, this, &Chat::promptProcessing, Qt::QueuedConnection);
     connect(m_llmodel, &ChatLLM::responseStopped, this, &Chat::responseStopped, Qt::QueuedConnection);
     connect(m_llmodel, &ChatLLM::modelNameChanged, this, &Chat::handleModelNameChanged, Qt::QueuedConnection);
     connect(m_llmodel, &ChatLLM::modelLoadingError, this, &Chat::modelLoadingError, Qt::QueuedConnection);
@@ -99,6 +103,11 @@ void Chat::prompt(const QString &prompt, const QString &prompt_template, int32_t
     int32_t top_k, float top_p, float temp, int32_t n_batch, float repeat_penalty, int32_t repeat_penalty_tokens)
 {
+    m_contextContainsLocalDocs = false;
+    m_responseInProgress = true;
+    m_responseState = Chat::LocalDocsRetrieval;
+    emit responseInProgressChanged();
+    emit responseStateChanged();
     m_queuedPrompt.prompt = prompt;
     m_queuedPrompt.prompt_template = prompt_template;
     m_queuedPrompt.n_predict = n_predict;
@@ -116,8 +125,9 @@ void Chat::handleLocalDocsRetrieved()
     QList<QString> results = LocalDocs::globalInstance()->result();
     if (!results.isEmpty()) {
         augmentedTemplate.append("### Context:");
-        augmentedTemplate.append(results);
+        augmentedTemplate.append(results.join("\n\n"));
     }
+    m_contextContainsLocalDocs = !results.isEmpty();
     augmentedTemplate.append(m_queuedPrompt.prompt_template);
     emit promptRequested(
         m_queuedPrompt.prompt,
@@ -148,8 +158,26 @@ QString Chat::response() const
     return m_llmodel->response();
 }
 
+QString Chat::responseState() const
+{
+    switch (m_responseState) {
+    case ResponseStopped: return QStringLiteral("response stopped");
+    case LocalDocsRetrieval: return QStringLiteral("retrieving localdocs");
+    case LocalDocsProcessing: return QStringLiteral("processing localdocs");
+    case PromptProcessing: return QStringLiteral("processing prompt");
+    case ResponseGeneration: return QStringLiteral("generating response");
+    };
+    Q_UNREACHABLE();
+    return QString();
+}
+
 void Chat::handleResponseChanged()
 {
+    if (m_responseState != Chat::ResponseGeneration) {
+        m_responseState = Chat::ResponseGeneration;
+        emit responseStateChanged();
+    }
+
     const int index = m_chatModel->count() - 1;
     m_chatModel->updateValue(index, response());
     emit responseChanged();
@@ -161,16 +189,19 @@ void Chat::handleModelLoadedChanged()
         deleteLater();
 }
 
-void Chat::responseStarted()
+void Chat::promptProcessing()
 {
-    m_responseInProgress = true;
-    emit responseInProgressChanged();
+    m_responseState = m_contextContainsLocalDocs ? Chat::LocalDocsProcessing : Chat::PromptProcessing;
+    emit responseStateChanged();
 }
 
 void Chat::responseStopped()
 {
+    m_contextContainsLocalDocs = false;
     m_responseInProgress = false;
+    m_responseState = Chat::ResponseStopped;
     emit responseInProgressChanged();
+    emit responseStateChanged();
     if (m_llmodel->generatedName().isEmpty())
         emit generateNameRequested();
     if (chatModel()->count() < 3)
diff --git a/gpt4all-chat/chat.h b/gpt4all-chat/chat.h
index b4f30015..0f67ecba 100644
--- a/gpt4all-chat/chat.h
+++ b/gpt4all-chat/chat.h
@@ -22,10 +22,20 @@ class Chat : public QObject
     Q_PROPERTY(bool isRecalc READ isRecalc NOTIFY recalcChanged)
     Q_PROPERTY(QList<QString> modelList READ modelList NOTIFY modelListChanged)
     Q_PROPERTY(bool isServer READ isServer NOTIFY isServerChanged)
+    Q_PROPERTY(QString responseState READ responseState NOTIFY responseStateChanged)
     QML_ELEMENT
     QML_UNCREATABLE("Only creatable from c++!")
 
 public:
+    enum ResponseState {
+        ResponseStopped,
+        LocalDocsRetrieval,
+        LocalDocsProcessing,
+        PromptProcessing,
+        ResponseGeneration
+    };
+    Q_ENUM(ResponseState)
+
     explicit Chat(QObject *parent = nullptr);
     explicit Chat(bool isServer, QObject *parent = nullptr);
     virtual ~Chat();
@@ -50,6 +60,7 @@ public:
     QString response() const;
     bool responseInProgress() const { return m_responseInProgress; }
+    QString responseState() const;
     QString modelName() const;
     void setModelName(const QString &modelName);
     bool isRecalc() const;
@@ -77,6 +88,7 @@ Q_SIGNALS:
     void isModelLoadedChanged();
     void responseChanged();
     void responseInProgressChanged();
+    void responseStateChanged();
     void promptRequested(const QString &prompt, const QString &prompt_template, int32_t n_predict, int32_t top_k, float top_p,
         float temp, int32_t n_batch, float repeat_penalty, int32_t repeat_penalty_tokens, int32_t n_threads);
@@ -97,7 +109,7 @@ private Q_SLOTS:
     void handleLocalDocsRetrieved();
     void handleResponseChanged();
     void handleModelLoadedChanged();
-    void responseStarted();
+    void promptProcessing();
     void responseStopped();
     void generatedNameChanged();
     void handleRecalculating();
@@ -122,10 +134,12 @@ private:
     QString m_savedModelName;
     ChatModel *m_chatModel;
     bool m_responseInProgress;
+    ResponseState m_responseState;
     qint64 m_creationDate;
     ChatLLM *m_llmodel;
     bool m_isServer;
     bool m_shouldDeleteLater;
+    bool m_contextContainsLocalDocs;
     Prompt m_queuedPrompt;
 };
diff --git a/gpt4all-chat/chatllm.cpp b/gpt4all-chat/chatllm.cpp
index 481a2256..1073af66 100644
--- a/gpt4all-chat/chatllm.cpp
+++ b/gpt4all-chat/chatllm.cpp
@@ -411,7 +411,7 @@ bool ChatLLM::prompt(const QString &prompt, const QString &prompt_template, int3
     auto responseFunc = std::bind(&ChatLLM::handleResponse, this, std::placeholders::_1, std::placeholders::_2);
     auto recalcFunc = std::bind(&ChatLLM::handleRecalculate, this, std::placeholders::_1);
-    emit responseStarted();
+    emit promptProcessing();
     qint32 logitsBefore = m_ctx.logits.size();
     m_ctx.n_predict = n_predict;
     m_ctx.top_k = top_k;
diff --git a/gpt4all-chat/chatllm.h b/gpt4all-chat/chatllm.h
index d5fb1f3b..bf617f2d 100644
--- a/gpt4all-chat/chatllm.h
+++ b/gpt4all-chat/chatllm.h
@@ -75,7 +75,7 @@ Q_SIGNALS:
     void isModelLoadedChanged();
     void modelLoadingError(const QString &error);
     void responseChanged();
-    void responseStarted();
+    void promptProcessing();
     void responseStopped();
     void modelNameChanged();
     void recalcChanged();
diff --git a/gpt4all-chat/main.qml b/gpt4all-chat/main.qml
index cbcfe845..dc94e7b6 100644
--- a/gpt4all-chat/main.qml
+++ b/gpt4all-chat/main.qml
@@ -577,17 +577,29 @@ Window {
                 leftPadding: 100
                 rightPadding: 100
 
-                BusyIndicator {
+                Item {
                     anchors.left: parent.left
                     anchors.leftMargin: 90
                     anchors.top: parent.top
                     anchors.topMargin: 5
                     visible: (currentResponse ? true : false) && value === "" && currentChat.responseInProgress
-                    running: (currentResponse ? true : false) && value === "" && currentChat.responseInProgress
-
-                    Accessible.role: Accessible.Animation
-                    Accessible.name: qsTr("Busy indicator")
-                    Accessible.description: qsTr("Displayed when the model is thinking")
+                    width: childrenRect.width
+                    height: childrenRect.height
+                    Row {
+                        spacing: 5
+                        BusyIndicator {
+                            anchors.verticalCenter: parent.verticalCenter
+                            running: (currentResponse ? true : false) && value === "" && currentChat.responseInProgress
+                            Accessible.role: Accessible.Animation
+                            Accessible.name: qsTr("Busy indicator")
+                            Accessible.description: qsTr("Displayed when the model is thinking")
+                        }
+                        Label {
+                            anchors.verticalCenter: parent.verticalCenter
+                            text: currentChat.responseState + "..."
+                            color: theme.mutedTextColor
+                        }
+                    }
                 }
 
                 Rectangle {
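
Reviewer note (not part of the patch): below is a minimal sketch of how a C++
consumer of the Chat API modified above could observe the new busy-indicator
states. The helper name watchResponseState is hypothetical; it assumes only the
responseState() getter and the responseStateChanged signal added in this patch,
the same members the main.qml Label binds to via currentChat.responseState.

    // Illustrative sketch only -- not part of the patch.
    #include <QDebug>
    #include <QObject>
    #include "chat.h"

    // Hypothetical helper: logs each state transition driving the busy indicator,
    // e.g. "retrieving localdocs" -> "processing localdocs" -> "generating response"
    // -> "response stopped".
    void watchResponseState(Chat *chat)
    {
        QObject::connect(chat, &Chat::responseStateChanged, chat, [chat]() {
            qDebug() << "response state:" << chat->responseState();
        });
    }

Because responseState is exposed as a Q_PROPERTY with a NOTIFY signal rather
than as raw enum values, QML bindings and C++ observers alike pick up every
transition without polling.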