From 0ea31487e3cacd28915028146e9e2ef7132b07ea Mon Sep 17 00:00:00 2001
From: Adam Treat
Date: Tue, 11 Apr 2023 08:29:55 -0400
Subject: [PATCH] Programmatically get the model name from the LLM.

The LLM now searches for applicable models in the directory of the
executable given a pattern match and then loads the first one it finds.
Also, add a busy indicator for model loading.
---
 llm.cpp  | 34 +++++++++++++++++++++++++++++-----
 llm.h    |  8 ++++++++
 main.qml | 52 +++++++++++++++++++++++++++++++++++-----------------
 3 files changed, 72 insertions(+), 22 deletions(-)

diff --git a/llm.cpp b/llm.cpp
index 728ea96d..e16b4913 100644
--- a/llm.cpp
+++ b/llm.cpp
@@ -31,16 +31,29 @@ bool GPTJObject::loadModel()
     if (isModelLoaded())
         return true;
 
-    QString modelName("ggml-model-q4_0.bin");
-    QString fileName = QCoreApplication::applicationDirPath() + QDir::separator() + modelName;
-    QFile file(fileName);
-    if (file.exists()) {
+    QDir dir(QCoreApplication::applicationDirPath());
+    dir.setNameFilters(QStringList() << "ggml-*.bin");
+    QStringList fileNames = dir.entryList();
+    if (fileNames.isEmpty()) {
+        qDebug() << "ERROR: Could not find any applicable models in directory"
+            << QCoreApplication::applicationDirPath();
+    }
+
+    QString modelName = fileNames.first();
+    QString filePath = QCoreApplication::applicationDirPath() + QDir::separator() + modelName;
+    QFileInfo info(filePath);
+    if (info.exists()) {
 
-        auto fin = std::ifstream(fileName.toStdString(), std::ios::binary);
+        auto fin = std::ifstream(filePath.toStdString(), std::ios::binary);
         m_gptj->loadModel(modelName.toStdString(), fin);
         emit isModelLoadedChanged();
     }
 
+    if (m_gptj) {
+        m_modelName = info.baseName().remove(0, 5); // remove the ggml- prefix
+        emit modelNameChanged();
+    }
+
     return m_gptj;
 }
 
@@ -64,6 +77,11 @@ QString GPTJObject::response() const
     return QString::fromStdString(m_response);
 }
 
+QString GPTJObject::modelName() const
+{
+    return m_modelName;
+}
+
 bool GPTJObject::handleResponse(const std::string &response)
 {
 #if 0
@@ -97,6 +115,7 @@ LLM::LLM()
     connect(m_gptj, &GPTJObject::responseChanged, this, &LLM::responseChanged, Qt::QueuedConnection);
     connect(m_gptj, &GPTJObject::responseStarted, this, &LLM::responseStarted, Qt::QueuedConnection);
     connect(m_gptj, &GPTJObject::responseStopped, this, &LLM::responseStopped, Qt::QueuedConnection);
+    connect(m_gptj, &GPTJObject::modelNameChanged, this, &LLM::modelNameChanged, Qt::QueuedConnection);
 
     connect(this, &LLM::promptRequested, m_gptj, &GPTJObject::prompt, Qt::QueuedConnection);
     connect(this, &LLM::resetResponseRequested, m_gptj, &GPTJObject::resetResponse, Qt::BlockingQueuedConnection);
@@ -145,6 +164,11 @@ void LLM::responseStopped()
     emit responseInProgressChanged();
 }
 
+QString LLM::modelName() const
+{
+    return m_gptj->modelName();
+}
+
 bool LLM::checkForUpdates() const
 {
 #if defined(Q_OS_LINUX)
diff --git a/llm.h b/llm.h
index 35eb11a0..808fa886 100644
--- a/llm.h
+++ b/llm.h
@@ -10,6 +10,7 @@ class GPTJObject : public QObject
     Q_OBJECT
     Q_PROPERTY(bool isModelLoaded READ isModelLoaded NOTIFY isModelLoadedChanged)
     Q_PROPERTY(QString response READ response NOTIFY responseChanged)
+    Q_PROPERTY(QString modelName READ modelName NOTIFY modelNameChanged)
 
 public:
 
@@ -22,6 +23,7 @@ public:
     void stopGenerating() { m_stopGenerating = true; }
 
     QString response() const;
+    QString modelName() const;
 
 public Q_SLOTS:
     bool prompt(const QString &prompt);
@@ -31,6 +33,7 @@ Q_SIGNALS:
     void responseChanged();
     void responseStarted();
     void responseStopped();
+    void modelNameChanged();
 
 private:
     bool handleResponse(const std::string &response);
@@ -38,6 +41,7 @@ private:
 private:
     GPTJ *m_gptj;
     std::string m_response;
+    QString m_modelName;
     QThread m_llmThread;
     std::atomic<bool> m_stopGenerating;
 };
@@ -47,6 +51,7 @@ class LLM : public QObject
     Q_OBJECT
     Q_PROPERTY(bool isModelLoaded READ isModelLoaded NOTIFY isModelLoadedChanged)
     Q_PROPERTY(QString response READ response NOTIFY responseChanged)
+    Q_PROPERTY(QString modelName READ modelName NOTIFY modelNameChanged)
     Q_PROPERTY(bool responseInProgress READ responseInProgress NOTIFY responseInProgressChanged)
 
 public:
@@ -61,6 +66,8 @@ public:
     QString response() const;
     bool responseInProgress() const { return m_responseInProgress; }
 
+    QString modelName() const;
+
     Q_INVOKABLE bool checkForUpdates() const;
 
 Q_SIGNALS:
@@ -70,6 +77,7 @@ Q_SIGNALS:
     void promptRequested(const QString &prompt);
     void resetResponseRequested();
     void resetContextRequested();
+    void modelNameChanged();
 
 private Q_SLOTS:
     void responseStarted();
diff --git a/main.qml b/main.qml
index a357858d..e5251297 100644
--- a/main.qml
+++ b/main.qml
@@ -11,30 +11,48 @@ Window {
     title: qsTr("GPT4All Chat")
     color: "#d1d5db"
 
-    TextField {
+    Rectangle {
         id: header
         anchors.left: parent.left
         anchors.right: parent.right
         anchors.top: parent.top
         height: 100
-        color: "#d1d5db"
-        padding: 20
-        font.pixelSize: 24
-        text: "GPT4ALL Model: gpt4all-j"
-        background: Rectangle {
-            color: "#202123"
+        color: "#202123"
+
+        Item {
+            anchors.centerIn: parent
+            width: childrenRect.width
+            height: childrenRect.height
+            visible: LLM.isModelLoaded
+
+            TextField {
+                id: modelNameField
+                color: "#d1d5db"
+                padding: 20
+                font.pixelSize: 24
+                text: "GPT4ALL Model: " + LLM.modelName
+                background: Rectangle {
+                    color: "#202123"
+                }
+                focus: false
+                horizontalAlignment: TextInput.AlignHCenter
+            }
+
+            Image {
+                anchors.left: modelNameField.right
+                anchors.verticalCenter: modelNameField.baseline
+                width: 50
+                height: 65
+                source: "qrc:/gpt4all-chat/icons/logo.svg"
+                z: 300
+            }
         }
-        focus: false
-        horizontalAlignment: TextInput.AlignHCenter
-    }
 
-    Image {
-        anchors.verticalCenter: header.baseline
-        x: parent.width / 2 + 163
-        width: 50
-        height: 65
-        source: "qrc:/gpt4all-chat/icons/logo.svg"
-        z: 300
+        BusyIndicator {
+            anchors.centerIn: parent
+            visible: !LLM.isModelLoaded
+            running: !LLM.isModelLoaded
+        }
     }
 
     Button {
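
Note: the following is a minimal standalone sketch of the discovery logic this
patch adds to GPTJObject::loadModel(): scan the executable's directory for
ggml-*.bin files, take the first match, and derive a display name by stripping
the "ggml-" prefix. It is an illustrative rewrite, not the patch's code
verbatim; in particular it returns early when no model is found, whereas the
patch only logs an error before calling fileNames.first().

    // Sketch (not the patch's exact code) of model discovery beside the
    // executable. Assumes Qt Core only; build with qmake or CMake as usual.
    #include <QCoreApplication>
    #include <QDebug>
    #include <QDir>
    #include <QFileInfo>
    #include <QStringList>

    int main(int argc, char *argv[])
    {
        QCoreApplication app(argc, argv);

        // Filter on the same glob pattern the patch uses: ggml-*.bin.
        QDir dir(QCoreApplication::applicationDirPath());
        dir.setNameFilters(QStringList() << "ggml-*.bin");
        const QStringList fileNames = dir.entryList(QDir::Files);

        // The patch only logs here and still calls first(); returning early
        // avoids touching an empty list in this sketch.
        if (fileNames.isEmpty()) {
            qDebug() << "ERROR: no ggml-*.bin models in" << dir.absolutePath();
            return 1;
        }

        // Load the first match, as the patch does.
        const QString filePath = dir.absoluteFilePath(fileNames.first());

        // Display name: base name minus the 5-character "ggml-" prefix,
        // mirroring info.baseName().remove(0, 5) in the patch.
        const QString displayName = QFileInfo(filePath).baseName().remove(0, 5);
        qDebug() << "would load" << filePath << "as model" << displayName;
        return 0;
    }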
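The patch also forwards the new modelNameChanged signal from the worker object
to the LLM wrapper with Qt::QueuedConnection, the same pattern already used for
the response signals, so the QML binding on LLM.modelName updates on the GUI
thread even though the model loads on a worker thread. Below is a compact
sketch of that pattern; the class names Worker and Wrapper are illustrative
placeholders, not types from the patch.

    // Sketch of queued cross-thread signal forwarding, as the patch relies on.
    #include <QCoreApplication>
    #include <QDebug>
    #include <QObject>
    #include <QString>
    #include <QThread>

    class Worker : public QObject {
        Q_OBJECT
    public Q_SLOTS:
        void load() {                      // runs on the worker thread
            m_modelName = QStringLiteral("model-q4_0");
            emit modelNameChanged();
        }
    Q_SIGNALS:
        void modelNameChanged();
    private:
        QString m_modelName;
    };

    class Wrapper : public QObject {
        Q_OBJECT
    public:
        Wrapper() {
            m_worker.moveToThread(&m_thread);
            // Queued: delivery hops from the worker thread to the thread
            // the Wrapper lives on (the GUI thread in the real app).
            connect(&m_worker, &Worker::modelNameChanged,
                    this, &Wrapper::modelNameChanged, Qt::QueuedConnection);
            m_thread.start();
            // Start the work on the worker thread, not the caller's.
            QMetaObject::invokeMethod(&m_worker, "load", Qt::QueuedConnection);
        }
        ~Wrapper() override { m_thread.quit(); m_thread.wait(); }
    Q_SIGNALS:
        void modelNameChanged();
    private:
        Worker m_worker;
        QThread m_thread;
    };

    int main(int argc, char *argv[])
    {
        QCoreApplication app(argc, argv);
        Wrapper wrapper;
        QObject::connect(&wrapper, &Wrapper::modelNameChanged, &app, [] {
            qDebug() << "modelNameChanged delivered on the main thread";
            QCoreApplication::quit();
        });
        return app.exec();
    }

    #include "main.moc" // Q_OBJECT classes in a .cpp need this with automoc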