From 267601d670ffefae1dece00719d4cb0065c3804d Mon Sep 17 00:00:00 2001 From: Adam Treat Date: Tue, 27 Jun 2023 11:54:34 -0400 Subject: [PATCH] Enable the force metal setting. --- gpt4all-chat/CMakeLists.txt | 1 + gpt4all-chat/chatlistmodel.cpp | 2 +- gpt4all-chat/chatllm.cpp | 32 ++++++++++++++++++++++++++--- gpt4all-chat/chatllm.h | 3 +++ gpt4all-chat/llm.h | 1 + gpt4all-chat/main.cpp | 2 ++ gpt4all-chat/mysettings.cpp | 31 ++++++++++++++++++++++++++++ gpt4all-chat/mysettings.h | 30 +++++++++++++++++++++++++++ gpt4all-chat/qml/SettingsDialog.qml | 14 ++++++++++--- 9 files changed, 109 insertions(+), 7 deletions(-) create mode 100644 gpt4all-chat/mysettings.cpp create mode 100644 gpt4all-chat/mysettings.h diff --git a/gpt4all-chat/CMakeLists.txt b/gpt4all-chat/CMakeLists.txt index d48f6b9a..3432d911 100644 --- a/gpt4all-chat/CMakeLists.txt +++ b/gpt4all-chat/CMakeLists.txt @@ -74,6 +74,7 @@ qt_add_executable(chat localdocs.h localdocs.cpp localdocsmodel.h localdocsmodel.cpp llm.h llm.cpp modellist.h modellist.cpp + mysettings.h mysettings.cpp network.h network.cpp server.h server.cpp logger.h logger.cpp diff --git a/gpt4all-chat/chatlistmodel.cpp b/gpt4all-chat/chatlistmodel.cpp index afaa8f8f..8bac9013 100644 --- a/gpt4all-chat/chatlistmodel.cpp +++ b/gpt4all-chat/chatlistmodel.cpp @@ -306,4 +306,4 @@ void ChatListModel::handleServerEnabledChanged() Chat *nextChat = get(0); Q_ASSERT(nextChat && nextChat != m_serverChat); setCurrentChat(nextChat); -} \ No newline at end of file +} diff --git a/gpt4all-chat/chatllm.cpp b/gpt4all-chat/chatllm.cpp index 544949e3..12fe08f3 100644 --- a/gpt4all-chat/chatllm.cpp +++ b/gpt4all-chat/chatllm.cpp @@ -3,6 +3,7 @@ #include "chatgpt.h" #include "modellist.h" #include "network.h" +#include "mysettings.h" #include "../gpt4all-backend/llmodel.h" #include @@ -73,6 +74,8 @@ ChatLLM::ChatLLM(Chat *parent, bool isServer) , m_stopGenerating(false) , m_timer(nullptr) , m_isServer(isServer) + , m_forceMetal(MySettings::globalInstance()->forceMetal()) + , m_reloadingToChangeVariant(false) { moveToThread(&m_llmThread); connect(this, &ChatLLM::sendStartup, Network::globalInstance(), &Network::sendStartup); @@ -81,6 +84,7 @@ ChatLLM::ChatLLM(Chat *parent, bool isServer) Qt::QueuedConnection); // explicitly queued connect(parent, &Chat::idChanged, this, &ChatLLM::handleChatIdChanged); connect(&m_llmThread, &QThread::started, this, &ChatLLM::handleThreadStarted); + connect(MySettings::globalInstance(), &MySettings::forceMetalChanged, this, &ChatLLM::handleForceMetalChanged); // The following are blocking operations and will block the llm thread connect(this, &ChatLLM::requestRetrieveFromDB, LocalDocs::globalInstance()->database(), &Database::retrieveFromDB, @@ -110,6 +114,19 @@ void ChatLLM::handleThreadStarted() emit threadStarted(); } +void ChatLLM::handleForceMetalChanged(bool forceMetal) +{ +#if defined(Q_OS_MAC) && defined(__arm__) + m_forceMetal = forceMetal; + if (isModelLoaded() && m_shouldBeLoaded) { + m_reloadingToChangeVariant = true; + unloadModel(); + reloadModel(); + m_reloadingToChangeVariant = false; + } +#endif +} + bool ChatLLM::loadDefaultModel() { ModelInfo defaultModel = ModelList::globalInstance()->defaultModelInfo(); @@ -154,7 +171,7 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo) // returned to it, then the modelInfo.model pointer should be null which will happen on startup m_llModelInfo = LLModelStore::globalInstance()->acquireModel(); #if defined(DEBUG_MODEL_LOADING) - qDebug() << "acquired model from store" << m_llmThread.objectName() << m_llModelInfo3.model; + qDebug() << "acquired model from store" << m_llmThread.objectName() << m_llModelInfo.model; #endif // At this point it is possible that while we were blocked waiting to acquire the model from the // store, that our state was changed to not be loaded. If this is the case, release the model @@ -170,7 +187,7 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo) } // Check if the store just gave us exactly the model we were looking for - if (m_llModelInfo.model && m_llModelInfo.fileInfo == fileInfo) { + if (m_llModelInfo.model && m_llModelInfo.fileInfo == fileInfo && !m_reloadingToChangeVariant) { #if defined(DEBUG_MODEL_LOADING) qDebug() << "store had our model" << m_llmThread.objectName() << m_llModelInfo.model; #endif @@ -210,7 +227,16 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo) model->setAPIKey(apiKey); m_llModelInfo.model = model; } else { - m_llModelInfo.model = LLModel::construct(filePath.toStdString()); + +#if defined(Q_OS_MAC) && defined(__arm__) + if (m_forceMetal) + m_llModelInfo.model = LLModel::construct(filePath.toStdString(), "metal"); + else + m_llModelInfo.model = LLModel::construct(filePath.toStdString(), "auto"); +#else + m_llModelInfo.model = LLModel::construct(filePath.toStdString(), "auto"); +#endif + if (m_llModelInfo.model) { bool success = m_llModelInfo.model->loadModel(filePath.toStdString()); if (!success) { diff --git a/gpt4all-chat/chatllm.h b/gpt4all-chat/chatllm.h index 6f9c8ea6..9a1c17b5 100644 --- a/gpt4all-chat/chatllm.h +++ b/gpt4all-chat/chatllm.h @@ -110,6 +110,7 @@ public Q_SLOTS: void handleChatIdChanged(const QString &id); void handleShouldBeLoadedChanged(); void handleThreadStarted(); + void handleForceMetalChanged(bool forceMetal); Q_SIGNALS: void recalcChanged(); @@ -157,6 +158,8 @@ private: std::atomic m_shouldBeLoaded; std::atomic m_isRecalc; bool m_isServer; + bool m_forceMetal; + bool m_reloadingToChangeVariant; }; #endif // CHATLLM_H diff --git a/gpt4all-chat/llm.h b/gpt4all-chat/llm.h index 716f6246..c226b559 100644 --- a/gpt4all-chat/llm.h +++ b/gpt4all-chat/llm.h @@ -13,6 +13,7 @@ class LLM : public QObject public: static LLM *globalInstance(); + // FIXME: Move all settings to the new settings singleton int32_t threadCount() const; void setThreadCount(int32_t n_threads); bool serverEnabled() const; diff --git a/gpt4all-chat/main.cpp b/gpt4all-chat/main.cpp index 0d81e87b..5d16def3 100644 --- a/gpt4all-chat/main.cpp +++ b/gpt4all-chat/main.cpp @@ -11,6 +11,7 @@ #include "localdocs.h" #include "download.h" #include "network.h" +#include "mysettings.h" #include "config.h" #include "logger.h" @@ -26,6 +27,7 @@ int main(int argc, char *argv[]) QGuiApplication app(argc, argv); QQmlApplicationEngine engine; + qmlRegisterSingletonInstance("mysettings", 1, 0, "MySettings", MySettings::globalInstance()); qmlRegisterSingletonInstance("modellist", 1, 0, "ModelList", ModelList::globalInstance()); qmlRegisterSingletonInstance("chatlistmodel", 1, 0, "ChatListModel", ChatListModel::globalInstance()); qmlRegisterSingletonInstance("llm", 1, 0, "LLM", LLM::globalInstance()); diff --git a/gpt4all-chat/mysettings.cpp b/gpt4all-chat/mysettings.cpp new file mode 100644 index 00000000..b0bcebc4 --- /dev/null +++ b/gpt4all-chat/mysettings.cpp @@ -0,0 +1,31 @@ +#include "mysettings.h" + +#include + +class MyPrivateSettings: public MySettings { }; +Q_GLOBAL_STATIC(MyPrivateSettings, settingsInstance) +MySettings *MySettings::globalInstance() +{ + return settingsInstance(); +} + +MySettings::MySettings() + : QObject{nullptr} +{ + QSettings settings; + settings.sync(); + m_forceMetal = settings.value("forceMetal", false).toBool(); +} + +bool MySettings::forceMetal() const +{ + return m_forceMetal; +} + +void MySettings::setForceMetal(bool enabled) +{ + if (m_forceMetal == enabled) + return; + m_forceMetal = enabled; + emit forceMetalChanged(enabled); +} diff --git a/gpt4all-chat/mysettings.h b/gpt4all-chat/mysettings.h new file mode 100644 index 00000000..b2a3c951 --- /dev/null +++ b/gpt4all-chat/mysettings.h @@ -0,0 +1,30 @@ +#ifndef MYSETTINGS_H +#define MYSETTINGS_H + +#include +#include + +class MySettings : public QObject +{ + Q_OBJECT + Q_PROPERTY(bool forceMetal READ forceMetal WRITE setForceMetal NOTIFY forceMetalChanged) + +public: + static MySettings *globalInstance(); + + bool forceMetal() const; + void setForceMetal(bool enabled); + +Q_SIGNALS: + void forceMetalChanged(bool); + +private: + bool m_forceMetal; + +private: + explicit MySettings(); + ~MySettings() {} + friend class MyPrivateSettings; +}; + +#endif // MYSETTINGS_H diff --git a/gpt4all-chat/qml/SettingsDialog.qml b/gpt4all-chat/qml/SettingsDialog.qml index 26bc8405..4a9d50e3 100644 --- a/gpt4all-chat/qml/SettingsDialog.qml +++ b/gpt4all-chat/qml/SettingsDialog.qml @@ -10,6 +10,7 @@ import download import modellist import network import llm +import mysettings Dialog { id: settingsDialog @@ -51,6 +52,7 @@ Dialog { ### Assistant:\n" property string defaultModelPath: ModelList.defaultLocalModelsPath() property string defaultUserDefaultModel: "Application default" + property bool defaultForceMetal: false property alias temperature: settings.temperature property alias topP: settings.topP @@ -66,6 +68,7 @@ Dialog { property alias serverChat: settings.serverChat property alias modelPath: settings.modelPath property alias userDefaultModel: settings.userDefaultModel + property alias forceMetal: settings.forceMetal Settings { id: settings @@ -83,6 +86,7 @@ Dialog { property string promptTemplate: settingsDialog.defaultPromptTemplate property string modelPath: settingsDialog.defaultModelPath property string userDefaultModel: settingsDialog.defaultUserDefaultModel + property bool forceMetal: settingsDialog.defaultForceMetal } function restoreGenerationDefaults() { @@ -109,6 +113,7 @@ Dialog { LLM.serverEnabled = settings.serverChat ChatListModel.shouldSaveChats = settings.saveChats ChatListModel.shouldSaveChatGPTChats = settings.saveChatGPTChats + MySettings.forceMetal = settings.forceMetal settings.sync() } @@ -118,6 +123,7 @@ Dialog { ChatListModel.shouldSaveChats = settings.saveChats ChatListModel.shouldSaveChatGPTChats = settings.saveChatGPTChats ModelList.localModelsPath = settings.modelPath + MySettings.forceMetal = settings.forceMetal } Connections { @@ -811,9 +817,11 @@ Dialog { Layout.columnSpan: 2 MyCheckBox { id: gpuOverrideBox - checked: false + checked: settings.forceMetal onClicked: { - // fixme + settingsDialog.forceMetal = gpuOverrideBox.checked + MySettings.forceMetal = gpuOverrideBox.checked + settings.sync() } } Label { @@ -822,7 +830,7 @@ Dialog { Layout.alignment: Qt.AlignTop color: theme.textErrorColor wrapMode: Text.WordWrap - text: qsTr("WARNING: This setting forces usage of the GPU if it is detected. Can cause a crash if the model requires more RAM than the OS + GPU supports.") + text: qsTr("WARNING: On macOS with arm architecture (M1+), this setting forces usage of the GPU if it is detected. Can cause a crash if the model requires more RAM than the OS + GPU supports. Setting has no effect on non-macs or intel macs.") } } MyButton {