From c8a590bc6f820e5bbf5994d83b71ddbfbd9669f9 Mon Sep 17 00:00:00 2001
From: Adam Treat
Date: Tue, 20 Jun 2023 16:14:30 -0400
Subject: [PATCH] Get rid of last blocking operations and make the chat/llm thread safe.

---
 gpt4all-chat/chat.cpp          | 64 ++++++++++++++++++----------------
 gpt4all-chat/chat.h            | 12 ++++---
 gpt4all-chat/chatlistmodel.cpp |  1 -
 gpt4all-chat/chatlistmodel.h   |  8 -----
 gpt4all-chat/chatllm.cpp       | 35 +++++++++----------
 gpt4all-chat/chatllm.h         | 12 ++-----
 6 files changed, 60 insertions(+), 72 deletions(-)

diff --git a/gpt4all-chat/chat.cpp b/gpt4all-chat/chat.cpp
index cde4c99c..a73b829f 100644
--- a/gpt4all-chat/chat.cpp
+++ b/gpt4all-chat/chat.cpp
@@ -16,6 +16,7 @@ Chat::Chat(QObject *parent)
     , m_llmodel(new ChatLLM(this))
     , m_isServer(false)
     , m_shouldDeleteLater(false)
+    , m_isModelLoaded(false)
 {
     connectLLM();
 }
@@ -31,6 +32,7 @@ Chat::Chat(bool isServer, QObject *parent)
     , m_llmodel(new Server(this))
     , m_isServer(true)
     , m_shouldDeleteLater(false)
+    , m_isModelLoaded(false)
 {
     connectLLM();
 }
@@ -55,12 +57,10 @@ void Chat::connectLLM()
     connect(m_watcher, &QFileSystemWatcher::directoryChanged, this, &Chat::handleModelListChanged);
 
     // Should be in different threads
-    connect(m_llmodel, &ChatLLM::isModelLoadedChanged, this, &Chat::isModelLoadedChanged, Qt::QueuedConnection);
     connect(m_llmodel, &ChatLLM::isModelLoadedChanged, this, &Chat::handleModelLoadedChanged, Qt::QueuedConnection);
     connect(m_llmodel, &ChatLLM::responseChanged, this, &Chat::handleResponseChanged, Qt::QueuedConnection);
     connect(m_llmodel, &ChatLLM::promptProcessing, this, &Chat::promptProcessing, Qt::QueuedConnection);
     connect(m_llmodel, &ChatLLM::responseStopped, this, &Chat::responseStopped, Qt::QueuedConnection);
-    connect(m_llmodel, &ChatLLM::modelNameChanged, this, &Chat::handleModelNameChanged, Qt::QueuedConnection);
     connect(m_llmodel, &ChatLLM::modelLoadingError, this, &Chat::handleModelLoadingError, Qt::QueuedConnection);
     connect(m_llmodel, &ChatLLM::recalcChanged, this, &Chat::handleRecalculating, Qt::QueuedConnection);
     connect(m_llmodel, &ChatLLM::generatedNameChanged, this, &Chat::generatedNameChanged, Qt::QueuedConnection);
@@ -73,11 +73,8 @@ void Chat::connectLLM()
     connect(this, &Chat::loadModelRequested, m_llmodel, &ChatLLM::loadModel, Qt::QueuedConnection);
     connect(this, &Chat::generateNameRequested, m_llmodel, &ChatLLM::generateName, Qt::QueuedConnection);
     connect(this, &Chat::regenerateResponseRequested, m_llmodel, &ChatLLM::regenerateResponse, Qt::QueuedConnection);
-
-    // The following are blocking operations and will block the gui thread, therefore must be fast
-    // to respond to
-    connect(this, &Chat::resetResponseRequested, m_llmodel, &ChatLLM::resetResponse, Qt::BlockingQueuedConnection);
-    connect(this, &Chat::resetContextRequested, m_llmodel, &ChatLLM::resetContext, Qt::BlockingQueuedConnection);
+    connect(this, &Chat::resetResponseRequested, m_llmodel, &ChatLLM::resetResponse, Qt::QueuedConnection);
+    connect(this, &Chat::resetContextRequested, m_llmodel, &ChatLLM::resetContext, Qt::QueuedConnection);
 
     emit defaultModelChanged(modelList().first());
 }
@@ -87,7 +84,7 @@ void Chat::reset()
     stopGenerating();
     // Erase our current on disk representation as we're completely resetting the chat along with id
    LLM::globalInstance()->chatListModel()->removeChatFile(this);
-    emit resetContextRequested(); // blocking queued connection
+    emit resetContextRequested();
     m_id = Network::globalInstance()->generateUniqueId();
     emit idChanged(m_id);
     // NOTE: We deliberately do not reset the name or creation date to indicate that this was originally
@@ -105,7 +102,7 @@ void Chat::reset()
 
 bool Chat::isModelLoaded() const
 {
-    return m_llmodel->isModelLoaded();
+    return m_isModelLoaded;
 }
 
 void Chat::resetResponseState()
@@ -154,7 +151,7 @@ void Chat::stopGenerating()
 
 QString Chat::response() const
 {
-    return m_llmodel->response();
+    return m_response;
 }
 
 QString Chat::responseState() const
@@ -170,22 +167,29 @@ QString Chat::responseState() const
     return QString();
 }
 
-void Chat::handleResponseChanged()
+void Chat::handleResponseChanged(const QString &response)
 {
     if (m_responseState != Chat::ResponseGeneration) {
         m_responseState = Chat::ResponseGeneration;
         emit responseStateChanged();
     }
 
+    m_response = response;
     const int index = m_chatModel->count() - 1;
-    m_chatModel->updateValue(index, response());
+    m_chatModel->updateValue(index, this->response());
     emit responseChanged();
 }
 
-void Chat::handleModelLoadedChanged()
+void Chat::handleModelLoadedChanged(bool loaded)
 {
     if (m_shouldDeleteLater)
         deleteLater();
+
+    if (loaded == m_isModelLoaded)
+        return;
+
+    m_isModelLoaded = loaded;
+    emit isModelLoadedChanged();
 }
 
 void Chat::promptProcessing()
@@ -241,7 +245,7 @@ void Chat::responseStopped()
     m_responseState = Chat::ResponseStopped;
     emit responseInProgressChanged();
     emit responseStateChanged();
-    if (m_llmodel->generatedName().isEmpty())
+    if (m_generatedName.isEmpty())
         emit generateNameRequested();
     if (chatModel()->count() < 3)
         Network::globalInstance()->sendChatStarted();
@@ -249,15 +253,18 @@ void Chat::responseStopped()
 
 QString Chat::modelName() const
 {
-    return m_llmodel->modelName();
+    return m_modelName;
 }
 
 void Chat::setModelName(const QString &modelName)
 {
-    // doesn't block but will unload old model and load new one which the gui can see through changes
-    // to the isModelLoaded property
+    if (m_modelName == modelName)
+        return;
+
     m_modelLoadingError = QString();
     emit modelLoadingErrorChanged();
+    m_modelName = modelName;
+    emit modelNameChanged();
     emit modelNameChangeRequested(modelName);
 }
 
@@ -267,7 +274,7 @@ void Chat::newPromptResponsePair(const QString &prompt)
     m_chatModel->updateCurrentResponse(m_chatModel->count() - 1, false);
     m_chatModel->appendPrompt(tr("Prompt: "), prompt);
     m_chatModel->appendResponse(tr("Response: "), prompt);
-    emit resetResponseRequested(); // blocking queued connection
+    emit resetResponseRequested();
 }
 
 void Chat::serverNewPromptResponsePair(const QString &prompt)
@@ -320,11 +327,11 @@ void Chat::reloadModel()
     m_llmodel->setShouldBeLoaded(true);
 }
 
-void Chat::generatedNameChanged()
+void Chat::generatedNameChanged(const QString &name)
 {
     // Only use the first three words maximum and remove newlines and extra spaces
-    QString gen = m_llmodel->generatedName().simplified();
-    QStringList words = gen.split(' ', Qt::SkipEmptyParts);
+    m_generatedName = name.simplified();
+    QStringList words = m_generatedName.split(' ', Qt::SkipEmptyParts);
     int wordCount = qMin(3, words.size());
     m_name = words.mid(0, wordCount).join(' ');
     emit nameChanged();
@@ -336,12 +343,6 @@ void Chat::handleRecalculating()
     emit recalcChanged();
 }
 
-void Chat::handleModelNameChanged()
-{
-    m_savedModelName = modelName();
-    emit modelNameChanged();
-}
-
 void Chat::handleModelLoadingError(const QString &error)
 {
     qWarning() << "ERROR:" << qPrintable(error) << "id" << id();
@@ -366,7 +367,7 @@ bool Chat::serialize(QDataStream &stream, int version) const
     stream << m_id;
     stream << m_name;
     stream << m_userName;
-    stream << m_savedModelName;
+    stream << m_modelName;
     if (version > 2)
         stream << m_collections;
     if (!m_llmodel->serialize(stream, version))
@@ -384,16 +385,17 @@ bool Chat::deserialize(QDataStream &stream, int version)
     stream >> m_name;
     stream >> m_userName;
     emit nameChanged();
-    stream >> m_savedModelName;
+    stream >> m_modelName;
+    emit modelNameChanged();
     // Prior to version 2 gptj models had a bug that fixed the kv_cache to F32 instead of F16 so
     // unfortunately, we cannot deserialize these
-    if (version < 2 && m_savedModelName.contains("gpt4all-j"))
+    if (version < 2 && m_modelName.contains("gpt4all-j"))
         return false;
     if (version > 2) {
         stream >> m_collections;
         emit collectionListChanged(m_collections);
     }
-    m_llmodel->setModelName(m_savedModelName);
+    m_llmodel->setModelName(m_modelName);
     if (!m_llmodel->deserialize(stream, version))
         return false;
     if (!m_chatModel->deserialize(stream, version))
diff --git a/gpt4all-chat/chat.h b/gpt4all-chat/chat.h
index 6067ac01..50bf4e2d 100644
--- a/gpt4all-chat/chat.h
+++ b/gpt4all-chat/chat.h
@@ -125,13 +125,12 @@ Q_SIGNALS:
     void defaultModelChanged(const QString &defaultModel);
 
 private Q_SLOTS:
-    void handleResponseChanged();
-    void handleModelLoadedChanged();
+    void handleResponseChanged(const QString &response);
+    void handleModelLoadedChanged(bool);
     void promptProcessing();
     void responseStopped();
-    void generatedNameChanged();
+    void generatedNameChanged(const QString &name);
     void handleRecalculating();
-    void handleModelNameChanged();
     void handleModelLoadingError(const QString &error);
     void handleTokenSpeedChanged(const QString &tokenSpeed);
     void handleDatabaseResultsChanged(const QList<ResultInfo> &results);
@@ -141,10 +140,12 @@ private Q_SLOTS:
 private:
     QString m_id;
     QString m_name;
+    QString m_generatedName;
     QString m_userName;
-    QString m_savedModelName;
+    QString m_modelName;
     QString m_modelLoadingError;
     QString m_tokenSpeed;
+    QString m_response;
     QList<QString> m_collections;
     ChatModel *m_chatModel;
     bool m_responseInProgress;
@@ -154,6 +155,7 @@ private:
     QList<ResultInfo> m_databaseResults;
     bool m_isServer;
     bool m_shouldDeleteLater;
+    bool m_isModelLoaded;
     QFileSystemWatcher *m_watcher;
 };
 
diff --git a/gpt4all-chat/chatlistmodel.cpp b/gpt4all-chat/chatlistmodel.cpp
index 258b67c1..510d9dbc 100644
--- a/gpt4all-chat/chatlistmodel.cpp
+++ b/gpt4all-chat/chatlistmodel.cpp
@@ -233,7 +233,6 @@ void ChatListModel::restoreChat(Chat *chat)
 {
     chat->setParent(this);
     connect(chat, &Chat::nameChanged, this, &ChatListModel::nameChanged);
-    connect(chat, &Chat::modelLoadingErrorChanged, this, &ChatListModel::handleModelLoadingError);
 
     if (m_dummyChat) {
         beginResetModel();
diff --git a/gpt4all-chat/chatlistmodel.h b/gpt4all-chat/chatlistmodel.h
index bc38a414..a66aa3fe 100644
--- a/gpt4all-chat/chatlistmodel.h
+++ b/gpt4all-chat/chatlistmodel.h
@@ -122,8 +122,6 @@ public:
             this, &ChatListModel::newChatCountChanged);
         connect(m_newChat, &Chat::nameChanged,
             this, &ChatListModel::nameChanged);
-        connect(m_newChat, &Chat::modelLoadingError,
-            this, &ChatListModel::handleModelLoadingError);
         setCurrentChat(m_newChat);
     }
 
@@ -227,12 +225,6 @@ private Q_SLOTS:
         emit dataChanged(index, index, {NameRole});
     }
 
-    void handleModelLoadingError()
-    {
-        Chat *chat = qobject_cast<Chat *>(sender());
-        removeChat(chat);
-    }
-
     void printChats()
     {
         for (auto c : m_chats) {
diff --git a/gpt4all-chat/chatllm.cpp b/gpt4all-chat/chatllm.cpp
index 521edfd2..656b29df 100644
--- a/gpt4all-chat/chatllm.cpp
+++ b/gpt4all-chat/chatllm.cpp
@@ -11,7 +11,6 @@
 #include 
 #include 
 #include 
-#include 
 
 //#define DEBUG
 //#define DEBUG_MODEL_LOADING
@@ -154,7 +153,7 @@ bool ChatLLM::loadModel(const QString &modelName)
     // to provide an overview of what we're doing here.
 
     // We're already loaded with this model
-    if (isModelLoaded() && m_modelName == modelName)
+    if (isModelLoaded() && this->modelName() == modelName)
         return true;
 
     bool isChatGPT = modelName.startsWith("chatgpt-");
@@ -170,7 +169,7 @@ bool ChatLLM::loadModel(const QString &modelName)
 #endif
         delete m_modelInfo.model;
         m_modelInfo.model = nullptr;
-        emit isModelLoadedChanged();
+        emit isModelLoadedChanged(false);
     } else if (!m_isServer) {
         // This is a blocking call that tries to retrieve the model we need from the model store.
         // If it succeeds, then we just have to restore state. If the store has never had a model
@@ -188,7 +187,7 @@ bool ChatLLM::loadModel(const QString &modelName)
 #endif
             LLModelStore::globalInstance()->releaseModel(m_modelInfo);
             m_modelInfo = LLModelInfo();
-            emit isModelLoadedChanged();
+            emit isModelLoadedChanged(false);
             return false;
         }
 
@@ -198,7 +197,7 @@ bool ChatLLM::loadModel(const QString &modelName)
             qDebug() << "store had our model" << m_llmThread.objectName() << m_modelInfo.model;
 #endif
             restoreState();
-            emit isModelLoadedChanged();
+            emit isModelLoadedChanged(true);
             return true;
         } else {
             // Release the memory since we have to switch to a different model.
@@ -273,7 +272,7 @@ bool ChatLLM::loadModel(const QString &modelName)
         qDebug() << "modelLoadedChanged" << m_llmThread.objectName();
         fflush(stdout);
 #endif
-        emit isModelLoadedChanged();
+        emit isModelLoadedChanged(isModelLoaded());
 
         static bool isFirstLoad = true;
         if (isFirstLoad) {
@@ -316,7 +315,7 @@ void ChatLLM::regenerateResponse()
     m_promptResponseTokens = 0;
     m_promptTokens = 0;
     m_response = std::string();
-    emit responseChanged();
+    emit responseChanged(QString::fromStdString(m_response));
 }
 
 void ChatLLM::resetResponse()
@@ -324,7 +323,7 @@ void ChatLLM::resetResponse()
     m_promptTokens = 0;
     m_promptResponseTokens = 0;
     m_response = std::string();
-    emit responseChanged();
+    emit responseChanged(QString::fromStdString(m_response));
 }
 
 void ChatLLM::resetContext()
@@ -397,7 +396,7 @@ bool ChatLLM::handleResponse(int32_t token, const std::string &response)
     // check for error
     if (token < 0) {
         m_response.append(response);
-        emit responseChanged();
+        emit responseChanged(QString::fromStdString(m_response));
         return false;
     }
 
@@ -407,7 +406,7 @@ bool ChatLLM::handleResponse(int32_t token, const std::string &response)
     m_timer->inc();
     Q_ASSERT(!response.empty());
     m_response.append(response);
-    emit responseChanged();
+    emit responseChanged(QString::fromStdString(m_response));
     return !m_stopGenerating;
 }
 
@@ -470,7 +469,7 @@ bool ChatLLM::prompt(const QList<QString> &collectionList, const QString &prompt
     std::string trimmed = trim_whitespace(m_response);
     if (trimmed != m_response) {
         m_response = trimmed;
-        emit responseChanged();
+        emit responseChanged(QString::fromStdString(m_response));
     }
     emit responseStopped();
     return true;
@@ -510,7 +509,7 @@ void ChatLLM::unloadModel()
 #endif
     LLModelStore::globalInstance()->releaseModel(m_modelInfo);
     m_modelInfo = LLModelInfo();
-    emit isModelLoadedChanged();
+    emit isModelLoadedChanged(false);
 }
 
 void ChatLLM::reloadModel()
@@ -521,11 +520,11 @@ void ChatLLM::reloadModel()
 #if defined(DEBUG_MODEL_LOADING)
     qDebug() << "reloadModel" << m_llmThread.objectName() << m_modelInfo.model;
 #endif
-    if (m_modelName.isEmpty()) {
+    const QString m = modelName();
+    if (m.isEmpty())
         loadDefaultModel();
-    } else {
-        loadModel(m_modelName);
-    }
+    else
+        loadModel(m);
 }
 
 void ChatLLM::generateName()
@@ -554,7 +553,7 @@ void ChatLLM::generateName()
     std::string trimmed = trim_whitespace(m_nameResponse);
     if (trimmed != m_nameResponse) {
         m_nameResponse = trimmed;
-        emit generatedNameChanged();
+        emit generatedNameChanged(QString::fromStdString(m_nameResponse));
     }
 }
 
@@ -580,7 +579,7 @@ bool ChatLLM::handleNameResponse(int32_t token, const std::string &response)
     Q_UNUSED(token);
 
     m_nameResponse.append(response);
-    emit generatedNameChanged();
+    emit generatedNameChanged(QString::fromStdString(m_nameResponse));
     QString gen = QString::fromStdString(m_nameResponse).simplified();
     QStringList words = gen.split(' ', Qt::SkipEmptyParts);
     return words.size() <= 3;
diff --git a/gpt4all-chat/chatllm.h b/gpt4all-chat/chatllm.h
index a87c9067..86cbc6ed 100644
--- a/gpt4all-chat/chatllm.h
+++ b/gpt4all-chat/chatllm.h
@@ -116,16 +116,16 @@ public Q_SLOTS:
     void handleThreadStarted();
 
 Q_SIGNALS:
-    void isModelLoadedChanged();
+    void isModelLoadedChanged(bool);
     void modelLoadingError(const QString &error);
-    void responseChanged();
+    void responseChanged(const QString &response);
     void promptProcessing();
     void responseStopped();
     void modelNameChanged();
     void recalcChanged();
     void sendStartup();
     void sendModelLoaded();
-    void generatedNameChanged();
+    void generatedNameChanged(const QString &name);
     void stateChanged();
     void threadStarted();
     void shouldBeLoadedChanged();
@@ -144,22 +144,16 @@ protected:
     void restoreState();
 
 protected:
-    // The following are all accessed by multiple threads and are thus guarded with thread protection
-    // mechanisms
     LLModel::PromptContext m_ctx;
     quint32 m_promptTokens;
     quint32 m_promptResponseTokens;
 
 private:
-    // The following are all accessed by multiple threads and are thus guarded with thread protection
-    // mechanisms
     std::string m_response;
     std::string m_nameResponse;
     LLModelInfo m_modelInfo;
     LLModelType m_modelType;
     QString m_modelName;
-
-    // The following are only accessed by this thread
     QString m_defaultModel;
     TokenTimer *m_timer;
     QByteArray m_state;
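
Note (not part of the patch): the change above follows one pattern throughout. Cross-thread getters that reached into the worker (m_llmodel->response(), m_llmodel->isModelLoaded()) and Qt::BlockingQueuedConnection signals are replaced by Qt::QueuedConnection signals that carry the new state as a payload, which the GUI-side Chat caches in members such as m_response, m_modelName and m_isModelLoaded. Below is a minimal sketch of that pattern under stated assumptions: Worker and Controller are hypothetical names, not classes from gpt4all, and like any Q_OBJECT class the sketch must be compiled through moc.

    #include <QObject>

    class Worker : public QObject {
        Q_OBJECT
    public Q_SLOTS:
        void load()
        {
            // ... expensive work running on the worker thread ...
            emit loadedChanged(true); // the new state travels with the signal
        }
    Q_SIGNALS:
        void loadedChanged(bool loaded);
    };

    class Controller : public QObject {
        Q_OBJECT
    public:
        explicit Controller(Worker *worker, QObject *parent = nullptr)
            : QObject(parent)
        {
            // Queued: the slot runs later on the Controller's thread and never
            // blocks the emitter, unlike Qt::BlockingQueuedConnection.
            connect(worker, &Worker::loadedChanged,
                    this, &Controller::handleLoadedChanged, Qt::QueuedConnection);
        }

        // Reads the cached copy on this thread; no cross-thread call into Worker.
        bool isLoaded() const { return m_loaded; }

    Q_SIGNALS:
        void isLoadedChanged();

    private Q_SLOTS:
        void handleLoadedChanged(bool loaded)
        {
            if (loaded == m_loaded)
                return;
            m_loaded = loaded;
            emit isLoadedChanged();
        }

    private:
        bool m_loaded = false;
    };

Controller::isLoaded and handleLoadedChanged mirror Chat::isModelLoaded and Chat::handleModelLoadedChanged in the patch: the getter becomes a plain read of GUI-thread state, so nothing ever waits on the worker thread.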