Enable the force metal setting.

This commit is contained in:
Adam Treat 2023-06-27 11:54:34 -04:00 committed by AT
parent 2565f6a94a
commit 267601d670
9 changed files with 109 additions and 7 deletions

View File

@ -74,6 +74,7 @@ qt_add_executable(chat
localdocs.h localdocs.cpp localdocsmodel.h localdocsmodel.cpp localdocs.h localdocs.cpp localdocsmodel.h localdocsmodel.cpp
llm.h llm.cpp llm.h llm.cpp
modellist.h modellist.cpp modellist.h modellist.cpp
mysettings.h mysettings.cpp
network.h network.cpp network.h network.cpp
server.h server.cpp server.h server.cpp
logger.h logger.cpp logger.h logger.cpp

View File

@ -3,6 +3,7 @@
#include "chatgpt.h" #include "chatgpt.h"
#include "modellist.h" #include "modellist.h"
#include "network.h" #include "network.h"
#include "mysettings.h"
#include "../gpt4all-backend/llmodel.h" #include "../gpt4all-backend/llmodel.h"
#include <QCoreApplication> #include <QCoreApplication>
@ -73,6 +74,8 @@ ChatLLM::ChatLLM(Chat *parent, bool isServer)
, m_stopGenerating(false) , m_stopGenerating(false)
, m_timer(nullptr) , m_timer(nullptr)
, m_isServer(isServer) , m_isServer(isServer)
, m_forceMetal(MySettings::globalInstance()->forceMetal())
, m_reloadingToChangeVariant(false)
{ {
moveToThread(&m_llmThread); moveToThread(&m_llmThread);
connect(this, &ChatLLM::sendStartup, Network::globalInstance(), &Network::sendStartup); connect(this, &ChatLLM::sendStartup, Network::globalInstance(), &Network::sendStartup);
@ -81,6 +84,7 @@ ChatLLM::ChatLLM(Chat *parent, bool isServer)
Qt::QueuedConnection); // explicitly queued Qt::QueuedConnection); // explicitly queued
connect(parent, &Chat::idChanged, this, &ChatLLM::handleChatIdChanged); connect(parent, &Chat::idChanged, this, &ChatLLM::handleChatIdChanged);
connect(&m_llmThread, &QThread::started, this, &ChatLLM::handleThreadStarted); connect(&m_llmThread, &QThread::started, this, &ChatLLM::handleThreadStarted);
connect(MySettings::globalInstance(), &MySettings::forceMetalChanged, this, &ChatLLM::handleForceMetalChanged);
// The following are blocking operations and will block the llm thread // The following are blocking operations and will block the llm thread
connect(this, &ChatLLM::requestRetrieveFromDB, LocalDocs::globalInstance()->database(), &Database::retrieveFromDB, connect(this, &ChatLLM::requestRetrieveFromDB, LocalDocs::globalInstance()->database(), &Database::retrieveFromDB,
@ -110,6 +114,19 @@ void ChatLLM::handleThreadStarted()
emit threadStarted(); emit threadStarted();
} }
void ChatLLM::handleForceMetalChanged(bool forceMetal)
{
#if defined(Q_OS_MAC) && defined(__arm__)
m_forceMetal = forceMetal;
if (isModelLoaded() && m_shouldBeLoaded) {
m_reloadingToChangeVariant = true;
unloadModel();
reloadModel();
m_reloadingToChangeVariant = false;
}
#endif
}
bool ChatLLM::loadDefaultModel() bool ChatLLM::loadDefaultModel()
{ {
ModelInfo defaultModel = ModelList::globalInstance()->defaultModelInfo(); ModelInfo defaultModel = ModelList::globalInstance()->defaultModelInfo();
@ -154,7 +171,7 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo)
// returned to it, then the modelInfo.model pointer should be null which will happen on startup // returned to it, then the modelInfo.model pointer should be null which will happen on startup
m_llModelInfo = LLModelStore::globalInstance()->acquireModel(); m_llModelInfo = LLModelStore::globalInstance()->acquireModel();
#if defined(DEBUG_MODEL_LOADING) #if defined(DEBUG_MODEL_LOADING)
qDebug() << "acquired model from store" << m_llmThread.objectName() << m_llModelInfo3.model; qDebug() << "acquired model from store" << m_llmThread.objectName() << m_llModelInfo.model;
#endif #endif
// At this point it is possible that while we were blocked waiting to acquire the model from the // At this point it is possible that while we were blocked waiting to acquire the model from the
// store, that our state was changed to not be loaded. If this is the case, release the model // store, that our state was changed to not be loaded. If this is the case, release the model
@ -170,7 +187,7 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo)
} }
// Check if the store just gave us exactly the model we were looking for // Check if the store just gave us exactly the model we were looking for
if (m_llModelInfo.model && m_llModelInfo.fileInfo == fileInfo) { if (m_llModelInfo.model && m_llModelInfo.fileInfo == fileInfo && !m_reloadingToChangeVariant) {
#if defined(DEBUG_MODEL_LOADING) #if defined(DEBUG_MODEL_LOADING)
qDebug() << "store had our model" << m_llmThread.objectName() << m_llModelInfo.model; qDebug() << "store had our model" << m_llmThread.objectName() << m_llModelInfo.model;
#endif #endif
@ -210,7 +227,16 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo)
model->setAPIKey(apiKey); model->setAPIKey(apiKey);
m_llModelInfo.model = model; m_llModelInfo.model = model;
} else { } else {
m_llModelInfo.model = LLModel::construct(filePath.toStdString());
#if defined(Q_OS_MAC) && defined(__arm__)
if (m_forceMetal)
m_llModelInfo.model = LLModel::construct(filePath.toStdString(), "metal");
else
m_llModelInfo.model = LLModel::construct(filePath.toStdString(), "auto");
#else
m_llModelInfo.model = LLModel::construct(filePath.toStdString(), "auto");
#endif
if (m_llModelInfo.model) { if (m_llModelInfo.model) {
bool success = m_llModelInfo.model->loadModel(filePath.toStdString()); bool success = m_llModelInfo.model->loadModel(filePath.toStdString());
if (!success) { if (!success) {

View File

@ -110,6 +110,7 @@ public Q_SLOTS:
void handleChatIdChanged(const QString &id); void handleChatIdChanged(const QString &id);
void handleShouldBeLoadedChanged(); void handleShouldBeLoadedChanged();
void handleThreadStarted(); void handleThreadStarted();
void handleForceMetalChanged(bool forceMetal);
Q_SIGNALS: Q_SIGNALS:
void recalcChanged(); void recalcChanged();
@ -157,6 +158,8 @@ private:
std::atomic<bool> m_shouldBeLoaded; std::atomic<bool> m_shouldBeLoaded;
std::atomic<bool> m_isRecalc; std::atomic<bool> m_isRecalc;
bool m_isServer; bool m_isServer;
bool m_forceMetal;
bool m_reloadingToChangeVariant;
}; };
#endif // CHATLLM_H #endif // CHATLLM_H

View File

@ -13,6 +13,7 @@ class LLM : public QObject
public: public:
static LLM *globalInstance(); static LLM *globalInstance();
// FIXME: Move all settings to the new settings singleton
int32_t threadCount() const; int32_t threadCount() const;
void setThreadCount(int32_t n_threads); void setThreadCount(int32_t n_threads);
bool serverEnabled() const; bool serverEnabled() const;

View File

@ -11,6 +11,7 @@
#include "localdocs.h" #include "localdocs.h"
#include "download.h" #include "download.h"
#include "network.h" #include "network.h"
#include "mysettings.h"
#include "config.h" #include "config.h"
#include "logger.h" #include "logger.h"
@ -26,6 +27,7 @@ int main(int argc, char *argv[])
QGuiApplication app(argc, argv); QGuiApplication app(argc, argv);
QQmlApplicationEngine engine; QQmlApplicationEngine engine;
qmlRegisterSingletonInstance("mysettings", 1, 0, "MySettings", MySettings::globalInstance());
qmlRegisterSingletonInstance("modellist", 1, 0, "ModelList", ModelList::globalInstance()); qmlRegisterSingletonInstance("modellist", 1, 0, "ModelList", ModelList::globalInstance());
qmlRegisterSingletonInstance("chatlistmodel", 1, 0, "ChatListModel", ChatListModel::globalInstance()); qmlRegisterSingletonInstance("chatlistmodel", 1, 0, "ChatListModel", ChatListModel::globalInstance());
qmlRegisterSingletonInstance("llm", 1, 0, "LLM", LLM::globalInstance()); qmlRegisterSingletonInstance("llm", 1, 0, "LLM", LLM::globalInstance());

View File

@ -0,0 +1,31 @@
#include "mysettings.h"
#include <QSettings>
class MyPrivateSettings: public MySettings { };
Q_GLOBAL_STATIC(MyPrivateSettings, settingsInstance)
MySettings *MySettings::globalInstance()
{
return settingsInstance();
}
MySettings::MySettings()
: QObject{nullptr}
{
QSettings settings;
settings.sync();
m_forceMetal = settings.value("forceMetal", false).toBool();
}
bool MySettings::forceMetal() const
{
return m_forceMetal;
}
void MySettings::setForceMetal(bool enabled)
{
if (m_forceMetal == enabled)
return;
m_forceMetal = enabled;
emit forceMetalChanged(enabled);
}

30
gpt4all-chat/mysettings.h Normal file
View File

@ -0,0 +1,30 @@
#ifndef MYSETTINGS_H
#define MYSETTINGS_H
#include <QObject>
#include <QMutex>
class MySettings : public QObject
{
Q_OBJECT
Q_PROPERTY(bool forceMetal READ forceMetal WRITE setForceMetal NOTIFY forceMetalChanged)
public:
static MySettings *globalInstance();
bool forceMetal() const;
void setForceMetal(bool enabled);
Q_SIGNALS:
void forceMetalChanged(bool);
private:
bool m_forceMetal;
private:
explicit MySettings();
~MySettings() {}
friend class MyPrivateSettings;
};
#endif // MYSETTINGS_H

View File

@ -10,6 +10,7 @@ import download
import modellist import modellist
import network import network
import llm import llm
import mysettings
Dialog { Dialog {
id: settingsDialog id: settingsDialog
@ -51,6 +52,7 @@ Dialog {
### Assistant:\n" ### Assistant:\n"
property string defaultModelPath: ModelList.defaultLocalModelsPath() property string defaultModelPath: ModelList.defaultLocalModelsPath()
property string defaultUserDefaultModel: "Application default" property string defaultUserDefaultModel: "Application default"
property bool defaultForceMetal: false
property alias temperature: settings.temperature property alias temperature: settings.temperature
property alias topP: settings.topP property alias topP: settings.topP
@ -66,6 +68,7 @@ Dialog {
property alias serverChat: settings.serverChat property alias serverChat: settings.serverChat
property alias modelPath: settings.modelPath property alias modelPath: settings.modelPath
property alias userDefaultModel: settings.userDefaultModel property alias userDefaultModel: settings.userDefaultModel
property alias forceMetal: settings.forceMetal
Settings { Settings {
id: settings id: settings
@ -83,6 +86,7 @@ Dialog {
property string promptTemplate: settingsDialog.defaultPromptTemplate property string promptTemplate: settingsDialog.defaultPromptTemplate
property string modelPath: settingsDialog.defaultModelPath property string modelPath: settingsDialog.defaultModelPath
property string userDefaultModel: settingsDialog.defaultUserDefaultModel property string userDefaultModel: settingsDialog.defaultUserDefaultModel
property bool forceMetal: settingsDialog.defaultForceMetal
} }
function restoreGenerationDefaults() { function restoreGenerationDefaults() {
@ -109,6 +113,7 @@ Dialog {
LLM.serverEnabled = settings.serverChat LLM.serverEnabled = settings.serverChat
ChatListModel.shouldSaveChats = settings.saveChats ChatListModel.shouldSaveChats = settings.saveChats
ChatListModel.shouldSaveChatGPTChats = settings.saveChatGPTChats ChatListModel.shouldSaveChatGPTChats = settings.saveChatGPTChats
MySettings.forceMetal = settings.forceMetal
settings.sync() settings.sync()
} }
@ -118,6 +123,7 @@ Dialog {
ChatListModel.shouldSaveChats = settings.saveChats ChatListModel.shouldSaveChats = settings.saveChats
ChatListModel.shouldSaveChatGPTChats = settings.saveChatGPTChats ChatListModel.shouldSaveChatGPTChats = settings.saveChatGPTChats
ModelList.localModelsPath = settings.modelPath ModelList.localModelsPath = settings.modelPath
MySettings.forceMetal = settings.forceMetal
} }
Connections { Connections {
@ -811,9 +817,11 @@ Dialog {
Layout.columnSpan: 2 Layout.columnSpan: 2
MyCheckBox { MyCheckBox {
id: gpuOverrideBox id: gpuOverrideBox
checked: false checked: settings.forceMetal
onClicked: { onClicked: {
// fixme settingsDialog.forceMetal = gpuOverrideBox.checked
MySettings.forceMetal = gpuOverrideBox.checked
settings.sync()
} }
} }
Label { Label {
@ -822,7 +830,7 @@ Dialog {
Layout.alignment: Qt.AlignTop Layout.alignment: Qt.AlignTop
color: theme.textErrorColor color: theme.textErrorColor
wrapMode: Text.WordWrap wrapMode: Text.WordWrap
text: qsTr("WARNING: This setting forces usage of the GPU if it is detected. Can cause a crash if the model requires more RAM than the OS + GPU supports.") text: qsTr("WARNING: On macOS with arm architecture (M1+), this setting forces usage of the GPU if it is detected. Can cause a crash if the model requires more RAM than the OS + GPU supports. Setting has no effect on non-macs or intel macs.")
} }
} }
MyButton { MyButton {