From c086a45173a74340ac724b5d26c0aea5c4ac9f53 Mon Sep 17 00:00:00 2001 From: Adam Treat Date: Sun, 23 Apr 2023 11:28:17 -0400 Subject: [PATCH] Provide a non-priviledged place for model downloads when exe is installed to root. --- download.cpp | 35 +++- download.h | 1 + llm.cpp | 88 +++++++-- qml/ModelDownloaderDialog.qml | 357 ++++++++++++++++++---------------- 4 files changed, 286 insertions(+), 195 deletions(-) diff --git a/download.cpp b/download.cpp index c398f0ac..459d2ae6 100644 --- a/download.cpp +++ b/download.cpp @@ -8,6 +8,7 @@ #include #include #include +#include class MyDownload: public Download { }; Q_GLOBAL_STATIC(MyDownload, downloadInstance) @@ -38,6 +39,26 @@ QList Download::modelList() const return values; } +QString Download::downloadLocalModelsPath() const +{ + QString exePath = QCoreApplication::applicationDirPath() + QDir::separator(); + QFileInfo infoExe(exePath); + if (infoExe.isWritable()) + return exePath; + + QString localPath = QStandardPaths::writableLocation(QStandardPaths::AppLocalDataLocation); + QDir localDir(localPath); + if (!localDir.exists()) + localDir.mkpath(localPath); + QString localDownloadPath = localPath + + QDir::separator(); + QFileInfo infoLocal(localDownloadPath); + if (infoLocal.isWritable()) + return localDownloadPath; + qWarning() << "ERROR: Local download path appears not writeable:" << localDownloadPath; + return localDownloadPath; +} + void Download::updateModelList() { QUrl jsonUrl("http://gpt4all.io/models/models.json"); @@ -143,7 +164,7 @@ void Download::parseJsonFile(const QByteArray &jsonData) modelFilesize = QString("%1 GB").arg(qreal(sz) / (1024 * 1024 * 1024), 0, 'g', 3); } - QString filePath = QCoreApplication::applicationDirPath() + QDir::separator() + modelFilename; + QString filePath = downloadLocalModelsPath() + modelFilename; QFileInfo info(filePath); ModelInfo modelInfo; modelInfo.filename = modelFilename; @@ -164,7 +185,6 @@ void Download::handleDownloadProgress(qint64 bytesReceived, qint64 bytesTotal) return; QString modelFilename = modelReply->url().fileName(); -// qDebug() << "handleDownloadProgress" << bytesReceived << bytesTotal << modelFilename; emit downloadProgress(bytesReceived, bytesTotal, modelFilename); } @@ -179,7 +199,6 @@ void Download::handleModelDownloadFinished() return; QString modelFilename = modelReply->url().fileName(); -// qDebug() << "handleModelDownloadFinished" << modelFilename; m_activeDownloads.removeAll(modelReply); if (modelReply->error()) { @@ -210,10 +229,18 @@ void Download::handleModelDownloadFinished() } // Save the model file to disk - QFile file(QCoreApplication::applicationDirPath() + QDir::separator() + modelFilename); + QFile file(downloadLocalModelsPath() + modelFilename); if (file.open(QIODevice::WriteOnly)) { file.write(modelData); file.close(); + } else { + QFile::FileError error = file.error(); + qWarning() << "ERROR: Could not save model to location:" + << downloadLocalModelsPath() + modelFilename + << "failed with code" << error; + modelReply->deleteLater(); + emit downloadFinished(modelFilename); + return; } modelReply->deleteLater(); diff --git a/download.h b/download.h index 92f20cc6..0528abc1 100644 --- a/download.h +++ b/download.h @@ -36,6 +36,7 @@ public: Q_INVOKABLE void updateModelList(); Q_INVOKABLE void downloadModel(const QString &modelFile); Q_INVOKABLE void cancelDownload(const QString &modelFile); + Q_INVOKABLE QString downloadLocalModelsPath() const; private Q_SLOTS: void handleJsonDownloadFinished(); diff --git a/llm.cpp b/llm.cpp index 69e486d0..d951a085 100644 --- a/llm.cpp +++ b/llm.cpp @@ -17,6 +17,23 @@ LLM *LLM::globalInstance() static LLModel::PromptContext s_ctx; +static QString modelFilePath(const QString &modelName) +{ + QString appPath = QCoreApplication::applicationDirPath() + + QDir::separator() + "ggml-" + modelName + ".bin"; + QFileInfo infoAppPath(appPath); + if (infoAppPath.exists()) + return appPath; + + QString downloadPath = Download::globalInstance()->downloadLocalModelsPath() + + QDir::separator() + "ggml-" + modelName + ".bin"; + + QFileInfo infoLocalPath(downloadPath); + if (infoLocalPath.exists()) + return downloadPath; + return QString(); +} + LLMObject::LLMObject() : QObject{nullptr} , m_llmodel(nullptr) @@ -31,14 +48,15 @@ LLMObject::LLMObject() bool LLMObject::loadModel() { - if (modelList().isEmpty()) { + const QList models = modelList(); + if (models.isEmpty()) { // try again when we get a list of models connect(Download::globalInstance(), &Download::modelListChanged, this, &LLMObject::loadModel, Qt::SingleShotConnection); return false; } - return loadModelPrivate(modelList().first()); + return loadModelPrivate(models.first()); } bool LLMObject::loadModelPrivate(const QString &modelName) @@ -54,8 +72,7 @@ bool LLMObject::loadModelPrivate(const QString &modelName) } bool isGPTJ = false; - QString filePath = QCoreApplication::applicationDirPath() + QDir::separator() + - "ggml-" + modelName + ".bin"; + QString filePath = modelFilePath(modelName); QFileInfo info(filePath); if (info.exists()) { @@ -169,28 +186,57 @@ void LLMObject::modelNameChangeRequested(const QString &modelName) QList LLMObject::modelList() const { - QDir dir(QCoreApplication::applicationDirPath()); - dir.setNameFilters(QStringList() << "ggml-*.bin"); - QStringList fileNames = dir.entryList(); - if (fileNames.isEmpty()) { - qWarning() << "ERROR: Could not find any applicable models in directory" - << QCoreApplication::applicationDirPath(); - return QList(); + // Build a model list from exepath and from the localpath + QList list; + + QString exePath = QCoreApplication::applicationDirPath() + QDir::separator(); + QString localPath = Download::globalInstance()->downloadLocalModelsPath(); + + { + QDir dir(exePath); + dir.setNameFilters(QStringList() << "ggml-*.bin"); + QStringList fileNames = dir.entryList(); + for (QString f : fileNames) { + QString filePath = exePath + f; + QFileInfo info(filePath); + QString name = info.completeBaseName().remove(0, 5); + if (info.exists()) { + if (name == m_modelName) + list.prepend(name); + else + list.append(name); + } + } } - QList list; - for (QString f : fileNames) { - QString filePath = QCoreApplication::applicationDirPath() + QDir::separator() + f; - QFileInfo info(filePath); - QString name = info.completeBaseName().remove(0, 5); - if (info.exists()) { - if (name == m_modelName) - list.prepend(name); - else - list.append(name); + if (localPath != exePath) { + QDir dir(localPath); + dir.setNameFilters(QStringList() << "ggml-*.bin"); + QStringList fileNames = dir.entryList(); + for (QString f : fileNames) { + QString filePath = localPath + f; + QFileInfo info(filePath); + QString name = info.completeBaseName().remove(0, 5); + if (info.exists() && !list.contains(name)) { // don't allow duplicates + if (name == m_modelName) + list.prepend(name); + else + list.append(name); + } } } + if (list.isEmpty()) { + if (exePath != localPath) { + qWarning() << "ERROR: Could not find any applicable models in" + << exePath << "nor" << localPath; + } else { + qWarning() << "ERROR: Could not find any applicable models in" + << exePath; + } + return QList(); + } + return list; } diff --git a/qml/ModelDownloaderDialog.qml b/qml/ModelDownloaderDialog.qml index 1f8028a0..f1a64d1c 100644 --- a/qml/ModelDownloaderDialog.qml +++ b/qml/ModelDownloaderDialog.qml @@ -7,7 +7,7 @@ import llm Dialog { id: modelDownloaderDialog width: 1024 - height: 400 + height: 435 modal: true opacity: 0.9 closePolicy: LLM.modelList.length === 0 ? Popup.NoAutoClose : (Popup.CloseOnEscape | Popup.CloseOnPressOutside) @@ -28,7 +28,7 @@ Dialog { ColumnLayout { anchors.fill: parent anchors.margins: 20 - spacing: 10 + spacing: 30 Label { id: listLabel @@ -38,199 +38,216 @@ Dialog { color: theme.textColor } - ListView { - id: modelList + ScrollView { + id: scrollView + ScrollBar.vertical.policy: ScrollBar.AlwaysOn Layout.fillWidth: true Layout.fillHeight: true - model: Download.modelList clip: true - boundsBehavior: Flickable.StopAtBounds - - delegate: Item { - id: delegateItem - width: modelList.width - height: 70 - objectName: "delegateItem" - property bool downloading: false - Rectangle { - anchors.fill: parent - color: index % 2 === 0 ? theme.backgroundLight : theme.backgroundLighter - } - - Text { - id: modelName - objectName: "modelName" - property string filename: modelData.filename - text: filename.slice(5, filename.length - 4) - anchors.verticalCenter: parent.verticalCenter - anchors.left: parent.left - anchors.leftMargin: 10 - color: theme.textColor - Accessible.role: Accessible.Paragraph - Accessible.name: qsTr("Model file") - Accessible.description: qsTr("Model file to be downloaded") - } - - Text { - id: isDefault - text: qsTr("(default)") - visible: modelData.isDefault - anchors.verticalCenter: parent.verticalCenter - anchors.left: modelName.right - anchors.leftMargin: 10 - color: theme.textColor - Accessible.role: Accessible.Paragraph - Accessible.name: qsTr("Default file") - Accessible.description: qsTr("Whether the file is the default model") - } - Text { - text: modelData.filesize - anchors.verticalCenter: parent.verticalCenter - anchors.left: isDefault.visible ? isDefault.right : modelName.right - anchors.leftMargin: 10 - color: theme.textColor - Accessible.role: Accessible.Paragraph - Accessible.name: qsTr("File size") - Accessible.description: qsTr("The size of the file") - } - - Label { - id: speedLabel - anchors.verticalCenter: parent.verticalCenter - anchors.right: itemProgressBar.left - anchors.rightMargin: 10 - objectName: "speedLabel" - color: theme.textColor - text: "" - visible: downloading - Accessible.role: Accessible.Paragraph - Accessible.name: qsTr("Download speed") - Accessible.description: qsTr("Download speed in bytes/kilobytes/megabytes per second") - } - - ProgressBar { - id: itemProgressBar - objectName: "itemProgressBar" - anchors.verticalCenter: parent.verticalCenter - anchors.right: downloadButton.left - anchors.rightMargin: 10 - width: 100 - visible: downloading - Accessible.role: Accessible.ProgressBar - Accessible.name: qsTr("Download progressBar") - Accessible.description: qsTr("Shows the progress made in the download") - } + ListView { + id: modelList + model: Download.modelList + boundsBehavior: Flickable.StopAtBounds + + delegate: Item { + id: delegateItem + width: modelList.width + height: 70 + objectName: "delegateItem" + property bool downloading: false + Rectangle { + anchors.fill: parent + color: index % 2 === 0 ? theme.backgroundLight : theme.backgroundLighter + } - Label { - id: installedLabel - anchors.verticalCenter: parent.verticalCenter - anchors.right: parent.right - anchors.rightMargin: 15 - objectName: "installedLabel" - color: theme.textColor - text: qsTr("Already installed") - visible: modelData.installed - Accessible.role: Accessible.Paragraph - Accessible.name: text - Accessible.description: qsTr("Whether the file is already installed on your system") - } + Text { + id: modelName + objectName: "modelName" + property string filename: modelData.filename + text: filename.slice(5, filename.length - 4) + anchors.verticalCenter: parent.verticalCenter + anchors.left: parent.left + anchors.leftMargin: 10 + color: theme.textColor + Accessible.role: Accessible.Paragraph + Accessible.name: qsTr("Model file") + Accessible.description: qsTr("Model file to be downloaded") + } - Button { - id: downloadButton - text: downloading ? "Cancel" : "Download" - anchors.verticalCenter: parent.verticalCenter - anchors.right: parent.right - anchors.rightMargin: 10 - visible: !modelData.installed - padding: 10 - onClicked: { - if (!downloading) { - downloading = true; - Download.downloadModel(modelData.filename); - } else { - downloading = false; - Download.cancelDownload(modelData.filename); - } + Text { + id: isDefault + text: qsTr("(default)") + visible: modelData.isDefault + anchors.verticalCenter: parent.verticalCenter + anchors.left: modelName.right + anchors.leftMargin: 10 + color: theme.textColor + Accessible.role: Accessible.Paragraph + Accessible.name: qsTr("Default file") + Accessible.description: qsTr("Whether the file is the default model") } - background: Rectangle { - opacity: .5 - border.color: theme.backgroundLightest - border.width: 1 - radius: 10 - color: theme.backgroundLight + + Text { + text: modelData.filesize + anchors.verticalCenter: parent.verticalCenter + anchors.left: isDefault.visible ? isDefault.right : modelName.right + anchors.leftMargin: 10 + color: theme.textColor + Accessible.role: Accessible.Paragraph + Accessible.name: qsTr("File size") + Accessible.description: qsTr("The size of the file") } - Accessible.role: Accessible.Button - Accessible.name: text - Accessible.description: qsTr("Cancel/Download button to stop/start the download") - } - } + Label { + id: speedLabel + anchors.verticalCenter: parent.verticalCenter + anchors.right: itemProgressBar.left + anchors.rightMargin: 10 + objectName: "speedLabel" + color: theme.textColor + text: "" + visible: downloading + Accessible.role: Accessible.Paragraph + Accessible.name: qsTr("Download speed") + Accessible.description: qsTr("Download speed in bytes/kilobytes/megabytes per second") + } - Component.onCompleted: { - Download.downloadProgress.connect(updateProgress); - Download.downloadFinished.connect(resetProgress); - } + ProgressBar { + id: itemProgressBar + objectName: "itemProgressBar" + anchors.verticalCenter: parent.verticalCenter + anchors.right: downloadButton.left + anchors.rightMargin: 10 + width: 100 + visible: downloading + Accessible.role: Accessible.ProgressBar + Accessible.name: qsTr("Download progressBar") + Accessible.description: qsTr("Shows the progress made in the download") + } - property var lastUpdate: ({}) + Label { + id: installedLabel + anchors.verticalCenter: parent.verticalCenter + anchors.right: parent.right + anchors.rightMargin: 15 + objectName: "installedLabel" + color: theme.textColor + text: qsTr("Already installed") + visible: modelData.installed + Accessible.role: Accessible.Paragraph + Accessible.name: text + Accessible.description: qsTr("Whether the file is already installed on your system") + } - function updateProgress(bytesReceived, bytesTotal, modelName) { - let currentTime = new Date().getTime(); + Button { + id: downloadButton + text: downloading ? "Cancel" : "Download" + anchors.verticalCenter: parent.verticalCenter + anchors.right: parent.right + anchors.rightMargin: 10 + visible: !modelData.installed + padding: 10 + onClicked: { + if (!downloading) { + downloading = true; + Download.downloadModel(modelData.filename); + } else { + downloading = false; + Download.cancelDownload(modelData.filename); + } + } + background: Rectangle { + opacity: .5 + border.color: theme.backgroundLightest + border.width: 1 + radius: 10 + color: theme.backgroundLight + } + Accessible.role: Accessible.Button + Accessible.name: text + Accessible.description: qsTr("Cancel/Download button to stop/start the download") - for (let i = 0; i < modelList.contentItem.children.length; i++) { - let delegateItem = modelList.contentItem.children[i]; - if (delegateItem.objectName === "delegateItem") { - let modelNameText = delegateItem.children.find(child => child.objectName === "modelName").filename; - if (modelNameText === modelName) { - let progressBar = delegateItem.children.find(child => child.objectName === "itemProgressBar"); - progressBar.value = bytesReceived / bytesTotal; + } + } - // Calculate the download speed - if (lastUpdate[modelName] && lastUpdate[modelName].timestamp) { - let timeDifference = currentTime - lastUpdate[modelName].timestamp; - let bytesDifference = bytesReceived - lastUpdate[modelName].bytesReceived; - let speed = (bytesDifference / timeDifference) * 1000; // bytes per second + Component.onCompleted: { + Download.downloadProgress.connect(updateProgress); + Download.downloadFinished.connect(resetProgress); + } - // Update the speed label - let speedLabel = delegateItem.children.find(child => child.objectName === "speedLabel"); - if (speed < 1024) { - speedLabel.text = speed.toFixed(2) + " B/s"; - } else if (speed < 1024 * 1024) { - speedLabel.text = (speed / 1024).toFixed(2) + " KB/s"; - } else { - speedLabel.text = (speed / (1024 * 1024)).toFixed(2) + " MB/s"; + property var lastUpdate: ({}) + + function updateProgress(bytesReceived, bytesTotal, modelName) { + let currentTime = new Date().getTime(); + + for (let i = 0; i < modelList.contentItem.children.length; i++) { + let delegateItem = modelList.contentItem.children[i]; + if (delegateItem.objectName === "delegateItem") { + let modelNameText = delegateItem.children.find(child => child.objectName === "modelName").filename; + if (modelNameText === modelName) { + let progressBar = delegateItem.children.find(child => child.objectName === "itemProgressBar"); + progressBar.value = bytesReceived / bytesTotal; + + // Calculate the download speed + if (lastUpdate[modelName] && lastUpdate[modelName].timestamp) { + let timeDifference = currentTime - lastUpdate[modelName].timestamp; + let bytesDifference = bytesReceived - lastUpdate[modelName].bytesReceived; + let speed = (bytesDifference / timeDifference) * 1000; // bytes per second + + // Update the speed label + let speedLabel = delegateItem.children.find(child => child.objectName === "speedLabel"); + if (speed < 1024) { + speedLabel.text = speed.toFixed(2) + " B/s"; + } else if (speed < 1024 * 1024) { + speedLabel.text = (speed / 1024).toFixed(2) + " KB/s"; + } else { + speedLabel.text = (speed / (1024 * 1024)).toFixed(2) + " MB/s"; + } } - } - // Update the lastUpdate object for the current model - lastUpdate[modelName] = {"timestamp": currentTime, "bytesReceived": bytesReceived}; - break; + // Update the lastUpdate object for the current model + lastUpdate[modelName] = {"timestamp": currentTime, "bytesReceived": bytesReceived}; + break; + } } } } - } - function resetProgress(modelName) { - for (let i = 0; i < modelList.contentItem.children.length; i++) { - let delegateItem = modelList.contentItem.children[i]; - if (delegateItem.objectName === "delegateItem") { - let modelNameText = delegateItem.children.find(child => child.objectName === "modelName").filename; - if (modelNameText === modelName) { - let progressBar = delegateItem.children.find(child => child.objectName === "itemProgressBar"); - progressBar.value = 0; - delegateItem.downloading = false; - - // Remove speed label text - let speedLabel = delegateItem.children.find(child => child.objectName === "speedLabel"); - speedLabel.text = ""; - - // Remove the lastUpdate object for the canceled model - delete lastUpdate[modelName]; - break; + function resetProgress(modelName) { + for (let i = 0; i < modelList.contentItem.children.length; i++) { + let delegateItem = modelList.contentItem.children[i]; + if (delegateItem.objectName === "delegateItem") { + let modelNameText = delegateItem.children.find(child => child.objectName === "modelName").filename; + if (modelNameText === modelName) { + let progressBar = delegateItem.children.find(child => child.objectName === "itemProgressBar"); + progressBar.value = 0; + delegateItem.downloading = false; + + // Remove speed label text + let speedLabel = delegateItem.children.find(child => child.objectName === "speedLabel"); + speedLabel.text = ""; + + // Remove the lastUpdate object for the canceled model + delete lastUpdate[modelName]; + break; + } } } } } } + + Label { + Layout.alignment: Qt.AlignLeft + Layout.fillWidth: true + text: qsTr("NOTE: models will be downloaded to\n") + Download.downloadLocalModelsPath() + wrapMode: Text.WrapAnywhere + horizontalAlignment: Text.AlignHCenter + color: theme.textColor + Accessible.role: Accessible.Paragraph + Accessible.name: qsTr("Model download path") + Accessible.description: qsTr("The path where downloaded models will be saved.") + } } }