diff --git a/gpt4all-chat/chat.cpp b/gpt4all-chat/chat.cpp index 0a4154be..be80c3bc 100644 --- a/gpt4all-chat/chat.cpp +++ b/gpt4all-chat/chat.cpp @@ -258,6 +258,7 @@ bool Chat::deserialize(QDataStream &stream, int version) // unfortunately, we cannot deserialize these if (version < 2 && m_savedModelName.contains("gpt4all-j")) return false; + m_llmodel->setModelName(m_savedModelName); if (!m_llmodel->deserialize(stream, version)) return false; if (!m_chatModel->deserialize(stream, version)) diff --git a/gpt4all-chat/chatgpt.cpp b/gpt4all-chat/chatgpt.cpp index 0e2f3ab2..1ea16105 100644 --- a/gpt4all-chat/chatgpt.cpp +++ b/gpt4all-chat/chatgpt.cpp @@ -46,6 +46,7 @@ bool ChatGPT::isModelLoaded() const return true; } +// All three of the state virtual functions are handled custom inside of chatllm save/restore size_t ChatGPT::stateSize() const { return 0; @@ -53,11 +54,13 @@ size_t ChatGPT::stateSize() const size_t ChatGPT::saveState(uint8_t *dest) const { + Q_UNUSED(dest); return 0; } size_t ChatGPT::restoreState(const uint8_t *src) { + Q_UNUSED(src); return 0; } @@ -141,8 +144,8 @@ void ChatGPT::handleFinished() bool ok; int code = response.toInt(&ok); if (!ok || code != 200) { - qWarning() << QString("\nERROR: ChatGPT responded with error code \"%1-%2%3\"\n") - .arg(code).arg(reply->errorString()).arg(reply->readAll()).toStdString(); + qWarning() << QString("ERROR: ChatGPT responded with error code \"%1-%2\"") + .arg(code).arg(reply->errorString()).toStdString(); } reply->deleteLater(); } @@ -190,8 +193,11 @@ void ChatGPT::handleReadyRead() const QString content = delta.value("content").toString(); Q_ASSERT(m_ctx); Q_ASSERT(m_responseCallback); - m_responseCallback(0, content.toStdString()); m_currentResponse += content; + if (!m_responseCallback(0, content.toStdString())) { + reply->abort(); + return; + } } } @@ -201,6 +207,6 @@ void ChatGPT::handleErrorOccurred(QNetworkReply::NetworkError code) if (!reply) return; - qWarning() << QString("\nERROR: ChatGPT responded with error code \"%1-%2%3\"\n") - .arg(code).arg(reply->errorString()).arg(reply->readAll()).toStdString(); + qWarning() << QString("ERROR: ChatGPT responded with error code \"%1-%2\"") + .arg(code).arg(reply->errorString()).toStdString(); } diff --git a/gpt4all-chat/chatgpt.h b/gpt4all-chat/chatgpt.h index 0a1e0d52..ef0f6ef5 100644 --- a/gpt4all-chat/chatgpt.h +++ b/gpt4all-chat/chatgpt.h @@ -30,6 +30,9 @@ public: void setModelName(const QString &modelName) { m_modelName = modelName; } void setAPIKey(const QString &apiKey) { m_apiKey = apiKey; } + QList context() const { return m_context; } + void setContext(const QList &context) { m_context = context; } + protected: void recalculateContext(PromptContext &promptCtx, std::function recalculate) override {} diff --git a/gpt4all-chat/chatlistmodel.cpp b/gpt4all-chat/chatlistmodel.cpp index a931c8ec..f2cb0ad2 100644 --- a/gpt4all-chat/chatlistmodel.cpp +++ b/gpt4all-chat/chatlistmodel.cpp @@ -38,6 +38,19 @@ void ChatListModel::setShouldSaveChats(bool b) emit shouldSaveChatsChanged(); } +bool ChatListModel::shouldSaveChatGPTChats() const +{ + return m_shouldSaveChatGPTChats; +} + +void ChatListModel::setShouldSaveChatGPTChats(bool b) +{ + if (m_shouldSaveChatGPTChats == b) + return; + m_shouldSaveChatGPTChats = b; + emit shouldSaveChatGPTChatsChanged(); +} + void ChatListModel::removeChatFile(Chat *chat) const { Q_ASSERT(chat != m_serverChat); @@ -52,15 +65,17 @@ void ChatListModel::removeChatFile(Chat *chat) const void ChatListModel::saveChats() const { - if (!m_shouldSaveChats) - return; - QElapsedTimer timer; timer.start(); const QString savePath = Download::globalInstance()->downloadLocalModelsPath(); for (Chat *chat : m_chats) { if (chat == m_serverChat) continue; + const bool isChatGPT = chat->modelName().startsWith("chatgpt-"); + if (!isChatGPT && !m_shouldSaveChats) + continue; + if (isChatGPT && !m_shouldSaveChatGPTChats) + continue; QString fileName = "gpt4all-" + chat->id() + ".chat"; QFile file(savePath + "/" + fileName); bool success = file.open(QIODevice::WriteOnly); diff --git a/gpt4all-chat/chatlistmodel.h b/gpt4all-chat/chatlistmodel.h index 881b2cd9..106c2fcb 100644 --- a/gpt4all-chat/chatlistmodel.h +++ b/gpt4all-chat/chatlistmodel.h @@ -20,6 +20,7 @@ class ChatListModel : public QAbstractListModel Q_PROPERTY(int count READ count NOTIFY countChanged) Q_PROPERTY(Chat *currentChat READ currentChat WRITE setCurrentChat NOTIFY currentChatChanged) Q_PROPERTY(bool shouldSaveChats READ shouldSaveChats WRITE setShouldSaveChats NOTIFY shouldSaveChatsChanged) + Q_PROPERTY(bool shouldSaveChatGPTChats READ shouldSaveChatGPTChats WRITE setShouldSaveChatGPTChats NOTIFY shouldSaveChatGPTChatsChanged) public: explicit ChatListModel(QObject *parent = nullptr); @@ -62,6 +63,9 @@ public: bool shouldSaveChats() const; void setShouldSaveChats(bool b); + bool shouldSaveChatGPTChats() const; + void setShouldSaveChatGPTChats(bool b); + Q_INVOKABLE void addChat() { // Don't add a new chat if we already have one @@ -199,6 +203,7 @@ Q_SIGNALS: void countChanged(); void currentChatChanged(); void shouldSaveChatsChanged(); + void shouldSaveChatGPTChatsChanged(); private Q_SLOTS: void newChatCountChanged() @@ -240,6 +245,7 @@ private Q_SLOTS: private: bool m_shouldSaveChats; + bool m_shouldSaveChatGPTChats; Chat* m_newChat; Chat* m_dummyChat; Chat* m_serverChat; diff --git a/gpt4all-chat/chatllm.cpp b/gpt4all-chat/chatllm.cpp index de9ee0a6..8d3acb21 100644 --- a/gpt4all-chat/chatllm.cpp +++ b/gpt4all-chat/chatllm.cpp @@ -611,6 +611,7 @@ bool ChatLLM::deserialize(QDataStream &stream, int version) stream >> compressed; m_state = qUncompress(compressed); } else { + stream >> m_state; } #if defined(DEBUG) @@ -624,6 +625,15 @@ void ChatLLM::saveState() if (!isModelLoaded()) return; + if (m_isChatGPT) { + m_state.clear(); + QDataStream stream(&m_state, QIODeviceBase::WriteOnly); + stream.setVersion(QDataStream::Qt_6_5); + ChatGPT *chatGPT = static_cast(m_modelInfo.model); + stream << chatGPT->context(); + return; + } + const size_t stateSize = m_modelInfo.model->stateSize(); m_state.resize(stateSize); #if defined(DEBUG) @@ -637,6 +647,18 @@ void ChatLLM::restoreState() if (!isModelLoaded() || m_state.isEmpty()) return; + if (m_isChatGPT) { + QDataStream stream(&m_state, QIODeviceBase::ReadOnly); + stream.setVersion(QDataStream::Qt_6_5); + ChatGPT *chatGPT = static_cast(m_modelInfo.model); + QList context; + stream >> context; + chatGPT->setContext(context); + m_state.clear(); + m_state.resize(0); + return; + } + #if defined(DEBUG) qDebug() << "restoreState" << m_chat->id() << "size:" << m_state.size(); #endif diff --git a/gpt4all-chat/qml/SettingsDialog.qml b/gpt4all-chat/qml/SettingsDialog.qml index 87873fa8..8e235f6b 100644 --- a/gpt4all-chat/qml/SettingsDialog.qml +++ b/gpt4all-chat/qml/SettingsDialog.qml @@ -40,6 +40,7 @@ Dialog { property int defaultRepeatPenaltyTokens: 64 property int defaultThreadCount: 0 property bool defaultSaveChats: false + property bool defaultSaveChatGPTChats: true property bool defaultServerChat: false property string defaultPromptTemplate: "### Human: %1 @@ -57,6 +58,7 @@ Dialog { property alias repeatPenaltyTokens: settings.repeatPenaltyTokens property alias threadCount: settings.threadCount property alias saveChats: settings.saveChats + property alias saveChatGPTChats: settings.saveChatGPTChats property alias serverChat: settings.serverChat property alias modelPath: settings.modelPath property alias userDefaultModel: settings.userDefaultModel @@ -70,6 +72,7 @@ Dialog { property int promptBatchSize: settingsDialog.defaultPromptBatchSize property int threadCount: settingsDialog.defaultThreadCount property bool saveChats: settingsDialog.defaultSaveChats + property bool saveChatGPTChats: settingsDialog.defaultSaveChatGPTChats property bool serverChat: settingsDialog.defaultServerChat property real repeatPenalty: settingsDialog.defaultRepeatPenalty property int repeatPenaltyTokens: settingsDialog.defaultRepeatPenaltyTokens @@ -94,12 +97,14 @@ Dialog { settings.modelPath = settingsDialog.defaultModelPath settings.threadCount = defaultThreadCount settings.saveChats = defaultSaveChats + settings.saveChatGPTChats = defaultSaveChatGPTChats settings.serverChat = defaultServerChat settings.userDefaultModel = defaultUserDefaultModel Download.downloadLocalModelsPath = settings.modelPath LLM.threadCount = settings.threadCount LLM.serverEnabled = settings.serverChat LLM.chatListModel.shouldSaveChats = settings.saveChats + LLM.chatListModel.shouldSaveChatGPTChats = settings.saveChatGPTChats settings.sync() } @@ -107,6 +112,7 @@ Dialog { LLM.threadCount = settings.threadCount LLM.serverEnabled = settings.serverChat LLM.chatListModel.shouldSaveChats = settings.saveChats + LLM.chatListModel.shouldSaveChatGPTChats = settings.saveChatGPTChats Download.downloadLocalModelsPath = settings.modelPath } @@ -802,16 +808,65 @@ Dialog { leftPadding: saveChatsBox.indicator.width + saveChatsBox.spacing } } + Label { + id: saveChatGPTChatsLabel + text: qsTr("Save ChatGPT chats to disk:") + color: theme.textColor + Layout.row: 5 + Layout.column: 0 + } + CheckBox { + id: saveChatGPTChatsBox + Layout.row: 5 + Layout.column: 1 + checked: settingsDialog.saveChatGPTChats + onClicked: { + settingsDialog.saveChatGPTChats = saveChatGPTChatsBox.checked + LLM.chatListModel.shouldSaveChatGPTChats = saveChatGPTChatsBox.checked + settings.sync() + } + + background: Rectangle { + color: "transparent" + } + + indicator: Rectangle { + implicitWidth: 26 + implicitHeight: 26 + x: saveChatGPTChatsBox.leftPadding + y: parent.height / 2 - height / 2 + border.color: theme.dialogBorder + color: "transparent" + + Rectangle { + width: 14 + height: 14 + x: 6 + y: 6 + color: theme.textColor + visible: saveChatGPTChatsBox.checked + } + } + + contentItem: Text { + text: saveChatGPTChatsBox.text + font: saveChatGPTChatsBox.font + opacity: enabled ? 1.0 : 0.3 + color: theme.textColor + verticalAlignment: Text.AlignVCenter + leftPadding: saveChatGPTChatsBox.indicator.width + saveChatGPTChatsBox.spacing + } + } Label { id: serverChatLabel text: qsTr("Enable web server:") color: theme.textColor - Layout.row: 5 + Layout.row: 6 Layout.column: 0 } CheckBox { id: serverChatBox - Layout.row: 5 + Layout.row: 6 Layout.column: 1 checked: settings.serverChat onClicked: { @@ -855,7 +910,7 @@ Dialog { } } Button { - Layout.row: 6 + Layout.row: 7 Layout.column: 1 Layout.fillWidth: true padding: 10