From d948a4f2ee88922b2861b19eb7e7660921f7bf67 Mon Sep 17 00:00:00 2001
From: Adam Treat
Date: Wed, 7 Feb 2024 09:37:59 -0500
Subject: [PATCH] Complete revamp of model loading to allow for more discrete
 control by the user over the model's loading behavior.

Signed-off-by: Adam Treat
---
 gpt4all-backend/llamamodel.cpp         |   3 +
 gpt4all-backend/llmodel.h              |  13 +
 gpt4all-chat/CMakeLists.txt            |   3 +
 gpt4all-chat/chat.cpp                  |  57 ++--
 gpt4all-chat/chat.h                    |  18 +-
 gpt4all-chat/chatlistmodel.h           |   4 +-
 gpt4all-chat/chatllm.cpp               |  93 ++++++-
 gpt4all-chat/chatllm.h                 |  11 +-
 gpt4all-chat/icons/eject.svg           |   6 +
 gpt4all-chat/main.qml                  | 372 ++++++++++++++++---------
 gpt4all-chat/qml/MyButton.qml          |   5 +-
 gpt4all-chat/qml/MyMiniButton.qml      |  47 ++++
 gpt4all-chat/qml/SwitchModelDialog.qml |  44 +++
 gpt4all-chat/qml/Theme.qml             |   1 +
 14 files changed, 504 insertions(+), 173 deletions(-)
 create mode 100644 gpt4all-chat/icons/eject.svg
 create mode 100644 gpt4all-chat/qml/MyMiniButton.qml
 create mode 100644 gpt4all-chat/qml/SwitchModelDialog.qml

diff --git a/gpt4all-backend/llamamodel.cpp b/gpt4all-backend/llamamodel.cpp
index 5b9960ff..0dd9de5d 100644
--- a/gpt4all-backend/llamamodel.cpp
+++ b/gpt4all-backend/llamamodel.cpp
@@ -180,6 +180,9 @@ bool LLamaModel::loadModel(const std::string &modelPath, int n_ctx, int ngl)
     d_ptr->model_params.use_mlock  = params.use_mlock;
 #endif

+    d_ptr->model_params.progress_callback = &LLModel::staticProgressCallback;
+    d_ptr->model_params.progress_callback_user_data = this;
+
 #ifdef GGML_USE_METAL
     if (llama_verbose()) {
         std::cerr << "llama.cpp: using Metal" << std::endl;
diff --git a/gpt4all-backend/llmodel.h b/gpt4all-backend/llmodel.h
index 7fc5e71d..c3cc937c 100644
--- a/gpt4all-backend/llmodel.h
+++ b/gpt4all-backend/llmodel.h
@@ -74,6 +74,8 @@ public:
         int32_t n_last_batch_tokens = 0;
     };

+    using ProgressCallback = std::function<bool(float progress)>;
+
     explicit LLModel() {}
     virtual ~LLModel() {}

@@ -125,6 +127,8 @@ public:
     virtual bool hasGPUDevice() { return false; }
     virtual bool usingGPUDevice() { return false; }

+    void setProgressCallback(ProgressCallback callback) { m_progressCallback = callback; }
+
 protected:
     // These are pure virtual because subclasses need to implement as the default implementation of
     // 'prompt' above calls these functions
@@ -153,6 +157,15 @@ protected:

     const Implementation *m_implementation = nullptr;

+    ProgressCallback m_progressCallback;
+    static bool staticProgressCallback(float progress, void* ctx)
+    {
+        LLModel* model = static_cast<LLModel*>(ctx);
+        if (model && model->m_progressCallback)
+            return model->m_progressCallback(progress);
+        return true;
+    }
+
 private:
     friend class LLMImplementation;
 };
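The `staticProgressCallback` trampoline above lets llama.cpp's C-style `progress_callback` drive an arbitrary `std::function`, and returning `false` from that function is what makes loads cancellable. A minimal sketch of how a caller might consume this new API — `loadWithProgress`, `userWantsLoad`, and the parameter values are illustrative, not part of the patch:

```cpp
#include <atomic>
#include <cstdio>
#include <string>
#include "llmodel.h" // assumed include for the LLModel base class

std::atomic<bool> userWantsLoad{true}; // flipped from the UI to cancel

void loadWithProgress(LLModel *model, const std::string &path) {
    model->setProgressCallback([](float progress) -> bool {
        std::printf("loading: %.0f%%\n", progress * 100.0f);
        // Returning false asks the backend to abort the load in flight;
        // ChatLLM uses exactly this to honor m_shouldBeLoaded (see below).
        return userWantsLoad.load();
    });
    model->loadModel(path, /*n_ctx=*/2048, /*ngl=*/100);
}
```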
diff --git a/gpt4all-chat/CMakeLists.txt b/gpt4all-chat/CMakeLists.txt
index ee72f846..0f9d0ab0 100644
--- a/gpt4all-chat/CMakeLists.txt
+++ b/gpt4all-chat/CMakeLists.txt
@@ -109,6 +109,7 @@ qt_add_qml_module(chat
     qml/ModelSettings.qml
     qml/ApplicationSettings.qml
     qml/LocalDocsSettings.qml
+    qml/SwitchModelDialog.qml
     qml/MySettingsTab.qml
     qml/MySettingsStack.qml
     qml/MySettingsDestructiveButton.qml
@@ -123,6 +124,7 @@ qt_add_qml_module(chat
     qml/MyTextField.qml
     qml/MyCheckBox.qml
     qml/MyBusyIndicator.qml
+    qml/MyMiniButton.qml
     qml/MyToolButton.qml
     RESOURCES
     icons/send_message.svg
@@ -133,6 +135,7 @@ qt_add_qml_module(chat
     icons/db.svg
     icons/download.svg
     icons/settings.svg
+    icons/eject.svg
     icons/edit.svg
     icons/image.svg
     icons/trash.svg
diff --git a/gpt4all-chat/chat.cpp b/gpt4all-chat/chat.cpp
index 0e66c5c2..8730adbc 100644
--- a/gpt4all-chat/chat.cpp
+++ b/gpt4all-chat/chat.cpp
@@ -23,14 +23,10 @@
 Chat::Chat(bool isServer, QObject *parent)
     , m_id(Network::globalInstance()->generateUniqueId())
     , m_name(tr("Server Chat"))
     , m_chatModel(new ChatModel(this))
-    , m_responseInProgress(false)
     , m_responseState(Chat::ResponseStopped)
     , m_creationDate(QDateTime::currentSecsSinceEpoch())
     , m_llmodel(new Server(this))
     , m_isServer(true)
-    , m_shouldDeleteLater(false)
-    , m_isModelLoaded(false)
-    , m_shouldLoadModelWhenInstalled(false)
     , m_collectionModel(new LocalDocsCollectionsModel(this))
 {
     connectLLM();
@@ -45,7 +41,7 @@ Chat::~Chat()
 void Chat::connectLLM()
 {
     // Should be in different threads
-    connect(m_llmodel, &ChatLLM::isModelLoadedChanged, this, &Chat::handleModelLoadedChanged, Qt::QueuedConnection);
+    connect(m_llmodel, &ChatLLM::modelLoadingPercentageChanged, this, &Chat::handleModelLoadingPercentageChanged, Qt::QueuedConnection);
     connect(m_llmodel, &ChatLLM::responseChanged, this, &Chat::handleResponseChanged, Qt::QueuedConnection);
     connect(m_llmodel, &ChatLLM::promptProcessing, this, &Chat::promptProcessing, Qt::QueuedConnection);
     connect(m_llmodel, &ChatLLM::responseStopped, this, &Chat::responseStopped, Qt::QueuedConnection);
@@ -57,6 +53,7 @@ void Chat::connectLLM()
     connect(m_llmodel, &ChatLLM::reportFallbackReason, this, &Chat::handleFallbackReasonChanged, Qt::QueuedConnection);
     connect(m_llmodel, &ChatLLM::databaseResultsChanged, this, &Chat::handleDatabaseResultsChanged, Qt::QueuedConnection);
     connect(m_llmodel, &ChatLLM::modelInfoChanged, this, &Chat::handleModelInfoChanged, Qt::QueuedConnection);
+    connect(m_llmodel, &ChatLLM::trySwitchContextOfLoadedModelCompleted, this, &Chat::trySwitchContextOfLoadedModelCompleted, Qt::QueuedConnection);

     connect(this, &Chat::promptRequested, m_llmodel, &ChatLLM::prompt, Qt::QueuedConnection);
     connect(this, &Chat::modelChangeRequested, m_llmodel, &ChatLLM::modelChangeRequested, Qt::QueuedConnection);
@@ -69,8 +66,6 @@ void Chat::connectLLM()
     connect(this, &Chat::processSystemPromptRequested, m_llmodel, &ChatLLM::processSystemPrompt, Qt::QueuedConnection);

     connect(this, &Chat::collectionListChanged, m_collectionModel, &LocalDocsCollectionsModel::setCollections);
-    connect(ModelList::globalInstance()->installedModels(), &InstalledModels::countChanged,
-        this, &Chat::handleModelInstalled, Qt::QueuedConnection);
 }

 void Chat::reset()
@@ -101,7 +96,12 @@ void Chat::processSystemPrompt()

 bool Chat::isModelLoaded() const
 {
-    return m_isModelLoaded;
+    return m_modelLoadingPercentage == 1.0f;
+}
+
+float Chat::modelLoadingPercentage() const
+{
+    return m_modelLoadingPercentage;
 }

 void Chat::resetResponseState()
@@ -158,16 +158,18 @@ void Chat::handleResponseChanged(const QString &response)
     emit responseChanged();
 }

-void Chat::handleModelLoadedChanged(bool loaded)
+void Chat::handleModelLoadingPercentageChanged(float loadingPercentage)
 {
     if (m_shouldDeleteLater)
         deleteLater();

-    if (loaded == m_isModelLoaded)
+    if (loadingPercentage == m_modelLoadingPercentage)
         return;

-    m_isModelLoaded = loaded;
-    emit isModelLoadedChanged();
+    m_modelLoadingPercentage = loadingPercentage;
+    emit modelLoadingPercentageChanged();
+    if (m_modelLoadingPercentage == 1.0f || m_modelLoadingPercentage == 0.0f)
+        emit isModelLoadedChanged();
 }

 void Chat::promptProcessing()
@@ -238,10 +240,10 @@ ModelInfo Chat::modelInfo() const

 void Chat::setModelInfo(const ModelInfo &modelInfo)
 {
-    if (m_modelInfo == modelInfo)
+    if (m_modelInfo == modelInfo && isModelLoaded())
         return;

-    m_isModelLoaded = false;
+    m_modelLoadingPercentage = std::numeric_limits<float>::min();
     emit isModelLoadedChanged();
     m_modelLoadingError = QString();
     emit modelLoadingErrorChanged();
@@ -291,21 +293,26 @@ void Chat::unloadModel()

 void Chat::reloadModel()
 {
-    // If the installed model list is empty, then we mark a special flag and monitor for when a model
-    // is installed
-    if (!ModelList::globalInstance()->installedModels()->count()) {
-        m_shouldLoadModelWhenInstalled = true;
-        return;
-    }
     m_llmodel->setShouldBeLoaded(true);
 }

-void Chat::handleModelInstalled()
+void Chat::forceUnloadModel()
 {
-    if (!m_shouldLoadModelWhenInstalled)
-        return;
-    m_shouldLoadModelWhenInstalled = false;
-    reloadModel();
+    stopGenerating();
+    m_llmodel->setForceUnloadModel(true);
+    m_llmodel->setShouldBeLoaded(false);
+}
+
+void Chat::forceReloadModel()
+{
+    m_llmodel->setForceUnloadModel(true);
+    m_llmodel->setShouldBeLoaded(true);
+}
+
+void Chat::trySwitchContextOfLoadedModel()
+{
+    emit trySwitchContextOfLoadedModelAttempted();
+    m_llmodel->setShouldTrySwitchContext(true);
 }

 void Chat::generatedNameChanged(const QString &name)
diff --git a/gpt4all-chat/chat.h b/gpt4all-chat/chat.h
index ae6910bf..cecbcbda 100644
--- a/gpt4all-chat/chat.h
+++ b/gpt4all-chat/chat.h
@@ -17,6 +17,7 @@ class Chat : public QObject
     Q_PROPERTY(QString name READ name WRITE setName NOTIFY nameChanged)
     Q_PROPERTY(ChatModel *chatModel READ chatModel NOTIFY chatModelChanged)
     Q_PROPERTY(bool isModelLoaded READ isModelLoaded NOTIFY isModelLoadedChanged)
+    Q_PROPERTY(float modelLoadingPercentage READ modelLoadingPercentage NOTIFY modelLoadingPercentageChanged)
     Q_PROPERTY(QString response READ response NOTIFY responseChanged)
     Q_PROPERTY(ModelInfo modelInfo READ modelInfo WRITE setModelInfo NOTIFY modelInfoChanged)
     Q_PROPERTY(bool responseInProgress READ responseInProgress NOTIFY responseInProgressChanged)
@@ -61,6 +62,7 @@ public:
     Q_INVOKABLE void reset();
     Q_INVOKABLE void processSystemPrompt();
     Q_INVOKABLE bool isModelLoaded() const;
+    Q_INVOKABLE float modelLoadingPercentage() const;
     Q_INVOKABLE void prompt(const QString &prompt);
     Q_INVOKABLE void regenerateResponse();
     Q_INVOKABLE void stopGenerating();
@@ -75,8 +77,11 @@ public:
     void setModelInfo(const ModelInfo &modelInfo);
     bool isRecalc() const;

-    void unloadModel();
-    void reloadModel();
+    Q_INVOKABLE void unloadModel();
+    Q_INVOKABLE void reloadModel();
+    Q_INVOKABLE void forceUnloadModel();
+    Q_INVOKABLE void forceReloadModel();
+    Q_INVOKABLE void trySwitchContextOfLoadedModel();
     void unloadAndDeleteLater();

     qint64 creationDate() const { return m_creationDate; }
@@ -106,6 +111,7 @@ Q_SIGNALS:
     void nameChanged();
     void chatModelChanged();
     void isModelLoadedChanged();
+    void modelLoadingPercentageChanged();
     void responseChanged();
     void responseInProgressChanged();
     void responseStateChanged();
@@ -127,10 +133,12 @@ Q_SIGNALS:
     void deviceChanged();
     void fallbackReasonChanged();
     void collectionModelChanged();
+    void trySwitchContextOfLoadedModelAttempted();
+    void trySwitchContextOfLoadedModelCompleted(bool);

 private Q_SLOTS:
     void handleResponseChanged(const QString &response);
-    void handleModelLoadedChanged(bool);
+    void handleModelLoadingPercentageChanged(float);
     void promptProcessing();
     void responseStopped();
     void generatedNameChanged(const QString &name);
@@ -141,7 +149,6 @@ private Q_SLOTS:
     void handleFallbackReasonChanged(const QString &device);
     void handleDatabaseResultsChanged(const QList<ResultInfo> &results);
     void handleModelInfoChanged(const ModelInfo &modelInfo);
-    void handleModelInstalled();

 private:
     QString m_id;
@@ -163,8 +170,7 @@ private:
     QList<ResultInfo> m_databaseResults;
     bool m_isServer = false;
     bool m_shouldDeleteLater = false;
-    bool m_isModelLoaded = false;
-    bool m_shouldLoadModelWhenInstalled = false;
+    float m_modelLoadingPercentage = 0.0f;
     LocalDocsCollectionsModel *m_collectionModel;
 };
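Note that the patch encodes the whole load state in this single `m_modelLoadingPercentage` float: `0.0f` means unloaded, `1.0f` means loaded, anything in between is in-flight progress, and `std::numeric_limits<float>::min()` is a sentinel for "load requested, no progress yet". A sketch naming those states for clarity — the enum and helper are illustrative, not part of the patch:

```cpp
#include <limits>

enum class LoadState { NotLoaded, Starting, Loading, Loaded };

LoadState classify(float pct) {
    if (pct == 0.0f)
        return LoadState::NotLoaded;                 // eligible for "Reload" UI
    if (pct == std::numeric_limits<float>::min())
        return LoadState::Starting;                  // requested, no progress yet
    if (pct == 1.0f)
        return LoadState::Loaded;                    // isModelLoaded() == true
    return LoadState::Loading;                       // drives the progress bar
}
```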
diff --git a/gpt4all-chat/chatlistmodel.h b/gpt4all-chat/chatlistmodel.h
index 3f99c622..ed04cc7a 100644
--- a/gpt4all-chat/chatlistmodel.h
+++ b/gpt4all-chat/chatlistmodel.h
@@ -179,9 +179,9 @@ public:
         if (m_currentChat && m_currentChat != m_serverChat)
             m_currentChat->unloadModel();
         m_currentChat = chat;
-        if (!m_currentChat->isModelLoaded() && m_currentChat != m_serverChat)
-            m_currentChat->reloadModel();
         emit currentChatChanged();
+        if (!m_currentChat->isModelLoaded() && m_currentChat != m_serverChat)
+            m_currentChat->trySwitchContextOfLoadedModel();
     }

     Q_INVOKABLE Chat* get(int index)
diff --git a/gpt4all-chat/chatllm.cpp b/gpt4all-chat/chatllm.cpp
index 844942e4..4b456e34 100644
--- a/gpt4all-chat/chatllm.cpp
+++ b/gpt4all-chat/chatllm.cpp
@@ -62,7 +62,9 @@ ChatLLM::ChatLLM(Chat *parent, bool isServer)
     , m_promptResponseTokens(0)
     , m_promptTokens(0)
     , m_isRecalc(false)
-    , m_shouldBeLoaded(true)
+    , m_shouldBeLoaded(false)
+    , m_forceUnloadModel(false)
+    , m_shouldTrySwitchContext(false)
     , m_stopGenerating(false)
     , m_timer(nullptr)
     , m_isServer(isServer)
@@ -76,6 +78,8 @@ ChatLLM::ChatLLM(Chat *parent, bool isServer)
     connect(this, &ChatLLM::sendModelLoaded, Network::globalInstance(), &Network::sendModelLoaded);
     connect(this, &ChatLLM::shouldBeLoadedChanged, this, &ChatLLM::handleShouldBeLoadedChanged,
         Qt::QueuedConnection); // explicitly queued
+    connect(this, &ChatLLM::shouldTrySwitchContextChanged, this, &ChatLLM::handleShouldTrySwitchContextChanged,
+        Qt::QueuedConnection); // explicitly queued
     connect(parent, &Chat::idChanged, this, &ChatLLM::handleChatIdChanged);
     connect(&m_llmThread, &QThread::started, this, &ChatLLM::handleThreadStarted);
     connect(MySettings::globalInstance(), &MySettings::forceMetalChanged, this, &ChatLLM::handleForceMetalChanged);
@@ -143,6 +147,54 @@ bool ChatLLM::loadDefaultModel()
     return loadModel(defaultModel);
 }

+bool ChatLLM::trySwitchContextOfLoadedModel(const ModelInfo &modelInfo)
+{
+    // We're trying to see if the store already has the model fully loaded that we wish to use
+    // and if so we just acquire it from the store, switch the context and return true. If the
+    // store doesn't have it, or we're already loaded, or in any other case, just return false.
+
+    // If we're already loaded, or a server, or we're reloading to change the variant/device, or
+    // the modelInfo is empty, then this should fail
+    if (isModelLoaded() || m_isServer || m_reloadingToChangeVariant || modelInfo.name().isEmpty()) {
+        m_shouldTrySwitchContext = false;
+        emit trySwitchContextOfLoadedModelCompleted(false);
+        return false;
+    }
+
+    QString filePath = modelInfo.dirpath + modelInfo.filename();
+    QFileInfo fileInfo(filePath);
+
+    m_llModelInfo = LLModelStore::globalInstance()->acquireModel();
+#if defined(DEBUG_MODEL_LOADING)
+    qDebug() << "acquired model from store" << m_llmThread.objectName() << m_llModelInfo.model;
+#endif
+
+    // If the store gave us no already-loaded model, or the wrong model, then give it back to the
+    // store and fail
+    if (!m_llModelInfo.model || m_llModelInfo.fileInfo != fileInfo) {
+        LLModelStore::globalInstance()->releaseModel(m_llModelInfo);
+        m_llModelInfo = LLModelInfo();
+        m_shouldTrySwitchContext = false;
+        emit trySwitchContextOfLoadedModelCompleted(false);
+        return false;
+    }
+
+#if defined(DEBUG_MODEL_LOADING)
+    qDebug() << "store had our model" << m_llmThread.objectName() << m_llModelInfo.model;
+#endif
+
+    // We should be loaded and now we are
+    m_shouldBeLoaded = true;
+    m_shouldTrySwitchContext = false;
+
+    // Restore, signal and process
+    restoreState();
+    emit modelLoadingPercentageChanged(1.0f);
+    emit trySwitchContextOfLoadedModelCompleted(true);
+    processSystemPrompt();
+    return true;
+}
+
 bool ChatLLM::loadModel(const ModelInfo &modelInfo)
 {
     // This is a complicated method because N different possible threads are interested in the outcome
@@ -170,7 +222,7 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo)
 #endif
         delete m_llModelInfo.model;
         m_llModelInfo.model = nullptr;
-        emit isModelLoadedChanged(false);
+        emit modelLoadingPercentageChanged(std::numeric_limits<float>::min());
     } else if (!m_isServer) {
         // This is a blocking call that tries to retrieve the model we need from the model store.
         // If it succeeds, then we just have to restore state. If the store has never had a model
@@ -188,7 +240,7 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo)
 #endif
             LLModelStore::globalInstance()->releaseModel(m_llModelInfo);
             m_llModelInfo = LLModelInfo();
-            emit isModelLoadedChanged(false);
+            emit modelLoadingPercentageChanged(0.0f);
             return false;
         }

@@ -198,7 +250,7 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo)
             qDebug() << "store had our model" << m_llmThread.objectName() << m_llModelInfo.model;
 #endif
             restoreState();
-            emit isModelLoadedChanged(true);
+            emit modelLoadingPercentageChanged(1.0f);
             setModelInfo(modelInfo);
             Q_ASSERT(!m_modelInfo.filename().isEmpty());
             if (m_modelInfo.filename().isEmpty())
@@ -261,6 +313,12 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo)
             m_llModelInfo.model = LLModel::Implementation::construct(filePath.toStdString(), buildVariant, n_ctx);

             if (m_llModelInfo.model) {
+
+                m_llModelInfo.model->setProgressCallback([this](float progress) -> bool {
+                    emit modelLoadingPercentageChanged(progress);
+                    return m_shouldBeLoaded;
+                });
+
                 // Update the settings that a model is being loaded and update the device list
                 MySettings::globalInstance()->setAttemptModelLoad(filePath);

@@ -354,7 +412,7 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo)
     qDebug() << "modelLoadedChanged" << m_llmThread.objectName();
     fflush(stdout);
 #endif
     emit modelLoadingPercentageChanged(isModelLoaded() ? 1.0f : 0.0f);

     static bool isFirstLoad = true;
     if (isFirstLoad) {
@@ -456,6 +514,7 @@ void ChatLLM::setModelInfo(const ModelInfo &modelInfo)

 void ChatLLM::modelChangeRequested(const ModelInfo &modelInfo)
 {
+    m_shouldBeLoaded = true;
     loadModel(modelInfo);
 }

@@ -598,6 +657,12 @@ void ChatLLM::setShouldBeLoaded(bool b)
     emit shouldBeLoadedChanged();
 }

+void ChatLLM::setShouldTrySwitchContext(bool b)
+{
+    m_shouldTrySwitchContext = b; // atomic
+    emit shouldTrySwitchContextChanged();
+}
+
 void ChatLLM::handleShouldBeLoadedChanged()
 {
     if (m_shouldBeLoaded)
@@ -606,10 +671,10 @@ void ChatLLM::handleShouldBeLoadedChanged()
         unloadModel();
 }

-void ChatLLM::forceUnloadModel()
+void ChatLLM::handleShouldTrySwitchContextChanged()
 {
-    m_shouldBeLoaded = false; // atomic
-    unloadModel();
+    if (m_shouldTrySwitchContext)
+        trySwitchContextOfLoadedModel(modelInfo());
 }

 void ChatLLM::unloadModel()
@@ -617,17 +682,27 @@ void ChatLLM::unloadModel()
     if (!isModelLoaded() || m_isServer)
         return;

+    emit modelLoadingPercentageChanged(0.0f);
     saveState();
 #if defined(DEBUG_MODEL_LOADING)
     qDebug() << "unloadModel" << m_llmThread.objectName() << m_llModelInfo.model;
 #endif
+
+    if (m_forceUnloadModel) {
+        delete m_llModelInfo.model;
+        m_llModelInfo.model = nullptr;
+        m_forceUnloadModel = false;
+    }
+
     LLModelStore::globalInstance()->releaseModel(m_llModelInfo);
     m_llModelInfo = LLModelInfo();
-    emit isModelLoadedChanged(false);
 }

 void ChatLLM::reloadModel()
 {
+    if (isModelLoaded() && m_forceUnloadModel)
+        unloadModel(); // we unload first if we are forcing an unload
+
     if (isModelLoaded() || m_isServer)
         return;

diff --git a/gpt4all-chat/chatllm.h b/gpt4all-chat/chatllm.h
index d6af4cb0..278e79cc 100644
--- a/gpt4all-chat/chatllm.h
+++ b/gpt4all-chat/chatllm.h
@@ -81,6 +81,8 @@ public:
     bool shouldBeLoaded() const { return m_shouldBeLoaded; }
     void setShouldBeLoaded(bool b);
+    void setShouldTrySwitchContext(bool b);
+    void setForceUnloadModel(bool b) { m_forceUnloadModel = b; }

     QString response() const;

@@ -98,14 +100,15 @@ public:
 public Q_SLOTS:
     bool prompt(const QList<QString> &collectionList, const QString &prompt);
     bool loadDefaultModel();
+    bool trySwitchContextOfLoadedModel(const ModelInfo &modelInfo);
     bool loadModel(const ModelInfo &modelInfo);
     void modelChangeRequested(const ModelInfo &modelInfo);
-    void forceUnloadModel();
     void unloadModel();
     void reloadModel();
     void generateName();
     void handleChatIdChanged(const QString &id);
     void handleShouldBeLoadedChanged();
+    void handleShouldTrySwitchContextChanged();
     void handleThreadStarted();
     void handleForceMetalChanged(bool forceMetal);
     void handleDeviceChanged();
@@ -114,7 +117,7 @@ public Q_SLOTS:

 Q_SIGNALS:
     void recalcChanged();
-    void isModelLoadedChanged(bool);
+    void modelLoadingPercentageChanged(float);
     void modelLoadingError(const QString &error);
     void responseChanged(const QString &response);
     void promptProcessing();
@@ -125,6 +128,8 @@ Q_SIGNALS:
     void stateChanged();
     void threadStarted();
     void shouldBeLoadedChanged();
+    void shouldTrySwitchContextChanged();
+    void trySwitchContextOfLoadedModelCompleted(bool);
     void requestRetrieveFromDB(const QList<QString> &collections, const QString &text, int retrievalSize, QList<ResultInfo> *results);
     void reportSpeed(const QString &speed);
     void reportDevice(const QString &device);
@@ -167,7 +172,9 @@ private:
     QThread m_llmThread;
     std::atomic<bool> m_stopGenerating;
     std::atomic<bool> m_shouldBeLoaded;
+    std::atomic<bool> m_shouldTrySwitchContext;
     std::atomic<bool> m_isRecalc;
+    std::atomic<bool> m_forceUnloadModel;
     bool m_isServer;
     bool m_forceMetal;
     bool m_reloadingToChangeVariant;
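The control surface above follows one pattern throughout: the GUI thread stores an atomic flag, then emits a signal over an explicitly queued connection, so the handler always runs later on the LLM worker thread and acts on whatever the flag says at that moment — a rapid unload/reload collapses to the latest request. A minimal sketch of the pattern (illustrative, with names abridged from the patch; not the patch itself):

```cpp
#include <QObject>
#include <atomic>

class Worker : public QObject {
    Q_OBJECT
public:
    Worker() {
        // Explicitly queued, as in ChatLLM: the slot runs on this object's
        // thread even when the setter is called from the GUI thread.
        connect(this, &Worker::shouldBeLoadedChanged,
                this, &Worker::handleShouldBeLoadedChanged,
                Qt::QueuedConnection);
    }
    void setShouldBeLoaded(bool b) {   // safe to call from any thread
        m_shouldBeLoaded = b;          // atomic store
        emit shouldBeLoadedChanged();
    }
Q_SIGNALS:
    void shouldBeLoadedChanged();
private Q_SLOTS:
    void handleShouldBeLoadedChanged() {
        if (m_shouldBeLoaded) {
            // reloadModel(): acquire from the store, restore saved state
        } else {
            // unloadModel(): save state, release back to the store
        }
    }
private:
    std::atomic<bool> m_shouldBeLoaded{false};
};
```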
diff --git a/gpt4all-chat/icons/eject.svg b/gpt4all-chat/icons/eject.svg
new file mode 100644
index 00000000..9649c487
--- /dev/null
+++ b/gpt4all-chat/icons/eject.svg
@@ -0,0 +1,6 @@
+
+
+
diff --git a/gpt4all-chat/main.qml b/gpt4all-chat/main.qml
index 72fbc3b8..66104e37 100644
--- a/gpt4all-chat/main.qml
+++ b/gpt4all-chat/main.qml
@@ -126,6 +126,10 @@ Window {
         }
     }

+    function currentModelName() {
+        return ModelList.modelInfo(currentChat.modelInfo.id).name;
+    }
+
     PopupDialog {
         id: errorCompatHardware
         anchors.centerIn: parent
@@ -282,6 +286,18 @@ Window {
         }
     }

+    SwitchModelDialog {
+        id: switchModelDialog
+        anchors.centerIn: parent
+        width: Math.min(1024, window.width - (window.width * .2))
+        height: Math.min(600, window.height - (window.height * .2))
+        Item {
+            Accessible.role: Accessible.Dialog
+            Accessible.name: qsTr("Switch model dialog")
+            Accessible.description: qsTr("Warn the user if they switch models, then context will be erased")
+        }
+    }
+
     Rectangle {
         id: header
         anchors.left: parent.left
@@ -292,7 +308,9 @@ Window {
         Item {
             anchors.centerIn: parent
             height: childrenRect.height
-            visible: currentChat.isModelLoaded || currentChat.modelLoadingError !== "" || currentChat.isServer
+            visible: true
+                || currentChat.modelLoadingError !== ""
+                || currentChat.isServer

             Label {
                 id: modelLabel
@@ -306,102 +324,168 @@ Window {
                 horizontalAlignment: TextInput.AlignRight
             }

-            MyComboBox {
-                id: comboBox
-                implicitWidth: 375
-                width: window.width >= 750 ? implicitWidth : implicitWidth - ((750 - window.width))
+            RowLayout {
+                id: comboLayout
                 anchors.top: modelLabel.top
                 anchors.bottom: modelLabel.bottom
                 anchors.horizontalCenter: parent.horizontalCenter
                 anchors.horizontalCenterOffset: window.width >= 950 ? 0 : Math.max(-((950 - window.width) / 2), -99.5)
-                enabled: !currentChat.isServer
-                model: ModelList.installedModels
-                valueRole: "id"
-                textRole: "name"
-                property string currentModelName: ""
-                function updateCurrentModelName() {
-                    var info = ModelList.modelInfo(currentChat.modelInfo.id);
-                    comboBox.currentModelName = info.name;
-                }
-                Connections {
-                    target: currentChat
-                    function onModelInfoChanged() {
-                        comboBox.updateCurrentModelName();
+                spacing: 20
+
+                MyComboBox {
+                    id: comboBox
+                    Layout.fillWidth: true
+                    Layout.fillHeight: true
+                    implicitWidth: 575
+                    width: window.width >= 750 ? implicitWidth : implicitWidth - ((750 - window.width))
+                    enabled: !currentChat.isServer
+                    model: ModelList.installedModels
+                    valueRole: "id"
+                    textRole: "name"
+                    property bool isCurrentlyLoading: false
+                    property real modelLoadingPercentage: 0.0
+                    property bool trySwitchContextInProgress: false
+
+                    function changeModel(index) {
+                        comboBox.modelLoadingPercentage = 0.0;
+                        comboBox.isCurrentlyLoading = true;
+                        currentChat.stopGenerating()
+                        currentChat.reset();
+                        currentChat.modelInfo = ModelList.modelInfo(comboBox.valueAt(index))
                     }
-                }
-                Connections {
-                    target: window
-                    function onCurrentChatChanged() {
-                        comboBox.updateCurrentModelName();
+
+                    Connections {
+                        target: currentChat
+                        function onModelLoadingPercentageChanged() {
+                            comboBox.modelLoadingPercentage = currentChat.modelLoadingPercentage;
+                            comboBox.isCurrentlyLoading = currentChat.modelLoadingPercentage !== 0.0
+                                && currentChat.modelLoadingPercentage !== 1.0;
+                        }
+                        function onTrySwitchContextOfLoadedModelAttempted() {
+                            comboBox.trySwitchContextInProgress = true;
+                        }
+                        function onTrySwitchContextOfLoadedModelCompleted() {
+                            comboBox.trySwitchContextInProgress = false;
+                        }
+                    }
+                    Connections {
+                        target: switchModelDialog
+                        function onAccepted() {
+                            comboBox.changeModel(switchModelDialog.index)
+                        }
+                    }
+
+                    background: ProgressBar {
+                        id: modelProgress
+                        value: comboBox.modelLoadingPercentage
+                        background: Rectangle {
+                            color: theme.mainComboBackground
+                            radius: 10
+                        }
+                        contentItem: Item {
+                            Rectangle {
+                                visible: comboBox.isCurrentlyLoading
+                                anchors.bottom: parent.bottom
+                                width: modelProgress.visualPosition * parent.width
+                                height: 10
+                                radius: 2
+                                color: theme.progressForeground
+                            }
+                        }
                     }
-                }
-                background: Rectangle {
-                    color: theme.mainComboBackground
-                    radius: 10
-                }
-                contentItem: Text {
-                    anchors.horizontalCenter: parent.horizontalCenter
-                    leftPadding: 10
-                    rightPadding: 20
-                    text: currentChat.modelLoadingError !== ""
-                        ? qsTr("Model loading error...")
-                        : comboBox.currentModelName
-                    font.pixelSize: theme.fontSizeLarger
-                    color: theme.white
-                    verticalAlignment: Text.AlignVCenter
-                    horizontalAlignment: Text.AlignHCenter
-                    elide: Text.ElideRight
-                }
-                delegate: ItemDelegate {
-                    width: comboBox.width
                     contentItem: Text {
-                        text: name
-                        color: theme.textColor
-                        font: comboBox.font
-                        elide: Text.ElideRight
+                        anchors.horizontalCenter: parent.horizontalCenter
+                        leftPadding: 10
+                        rightPadding: 20
+                        text: {
+                            if (currentChat.modelLoadingError !== "")
+                                return qsTr("Model loading error...")
+                            if (comboBox.trySwitchContextInProgress)
+                                return qsTr("Switching context...")
+                            if (currentModelName() === "")
+                                return qsTr("Choose a model...")
+                            if (currentChat.modelLoadingPercentage === 0.0)
+                                return qsTr("Reload \u00B7 ") + currentModelName()
+                            if (comboBox.isCurrentlyLoading)
+                                return qsTr("Loading \u00B7 ") + currentModelName()
+                            return currentModelName()
+                        }
+                        font.pixelSize: theme.fontSizeLarger
+                        color: theme.white
                         verticalAlignment: Text.AlignVCenter
+                        horizontalAlignment: Text.AlignHCenter
+                        elide: Text.ElideRight
+                    }
+                    delegate: ItemDelegate {
+                        id: comboItemDelegate
+                        width: comboBox.width
+                        contentItem: Text {
+                            text: name
+                            color: theme.textColor
+                            font: comboBox.font
+                            elide: Text.ElideRight
+                            verticalAlignment: Text.AlignVCenter
+                        }
+                        background: Rectangle {
+                            color: (index % 2 === 0 ? theme.darkContrast : theme.lightContrast)
+                            border.width: highlighted
+                            border.color: theme.accentColor
+                        }
+                        highlighted: comboBox.highlightedIndex === index
                     }
-                    background: Rectangle {
-                        color: (index % 2 === 0 ? theme.darkContrast : theme.lightContrast)
-                        border.width: highlighted
-                        border.color: theme.accentColor
+                    Accessible.role: Accessible.ComboBox
+                    Accessible.name: currentModelName()
+                    Accessible.description: qsTr("The top item is the current model")
+                    onActivated: function (index) {
+                        var newInfo = ModelList.modelInfo(comboBox.valueAt(index));
+                        if (currentModelName() !== ""
+                            && newInfo !== currentChat.modelInfo
+                            && chatModel.count !== 0) {
+                            switchModelDialog.index = index;
+                            switchModelDialog.open();
+                        } else {
+                            comboBox.changeModel(index);
+                        }
                     }
-                    highlighted: comboBox.highlightedIndex === index
-                }
-                Accessible.role: Accessible.ComboBox
-                Accessible.name: comboBox.currentModelName
-                Accessible.description: qsTr("The top item is the current model")
-                onActivated: function (index) {
-                    currentChat.stopGenerating()
-                    currentChat.reset();
-                    currentChat.modelInfo = ModelList.modelInfo(comboBox.valueAt(index))
-                }
-            }
-        }

-        Item {
-            anchors.centerIn: parent
-            visible: ModelList.installedModels.count
-                && !currentChat.isModelLoaded
-                && currentChat.modelLoadingError === ""
-                && !currentChat.isServer
-            width: childrenRect.width
-            height: childrenRect.height
-            Row {
-                spacing: 5
-                MyBusyIndicator {
-                    anchors.verticalCenter: parent.verticalCenter
-                    running: parent.visible
-                    Accessible.role: Accessible.Animation
-                    Accessible.name: qsTr("Busy indicator")
-                    Accessible.description: qsTr("loading model...")
-                }
+                    MyMiniButton {
+                        id: ejectButton
+                        visible: currentChat.isModelLoaded
+                        z: 500
+                        anchors.right: parent.right
+                        anchors.rightMargin: 50
+                        anchors.verticalCenter: parent.verticalCenter
+                        source: "qrc:/gpt4all/icons/eject.svg"
+                        backgroundColor: theme.gray300
+                        backgroundColorHovered: theme.iconBackgroundLight
+                        onClicked: {
+                            currentChat.forceUnloadModel();
+                        }
+                        ToolTip.text: qsTr("Eject the currently loaded model")
+                        ToolTip.visible: hovered
+                    }

-                Label {
-                    anchors.verticalCenter: parent.verticalCenter
-                    text: qsTr("Loading model...")
-                    font.pixelSize: theme.fontSizeLarge
-                    color: theme.oppositeTextColor
+                    MyMiniButton {
+                        id: reloadButton
+                        visible: currentChat.modelLoadingError === ""
+                            && !comboBox.trySwitchContextInProgress
+                            && (currentChat.isModelLoaded || currentModelName() !== "")
+                        z: 500
+                        anchors.right: ejectButton.visible ? ejectButton.left : parent.right
+                        anchors.rightMargin: ejectButton.visible ? 10 : 50
+                        anchors.verticalCenter: parent.verticalCenter
+                        source: "qrc:/gpt4all/icons/regenerate.svg"
+                        backgroundColor: theme.gray300
+                        backgroundColorHovered: theme.iconBackgroundLight
+                        onClicked: {
+                            if (currentChat.isModelLoaded)
+                                currentChat.forceReloadModel();
+                            else
+                                currentChat.reloadModel();
+                        }
+                        ToolTip.text: qsTr("Reload the currently loaded model")
+                        ToolTip.visible: hovered
+                    }
                 }
             }
         }
@@ -790,9 +874,9 @@ Window {

         Rectangle {
             id: homePage
-            color: "transparent"//theme.green200
+            color: "transparent"
             anchors.fill: parent
-            visible: (ModelList.installedModels.count === 0 || chatModel.count === 0) && !currentChat.isServer
+            visible: !currentChat.isModelLoaded && (ModelList.installedModels.count === 0 || currentModelName() === "") && !currentChat.isServer

             ColumnLayout {
                 anchors.centerIn: parent
"qrc:/gpt4all/icons/stop_generating.svg" : "qrc:/gpt4all/icons/regenerate.svg" - } - leftPadding: 50 - onClicked: { - var index = Math.max(0, chatModel.count - 1); - var listElement = chatModel.get(index); - - if (currentChat.responseInProgress) { - listElement.stopped = true - currentChat.stopGenerating() - } else { - currentChat.regenerateResponse() - if (chatModel.count) { - if (listElement.name === qsTr("Response: ")) { - chatModel.updateCurrentResponse(index, true); - chatModel.updateStopped(index, false); - chatModel.updateThumbsUpState(index, false); - chatModel.updateThumbsDownState(index, false); - chatModel.updateNewResponse(index, ""); - currentChat.prompt(listElement.prompt) + RowLayout { + anchors.bottom: textInputView.top + anchors.horizontalCenter: textInputView.horizontalCenter + anchors.bottomMargin: 20 + spacing: 10 + MyButton { + textColor: theme.textColor + visible: chatModel.count && !currentChat.isServer && currentChat.isModelLoaded + Image { + anchors.verticalCenter: parent.verticalCenter + anchors.left: parent.left + anchors.leftMargin: 15 + source: currentChat.responseInProgress ? "qrc:/gpt4all/icons/stop_generating.svg" : "qrc:/gpt4all/icons/regenerate.svg" + } + leftPadding: 50 + onClicked: { + var index = Math.max(0, chatModel.count - 1); + var listElement = chatModel.get(index); + + if (currentChat.responseInProgress) { + listElement.stopped = true + currentChat.stopGenerating() + } else { + currentChat.regenerateResponse() + if (chatModel.count) { + if (listElement.name === qsTr("Response: ")) { + chatModel.updateCurrentResponse(index, true); + chatModel.updateStopped(index, false); + chatModel.updateThumbsUpState(index, false); + chatModel.updateThumbsDownState(index, false); + chatModel.updateNewResponse(index, ""); + currentChat.prompt(listElement.prompt) + } } } } + + borderWidth: 1 + backgroundColor: theme.conversationButtonBackground + backgroundColorHovered: theme.conversationButtonBackgroundHovered + backgroundRadius: 5 + padding: 15 + topPadding: 4 + bottomPadding: 4 + text: currentChat.responseInProgress ? qsTr("Stop generating") : qsTr("Regenerate response") + fontPixelSize: theme.fontSizeSmaller + Accessible.description: qsTr("Controls generation of the response") } - background: Rectangle { - border.color: theme.conversationButtonBorder - border.width: 2 - radius: 10 - color: myButton.hovered ? theme.conversationButtonBackgroundHovered : theme.conversationButtonBackground + + MyButton { + textColor: theme.textColor + visible: chatModel.count + && !currentChat.isServer + && !currentChat.isModelLoaded + && currentChat.modelLoadingPercentage === 0.0 + && currentChat.modelInfo.name !== "" + Image { + anchors.verticalCenter: parent.verticalCenter + anchors.left: parent.left + anchors.leftMargin: 15 + source: "qrc:/gpt4all/icons/regenerate.svg" + } + leftPadding: 50 + onClicked: { + currentChat.reloadModel(); + } + + borderWidth: 1 + backgroundColor: theme.conversationButtonBackground + backgroundColorHovered: theme.conversationButtonBackgroundHovered + backgroundRadius: 5 + padding: 15 + topPadding: 4 + bottomPadding: 4 + text: qsTr("Reload \u00B7 ") + currentChat.modelInfo.name + fontPixelSize: theme.fontSizeSmaller + Accessible.description: qsTr("Reloads the model") } - anchors.bottom: textInputView.top - anchors.horizontalCenter: textInputView.horizontalCenter - anchors.bottomMargin: 20 - padding: 15 - text: currentChat.responseInProgress ? 
qsTr("Stop generating") : qsTr("Regenerate response") - Accessible.description: qsTr("Controls generation of the response") } Text { @@ -1224,7 +1342,7 @@ Window { rightPadding: 40 enabled: currentChat.isModelLoaded && !currentChat.isServer font.pixelSize: theme.fontSizeLarger - placeholderText: qsTr("Send a message...") + placeholderText: currentChat.isModelLoaded ? qsTr("Send a message...") : qsTr("Load a model to continue...") Accessible.role: Accessible.EditableText Accessible.name: placeholderText Accessible.description: qsTr("Send messages/prompts to the model") diff --git a/gpt4all-chat/qml/MyButton.qml b/gpt4all-chat/qml/MyButton.qml index d79c275b..6f14f9d3 100644 --- a/gpt4all-chat/qml/MyButton.qml +++ b/gpt4all-chat/qml/MyButton.qml @@ -13,9 +13,10 @@ Button { property color mutedTextColor: theme.oppositeMutedTextColor property color backgroundColor: theme.buttonBackground property color backgroundColorHovered: theme.buttonBackgroundHovered + property real backgroundRadius: 10 property real borderWidth: MySettings.chatTheme === "LegacyDark" ? 1 : 0 property color borderColor: theme.buttonBorder - property real fontPixelSize: theme.fontSizeLarge + property real fontPixelSize: theme.fontSizeLarge contentItem: Text { text: myButton.text horizontalAlignment: Text.AlignHCenter @@ -25,7 +26,7 @@ Button { Accessible.name: text } background: Rectangle { - radius: 10 + radius: myButton.backgroundRadius border.width: myButton.borderWidth border.color: myButton.borderColor color: myButton.hovered ? backgroundColorHovered : backgroundColor diff --git a/gpt4all-chat/qml/MyMiniButton.qml b/gpt4all-chat/qml/MyMiniButton.qml new file mode 100644 index 00000000..d5e5571a --- /dev/null +++ b/gpt4all-chat/qml/MyMiniButton.qml @@ -0,0 +1,47 @@ +import QtCore +import QtQuick +import QtQuick.Controls +import QtQuick.Controls.Basic +import Qt5Compat.GraphicalEffects + +Button { + id: myButton + padding: 0 + property color backgroundColor: theme.iconBackgroundDark + property color backgroundColorHovered: theme.iconBackgroundHovered + property alias source: image.source + property alias fillMode: image.fillMode + width: 30 + height: 30 + contentItem: Text { + text: myButton.text + horizontalAlignment: Text.AlignHCenter + color: myButton.enabled ? theme.textColor : theme.mutedTextColor + font.pixelSize: theme.fontSizeLarge + Accessible.role: Accessible.Button + Accessible.name: text + } + + background: Item { + anchors.fill: parent + Rectangle { + anchors.fill: parent + color: "transparent" + } + Image { + id: image + anchors.centerIn: parent + mipmap: true + width: 20 + height: 20 + } + ColorOverlay { + anchors.fill: image + source: image + color: myButton.hovered ? 
backgroundColorHovered : backgroundColor + } + } + Accessible.role: Accessible.Button + Accessible.name: text + ToolTip.delay: Qt.styleHints.mousePressAndHoldInterval +} diff --git a/gpt4all-chat/qml/SwitchModelDialog.qml b/gpt4all-chat/qml/SwitchModelDialog.qml new file mode 100644 index 00000000..54dfbe60 --- /dev/null +++ b/gpt4all-chat/qml/SwitchModelDialog.qml @@ -0,0 +1,44 @@ +import QtCore +import QtQuick +import QtQuick.Controls +import QtQuick.Controls.Basic +import QtQuick.Layouts +import llm +import mysettings + +MyDialog { + id: switchModelDialog + anchors.centerIn: parent + modal: true + padding: 20 + property int index: -1 + + Theme { + id: theme + } + + Column { + id: column + spacing: 20 + } + + footer: DialogButtonBox { + id: dialogBox + padding: 20 + alignment: Qt.AlignRight + spacing: 10 + MySettingsButton { + text: qsTr("Continue") + Accessible.description: qsTr("Continue with model loading") + DialogButtonBox.buttonRole: DialogButtonBox.AcceptRole + } + MySettingsButton { + text: qsTr("Cancel") + Accessible.description: qsTr("Cancel") + DialogButtonBox.buttonRole: DialogButtonBox.RejectRole + } + background: Rectangle { + color: "transparent" + } + } +} diff --git a/gpt4all-chat/qml/Theme.qml b/gpt4all-chat/qml/Theme.qml index 49f8343c..2b8c9733 100644 --- a/gpt4all-chat/qml/Theme.qml +++ b/gpt4all-chat/qml/Theme.qml @@ -555,6 +555,7 @@ QtObject { property real fontSizeFixedSmall: 16 property real fontSize: Qt.application.font.pixelSize + property real fontSizeSmaller: fontSizeLarge - 4 property real fontSizeSmall: fontSizeLarge - 2 property real fontSizeLarge: MySettings.fontSize === "Small" ? fontSize : MySettings.fontSize === "Medium" ?
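Taken together, the patch replaces the old boolean load state with a float percentage plus explicit force-unload, force-reload, and try-switch-context entry points on `Chat`. A hedged sketch of the resulting contract from the caller's side — illustrative driver code, not part of the patch; assumes a running Qt event loop, a valid `Chat *`, and a valid `ModelInfo`:

```cpp
#include <QObject>
#include <QDebug>
#include "chat.h" // assumed include path for the Chat class above

void demonstrateLoadControls(Chat *chat, const ModelInfo &info) {
    // Observe fine-grained load progress (0.0 = unloaded, 1.0 = loaded).
    QObject::connect(chat, &Chat::modelLoadingPercentageChanged, chat, [chat] {
        qDebug() << "load progress:" << chat->modelLoadingPercentage();
    });
    // Learn whether a cheap context switch reused an already-loaded model.
    QObject::connect(chat, &Chat::trySwitchContextOfLoadedModelCompleted, chat,
                     [](bool success) {
        qDebug() << "context switch"
                 << (success ? "reused loaded model" : "requires full load");
    });

    chat->setModelInfo(info);              // kicks off an asynchronous load
    chat->forceUnloadModel();              // the new eject button: destroy the backend model
    chat->reloadModel();                   // plain reload: reuse the store's copy if possible
    chat->forceReloadModel();              // destroy and rebuild the backend model
    chat->trySwitchContextOfLoadedModel(); // cheap switch when the store already has it
}
```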