diff --git a/gpt4all-chat/chat.cpp b/gpt4all-chat/chat.cpp index 6e99c9e2..6cf637d2 100644 --- a/gpt4all-chat/chat.cpp +++ b/gpt4all-chat/chat.cpp @@ -12,6 +12,7 @@ Chat::Chat(QObject *parent) , m_creationDate(QDateTime::currentSecsSinceEpoch()) , m_llmodel(new ChatLLM(this)) , m_isServer(false) + , m_shouldDeleteLater(false) { connectLLM(); } @@ -25,6 +26,7 @@ Chat::Chat(bool isServer, QObject *parent) , m_creationDate(QDateTime::currentSecsSinceEpoch()) , m_llmodel(new Server(this)) , m_isServer(true) + , m_shouldDeleteLater(false) { connectLLM(); } @@ -43,6 +45,7 @@ void Chat::connectLLM() // Should be in different threads connect(m_llmodel, &ChatLLM::isModelLoadedChanged, this, &Chat::isModelLoadedChanged, Qt::QueuedConnection); + connect(m_llmodel, &ChatLLM::isModelLoadedChanged, this, &Chat::handleModelLoadedChanged, Qt::QueuedConnection); connect(m_llmodel, &ChatLLM::responseChanged, this, &Chat::handleResponseChanged, Qt::QueuedConnection); connect(m_llmodel, &ChatLLM::responseStarted, this, &Chat::responseStarted, Qt::QueuedConnection); connect(m_llmodel, &ChatLLM::responseStopped, this, &Chat::responseStopped, Qt::QueuedConnection); @@ -55,8 +58,6 @@ void Chat::connectLLM() connect(this, &Chat::modelNameChangeRequested, m_llmodel, &ChatLLM::modelNameChangeRequested, Qt::QueuedConnection); connect(this, &Chat::loadDefaultModelRequested, m_llmodel, &ChatLLM::loadDefaultModel, Qt::QueuedConnection); connect(this, &Chat::loadModelRequested, m_llmodel, &ChatLLM::loadModel, Qt::QueuedConnection); - connect(this, &Chat::unloadModelRequested, m_llmodel, &ChatLLM::unloadModel, Qt::QueuedConnection); - connect(this, &Chat::reloadModelRequested, m_llmodel, &ChatLLM::reloadModel, Qt::QueuedConnection); connect(this, &Chat::generateNameRequested, m_llmodel, &ChatLLM::generateName, Qt::QueuedConnection); // The following are blocking operations and will block the gui thread, therefore must be fast @@ -122,6 +123,12 @@ void Chat::handleResponseChanged() emit responseChanged(); } +void Chat::handleModelLoadedChanged() +{ + if (m_shouldDeleteLater) + deleteLater(); +} + void Chat::responseStarted() { m_responseInProgress = true; @@ -180,15 +187,26 @@ void Chat::loadModel(const QString &modelName) emit loadModelRequested(modelName); } +void Chat::unloadAndDeleteLater() +{ + if (!isModelLoaded()) { + deleteLater(); + return; + } + + m_shouldDeleteLater = true; + unloadModel(); +} + void Chat::unloadModel() { stopGenerating(); - emit unloadModelRequested(); + m_llmodel->setShouldBeLoaded(false); } void Chat::reloadModel() { - emit reloadModelRequested(m_savedModelName); + m_llmodel->setShouldBeLoaded(true); } void Chat::generatedNameChanged() @@ -236,12 +254,10 @@ bool Chat::deserialize(QDataStream &stream, int version) stream >> m_userName; emit nameChanged(); stream >> m_savedModelName; - // Prior to version 2 gptj models had a bug that fixed the kv_cache to F32 instead of F16 so // unfortunately, we cannot deserialize these if (version < 2 && m_savedModelName.contains("gpt4all-j")) return false; - if (!m_llmodel->deserialize(stream, version)) return false; if (!m_chatModel->deserialize(stream, version)) diff --git a/gpt4all-chat/chat.h b/gpt4all-chat/chat.h index b3275caf..974378bc 100644 --- a/gpt4all-chat/chat.h +++ b/gpt4all-chat/chat.h @@ -58,6 +58,7 @@ public: void loadModel(const QString &modelName); void unloadModel(); void reloadModel(); + void unloadAndDeleteLater(); qint64 creationDate() const { return m_creationDate; } bool serialize(QDataStream &stream, int version) const; @@ -87,8 +88,6 @@ Q_SIGNALS: void recalcChanged(); void loadDefaultModelRequested(); void loadModelRequested(const QString &modelName); - void unloadModelRequested(); - void reloadModelRequested(const QString &modelName); void generateNameRequested(); void modelListChanged(); void modelLoadingError(const QString &error); @@ -96,6 +95,7 @@ Q_SIGNALS: private Q_SLOTS: void handleResponseChanged(); + void handleModelLoadedChanged(); void responseStarted(); void responseStopped(); void generatedNameChanged(); @@ -112,6 +112,7 @@ private: qint64 m_creationDate; ChatLLM *m_llmodel; bool m_isServer; + bool m_shouldDeleteLater; }; #endif // CHAT_H diff --git a/gpt4all-chat/chatlistmodel.cpp b/gpt4all-chat/chatlistmodel.cpp index a0fe17f6..a931c8ec 100644 --- a/gpt4all-chat/chatlistmodel.cpp +++ b/gpt4all-chat/chatlistmodel.cpp @@ -40,6 +40,7 @@ void ChatListModel::setShouldSaveChats(bool b) void ChatListModel::removeChatFile(Chat *chat) const { + Q_ASSERT(chat != m_serverChat); const QString savePath = Download::globalInstance()->downloadLocalModelsPath(); QFile file(savePath + "/gpt4all-" + chat->id() + ".chat"); if (!file.exists()) @@ -58,6 +59,8 @@ void ChatListModel::saveChats() const timer.start(); const QString savePath = Download::globalInstance()->downloadLocalModelsPath(); for (Chat *chat : m_chats) { + if (chat == m_serverChat) + continue; QString fileName = "gpt4all-" + chat->id() + ".chat"; QFile file(savePath + "/" + fileName); bool success = file.open(QIODevice::WriteOnly); diff --git a/gpt4all-chat/chatlistmodel.h b/gpt4all-chat/chatlistmodel.h index e8858b60..8c60b7a9 100644 --- a/gpt4all-chat/chatlistmodel.h +++ b/gpt4all-chat/chatlistmodel.h @@ -125,6 +125,7 @@ public: Q_INVOKABLE void removeChat(Chat* chat) { + Q_ASSERT(chat != m_serverChat); if (!m_chats.contains(chat)) { qWarning() << "WARNING: Removing chat failed with id" << chat->id(); return; @@ -138,11 +139,11 @@ public: } const int index = m_chats.indexOf(chat); - if (m_chats.count() < 2) { + if (m_chats.count() < 3 /*m_serverChat included*/) { addChat(); } else { int nextIndex; - if (index == m_chats.count() - 1) + if (index == m_chats.count() - 2 /*m_serverChat is last*/) nextIndex = index - 1; else nextIndex = index + 1; @@ -155,7 +156,7 @@ public: beginRemoveRows(QModelIndex(), newIndex, newIndex); m_chats.removeAll(chat); endRemoveRows(); - delete chat; + chat->unloadAndDeleteLater(); } Chat *currentChat() const @@ -170,7 +171,7 @@ public: return; } - if (m_currentChat && m_currentChat->isModelLoaded()) + if (m_currentChat) m_currentChat->unloadModel(); m_currentChat = chat; diff --git a/gpt4all-chat/chatllm.cpp b/gpt4all-chat/chatllm.cpp index 34c604fb..0e8b08cb 100644 --- a/gpt4all-chat/chatllm.cpp +++ b/gpt4all-chat/chatllm.cpp @@ -15,6 +15,7 @@ #include //#define DEBUG +//#define DEBUG_MODEL_LOADING #define MPT_INTERNAL_STATE_VERSION 0 #define GPTJ_INTERNAL_STATE_VERSION 0 @@ -37,9 +38,51 @@ static QString modelFilePath(const QString &modelName) return QString(); } +class LLModelStore { +public: + static LLModelStore *globalInstance(); + + LLModelInfo acquireModel(); // will block until llmodel is ready + void releaseModel(const LLModelInfo &info); // must be called when you are done + +private: + LLModelStore() + { + // seed with empty model + m_availableModels.append(LLModelInfo()); + } + ~LLModelStore() {} + QVector m_availableModels; + QMutex m_mutex; + QWaitCondition m_condition; + friend class MyLLModelStore; +}; + +class MyLLModelStore : public LLModelStore { }; +Q_GLOBAL_STATIC(MyLLModelStore, storeInstance) +LLModelStore *LLModelStore::globalInstance() +{ + return storeInstance(); +} + +LLModelInfo LLModelStore::acquireModel() +{ + QMutexLocker locker(&m_mutex); + while (m_availableModels.isEmpty()) + m_condition.wait(locker.mutex()); + return m_availableModels.takeFirst(); +} + +void LLModelStore::releaseModel(const LLModelInfo &info) +{ + QMutexLocker locker(&m_mutex); + m_availableModels.append(info); + Q_ASSERT(m_availableModels.count() < 2); + m_condition.wakeAll(); +} + ChatLLM::ChatLLM(Chat *parent) : QObject{nullptr} - , m_llmodel(nullptr) , m_promptResponseTokens(0) , m_promptTokens(0) , m_responseLogits(0) @@ -49,6 +92,7 @@ ChatLLM::ChatLLM(Chat *parent) moveToThread(&m_llmThread); connect(this, &ChatLLM::sendStartup, Network::globalInstance(), &Network::sendStartup); connect(this, &ChatLLM::sendModelLoaded, Network::globalInstance(), &Network::sendModelLoaded); + connect(this, &ChatLLM::shouldBeLoadedChanged, this, &ChatLLM::handleShouldBeLoadedChanged, Qt::QueuedConnection); connect(m_chat, &Chat::idChanged, this, &ChatLLM::handleChatIdChanged); connect(&m_llmThread, &QThread::started, this, &ChatLLM::threadStarted); m_llmThread.setObjectName(m_chat->id()); @@ -59,7 +103,13 @@ ChatLLM::~ChatLLM() { m_llmThread.quit(); m_llmThread.wait(); - delete m_llmodel; + + // The only time we should have a model loaded here is on shutdown + // as we explicitly unload the model in all other circumstances + if (isModelLoaded()) { + delete m_modelInfo.model; + m_modelInfo.model = nullptr; + } } bool ChatLLM::loadDefaultModel() @@ -76,50 +126,103 @@ bool ChatLLM::loadDefaultModel() bool ChatLLM::loadModel(const QString &modelName) { + // This is a complicated method because N different possible threads are interested in the outcome + // of this method. Why? Because we have a main/gui thread trying to monitor the state of N different + // possible chat threads all vying for a single resource - the currently loaded model - as the user + // switches back and forth between chats. It is important for our main/gui thread to never block + // but simultaneously always have up2date information with regards to which chat has the model loaded + // and what the type and name of that model is. I've tried to comment extensively in this method + // to provide an overview of what we're doing here. + + // We're already loaded with this model if (isModelLoaded() && m_modelName == modelName) return true; - if (isModelLoaded()) { + QString filePath = modelFilePath(modelName); + QFileInfo fileInfo(filePath); + + // We have a live model, but it isn't the one we want + bool alreadyAcquired = isModelLoaded(); + if (alreadyAcquired) { resetContextProtected(); - delete m_llmodel; - m_llmodel = nullptr; +#if defined(DEBUG_MODEL_LOADING) + qDebug() << "already acquired model deleted" << m_chat->id() << m_modelInfo.model; +#endif + delete m_modelInfo.model; + m_modelInfo.model = nullptr; emit isModelLoadedChanged(); + } else { + // This is a blocking call that tries to retrieve the model we need from the model store. + // If it succeeds, then we just have to restore state. If the store has never had a model + // returned to it, then the modelInfo.model pointer should be null which will happen on startup + m_modelInfo = LLModelStore::globalInstance()->acquireModel(); +#if defined(DEBUG_MODEL_LOADING) + qDebug() << "acquired model from store" << m_chat->id() << m_modelInfo.model; +#endif + // At this point it is possible that while we were blocked waiting to acquire the model from the + // store, that our state was changed to not be loaded. If this is the case, release the model + // back into the store and quit loading + if (!m_shouldBeLoaded) { + qDebug() << "no longer need model" << m_chat->id() << m_modelInfo.model; + LLModelStore::globalInstance()->releaseModel(m_modelInfo); + m_modelInfo = LLModelInfo(); + emit isModelLoadedChanged(); + return false; + } + + // Check if the store just gave us exactly the model we were looking for + if (m_modelInfo.model && m_modelInfo.fileInfo == fileInfo) { +#if defined(DEBUG_MODEL_LOADING) + qDebug() << "store had our model" << m_chat->id() << m_modelInfo.model; +#endif + restoreState(); + emit isModelLoadedChanged(); + return true; + } else { + // Release the memory since we have to switch to a different model. +#if defined(DEBUG_MODEL_LOADING) + qDebug() << "deleting model" << m_chat->id() << m_modelInfo.model; +#endif + delete m_modelInfo.model; + m_modelInfo.model = nullptr; + } } - bool isGPTJ = false; - bool isMPT = false; - QString filePath = modelFilePath(modelName); - QFileInfo info(filePath); - if (info.exists()) { + // Guarantee we've released the previous models memory + Q_ASSERT(!m_modelInfo.model); + // Store the file info in the modelInfo in case we have an error loading + m_modelInfo.fileInfo = fileInfo; + + if (fileInfo.exists()) { auto fin = std::ifstream(filePath.toStdString(), std::ios::binary); uint32_t magic; fin.read((char *) &magic, sizeof(magic)); fin.seekg(0); fin.close(); - isGPTJ = magic == 0x67676d6c; - isMPT = magic == 0x67676d6d; + const bool isGPTJ = magic == 0x67676d6c; + const bool isMPT = magic == 0x67676d6d; if (isGPTJ) { - m_modelType = ModelType::GPTJ_; - m_llmodel = new GPTJ; - m_llmodel->loadModel(filePath.toStdString()); + m_modelType = LLModelType::GPTJ_; + m_modelInfo.model = new GPTJ; + m_modelInfo.model->loadModel(filePath.toStdString()); } else if (isMPT) { - m_modelType = ModelType::MPT_; - m_llmodel = new MPT; - m_llmodel->loadModel(filePath.toStdString()); + m_modelType = LLModelType::MPT_; + m_modelInfo.model = new MPT; + m_modelInfo.model->loadModel(filePath.toStdString()); } else { - m_modelType = ModelType::LLAMA_; - m_llmodel = new LLamaModel; - m_llmodel->loadModel(filePath.toStdString()); + m_modelType = LLModelType::LLAMA_; + m_modelInfo.model = new LLamaModel; + m_modelInfo.model->loadModel(filePath.toStdString()); } - - restoreState(); - -#if defined(DEBUG) - qDebug() << "chatllm modelLoadedChanged" << m_chat->id(); - fflush(stdout); +#if defined(DEBUG_MODEL_LOADING) + qDebug() << "new model" << m_chat->id() << m_modelInfo.model; +#endif + restoreState(); +#if defined(DEBUG) + qDebug() << "modelLoadedChanged" << m_chat->id(); + fflush(stdout); #endif - emit isModelLoadedChanged(); static bool isFirstLoad = true; @@ -129,19 +232,20 @@ bool ChatLLM::loadModel(const QString &modelName) } else emit sendModelLoaded(); } else { + LLModelStore::globalInstance()->releaseModel(m_modelInfo); // release back into the store const QString error = QString("Could not find model %1").arg(modelName); emit modelLoadingError(error); } - if (m_llmodel) - setModelName(info.completeBaseName().remove(0, 5)); // remove the ggml- prefix + if (m_modelInfo.model) + setModelName(fileInfo.completeBaseName().remove(0, 5)); // remove the ggml- prefix - return m_llmodel; + return m_modelInfo.model; } bool ChatLLM::isModelLoaded() const { - return m_llmodel && m_llmodel->isModelLoaded(); + return m_modelInfo.model && m_modelInfo.model->isModelLoaded(); } void ChatLLM::regenerateResponse() @@ -226,7 +330,7 @@ bool ChatLLM::handlePrompt(int32_t token) // m_promptResponseTokens and m_responseLogits are related to last prompt/response not // the entire context window which we can reset on regenerate prompt #if defined(DEBUG) - qDebug() << "chatllm prompt process" << m_chat->id() << token; + qDebug() << "prompt process" << m_chat->id() << token; #endif ++m_promptTokens; ++m_promptResponseTokens; @@ -287,12 +391,12 @@ bool ChatLLM::prompt(const QString &prompt, const QString &prompt_template, int3 m_ctx.n_batch = n_batch; m_ctx.repeat_penalty = repeat_penalty; m_ctx.repeat_last_n = repeat_penalty_tokens; - m_llmodel->setThreadCount(n_threads); + m_modelInfo.model->setThreadCount(n_threads); #if defined(DEBUG) printf("%s", qPrintable(instructPrompt)); fflush(stdout); #endif - m_llmodel->prompt(instructPrompt.toStdString(), promptFunc, responseFunc, recalcFunc, m_ctx); + m_modelInfo.model->prompt(instructPrompt.toStdString(), promptFunc, responseFunc, recalcFunc, m_ctx); #if defined(DEBUG) printf("\n"); fflush(stdout); @@ -307,26 +411,55 @@ bool ChatLLM::prompt(const QString &prompt, const QString &prompt_template, int3 return true; } +void ChatLLM::setShouldBeLoaded(bool b) +{ +#if defined(DEBUG_MODEL_LOADING) + qDebug() << "setShouldBeLoaded" << m_chat->id() << b << m_modelInfo.model; +#endif + m_shouldBeLoaded = b; // atomic + emit shouldBeLoadedChanged(); +} + +void ChatLLM::handleShouldBeLoadedChanged() +{ + if (m_shouldBeLoaded) + reloadModel(); + else + unloadModel(); +} + +void ChatLLM::forceUnloadModel() +{ + m_shouldBeLoaded = false; // atomic + unloadModel(); +} + void ChatLLM::unloadModel() { -#if defined(DEBUG) - qDebug() << "chatllm unloadModel" << m_chat->id(); -#endif + if (!isModelLoaded()) + return; + saveState(); - delete m_llmodel; - m_llmodel = nullptr; +#if defined(DEBUG_MODEL_LOADING) + qDebug() << "unloadModel" << m_chat->id() << m_modelInfo.model; +#endif + LLModelStore::globalInstance()->releaseModel(m_modelInfo); + m_modelInfo = LLModelInfo(); emit isModelLoadedChanged(); } -void ChatLLM::reloadModel(const QString &modelName) +void ChatLLM::reloadModel() { -#if defined(DEBUG) - qDebug() << "chatllm reloadModel" << m_chat->id(); + if (isModelLoaded()) + return; + +#if defined(DEBUG_MODEL_LOADING) + qDebug() << "reloadModel" << m_chat->id() << m_modelInfo.model; #endif - if (modelName.isEmpty()) { + if (m_modelName.isEmpty()) { loadDefaultModel(); } else { - loadModel(modelName); + loadModel(m_modelName); } } @@ -348,7 +481,7 @@ void ChatLLM::generateName() printf("%s", qPrintable(instructPrompt)); fflush(stdout); #endif - m_llmodel->prompt(instructPrompt.toStdString(), promptFunc, responseFunc, recalcFunc, ctx); + m_modelInfo.model->prompt(instructPrompt.toStdString(), promptFunc, responseFunc, recalcFunc, ctx); #if defined(DEBUG) printf("\n"); fflush(stdout); @@ -415,7 +548,7 @@ bool ChatLLM::serialize(QDataStream &stream, int version) QByteArray compressed = qCompress(m_state); stream << compressed; #if defined(DEBUG) - qDebug() << "chatllm serialize" << m_chat->id() << m_state.size(); + qDebug() << "serialize" << m_chat->id() << m_state.size(); #endif return stream.status() == QDataStream::Ok; } @@ -452,7 +585,7 @@ bool ChatLLM::deserialize(QDataStream &stream, int version) stream >> m_state; } #if defined(DEBUG) - qDebug() << "chatllm deserialize" << m_chat->id(); + qDebug() << "deserialize" << m_chat->id(); #endif return stream.status() == QDataStream::Ok; } @@ -462,12 +595,12 @@ void ChatLLM::saveState() if (!isModelLoaded()) return; - const size_t stateSize = m_llmodel->stateSize(); + const size_t stateSize = m_modelInfo.model->stateSize(); m_state.resize(stateSize); #if defined(DEBUG) - qDebug() << "chatllm saveState" << m_chat->id() << "size:" << m_state.size(); + qDebug() << "saveState" << m_chat->id() << "size:" << m_state.size(); #endif - m_llmodel->saveState(static_cast(reinterpret_cast(m_state.data()))); + m_modelInfo.model->saveState(static_cast(reinterpret_cast(m_state.data()))); } void ChatLLM::restoreState() @@ -476,9 +609,9 @@ void ChatLLM::restoreState() return; #if defined(DEBUG) - qDebug() << "chatllm restoreState" << m_chat->id() << "size:" << m_state.size(); + qDebug() << "restoreState" << m_chat->id() << "size:" << m_state.size(); #endif - m_llmodel->restoreState(static_cast(reinterpret_cast(m_state.data()))); + m_modelInfo.model->restoreState(static_cast(reinterpret_cast(m_state.data()))); m_state.clear(); m_state.resize(0); } diff --git a/gpt4all-chat/chatllm.h b/gpt4all-chat/chatllm.h index 59d480f3..4f6d39f7 100644 --- a/gpt4all-chat/chatllm.h +++ b/gpt4all-chat/chatllm.h @@ -3,9 +3,23 @@ #include #include +#include #include "../gpt4all-backend/llmodel.h" +enum LLModelType { + MPT_, + GPTJ_, + LLAMA_ +}; + +struct LLModelInfo { + LLModel *model = nullptr; + QFileInfo fileInfo; + // NOTE: This does not store the model type or name on purpose as this is left for ChatLLM which + // must be able to serialize the information even if it is in the unloaded state +}; + class Chat; class ChatLLM : public QObject { @@ -17,12 +31,6 @@ class ChatLLM : public QObject Q_PROPERTY(QString generatedName READ generatedName NOTIFY generatedNameChanged) public: - enum ModelType { - MPT_, - GPTJ_, - LLAMA_ - }; - ChatLLM(Chat *parent); virtual ~ChatLLM(); @@ -33,6 +41,9 @@ public: void stopGenerating() { m_stopGenerating = true; } + bool shouldBeLoaded() const { return m_shouldBeLoaded; } + void setShouldBeLoaded(bool b); + QString response() const; QString modelName() const; @@ -52,10 +63,12 @@ public Q_SLOTS: bool loadDefaultModel(); bool loadModel(const QString &modelName); void modelNameChangeRequested(const QString &modelName); + void forceUnloadModel(); void unloadModel(); - void reloadModel(const QString &modelName); + void reloadModel(); void generateName(); void handleChatIdChanged(); + void handleShouldBeLoadedChanged(); Q_SIGNALS: void isModelLoadedChanged(); @@ -71,6 +84,7 @@ Q_SIGNALS: void generatedNameChanged(); void stateChanged(); void threadStarted(); + void shouldBeLoadedChanged(); protected: LLModel::PromptContext m_ctx; @@ -89,16 +103,17 @@ private: void restoreState(); private: - LLModel *m_llmodel; + LLModelInfo m_modelInfo; + LLModelType m_modelType; std::string m_response; std::string m_nameResponse; quint32 m_responseLogits; QString m_modelName; - ModelType m_modelType; Chat *m_chat; QByteArray m_state; QThread m_llmThread; std::atomic m_stopGenerating; + std::atomic m_shouldBeLoaded; bool m_isRecalc; };