From 2b1cae5a7ee9cba4a72d54b2c20947643749a619 Mon Sep 17 00:00:00 2001 From: Adam Treat Date: Tue, 18 Apr 2023 11:42:16 -0400 Subject: [PATCH] Allow unloading/loading/changing of models. --- llm.cpp | 91 ++++++++++++++++++++++++++++++++++++++++++++++---------- llm.h | 19 ++++++++++-- main.qml | 41 +++++++++++++++++++++---- 3 files changed, 126 insertions(+), 25 deletions(-) diff --git a/llm.cpp b/llm.cpp index db13807a..93c26d02 100644 --- a/llm.cpp +++ b/llm.cpp @@ -18,7 +18,7 @@ static LLModel::PromptContext s_ctx; LLMObject::LLMObject() : QObject{nullptr} - , m_llmodel(new GPTJ) + , m_llmodel(nullptr) , m_responseTokens(0) , m_responseLogits(0) { @@ -30,19 +30,24 @@ LLMObject::LLMObject() bool LLMObject::loadModel() { - if (isModelLoaded()) + return loadModelPrivate(modelList().first()); +} + +bool LLMObject::loadModelPrivate(const QString &modelName) +{ + if (isModelLoaded() && m_modelName == modelName) return true; - QDir dir(QCoreApplication::applicationDirPath()); - dir.setNameFilters(QStringList() << "ggml-*.bin"); - QStringList fileNames = dir.entryList(); - if (fileNames.isEmpty()) { - qDebug() << "ERROR: Could not find any applicable models in directory" - << QCoreApplication::applicationDirPath(); + if (isModelLoaded()) { + delete m_llmodel; + m_llmodel = nullptr; + emit isModelLoadedChanged(); } - QString modelName = fileNames.first(); - QString filePath = QCoreApplication::applicationDirPath() + QDir::separator() + modelName; + m_llmodel = new GPTJ; + + QString filePath = QCoreApplication::applicationDirPath() + QDir::separator() + + "ggml-" + modelName + ".bin"; QFileInfo info(filePath); if (info.exists()) { @@ -51,17 +56,15 @@ bool LLMObject::loadModel() emit isModelLoadedChanged(); } - if (m_llmodel) { - m_modelName = info.completeBaseName().remove(0, 5); // remove the ggml- prefix - emit modelNameChanged(); - } + if (m_llmodel) + setModelName(info.completeBaseName().remove(0, 5)); // remove the ggml- prefix return m_llmodel; } bool LLMObject::isModelLoaded() const { - return m_llmodel->isModelLoaded(); + return m_llmodel && m_llmodel->isModelLoaded(); } void LLMObject::regenerateResponse() @@ -119,6 +122,46 @@ QString LLMObject::modelName() const return m_modelName; } +void LLMObject::setModelName(const QString &modelName) +{ + m_modelName = modelName; + emit modelNameChanged(); + emit modelListChanged(); +} + +void LLMObject::modelNameChangeRequested(const QString &modelName) +{ + if (!loadModelPrivate(modelName)) + qWarning() << "ERROR: Could not load model" << modelName; +} + +QList LLMObject::modelList() const +{ + QDir dir(QCoreApplication::applicationDirPath()); + dir.setNameFilters(QStringList() << "ggml-*.bin"); + QStringList fileNames = dir.entryList(); + if (fileNames.isEmpty()) { + qWarning() << "ERROR: Could not find any applicable models in directory" + << QCoreApplication::applicationDirPath(); + return QList(); + } + + QList list; + for (QString f : fileNames) { + QString filePath = QCoreApplication::applicationDirPath() + QDir::separator() + f; + QFileInfo info(filePath); + QString name = info.completeBaseName().remove(0, 5); + if (info.exists()) { + if (name == m_modelName) + list.prepend(name); + else + list.append(name); + } + } + + return list; +} + bool LLMObject::handleResponse(const std::string &response) { #if 0 @@ -172,8 +215,12 @@ LLM::LLM() connect(m_llmodel, &LLMObject::responseStarted, this, &LLM::responseStarted, Qt::QueuedConnection); connect(m_llmodel, &LLMObject::responseStopped, this, &LLM::responseStopped, Qt::QueuedConnection); connect(m_llmodel, &LLMObject::modelNameChanged, this, &LLM::modelNameChanged, Qt::QueuedConnection); - + connect(m_llmodel, &LLMObject::modelListChanged, this, &LLM::modelListChanged, Qt::QueuedConnection); connect(this, &LLM::promptRequested, m_llmodel, &LLMObject::prompt, Qt::QueuedConnection); + connect(this, &LLM::modelNameChangeRequested, m_llmodel, &LLMObject::modelNameChangeRequested, Qt::QueuedConnection); + + // The following are blocking operations and will block the gui thread, therefore must be fast + // to respond to connect(this, &LLM::regenerateResponseRequested, m_llmodel, &LLMObject::regenerateResponse, Qt::BlockingQueuedConnection); connect(this, &LLM::resetResponseRequested, m_llmodel, &LLMObject::resetResponse, Qt::BlockingQueuedConnection); connect(this, &LLM::resetContextRequested, m_llmodel, &LLMObject::resetContext, Qt::BlockingQueuedConnection); @@ -232,6 +279,18 @@ QString LLM::modelName() const return m_llmodel->modelName(); } +void LLM::setModelName(const QString &modelName) +{ + // doesn't block but will unload old model and load new one which the gui can see through changes + // to the isModelLoaded property + emit modelNameChangeRequested(modelName); +} + +QList LLM::modelList() const +{ + return m_llmodel->modelList(); +} + bool LLM::checkForUpdates() const { #if defined(Q_OS_LINUX) diff --git a/llm.h b/llm.h index 33aa95c6..0e189e42 100644 --- a/llm.h +++ b/llm.h @@ -8,15 +8,15 @@ class LLMObject : public QObject { Q_OBJECT + Q_PROPERTY(QList modelList READ modelList NOTIFY modelListChanged) Q_PROPERTY(bool isModelLoaded READ isModelLoaded NOTIFY isModelLoadedChanged) Q_PROPERTY(QString response READ response NOTIFY responseChanged) - Q_PROPERTY(QString modelName READ modelName NOTIFY modelNameChanged) + Q_PROPERTY(QString modelName READ modelName WRITE setModelName NOTIFY modelNameChanged) public: LLMObject(); - bool loadModel(); bool isModelLoaded() const; void regenerateResponse(); void resetResponse(); @@ -26,9 +26,14 @@ public: QString response() const; QString modelName() const; + QList modelList() const; + void setModelName(const QString &modelName); + public Q_SLOTS: bool prompt(const QString &prompt, const QString &prompt_template, int32_t n_predict, int32_t top_k, float top_p, float temp, int32_t n_batch); + bool loadModel(); + void modelNameChangeRequested(const QString &modelName); Q_SIGNALS: void isModelLoadedChanged(); @@ -36,8 +41,10 @@ Q_SIGNALS: void responseStarted(); void responseStopped(); void modelNameChanged(); + void modelListChanged(); private: + bool loadModelPrivate(const QString &modelName); bool handleResponse(const std::string &response); private: @@ -53,9 +60,10 @@ private: class LLM : public QObject { Q_OBJECT + Q_PROPERTY(QList modelList READ modelList NOTIFY modelListChanged) Q_PROPERTY(bool isModelLoaded READ isModelLoaded NOTIFY isModelLoadedChanged) Q_PROPERTY(QString response READ response NOTIFY responseChanged) - Q_PROPERTY(QString modelName READ modelName NOTIFY modelNameChanged) + Q_PROPERTY(QString modelName READ modelName WRITE setModelName NOTIFY modelNameChanged) Q_PROPERTY(bool responseInProgress READ responseInProgress NOTIFY responseInProgressChanged) public: @@ -72,7 +80,10 @@ public: QString response() const; bool responseInProgress() const { return m_responseInProgress; } + QList modelList() const; + QString modelName() const; + void setModelName(const QString &modelName); Q_INVOKABLE bool checkForUpdates() const; @@ -85,7 +96,9 @@ Q_SIGNALS: void regenerateResponseRequested(); void resetResponseRequested(); void resetContextRequested(); + void modelNameChangeRequested(const QString &modelName); void modelNameChanged(); + void modelListChanged(); private Q_SLOTS: void responseStarted(); diff --git a/main.qml b/main.qml index b341b480..f899816b 100644 --- a/main.qml +++ b/main.qml @@ -32,18 +32,47 @@ Window { visible: LLM.isModelLoaded Label { - id: modelNameField + id: modelLabel color: "#d1d5db" padding: 20 font.pixelSize: 24 - text: "GPT4ALL Model: " + LLM.modelName + text: "" background: Rectangle { color: "#202123" } - horizontalAlignment: TextInput.AlignHCenter - Accessible.role: Accessible.Heading - Accessible.name: text - Accessible.description: qsTr("Displays the model name that is currently loaded") + horizontalAlignment: TextInput.AlignRight + } + + ComboBox { + id: comboBox + width: 400 + anchors.top: modelLabel.top + anchors.bottom: modelLabel.bottom + anchors.horizontalCenter: parent.horizontalCenter + font.pixelSize: 24 + spacing: 0 + model: LLM.modelList + Accessible.role: Accessible.ComboBox + Accessible.name: qsTr("ComboBox for displaying/picking the current model") + Accessible.description: qsTr("Use this for picking the current model to use; the first item is the current model") + contentItem: Text { + anchors.horizontalCenter: parent.horizontalCenter + leftPadding: 10 + rightPadding: 10 + text: comboBox.displayText + font: comboBox.font + color: "#d1d5db" + verticalAlignment: Text.AlignVCenter + horizontalAlignment: Text.AlignHCenter + elide: Text.ElideRight + } + background: Rectangle { + color: "#242528" + } + + onActivated: { + LLM.modelName = comboBox.currentText + } } }