diff --git a/gpt4all-chat/modellist.cpp b/gpt4all-chat/modellist.cpp index b705a4bf..5a7deed6 100644 --- a/gpt4all-chat/modellist.cpp +++ b/gpt4all-chat/modellist.cpp @@ -2,6 +2,8 @@ #include "mysettings.h" #include "network.h" +#include +#include #include //#define USE_LOCAL_MODELSJSON @@ -241,6 +243,7 @@ ModelList::ModelList() connect(MySettings::globalInstance(), &MySettings::repeatPenaltyTokensChanged, this, &ModelList::updateDataForSettings);; connect(MySettings::globalInstance(), &MySettings::promptTemplateChanged, this, &ModelList::updateDataForSettings); connect(MySettings::globalInstance(), &MySettings::systemPromptChanged, this, &ModelList::updateDataForSettings); + connect(&m_networkManager, &QNetworkAccessManager::sslErrors, this, &ModelList::handleSslErrors); updateModelsFromJson(); updateModelsFromSettings(); @@ -390,6 +393,21 @@ void ModelList::addModel(const QString &id) emit userDefaultModelListChanged(); } +void ModelList::changeId(const QString &oldId, const QString &newId) +{ + const bool hasModel = contains(oldId); + Q_ASSERT(hasModel); + if (!hasModel) { + qWarning() << "ERROR: model list does not contain" << oldId; + return; + } + + QMutexLocker locker(&m_mutex); + ModelInfo *info = m_modelMap.take(oldId); + info->setId(newId); + m_modelMap.insert(newId, info); +} + int ModelList::rowCount(const QModelIndex &parent) const { Q_UNUSED(parent) @@ -857,13 +875,60 @@ void ModelList::updateModelsFromJson() if (jsonReply->error() == QNetworkReply::NoError && jsonReply->isFinished()) { QByteArray jsonData = jsonReply->readAll(); jsonReply->deleteLater(); - parseModelsJsonFile(jsonData); + parseModelsJsonFile(jsonData, true); } else { - qWarning() << "Could not download models.json"; + qWarning() << "WARNING: Could not download models.json synchronously"; + updateModelsFromJsonAsync(); + + QSettings settings; + QFileInfo info(settings.fileName()); + QString dirPath = info.canonicalPath(); + const QString modelsConfig = dirPath + "/models.json"; + QFile file(modelsConfig); + if (!file.open(QIODeviceBase::ReadOnly)) { + qWarning() << "ERROR: Couldn't read models config file: " << modelsConfig; + } else { + QByteArray jsonData = file.readAll(); + file.close(); + parseModelsJsonFile(jsonData, false); + } } delete jsonReply; } +void ModelList::updateModelsFromJsonAsync() +{ +#if defined(USE_LOCAL_MODELSJSON) + QUrl jsonUrl("file://" + QDir::homePath() + "/dev/large_language_models/gpt4all/gpt4all-chat/metadata/models.json"); +#else + QUrl jsonUrl("http://gpt4all.io/models/models.json"); +#endif + QNetworkRequest request(jsonUrl); + QSslConfiguration conf = request.sslConfiguration(); + conf.setPeerVerifyMode(QSslSocket::VerifyNone); + request.setSslConfiguration(conf); + QNetworkReply *jsonReply = m_networkManager.get(request); + connect(jsonReply, &QNetworkReply::finished, this, &ModelList::handleModelsJsonDownloadFinished); +} + +void ModelList::handleModelsJsonDownloadFinished() +{ + QNetworkReply *jsonReply = qobject_cast(sender()); + if (!jsonReply) + return; + + QByteArray jsonData = jsonReply->readAll(); + jsonReply->deleteLater(); + parseModelsJsonFile(jsonData, true); +} + +void ModelList::handleSslErrors(QNetworkReply *reply, const QList &errors) +{ + QUrl url = reply->request().url(); + for (const auto &e : errors) + qWarning() << "ERROR: Received ssl error:" << e.errorString() << "for" << url; +} + void ModelList::updateDataForSettings() { emit dataChanged(index(0, 0), index(m_models.size() - 1, 0)); @@ -887,7 +952,7 @@ static bool compareVersions(const QString &a, const QString &b) { return aParts.size() > bParts.size(); } -void ModelList::parseModelsJsonFile(const QByteArray &jsonData) +void ModelList::parseModelsJsonFile(const QByteArray &jsonData, bool save) { QJsonParseError err; QJsonDocument document = QJsonDocument::fromJson(jsonData, &err); @@ -896,6 +961,20 @@ void ModelList::parseModelsJsonFile(const QByteArray &jsonData) return; } + if (save) { + QSettings settings; + QFileInfo info(settings.fileName()); + QString dirPath = info.canonicalPath(); + const QString modelsConfig = dirPath + "/models.json"; + QFile file(modelsConfig); + if (!file.open(QIODeviceBase::WriteOnly)) { + qWarning() << "ERROR: Couldn't write models config file: " << modelsConfig; + } else { + file.write(jsonData.constData()); + file.close(); + } + } + QJsonArray jsonArray = document.array(); const QString currentVersion = QCoreApplication::applicationVersion(); @@ -936,6 +1015,9 @@ void ModelList::parseModelsJsonFile(const QByteArray &jsonData) const QString id = modelName; Q_ASSERT(!id.isEmpty()); + if (contains(modelFilename)) + changeId(modelFilename, id); + if (!contains(id)) addModel(id); @@ -983,6 +1065,8 @@ void ModelList::parseModelsJsonFile(const QByteArray &jsonData) const QString modelName = "ChatGPT-3.5 Turbo"; const QString id = modelName; const QString modelFilename = "chatgpt-gpt-3.5-turbo.txt"; + if (contains(modelFilename)) + changeId(modelFilename, id); if (!contains(id)) addModel(id); updateData(id, ModelList::NameRole, modelName); @@ -1003,6 +1087,8 @@ void ModelList::parseModelsJsonFile(const QByteArray &jsonData) const QString modelName = "ChatGPT-4"; const QString id = modelName; const QString modelFilename = "chatgpt-gpt-4.txt"; + if (contains(modelFilename)) + changeId(modelFilename, id); if (!contains(id)) addModel(id); updateData(id, ModelList::NameRole, modelName); diff --git a/gpt4all-chat/modellist.h b/gpt4all-chat/modellist.h index 89c68229..c749254c 100644 --- a/gpt4all-chat/modellist.h +++ b/gpt4all-chat/modellist.h @@ -275,6 +275,7 @@ public: ModelInfo defaultModelInfo() const; void addModel(const QString &id); + void changeId(const QString &oldId, const QString &newId); const QList exportModelList() const; const QList userDefaultModelList() const; @@ -304,16 +305,19 @@ Q_SIGNALS: private Q_SLOTS: void updateModelsFromJson(); + void updateModelsFromJsonAsync(); void updateModelsFromSettings(); void updateModelsFromDirectory(); void updateDataForSettings(); + void handleModelsJsonDownloadFinished(); + void handleSslErrors(QNetworkReply *reply, const QList &errors); private: QString modelDirPath(const QString &modelName, bool isChatGPT); int indexForModel(ModelInfo *model); QVariant dataInternal(const ModelInfo *info, int role) const; static bool lessThan(const ModelInfo* a, const ModelInfo* b); - void parseModelsJsonFile(const QByteArray &jsonData); + void parseModelsJsonFile(const QByteArray &jsonData, bool save); QString uniqueModelName(const ModelInfo &model) const; private: