From 35e7503571d6405a71aaf82982c959bd6cb936e3 Mon Sep 17 00:00:00 2001 From: Adam Treat Date: Sun, 23 Apr 2023 19:43:20 -0400 Subject: [PATCH] Make the download use a temp file to save ram and make it threaded. --- download.cpp | 173 ++++++++++++++++++++++++++-------- download.h | 30 +++++- qml/ModelDownloaderDialog.qml | 32 ++++++- 3 files changed, 192 insertions(+), 43 deletions(-) diff --git a/download.cpp b/download.cpp index 459d2ae6..8e339bde 100644 --- a/download.cpp +++ b/download.cpp @@ -19,7 +19,12 @@ Download *Download::globalInstance() Download::Download() : QObject(nullptr) + , m_hashAndSave(new HashAndSaveFile) { + connect(this, &Download::requestHashAndSave, m_hashAndSave, + &HashAndSaveFile::hashAndSave, Qt::QueuedConnection); + connect(m_hashAndSave, &HashAndSaveFile::hashAndSaveFinished, this, + &Download::handleHashAndSaveFinished, Qt::QueuedConnection); updateModelList(); } @@ -69,17 +74,27 @@ void Download::updateModelList() void Download::downloadModel(const QString &modelFile) { + QTemporaryFile *tempFile = new QTemporaryFile; + bool success = tempFile->open(); + qWarning() << "Opening temp file for writing:" << tempFile->fileName(); + if (!success) { + qWarning() << "ERROR: Could not open temp file:" + << tempFile->fileName() << modelFile; + return; + } + QNetworkRequest request("http://gpt4all.io/models/" + modelFile); QNetworkReply *modelReply = m_networkManager.get(request); connect(modelReply, &QNetworkReply::downloadProgress, this, &Download::handleDownloadProgress); connect(modelReply, &QNetworkReply::finished, this, &Download::handleModelDownloadFinished); - m_activeDownloads.append(modelReply); + connect(modelReply, &QNetworkReply::readyRead, this, &Download::handleReadyRead); + m_activeDownloads.insert(modelReply, tempFile); } void Download::cancelDownload(const QString &modelFile) { for (int i = 0; i < m_activeDownloads.size(); ++i) { - QNetworkReply *modelReply = m_activeDownloads.at(i); + QNetworkReply *modelReply = m_activeDownloads.keys().at(i); QUrl url = modelReply->request().url(); if (url.toString().endsWith(modelFile)) { // Disconnect the signals @@ -88,7 +103,10 @@ void Download::cancelDownload(const QString &modelFile) modelReply->abort(); // Abort the download modelReply->deleteLater(); // Schedule the reply for deletion - m_activeDownloads.removeAll(modelReply); + + QTemporaryFile *tempFile = m_activeDownloads.value(modelReply); + tempFile->deleteLater(); + m_activeDownloads.remove(modelReply); // Emit downloadFinished signal for cleanup emit downloadFinished(modelFile); @@ -192,6 +210,74 @@ bool operator==(const ModelInfo& lhs, const ModelInfo& rhs) { return lhs.filename == rhs.filename && lhs.md5sum == rhs.md5sum; } +HashAndSaveFile::HashAndSaveFile() + : QObject(nullptr) +{ + moveToThread(&m_hashAndSaveThread); + m_hashAndSaveThread.setObjectName("hashandsave thread"); + m_hashAndSaveThread.start(); +} + +void HashAndSaveFile::hashAndSave(const QString &expectedHash, const QString &saveFilePath, + QTemporaryFile *tempFile, QNetworkReply *modelReply) +{ + Q_ASSERT(!tempFile->isOpen()); + QString modelFilename = modelReply->url().fileName(); + + // Reopen the tempFile for hashing + if (!tempFile->open()) { + qWarning() << "ERROR: Could not open temp file for hashing:" + << tempFile->fileName() << modelFilename; + emit hashAndSaveFinished(false, tempFile, modelReply); + return; + } + + QCryptographicHash hash(QCryptographicHash::Md5); + hash.addData(tempFile->readAll()); + while(!tempFile->atEnd()) + hash.addData(tempFile->read(16384)); + if (hash.result().toHex() != expectedHash) { + tempFile->close(); + qWarning() << "ERROR: Download error MD5SUM did not match:" + << hash.result().toHex() + << "!=" << expectedHash << "for" << modelFilename; + emit hashAndSaveFinished(false, tempFile, modelReply); + return; + } + + // The file save needs the tempFile closed + tempFile->close(); + + // Reopen the tempFile for copying + if (!tempFile->open()) { + qWarning() << "ERROR: Could not open temp file at finish:" + << tempFile->fileName() << modelFilename; + emit hashAndSaveFinished(false, tempFile, modelReply); + return; + } + + // Save the model file to disk + QFile file(saveFilePath); + if (file.open(QIODevice::WriteOnly)) { + QByteArray buffer; + while (!tempFile->atEnd()) { + buffer = tempFile->read(16384); + file.write(buffer); + } + file.close(); + tempFile->close(); + emit hashAndSaveFinished(true, tempFile, modelReply); + } else { + QFile::FileError error = file.error(); + qWarning() << "ERROR: Could not save model to location:" + << saveFilePath + << "failed with code" << error; + tempFile->close(); + emit hashAndSaveFinished(false, tempFile, modelReply); + return; + } +} + void Download::handleModelDownloadFinished() { QNetworkReply *modelReply = qobject_cast(sender()); @@ -199,54 +285,59 @@ void Download::handleModelDownloadFinished() return; QString modelFilename = modelReply->url().fileName(); - m_activeDownloads.removeAll(modelReply); + QTemporaryFile *tempFile = m_activeDownloads.value(modelReply); + m_activeDownloads.remove(modelReply); if (modelReply->error()) { qWarning() << "ERROR: downloading:" << modelReply->errorString(); modelReply->deleteLater(); + tempFile->deleteLater(); emit downloadFinished(modelFilename); return; } - QByteArray modelData = modelReply->readAll(); - if (!m_modelMap.contains(modelFilename)) { - qWarning() << "ERROR: Cannot find in download map:" << modelFilename; - modelReply->deleteLater(); - emit downloadFinished(modelFilename); - return; - } + // The hash and save needs the tempFile closed + tempFile->close(); + // Notify that we are calculating hash ModelInfo info = m_modelMap.value(modelFilename); - QCryptographicHash hash(QCryptographicHash::Md5); - hash.addData(modelData); - if (hash.result().toHex() != info.md5sum) { - qWarning() << "ERROR: Download error MD5SUM did not match:" - << hash.result().toHex() - << "!=" << info.md5sum << "for" << modelFilename; - modelReply->deleteLater(); - emit downloadFinished(modelFilename); - return; - } - - // Save the model file to disk - QFile file(downloadLocalModelsPath() + modelFilename); - if (file.open(QIODevice::WriteOnly)) { - file.write(modelData); - file.close(); - } else { - QFile::FileError error = file.error(); - qWarning() << "ERROR: Could not save model to location:" - << downloadLocalModelsPath() + modelFilename - << "failed with code" << error; - modelReply->deleteLater(); - emit downloadFinished(modelFilename); - return; - } - - modelReply->deleteLater(); - emit downloadFinished(modelFilename); - - info.installed = true; + info.calcHash = true; m_modelMap.insert(modelFilename, info); emit modelListChanged(); + + const QString saveFilePath = downloadLocalModelsPath() + modelFilename; + emit requestHashAndSave(info.md5sum, saveFilePath, tempFile, modelReply); +} + +void Download::handleHashAndSaveFinished(bool success, + QTemporaryFile *tempFile, QNetworkReply *modelReply) +{ + // The hash and save should send back with tempfile closed + Q_ASSERT(!tempFile->isOpen()); + QString modelFilename = modelReply->url().fileName(); + + ModelInfo info = m_modelMap.value(modelFilename); + info.calcHash = false; + info.installed = success; + m_modelMap.insert(modelFilename, info); + emit modelListChanged(); + + modelReply->deleteLater(); + tempFile->deleteLater(); + emit downloadFinished(modelFilename); +} + +void Download::handleReadyRead() +{ + QNetworkReply *modelReply = qobject_cast(sender()); + if (!modelReply) + return; + + QString modelFilename = modelReply->url().fileName(); + QTemporaryFile *tempFile = m_activeDownloads.value(modelReply); + QByteArray buffer; + while (!modelReply->atEnd()) { + buffer = modelReply->read(16384); + tempFile->write(buffer); + } } diff --git a/download.h b/download.h index 0528abc1..53d01b37 100644 --- a/download.h +++ b/download.h @@ -6,12 +6,15 @@ #include #include #include +#include +#include struct ModelInfo { Q_GADGET Q_PROPERTY(QString filename MEMBER filename) Q_PROPERTY(QString filesize MEMBER filesize) Q_PROPERTY(QByteArray md5sum MEMBER md5sum) + Q_PROPERTY(bool calcHash MEMBER calcHash) Q_PROPERTY(bool installed MEMBER installed) Q_PROPERTY(bool isDefault MEMBER isDefault) @@ -19,11 +22,30 @@ public: QString filename; QString filesize; QByteArray md5sum; + bool calcHash = false; bool installed = false; bool isDefault = false; }; Q_DECLARE_METATYPE(ModelInfo) +class HashAndSaveFile : public QObject +{ + Q_OBJECT +public: + HashAndSaveFile(); + +public Q_SLOTS: + void hashAndSave(const QString &hash, const QString &saveFilePath, + QTemporaryFile *tempFile, QNetworkReply *modelReply); + +Q_SIGNALS: + void hashAndSaveFinished(bool success, + QTemporaryFile *tempFile, QNetworkReply *modelReply); + +private: + QThread m_hashAndSaveThread; +}; + class Download : public QObject { Q_OBJECT @@ -42,18 +64,24 @@ private Q_SLOTS: void handleJsonDownloadFinished(); void handleDownloadProgress(qint64 bytesReceived, qint64 bytesTotal); void handleModelDownloadFinished(); + void handleHashAndSaveFinished(bool success, + QTemporaryFile *tempFile, QNetworkReply *modelReply); + void handleReadyRead(); Q_SIGNALS: void downloadProgress(qint64 bytesReceived, qint64 bytesTotal, const QString &modelFile); void downloadFinished(const QString &modelFile); void modelListChanged(); + void requestHashAndSave(const QString &hash, const QString &saveFilePath, + QTemporaryFile *tempFile, QNetworkReply *modelReply); private: void parseJsonFile(const QByteArray &jsonData); + HashAndSaveFile *m_hashAndSave; QMap m_modelMap; QNetworkAccessManager m_networkManager; - QVector m_activeDownloads; + QMap m_activeDownloads; private: explicit Download(); diff --git a/qml/ModelDownloaderDialog.qml b/qml/ModelDownloaderDialog.qml index f1a64d1c..95a2e77f 100644 --- a/qml/ModelDownloaderDialog.qml +++ b/qml/ModelDownloaderDialog.qml @@ -126,6 +126,36 @@ Dialog { Accessible.description: qsTr("Shows the progress made in the download") } + Item { + visible: modelData.calcHash + anchors.verticalCenter: parent.verticalCenter + anchors.right: parent.right + anchors.rightMargin: 10 + + Label { + id: calcHashLabel + anchors.right: busyCalcHash.left + anchors.rightMargin: 10 + anchors.verticalCenter: parent.verticalCenter + objectName: "calcHashLabel" + color: theme.textColor + text: qsTr("Calculating MD5...") + Accessible.role: Accessible.Paragraph + Accessible.name: text + Accessible.description: qsTr("Whether the file hash is being calculated") + } + + BusyIndicator { + id: busyCalcHash + anchors.right: parent.right + anchors.verticalCenter: calcHashLabel.verticalCenter + running: modelData.calcHash + Accessible.role: Accessible.Animation + Accessible.name: qsTr("Busy indicator") + Accessible.description: qsTr("Displayed when the file hash is being calculated") + } + } + Label { id: installedLabel anchors.verticalCenter: parent.verticalCenter @@ -146,7 +176,7 @@ Dialog { anchors.verticalCenter: parent.verticalCenter anchors.right: parent.right anchors.rightMargin: 10 - visible: !modelData.installed + visible: !modelData.installed && !modelData.calcHash padding: 10 onClicked: { if (!downloading) {