diff --git a/CMakeLists.txt b/CMakeLists.txt index 641936c7..793a9926 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -33,6 +33,7 @@ add_subdirectory(ggml) qt_add_executable(chat main.cpp + download.h download.cpp gptj.h gptj.cpp llm.h llm.cpp llmodel.h @@ -41,7 +42,7 @@ qt_add_executable(chat qt_add_qml_module(chat URI gpt4all-chat VERSION 1.0 - QML_FILES main.qml + QML_FILES main.qml qml/ModelDownloaderDialog.qml RESOURCES icons/send_message.svg icons/stop_generating.svg diff --git a/download.cpp b/download.cpp new file mode 100644 index 00000000..c95ca233 --- /dev/null +++ b/download.cpp @@ -0,0 +1,209 @@ +#include "download.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +class MyDownload: public Download { }; +Q_GLOBAL_STATIC(MyDownload, downloadInstance) +Download *Download::globalInstance() +{ + return downloadInstance(); +} + +Download::Download() + : QObject(nullptr) +{ + updateModelList(); +} + +QList Download::modelList() const +{ + // We make sure the default model is listed first + QList values = m_modelMap.values(); + ModelInfo defaultInfo; + for (ModelInfo v : values) { + if (v.isDefault) { + defaultInfo = v; + break; + } + } + values.removeAll(defaultInfo); + values.prepend(defaultInfo); + return values; +} + +void Download::updateModelList() +{ + QUrl jsonUrl("http://gpt4all.io/models/models.json"); + QNetworkRequest request(jsonUrl); + QNetworkReply *jsonReply = m_networkManager.get(request); + connect(jsonReply, &QNetworkReply::finished, this, &Download::handleJsonDownloadFinished); +} + +void Download::downloadModel(const QString &modelFile) +{ + QNetworkRequest request("http://gpt4all.io/models/" + modelFile); + QNetworkReply *modelReply = m_networkManager.get(request); + connect(modelReply, &QNetworkReply::downloadProgress, this, &Download::handleDownloadProgress); + connect(modelReply, &QNetworkReply::finished, this, &Download::handleModelDownloadFinished); + m_activeDownloads.append(modelReply); +} + +void Download::cancelDownload(const QString &modelFile) +{ + for (int i = 0; i < m_activeDownloads.size(); ++i) { + QNetworkReply *modelReply = m_activeDownloads.at(i); + QUrl url = modelReply->request().url(); + if (url.toString().endsWith(modelFile)) { + // Disconnect the signals + disconnect(modelReply, &QNetworkReply::downloadProgress, this, &Download::handleDownloadProgress); + disconnect(modelReply, &QNetworkReply::finished, this, &Download::handleModelDownloadFinished); + + modelReply->abort(); // Abort the download + modelReply->deleteLater(); // Schedule the reply for deletion + m_activeDownloads.removeAll(modelReply); + + // Emit downloadFinished signal for cleanup + emit downloadFinished(modelFile); + break; + } + } +} + +void Download::handleJsonDownloadFinished() +{ +#if 0 + QByteArray jsonData = QString("" + "[" + " {" + " \"md5sum\": \"61d48a82cb188cceb14ebb8082bfec37\"," + " \"filename\": \"ggml-gpt4all-j-v1.1-breezy.bin\"" + " }," + " {" + " \"md5sum\": \"879344aaa9d62fdccbda0be7a09e7976\"," + " \"filename\": \"ggml-gpt4all-j-v1.2-jazzy.bin\"," + " \"isDefault\": \"true\"" + " }," + " {" + " \"md5sum\": \"5b5a3f9b858d33b29b52b89692415595\"," + " \"filename\": \"ggml-gpt4all-j.bin\"" + " }" + "]" + ).toUtf8(); + printf("%s\n", jsonData.toStdString().c_str()); + fflush(stdout); +#else + QNetworkReply *jsonReply = qobject_cast(sender()); + if (!jsonReply) + return; + + QByteArray jsonData = jsonReply->readAll(); + jsonReply->deleteLater(); +#endif + parseJsonFile(jsonData); +} + +void Download::parseJsonFile(const QByteArray &jsonData) +{ + QJsonParseError err; + QJsonDocument document = QJsonDocument::fromJson(jsonData, &err); + if (err.error != QJsonParseError::NoError) { + qDebug() << "ERROR: Couldn't parse: " << jsonData << err.errorString(); + return; + } + + QJsonArray jsonArray = document.array(); + + m_modelMap.clear(); + for (const QJsonValue &value : jsonArray) { + QJsonObject obj = value.toObject(); + + QString modelFilename = obj["filename"].toString(); + QByteArray modelMd5sum = obj["md5sum"].toString().toLatin1().constData(); + bool isDefault = obj.contains("isDefault") && obj["isDefault"] == QString("true"); + + QString filePath = QCoreApplication::applicationDirPath() + QDir::separator() + modelFilename; + QFileInfo info(filePath); + ModelInfo modelInfo; + modelInfo.filename = modelFilename; + modelInfo.md5sum = modelMd5sum; + modelInfo.installed = info.exists(); + modelInfo.isDefault = isDefault; + m_modelMap.insert(modelInfo.filename, modelInfo); + } + + emit modelListChanged(); +} + +void Download::handleDownloadProgress(qint64 bytesReceived, qint64 bytesTotal) +{ + QNetworkReply *modelReply = qobject_cast(sender()); + if (!modelReply) + return; + + QString modelFilename = modelReply->url().fileName(); +// qDebug() << "handleDownloadProgress" << bytesReceived << bytesTotal << modelFilename; + emit downloadProgress(bytesReceived, bytesTotal, modelFilename); +} + +bool operator==(const ModelInfo& lhs, const ModelInfo& rhs) { + return lhs.filename == rhs.filename && lhs.md5sum == rhs.md5sum; +} + +void Download::handleModelDownloadFinished() +{ + QNetworkReply *modelReply = qobject_cast(sender()); + if (!modelReply) + return; + + QString modelFilename = modelReply->url().fileName(); +// qDebug() << "handleModelDownloadFinished" << modelFilename; + m_activeDownloads.removeAll(modelReply); + + if (modelReply->error()) { + qWarning() << "ERROR: downloading:" << modelReply->errorString(); + modelReply->deleteLater(); + emit downloadFinished(modelFilename); + return; + } + + QByteArray modelData = modelReply->readAll(); + if (!m_modelMap.contains(modelFilename)) { + qWarning() << "ERROR: Cannot find in download map:" << modelFilename; + modelReply->deleteLater(); + emit downloadFinished(modelFilename); + return; + } + + ModelInfo info = m_modelMap.value(modelFilename); + QCryptographicHash hash(QCryptographicHash::Md5); + hash.addData(modelData); + if (hash.result().toHex() != info.md5sum) { + qWarning() << "ERROR: Download error MD5SUM did not match:" + << hash.result().toHex() + << "!=" << info.md5sum << "for" << modelFilename; + modelReply->deleteLater(); + emit downloadFinished(modelFilename); + return; + } + + // Save the model file to disk + QFile file(QCoreApplication::applicationDirPath() + QDir::separator() + modelFilename); + if (file.open(QIODevice::WriteOnly)) { + file.write(modelData); + file.close(); + } + + modelReply->deleteLater(); + emit downloadFinished(modelFilename); + + info.installed = true; + m_modelMap.insert(modelFilename, info); + emit modelListChanged(); +} diff --git a/download.h b/download.h new file mode 100644 index 00000000..6d69e4ce --- /dev/null +++ b/download.h @@ -0,0 +1,62 @@ +#ifndef DOWNLOAD_H +#define DOWNLOAD_H + +#include +#include +#include +#include +#include +#include + +struct ModelInfo { + Q_GADGET + Q_PROPERTY(QString filename MEMBER filename) + Q_PROPERTY(QByteArray md5sum MEMBER md5sum) + Q_PROPERTY(bool installed MEMBER installed) + Q_PROPERTY(bool isDefault MEMBER isDefault) + +public: + QString filename; + QByteArray md5sum; + bool installed = false; + bool isDefault = false; +}; +Q_DECLARE_METATYPE(ModelInfo) + +class Download : public QObject +{ + Q_OBJECT + Q_PROPERTY(QList modelList READ modelList NOTIFY modelListChanged) + +public: + static Download *globalInstance(); + + QList modelList() const; + Q_INVOKABLE void updateModelList(); + Q_INVOKABLE void downloadModel(const QString &modelFile); + Q_INVOKABLE void cancelDownload(const QString &modelFile); + +public Q_SLOTS: + void handleJsonDownloadFinished(); + void handleDownloadProgress(qint64 bytesReceived, qint64 bytesTotal); + void handleModelDownloadFinished(); + +Q_SIGNALS: + void downloadProgress(qint64 bytesReceived, qint64 bytesTotal, const QString &modelFile); + void downloadFinished(const QString &modelFile); + void modelListChanged(); + +private: + void parseJsonFile(const QByteArray &jsonData); + + QMap m_modelMap; + QNetworkAccessManager m_networkManager; + QVector m_activeDownloads; + +private: + explicit Download(); + ~Download() {} + friend class MyDownload; +}; + +#endif // DOWNLOAD_H diff --git a/llm.cpp b/llm.cpp index 93c26d02..a73b3fab 100644 --- a/llm.cpp +++ b/llm.cpp @@ -1,4 +1,5 @@ #include "llm.h" +#include "download.h" #include #include @@ -30,6 +31,13 @@ LLMObject::LLMObject() bool LLMObject::loadModel() { + if (modelList().isEmpty()) { + // try again when we get a list of models + connect(Download::globalInstance(), &Download::modelListChanged, this, + &LLMObject::loadModel, Qt::SingleShotConnection); + return false; + } + return loadModelPrivate(modelList().first()); } @@ -210,6 +218,7 @@ LLM::LLM() , m_llmodel(new LLMObject) , m_responseInProgress(false) { + connect(Download::globalInstance(), &Download::modelListChanged, this, &LLM::modelListChanged, Qt::QueuedConnection); connect(m_llmodel, &LLMObject::isModelLoadedChanged, this, &LLM::isModelLoadedChanged, Qt::QueuedConnection); connect(m_llmodel, &LLMObject::responseChanged, this, &LLM::responseChanged, Qt::QueuedConnection); connect(m_llmodel, &LLMObject::responseStarted, this, &LLM::responseStarted, Qt::QueuedConnection); diff --git a/main.cpp b/main.cpp index b0659aba..098f67ca 100644 --- a/main.cpp +++ b/main.cpp @@ -5,15 +5,21 @@ #include #include "llm.h" +#include "download.h" #include "config.h" int main(int argc, char *argv[]) { + QCoreApplication::setOrganizationName("nomic.ai"); + QCoreApplication::setOrganizationDomain("gpt4all.io"); + QCoreApplication::setApplicationName("GPT4All"); QCoreApplication::setApplicationVersion(APP_VERSION); QGuiApplication app(argc, argv); QQmlApplicationEngine engine; qmlRegisterSingletonInstance("llm", 1, 0, "LLM", LLM::globalInstance()); + qmlRegisterSingletonInstance("download", 1, 0, "Download", Download::globalInstance()); + const QUrl url(u"qrc:/gpt4all-chat/main.qml"_qs); QObject::connect(&engine, &QQmlApplicationEngine::objectCreated, @@ -23,7 +29,7 @@ int main(int argc, char *argv[]) }, Qt::QueuedConnection); engine.load(url); -#if 1 +#if 0 QDirIterator it("qrc:", QDirIterator::Subdirectories); while (it.hasNext()) { qDebug() << it.next(); diff --git a/main.qml b/main.qml index f899816b..002f8060 100644 --- a/main.qml +++ b/main.qml @@ -27,7 +27,6 @@ Window { Item { anchors.centerIn: parent - width: childrenRect.width height: childrenRect.height visible: LLM.isModelLoaded @@ -93,6 +92,16 @@ Window { title: qsTr("Settings") height: 600 width: 600 + opacity: 0.9 + background: Rectangle { + anchors.fill: parent + anchors.margins: -20 + color: "#202123" + border.width: 1 + border.color: "white" + radius: 10 + } + property real defaultTemperature: 0.28 property real defaultTopP: 0.95 property int defaultTopK: 40 @@ -134,10 +143,7 @@ Window { columns: 2 rowSpacing: 10 columnSpacing: 10 - anchors.top: parent.top - anchors.left: parent.left - anchors.right: parent.right - anchors.bottom: parent.bottom + anchors.fill: parent Label { id: tempLabel @@ -558,6 +564,7 @@ Window { } background: Rectangle { anchors.fill: parent + anchors.margins: -20 color: "#202123" border.width: 1 border.color: "white" @@ -565,6 +572,16 @@ Window { } } + ModelDownloaderDialog { + id: downloadNewModels + anchors.centerIn: parent + Item { + Accessible.role: Accessible.Dialog + Accessible.name: qsTr("Download new models dialog") + Accessible.description: qsTr("Dialog for downloading new models") + } + } + Drawer { id: drawer y: header.height @@ -638,7 +655,8 @@ Window { Button { anchors.left: parent.left anchors.right: parent.right - anchors.bottom: parent.bottom + anchors.bottom: downloadButton.top + anchors.bottomMargin: 20 padding: 15 contentItem: Text { text: qsTr("Check for updates...") @@ -663,6 +681,36 @@ Window { checkForUpdatesError.open() } } + + Button { + id: downloadButton + anchors.left: parent.left + anchors.right: parent.right + anchors.bottom: parent.bottom + padding: 15 + contentItem: Text { + text: qsTr("Download new models...") + horizontalAlignment: Text.AlignHCenter + color: "#d1d5db" + + Accessible.role: Accessible.Button + Accessible.name: text + Accessible.description: qsTr("Use this to launch a dialog to download new models") + } + + background: Rectangle { + opacity: .5 + border.color: "#7d7d8e" + border.width: 1 + radius: 10 + color: "#343541" + } + + onClicked: { + downloadNewModels.open() + } + } + } } diff --git a/qml/ModelDownloaderDialog.qml b/qml/ModelDownloaderDialog.qml new file mode 100644 index 00000000..67318411 --- /dev/null +++ b/qml/ModelDownloaderDialog.qml @@ -0,0 +1,219 @@ +import QtQuick 6.5 +import QtQuick.Controls 6.5 +import QtQuick.Layouts 1.12 +import download +import llm + +Dialog { + id: modelDownloaderDialog + width: 900 + height: 400 + title: "Model Downloader" + modal: true + opacity: 0.9 + closePolicy: LLM.modelList.length === 0 ? Popup.NoAutoClose : (Popup.CloseOnEscape | Popup.CloseOnPressOutside) + background: Rectangle { + anchors.fill: parent + anchors.margins: -20 + color: "#202123" + border.width: 1 + border.color: "white" + radius: 10 + } + + Component.onCompleted: { + if (LLM.modelList.length === 0) + open(); + } + + ColumnLayout { + anchors.fill: parent + anchors.margins: 20 + spacing: 10 + + Label { + id: listLabel + text: "Available Models:" + Layout.alignment: Qt.AlignLeft + Layout.fillWidth: true + color: "#d1d5db" + } + + ListView { + id: modelList + Layout.fillWidth: true + Layout.fillHeight: true + model: Download.modelList + clip: true + boundsBehavior: Flickable.StopAtBounds + + delegate: Item { + id: delegateItem + width: modelList.width + height: 50 + objectName: "delegateItem" + property bool downloading: false + + Rectangle { + anchors.fill: parent + color: index % 2 === 0 ? "#2c2f33" : "#1e2125" + } + + Text { + id: modelName + objectName: "modelName" + text: modelData.filename + anchors.verticalCenter: parent.verticalCenter + anchors.left: parent.left + anchors.leftMargin: 10 + font.pixelSize: 24 + color: "#d1d5db" + Accessible.role: Accessible.Paragraph + Accessible.name: qsTr("Model file") + Accessible.description: qsTr("Model file to be downloaded") + } + + Text { + text: qsTr("(default)") + visible: modelData.isDefault + anchors.verticalCenter: parent.verticalCenter + anchors.left: modelName.right + anchors.leftMargin: 10 + font.pixelSize: 24 + color: "#d1d5db" + Accessible.role: Accessible.Paragraph + Accessible.name: qsTr("Default file") + Accessible.description: qsTr("Whether the file is the default model") + } + + Label { + id: speedLabel + anchors.verticalCenter: parent.verticalCenter + anchors.right: itemProgressBar.left + anchors.rightMargin: 10 + objectName: "speedLabel" + color: "#d1d5db" + text: "" + visible: downloading + Accessible.role: Accessible.Paragraph + Accessible.name: qsTr("Download speed") + Accessible.description: qsTr("Download speed in bytes/kilobytes/megabytes per second") + } + + ProgressBar { + id: itemProgressBar + objectName: "itemProgressBar" + anchors.verticalCenter: parent.verticalCenter + anchors.right: downloadButton.left + anchors.rightMargin: 10 + width: 100 + visible: downloading + Accessible.role: Accessible.ProgressBar + Accessible.name: qsTr("Download progressBar") + Accessible.description: qsTr("Shows the progress made in the download") + } + + Label { + id: installedLabel + anchors.verticalCenter: parent.verticalCenter + anchors.right: parent.right + anchors.rightMargin: 15 + objectName: "installedLabel" + color: "#d1d5db" + text: qsTr("Already installed") + visible: modelData.installed + Accessible.role: Accessible.Paragraph + Accessible.name: text + Accessible.description: qsTr("Whether the file is already installed on your system") + } + + Button { + id: downloadButton + text: downloading ? "Cancel" : "Download" + anchors.verticalCenter: parent.verticalCenter + anchors.right: parent.right + anchors.rightMargin: 10 + visible: !modelData.installed + onClicked: { + if (!downloading) { + downloading = true; + Download.downloadModel(modelData.filename); + } else { + downloading = false; + Download.cancelDownload(modelData.filename); + } + } + Accessible.role: Accessible.Button + Accessible.name: text + Accessible.description: qsTr("Cancel/Download button to stop/start the download") + + } + } + + Component.onCompleted: { + Download.downloadProgress.connect(updateProgress); + Download.downloadFinished.connect(resetProgress); + } + + property var lastUpdate: ({}) + + function updateProgress(bytesReceived, bytesTotal, modelName) { + let currentTime = new Date().getTime(); + + for (let i = 0; i < modelList.contentItem.children.length; i++) { + let delegateItem = modelList.contentItem.children[i]; + if (delegateItem.objectName === "delegateItem") { + let modelNameText = delegateItem.children.find(child => child.objectName === "modelName").text; + if (modelNameText === modelName) { + let progressBar = delegateItem.children.find(child => child.objectName === "itemProgressBar"); + progressBar.value = bytesReceived / bytesTotal; + + // Calculate the download speed + if (lastUpdate[modelName] && lastUpdate[modelName].timestamp) { + let timeDifference = currentTime - lastUpdate[modelName].timestamp; + let bytesDifference = bytesReceived - lastUpdate[modelName].bytesReceived; + let speed = (bytesDifference / timeDifference) * 1000; // bytes per second + + // Update the speed label + let speedLabel = delegateItem.children.find(child => child.objectName === "speedLabel"); + if (speed < 1024) { + speedLabel.text = speed.toFixed(2) + " B/s"; + } else if (speed < 1024 * 1024) { + speedLabel.text = (speed / 1024).toFixed(2) + " KB/s"; + } else { + speedLabel.text = (speed / (1024 * 1024)).toFixed(2) + " MB/s"; + } + } + + // Update the lastUpdate object for the current model + lastUpdate[modelName] = {"timestamp": currentTime, "bytesReceived": bytesReceived}; + break; + } + } + } + } + + function resetProgress(modelName) { + for (let i = 0; i < modelList.contentItem.children.length; i++) { + let delegateItem = modelList.contentItem.children[i]; + if (delegateItem.objectName === "delegateItem") { + let modelNameText = delegateItem.children.find(child => child.objectName === "modelName").text; + if (modelNameText === modelName) { + let progressBar = delegateItem.children.find(child => child.objectName === "itemProgressBar"); + progressBar.value = 0; + delegateItem.downloading = false; + + // Remove speed label text + let speedLabel = delegateItem.children.find(child => child.objectName === "speedLabel"); + speedLabel.text = ""; + + // Remove the lastUpdate object for the canceled model + delete lastUpdate[modelName]; + break; + } + } + } + } + } + } +}