From 2b6cc99a31a124f1f27f2dc6515b94b84d35b254 Mon Sep 17 00:00:00 2001
From: AT
Date: Mon, 19 Jun 2023 11:34:53 -0700
Subject: [PATCH] Show token generation speed in gui. (#1020)

---
 gpt4all-chat/chat.cpp    | 12 ++++++++++++
 gpt4all-chat/chat.h      |  6 ++++++
 gpt4all-chat/chatllm.cpp | 14 +++++++++++++-
 gpt4all-chat/chatllm.h   | 44 +++++++++++++++++++++++++++++++++++++++-
 gpt4all-chat/main.qml    | 10 ++++++++++
 5 files changed, 84 insertions(+), 2 deletions(-)

diff --git a/gpt4all-chat/chat.cpp b/gpt4all-chat/chat.cpp
index 397c351b..9c54afb3 100644
--- a/gpt4all-chat/chat.cpp
+++ b/gpt4all-chat/chat.cpp
@@ -57,6 +57,7 @@ void Chat::connectLLM()
     connect(m_llmodel, &ChatLLM::modelLoadingError, this, &Chat::handleModelLoadingError, Qt::QueuedConnection);
     connect(m_llmodel, &ChatLLM::recalcChanged, this, &Chat::handleRecalculating, Qt::QueuedConnection);
     connect(m_llmodel, &ChatLLM::generatedNameChanged, this, &Chat::generatedNameChanged, Qt::QueuedConnection);
+    connect(m_llmodel, &ChatLLM::reportSpeed, this, &Chat::handleTokenSpeedChanged, Qt::QueuedConnection);

     connect(this, &Chat::promptRequested, m_llmodel, &ChatLLM::prompt, Qt::QueuedConnection);
     connect(this, &Chat::modelNameChangeRequested, m_llmodel, &ChatLLM::modelNameChangeRequested, Qt::QueuedConnection);
@@ -102,6 +103,8 @@ void Chat::resetResponseState()
     if (m_responseInProgress && m_responseState == Chat::LocalDocsRetrieval)
         return;

+    m_tokenSpeed = QString();
+    emit tokenSpeedChanged();
     m_responseInProgress = true;
     m_responseState = Chat::LocalDocsRetrieval;
     emit responseInProgressChanged();
@@ -187,6 +190,9 @@ void Chat::promptProcessing()

 void Chat::responseStopped()
 {
+    m_tokenSpeed = QString();
+    emit tokenSpeedChanged();
+
     const QString chatResponse = response();
     QList<QString> references;
     QList<QString> referencesContext;
@@ -336,6 +342,12 @@ void Chat::handleModelLoadingError(const QString &error)
     emit modelLoadingErrorChanged();
 }

+void Chat::handleTokenSpeedChanged(const QString &tokenSpeed)
+{
+    m_tokenSpeed = tokenSpeed;
+    emit tokenSpeedChanged();
+}
+
 bool Chat::serialize(QDataStream &stream, int version) const
 {
     stream << m_creationDate;
diff --git a/gpt4all-chat/chat.h b/gpt4all-chat/chat.h
index 7d6ea593..71c2f761 100644
--- a/gpt4all-chat/chat.h
+++ b/gpt4all-chat/chat.h
@@ -25,6 +25,7 @@ class Chat : public QObject
     Q_PROPERTY(QString responseState READ responseState NOTIFY responseStateChanged)
     Q_PROPERTY(QList<QString> collectionList READ collectionList NOTIFY collectionListChanged)
     Q_PROPERTY(QString modelLoadingError READ modelLoadingError NOTIFY modelLoadingErrorChanged)
+    Q_PROPERTY(QString tokenSpeed READ tokenSpeed NOTIFY tokenSpeedChanged);
     QML_ELEMENT
     QML_UNCREATABLE("Only creatable from c++!")

@@ -91,6 +92,8 @@ public:

     QString modelLoadingError() const { return m_modelLoadingError; }

+    QString tokenSpeed() const { return m_tokenSpeed; }
+
 public Q_SLOTS:
     void serverNewPromptResponsePair(const QString &prompt);

@@ -118,6 +121,7 @@ Q_SIGNALS:
     void modelLoadingErrorChanged();
     void isServerChanged();
     void collectionListChanged();
+    void tokenSpeedChanged();

 private Q_SLOTS:
     void handleResponseChanged();
@@ -128,6 +132,7 @@ private Q_SLOTS:
     void handleRecalculating();
     void handleModelNameChanged();
     void handleModelLoadingError(const QString &error);
+    void handleTokenSpeedChanged(const QString &tokenSpeed);

 private:
     QString m_id;
@@ -135,6 +140,7 @@ private:
     QString m_userName;
     QString m_savedModelName;
     QString m_modelLoadingError;
+    QString m_tokenSpeed;
     QList<QString> m_collections;
     ChatModel *m_chatModel;
     bool m_responseInProgress;
diff --git a/gpt4all-chat/chatllm.cpp b/gpt4all-chat/chatllm.cpp
index 6b86663d..92d14376 100644
--- a/gpt4all-chat/chatllm.cpp
+++ b/gpt4all-chat/chatllm.cpp
@@ -94,6 +94,7 @@ ChatLLM::ChatLLM(Chat *parent, bool isServer)
     , m_responseLogits(0)
     , m_isRecalc(false)
     , m_chat(parent)
+    , m_timer(nullptr)
     , m_isServer(isServer)
     , m_isChatGPT(false)
 {
@@ -103,7 +104,7 @@ ChatLLM::ChatLLM(Chat *parent, bool isServer)
     connect(this, &ChatLLM::shouldBeLoadedChanged, this, &ChatLLM::handleShouldBeLoadedChanged,
        Qt::QueuedConnection); // explicitly queued
     connect(m_chat, &Chat::idChanged, this, &ChatLLM::handleChatIdChanged);
-    connect(&m_llmThread, &QThread::started, this, &ChatLLM::threadStarted);
+    connect(&m_llmThread, &QThread::started, this, &ChatLLM::handleThreadStarted);

     // The following are blocking operations and will block the llm thread
     connect(this, &ChatLLM::requestRetrieveFromDB, LocalDocs::globalInstance()->database(), &Database::retrieveFromDB,
@@ -126,6 +127,13 @@ ChatLLM::~ChatLLM()
     }
 }

+void ChatLLM::handleThreadStarted()
+{
+    m_timer = new TokenTimer(this);
+    connect(m_timer, &TokenTimer::report, this, &ChatLLM::reportSpeed);
+    emit threadStarted();
+}
+
 bool ChatLLM::loadDefaultModel()
 {
     const QList<QString> models = m_chat->modelList();
@@ -367,6 +375,7 @@ bool ChatLLM::handlePrompt(int32_t token)
 #endif
     ++m_promptTokens;
     ++m_promptResponseTokens;
+    m_timer->inc();
     return !m_stopGenerating;
 }

@@ -387,6 +396,7 @@ bool ChatLLM::handleResponse(int32_t token, const std::string &response)
     // m_promptResponseTokens and m_responseLogits are related to last prompt/response not
     // the entire context window which we can reset on regenerate prompt
     ++m_promptResponseTokens;
+    m_timer->inc();
     Q_ASSERT(!response.empty());
     m_response.append(response);
     emit responseChanged();
@@ -441,11 +451,13 @@ bool ChatLLM::prompt(const QString &prompt, const QString &prompt_template, int3
     printf("%s", qPrintable(instructPrompt));
     fflush(stdout);
 #endif
+    m_timer->start();
     m_modelInfo.model->prompt(instructPrompt.toStdString(), promptFunc, responseFunc, recalcFunc, m_ctx);
 #if defined(DEBUG)
     printf("\n");
     fflush(stdout);
 #endif
+    m_timer->stop();
     m_responseLogits += m_ctx.logits.size() - logitsBefore;
     std::string trimmed = trim_whitespace(m_response);
     if (trimmed != m_response) {
diff --git a/gpt4all-chat/chatllm.h b/gpt4all-chat/chatllm.h
index b00b316f..dd34f3f4 100644
--- a/gpt4all-chat/chatllm.h
+++ b/gpt4all-chat/chatllm.h
@@ -23,6 +23,46 @@ struct LLModelInfo {
     // must be able to serialize the information even if it is in the unloaded state
 };

+class TokenTimer : public QObject {
+    Q_OBJECT
+public:
+    explicit TokenTimer(QObject *parent)
+        : QObject(parent)
+        , m_elapsed(0) {}
+
+    static int rollingAverage(int oldAvg, int newNumber, int n)
+    {
+        // i.e. to calculate the new average after the nth number,
+        // you multiply the old average by n−1, add the new number, and divide the total by n.
+        return qRound(((float(oldAvg) * (n - 1)) + newNumber) / float(n));
+    }
+
+    void start() { m_tokens = 0; m_elapsed = 0; m_time.invalidate(); }
+    void stop() { handleTimeout(); }
+    void inc() {
+        if (!m_time.isValid())
+            m_time.start();
+        ++m_tokens;
+        if (m_time.elapsed() > 999)
+            handleTimeout();
+    }
+
+Q_SIGNALS:
+    void report(const QString &speed);
+
+private Q_SLOTS:
+    void handleTimeout()
+    {
+        m_elapsed += m_time.restart();
+        emit report(QString("%1 tokens/sec").arg(m_tokens / float(m_elapsed / 1000.0f), 0, 'g', 2));
+    }
+
+private:
+    QElapsedTimer m_time;
+    qint64 m_elapsed;
+    quint32 m_tokens;
+};
+
 class Chat;

 class ChatLLM : public QObject {
@@ -73,6 +113,7 @@ public Q_SLOTS:
     void generateName();
     void handleChatIdChanged();
     void handleShouldBeLoadedChanged();
+    void handleThreadStarted();

 Q_SIGNALS:
     void isModelLoadedChanged();
@@ -89,7 +130,7 @@ Q_SIGNALS:
     void threadStarted();
     void shouldBeLoadedChanged();
     void requestRetrieveFromDB(const QList<QString> &collections, const QString &text, int retrievalSize, QList<ResultInfo> *results);
-
+    void reportSpeed(const QString &speed);

 protected:
     bool handlePrompt(int32_t token);
@@ -112,6 +153,7 @@ protected:
     quint32 m_responseLogits;
     QString m_modelName;
     Chat *m_chat;
+    TokenTimer *m_timer;
     QByteArray m_state;
     QThread m_llmThread;
     std::atomic<bool> m_stopGenerating;
diff --git a/gpt4all-chat/main.qml b/gpt4all-chat/main.qml
index 5946073b..20d4a7ed 100644
--- a/gpt4all-chat/main.qml
+++ b/gpt4all-chat/main.qml
@@ -845,6 +845,16 @@ Window {
             Accessible.description: qsTr("Controls generation of the response")
         }

+        Text {
+            id: speed
+            anchors.bottom: textInputView.top
+            anchors.bottomMargin: 20
+            anchors.right: parent.right
+            anchors.rightMargin: 30
+            color: theme.mutedTextColor
+            text: currentChat.tokenSpeed
+        }
+
         RectangularGlow {
             id: effect
             anchors.fill: textInputView
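
For reference, the figure this patch surfaces is a running average: TokenTimer accumulates both the token count and the elapsed wall-clock time across the whole response and re-emits "%1 tokens/sec" at most roughly once per second while tokens stream in. Below is a minimal standalone sketch of that bookkeeping using std::chrono in place of QElapsedTimer; the SpeedMeter and onToken names are illustrative only and are not part of the patch.

    // Illustrative sketch, not part of the patch: the same tokens-per-second
    // bookkeeping TokenTimer does, written with std::chrono instead of Qt.
    #include <chrono>
    #include <cstdio>

    class SpeedMeter {
        using Clock = std::chrono::steady_clock;
    public:
        // Reset counters before a new response (cf. TokenTimer::start()).
        void start() { m_tokens = 0; m_elapsedMs = 0; m_running = false; }

        // Call once per generated token (cf. TokenTimer::inc()).
        void onToken() {
            if (!m_running) { m_last = Clock::now(); m_running = true; }
            ++m_tokens;
            if (sinceLastMs() > 999)   // re-report roughly once per second
                report();
        }

        // Final report when generation ends (cf. TokenTimer::stop()).
        void stop() { report(); }

    private:
        long long sinceLastMs() const {
            return std::chrono::duration_cast<std::chrono::milliseconds>(Clock::now() - m_last).count();
        }

        void report() {
            if (!m_running)
                return;
            // Tokens and elapsed time both accumulate, so the printed figure is a
            // running average over the whole response, not an instantaneous rate.
            m_elapsedMs += sinceLastMs();
            m_last = Clock::now();
            if (m_elapsedMs > 0)
                std::printf("%.2g tokens/sec\n", m_tokens / (m_elapsedMs / 1000.0));
        }

        Clock::time_point m_last{};
        long long m_elapsedMs = 0;
        unsigned m_tokens = 0;
        bool m_running = false;
    };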