From 01e582f15b0a617c46b131a00c499cdbf201482e Mon Sep 17 00:00:00 2001 From: Adam Treat Date: Thu, 4 May 2023 15:31:41 -0400 Subject: [PATCH] First attempt at providing a persistent chat list experience. Limitations: 1) Context is not restored for gpt-j models 2) When you switch between different model types in an existing chat the context and all the conversation is lost 3) The settings are not chat or conversation specific 4) The sizes of the chat persisted files are very large due to how much data the llama.cpp backend tries to persist. Need to investigate how we can shrink this. --- CMakeLists.txt | 2 +- chat.cpp | 180 +++++++++++++++++++++++++++------- chat.h | 42 +++++--- chatlistmodel.cpp | 72 ++++++++++++++ chatlistmodel.h | 31 ++++-- chatllm.cpp | 119 ++++++++++++++++------ chatllm.h | 29 +++--- chatmodel.h | 41 ++++++++ llm.cpp | 98 ++++-------------- llm.h | 16 ++- llmodel/llamamodel.cpp | 16 +++ llmodel/llamamodel.h | 3 + llmodel/llmodel.h | 3 + llmodel/llmodel_c.cpp | 18 ++++ llmodel/llmodel_c.h | 26 +++++ main.qml | 32 +++--- network.cpp | 1 - qml/ChatDrawer.qml | 7 ++ qml/ModelDownloaderDialog.qml | 2 +- 19 files changed, 530 insertions(+), 208 deletions(-) create mode 100644 chatlistmodel.cpp diff --git a/CMakeLists.txt b/CMakeLists.txt index bddcdd1f..6fe03ed6 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -60,7 +60,7 @@ qt_add_executable(chat main.cpp chat.h chat.cpp chatllm.h chatllm.cpp - chatmodel.h chatlistmodel.h + chatmodel.h chatlistmodel.h chatlistmodel.cpp download.h download.cpp network.h network.cpp llm.h llm.cpp diff --git a/chat.cpp b/chat.cpp index 87f1dbd6..2350949f 100644 --- a/chat.cpp +++ b/chat.cpp @@ -1,32 +1,37 @@ #include "chat.h" +#include "llm.h" #include "network.h" +#include "download.h" Chat::Chat(QObject *parent) : QObject(parent) - , m_llmodel(new ChatLLM) , m_id(Network::globalInstance()->generateUniqueId()) , m_name(tr("New Chat")) , m_chatModel(new ChatModel(this)) , m_responseInProgress(false) - , m_desiredThreadCount(std::min(4, (int32_t) std::thread::hardware_concurrency())) + , m_creationDate(QDateTime::currentSecsSinceEpoch()) + , m_llmodel(new ChatLLM(this)) { + // Should be in same thread + connect(Download::globalInstance(), &Download::modelListChanged, this, &Chat::modelListChanged, Qt::DirectConnection); + connect(this, &Chat::modelNameChanged, this, &Chat::modelListChanged, Qt::DirectConnection); + + // Should be in different threads connect(m_llmodel, &ChatLLM::isModelLoadedChanged, this, &Chat::isModelLoadedChanged, Qt::QueuedConnection); - connect(m_llmodel, &ChatLLM::responseChanged, this, &Chat::responseChanged, Qt::QueuedConnection); + connect(m_llmodel, &ChatLLM::responseChanged, this, &Chat::handleResponseChanged, Qt::QueuedConnection); connect(m_llmodel, &ChatLLM::responseStarted, this, &Chat::responseStarted, Qt::QueuedConnection); connect(m_llmodel, &ChatLLM::responseStopped, this, &Chat::responseStopped, Qt::QueuedConnection); - connect(m_llmodel, &ChatLLM::modelNameChanged, this, &Chat::modelNameChanged, Qt::QueuedConnection); - connect(m_llmodel, &ChatLLM::threadCountChanged, this, &Chat::threadCountChanged, Qt::QueuedConnection); - connect(m_llmodel, &ChatLLM::threadCountChanged, this, &Chat::syncThreadCount, Qt::QueuedConnection); - connect(m_llmodel, &ChatLLM::recalcChanged, this, &Chat::recalcChanged, Qt::QueuedConnection); + connect(m_llmodel, &ChatLLM::modelNameChanged, this, &Chat::handleModelNameChanged, Qt::QueuedConnection); connect(m_llmodel, &ChatLLM::recalcChanged, this, &Chat::handleRecalculating, Qt::QueuedConnection); connect(m_llmodel, &ChatLLM::generatedNameChanged, this, &Chat::generatedNameChanged, Qt::QueuedConnection); connect(this, &Chat::promptRequested, m_llmodel, &ChatLLM::prompt, Qt::QueuedConnection); connect(this, &Chat::modelNameChangeRequested, m_llmodel, &ChatLLM::modelNameChangeRequested, Qt::QueuedConnection); - connect(this, &Chat::unloadRequested, m_llmodel, &ChatLLM::unload, Qt::QueuedConnection); - connect(this, &Chat::reloadRequested, m_llmodel, &ChatLLM::reload, Qt::QueuedConnection); + connect(this, &Chat::loadDefaultModelRequested, m_llmodel, &ChatLLM::loadDefaultModel, Qt::QueuedConnection); + connect(this, &Chat::loadModelRequested, m_llmodel, &ChatLLM::loadModel, Qt::QueuedConnection); + connect(this, &Chat::unloadModelRequested, m_llmodel, &ChatLLM::unloadModel, Qt::QueuedConnection); + connect(this, &Chat::reloadModelRequested, m_llmodel, &ChatLLM::reloadModel, Qt::QueuedConnection); connect(this, &Chat::generateNameRequested, m_llmodel, &ChatLLM::generateName, Qt::QueuedConnection); - connect(this, &Chat::setThreadCountRequested, m_llmodel, &ChatLLM::setThreadCount, Qt::QueuedConnection); // The following are blocking operations and will block the gui thread, therefore must be fast // to respond to @@ -38,9 +43,21 @@ Chat::Chat(QObject *parent) void Chat::reset() { stopGenerating(); + // Erase our current on disk representation as we're completely resetting the chat along with id + LLM::globalInstance()->chatListModel()->removeChatFile(this); emit resetContextRequested(); // blocking queued connection m_id = Network::globalInstance()->generateUniqueId(); emit idChanged(); + // NOTE: We deliberately do no reset the name or creation date to indictate that this was originally + // an older chat that was reset for another purpose. Resetting this data will lead to the chat + // name label changing back to 'New Chat' and showing up in the chat model list as a 'New Chat' + // further down in the list. This might surprise the user. In the future, we me might get rid of + // the "reset context" button in the UI. Right now, by changing the model in the combobox dropdown + // we effectively do a reset context. We *have* to do this right now when switching between different + // types of models. The only way to get rid of that would be a very long recalculate where we rebuild + // the context if we switch between different types of models. Probably the right way to fix this + // is to allow switching models but throwing up a dialog warning users if we switch between types + // of models that a long recalculation will ensue. m_chatModel->clear(); } @@ -49,10 +66,12 @@ bool Chat::isModelLoaded() const return m_llmodel->isModelLoaded(); } -void Chat::prompt(const QString &prompt, const QString &prompt_template, int32_t n_predict, int32_t top_k, float top_p, - float temp, int32_t n_batch, float repeat_penalty, int32_t repeat_penalty_tokens) +void Chat::prompt(const QString &prompt, const QString &prompt_template, int32_t n_predict, + int32_t top_k, float top_p, float temp, int32_t n_batch, float repeat_penalty, + int32_t repeat_penalty_tokens) { - emit promptRequested(prompt, prompt_template, n_predict, top_k, top_p, temp, n_batch, repeat_penalty, repeat_penalty_tokens); + emit promptRequested(prompt, prompt_template, n_predict, top_k, top_p, temp, n_batch, + repeat_penalty, repeat_penalty_tokens, LLM::globalInstance()->threadCount()); } void Chat::regenerateResponse() @@ -70,6 +89,13 @@ QString Chat::response() const return m_llmodel->response(); } +void Chat::handleResponseChanged() +{ + const int index = m_chatModel->count() - 1; + m_chatModel->updateValue(index, response()); + emit responseChanged(); +} + void Chat::responseStarted() { m_responseInProgress = true; @@ -98,21 +124,6 @@ void Chat::setModelName(const QString &modelName) emit modelNameChangeRequested(modelName); } -void Chat::syncThreadCount() { - emit setThreadCountRequested(m_desiredThreadCount); -} - -void Chat::setThreadCount(int32_t n_threads) { - if (n_threads <= 0) - n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency()); - m_desiredThreadCount = n_threads; - syncThreadCount(); -} - -int32_t Chat::threadCount() { - return m_llmodel->threadCount(); -} - void Chat::newPromptResponsePair(const QString &prompt) { m_chatModel->appendPrompt(tr("Prompt: "), prompt); @@ -125,16 +136,25 @@ bool Chat::isRecalc() const return m_llmodel->isRecalc(); } -void Chat::unload() +void Chat::loadDefaultModel() +{ + emit loadDefaultModelRequested(); +} + +void Chat::loadModel(const QString &modelName) +{ + emit loadModelRequested(modelName); +} + +void Chat::unloadModel() { - m_savedModelName = m_llmodel->modelName(); stopGenerating(); - emit unloadRequested(); + emit unloadModelRequested(); } -void Chat::reload() +void Chat::reloadModel() { - emit reloadRequested(m_savedModelName); + emit reloadModelRequested(m_savedModelName); } void Chat::generatedNameChanged() @@ -150,4 +170,98 @@ void Chat::generatedNameChanged() void Chat::handleRecalculating() { Network::globalInstance()->sendRecalculatingContext(m_chatModel->count()); + emit recalcChanged(); +} + +void Chat::handleModelNameChanged() +{ + m_savedModelName = modelName(); + emit modelNameChanged(); +} + +bool Chat::serialize(QDataStream &stream) const +{ + stream << m_creationDate; + stream << m_id; + stream << m_name; + stream << m_userName; + stream << m_savedModelName; + if (!m_llmodel->serialize(stream)) + return false; + if (!m_chatModel->serialize(stream)) + return false; + return stream.status() == QDataStream::Ok; +} + +bool Chat::deserialize(QDataStream &stream) +{ + stream >> m_creationDate; + stream >> m_id; + emit idChanged(); + stream >> m_name; + stream >> m_userName; + emit nameChanged(); + stream >> m_savedModelName; + if (!m_llmodel->deserialize(stream)) + return false; + if (!m_chatModel->deserialize(stream)) + return false; + emit chatModelChanged(); + return stream.status() == QDataStream::Ok; +} + +QList Chat::modelList() const +{ + // Build a model list from exepath and from the localpath + QList list; + + QString exePath = QCoreApplication::applicationDirPath() + QDir::separator(); + QString localPath = Download::globalInstance()->downloadLocalModelsPath(); + + { + QDir dir(exePath); + dir.setNameFilters(QStringList() << "ggml-*.bin"); + QStringList fileNames = dir.entryList(); + for (QString f : fileNames) { + QString filePath = exePath + f; + QFileInfo info(filePath); + QString name = info.completeBaseName().remove(0, 5); + if (info.exists()) { + if (name == modelName()) + list.prepend(name); + else + list.append(name); + } + } + } + + if (localPath != exePath) { + QDir dir(localPath); + dir.setNameFilters(QStringList() << "ggml-*.bin"); + QStringList fileNames = dir.entryList(); + for (QString f : fileNames) { + QString filePath = localPath + f; + QFileInfo info(filePath); + QString name = info.completeBaseName().remove(0, 5); + if (info.exists() && !list.contains(name)) { // don't allow duplicates + if (name == modelName()) + list.prepend(name); + else + list.append(name); + } + } + } + + if (list.isEmpty()) { + if (exePath != localPath) { + qWarning() << "ERROR: Could not find any applicable models in" + << exePath << "nor" << localPath; + } else { + qWarning() << "ERROR: Could not find any applicable models in" + << exePath; + } + return QList(); + } + + return list; } diff --git a/chat.h b/chat.h index 18a26d6a..fa5db003 100644 --- a/chat.h +++ b/chat.h @@ -3,6 +3,7 @@ #include #include +#include #include "chatllm.h" #include "chatmodel.h" @@ -17,8 +18,8 @@ class Chat : public QObject Q_PROPERTY(QString response READ response NOTIFY responseChanged) Q_PROPERTY(QString modelName READ modelName WRITE setModelName NOTIFY modelNameChanged) Q_PROPERTY(bool responseInProgress READ responseInProgress NOTIFY responseInProgressChanged) - Q_PROPERTY(int32_t threadCount READ threadCount WRITE setThreadCount NOTIFY threadCountChanged) Q_PROPERTY(bool isRecalc READ isRecalc NOTIFY recalcChanged) + Q_PROPERTY(QList modelList READ modelList NOTIFY modelListChanged) QML_ELEMENT QML_UNCREATABLE("Only creatable from c++!") @@ -36,13 +37,10 @@ public: Q_INVOKABLE void reset(); Q_INVOKABLE bool isModelLoaded() const; - Q_INVOKABLE void prompt(const QString &prompt, const QString &prompt_template, int32_t n_predict, int32_t top_k, float top_p, - float temp, int32_t n_batch, float repeat_penalty, int32_t repeat_penalty_tokens); + Q_INVOKABLE void prompt(const QString &prompt, const QString &prompt_template, int32_t n_predict, + int32_t top_k, float top_p, float temp, int32_t n_batch, float repeat_penalty, int32_t repeat_penalty_tokens); Q_INVOKABLE void regenerateResponse(); Q_INVOKABLE void stopGenerating(); - Q_INVOKABLE void syncThreadCount(); - Q_INVOKABLE void setThreadCount(int32_t n_threads); - Q_INVOKABLE int32_t threadCount(); Q_INVOKABLE void newPromptResponsePair(const QString &prompt); QString response() const; @@ -51,8 +49,16 @@ public: void setModelName(const QString &modelName); bool isRecalc() const; - void unload(); - void reload(); + void loadDefaultModel(); + void loadModel(const QString &modelName); + void unloadModel(); + void reloadModel(); + + qint64 creationDate() const { return m_creationDate; } + bool serialize(QDataStream &stream) const; + bool deserialize(QDataStream &stream); + + QList modelList() const; Q_SIGNALS: void idChanged(); @@ -61,35 +67,39 @@ Q_SIGNALS: void isModelLoadedChanged(); void responseChanged(); void responseInProgressChanged(); - void promptRequested(const QString &prompt, const QString &prompt_template, int32_t n_predict, int32_t top_k, float top_p, - float temp, int32_t n_batch, float repeat_penalty, int32_t repeat_penalty_tokens); + void promptRequested(const QString &prompt, const QString &prompt_template, int32_t n_predict, + int32_t top_k, float top_p, float temp, int32_t n_batch, float repeat_penalty, int32_t repeat_penalty_tokens, + int32_t n_threads); void regenerateResponseRequested(); void resetResponseRequested(); void resetContextRequested(); void modelNameChangeRequested(const QString &modelName); void modelNameChanged(); - void threadCountChanged(); - void setThreadCountRequested(int32_t threadCount); void recalcChanged(); - void unloadRequested(); - void reloadRequested(const QString &modelName); + void loadDefaultModelRequested(); + void loadModelRequested(const QString &modelName); + void unloadModelRequested(); + void reloadModelRequested(const QString &modelName); void generateNameRequested(); + void modelListChanged(); private Q_SLOTS: + void handleResponseChanged(); void responseStarted(); void responseStopped(); void generatedNameChanged(); void handleRecalculating(); + void handleModelNameChanged(); private: - ChatLLM *m_llmodel; QString m_id; QString m_name; QString m_userName; QString m_savedModelName; ChatModel *m_chatModel; bool m_responseInProgress; - int32_t m_desiredThreadCount; + qint64 m_creationDate; + ChatLLM *m_llmodel; }; #endif // CHAT_H diff --git a/chatlistmodel.cpp b/chatlistmodel.cpp new file mode 100644 index 00000000..5114e02d --- /dev/null +++ b/chatlistmodel.cpp @@ -0,0 +1,72 @@ +#include "chatlistmodel.h" + +#include +#include + +void ChatListModel::removeChatFile(Chat *chat) const +{ + QSettings settings; + QFileInfo settingsInfo(settings.fileName()); + QString settingsPath = settingsInfo.absolutePath(); + QFile file(settingsPath + "/gpt4all-" + chat->id() + ".chat"); + if (!file.exists()) + return; + bool success = file.remove(); + if (!success) + qWarning() << "ERROR: Couldn't remove chat file:" << file.fileName(); +} + +void ChatListModel::saveChats() const +{ + QSettings settings; + QFileInfo settingsInfo(settings.fileName()); + QString settingsPath = settingsInfo.absolutePath(); + for (Chat *chat : m_chats) { + QFile file(settingsPath + "/gpt4all-" + chat->id() + ".chat"); + bool success = file.open(QIODevice::WriteOnly); + if (!success) { + qWarning() << "ERROR: Couldn't save chat to file:" << file.fileName(); + continue; + } + QDataStream out(&file); + if (!chat->serialize(out)) { + qWarning() << "ERROR: Couldn't serialize chat to file:" << file.fileName(); + file.remove(); + } + file.close(); + } +} + +void ChatListModel::restoreChats() +{ + QSettings settings; + QFileInfo settingsInfo(settings.fileName()); + QString settingsPath = settingsInfo.absolutePath(); + QDir dir(settingsPath); + dir.setNameFilters(QStringList() << "gpt4all-*.chat"); + QStringList fileNames = dir.entryList(); + beginResetModel(); + for (QString f : fileNames) { + QString filePath = settingsPath + "/" + f; + QFile file(filePath); + bool success = file.open(QIODevice::ReadOnly); + if (!success) { + qWarning() << "ERROR: Couldn't restore chat from file:" << file.fileName(); + continue; + } + QDataStream in(&file); + Chat *chat = new Chat(this); + if (!chat->deserialize(in)) { + qWarning() << "ERROR: Couldn't deserialize chat from file:" << file.fileName(); + file.remove(); + } else { + connect(chat, &Chat::nameChanged, this, &ChatListModel::nameChanged); + m_chats.append(chat); + } + file.close(); + } + std::sort(m_chats.begin(), m_chats.end(), [](const Chat* a, const Chat* b) { + return a->creationDate() > b->creationDate(); + }); + endResetModel(); +} diff --git a/chatlistmodel.h b/chatlistmodel.h index 20c6eeba..68633da9 100644 --- a/chatlistmodel.h +++ b/chatlistmodel.h @@ -55,7 +55,7 @@ public: Q_INVOKABLE void addChat() { - // Don't add a new chat if the current chat is empty + // Don't add a new chat if we already have one if (m_newChat) return; @@ -73,13 +73,29 @@ public: setCurrentChat(m_newChat); } + void setNewChat(Chat* chat) + { + // Don't add a new chat if we already have one + if (m_newChat) + return; + + m_newChat = chat; + connect(m_newChat->chatModel(), &ChatModel::countChanged, + this, &ChatListModel::newChatCountChanged); + connect(m_newChat, &Chat::nameChanged, + this, &ChatListModel::nameChanged); + setCurrentChat(m_newChat); + } + Q_INVOKABLE void removeChat(Chat* chat) { if (!m_chats.contains(chat)) { - qDebug() << "WARNING: Removing chat failed with id" << chat->id(); + qWarning() << "WARNING: Removing chat failed with id" << chat->id(); return; } + removeChatFile(chat); + emit disconnectChat(chat); if (chat == m_newChat) { m_newChat->disconnect(this); @@ -115,20 +131,20 @@ public: void setCurrentChat(Chat *chat) { if (!m_chats.contains(chat)) { - qDebug() << "ERROR: Setting current chat failed with id" << chat->id(); + qWarning() << "ERROR: Setting current chat failed with id" << chat->id(); return; } if (m_currentChat) { if (m_currentChat->isModelLoaded()) - m_currentChat->unload(); + m_currentChat->unloadModel(); emit disconnect(m_currentChat); } emit connectChat(chat); m_currentChat = chat; if (!m_currentChat->isModelLoaded()) - m_currentChat->reload(); + m_currentChat->reloadModel(); emit currentChatChanged(); } @@ -138,9 +154,12 @@ public: return m_chats.at(index); } - int count() const { return m_chats.size(); } + void removeChatFile(Chat *chat) const; + void saveChats() const; + void restoreChats(); + Q_SIGNALS: void countChanged(); void connectChat(Chat*); diff --git a/chatllm.cpp b/chatllm.cpp index 68230127..071a7ddb 100644 --- a/chatllm.cpp +++ b/chatllm.cpp @@ -1,7 +1,7 @@ #include "chatllm.h" +#include "chat.h" #include "download.h" #include "network.h" -#include "llm.h" #include "llmodel/gptj.h" #include "llmodel/llamamodel.h" @@ -32,28 +32,29 @@ static QString modelFilePath(const QString &modelName) return QString(); } -ChatLLM::ChatLLM() +ChatLLM::ChatLLM(Chat *parent) : QObject{nullptr} , m_llmodel(nullptr) , m_promptResponseTokens(0) , m_responseLogits(0) , m_isRecalc(false) + , m_chat(parent) { moveToThread(&m_llmThread); - connect(&m_llmThread, &QThread::started, this, &ChatLLM::loadModel); connect(this, &ChatLLM::sendStartup, Network::globalInstance(), &Network::sendStartup); connect(this, &ChatLLM::sendModelLoaded, Network::globalInstance(), &Network::sendModelLoaded); - m_llmThread.setObjectName("llm thread"); // FIXME: Should identify these with chat name + connect(m_chat, &Chat::idChanged, this, &ChatLLM::handleChatIdChanged); + m_llmThread.setObjectName(m_chat->id()); m_llmThread.start(); } -bool ChatLLM::loadModel() +bool ChatLLM::loadDefaultModel() { - const QList models = LLM::globalInstance()->modelList(); + const QList models = m_chat->modelList(); if (models.isEmpty()) { // try again when we get a list of models connect(Download::globalInstance(), &Download::modelListChanged, this, - &ChatLLM::loadModel, Qt::SingleShotConnection); + &ChatLLM::loadDefaultModel, Qt::SingleShotConnection); return false; } @@ -62,10 +63,10 @@ bool ChatLLM::loadModel() QString defaultModel = settings.value("defaultModel", "gpt4all-j-v1.3-groovy").toString(); if (defaultModel.isEmpty() || !models.contains(defaultModel)) defaultModel = models.first(); - return loadModelPrivate(defaultModel); + return loadModel(defaultModel); } -bool ChatLLM::loadModelPrivate(const QString &modelName) +bool ChatLLM::loadModel(const QString &modelName) { if (isModelLoaded() && m_modelName == modelName) return true; @@ -100,12 +101,13 @@ bool ChatLLM::loadModelPrivate(const QString &modelName) } emit isModelLoadedChanged(); - emit threadCountChanged(); if (isFirstLoad) emit sendStartup(); else emit sendModelLoaded(); + } else { + qWarning() << "ERROR: Could not find model at" << filePath; } if (m_llmodel) @@ -114,19 +116,6 @@ bool ChatLLM::loadModelPrivate(const QString &modelName) return m_llmodel; } -void ChatLLM::setThreadCount(int32_t n_threads) { - if (m_llmodel && m_llmodel->threadCount() != n_threads) { - m_llmodel->setThreadCount(n_threads); - emit threadCountChanged(); - } -} - -int32_t ChatLLM::threadCount() { - if (!m_llmodel) - return 1; - return m_llmodel->threadCount(); -} - bool ChatLLM::isModelLoaded() const { return m_llmodel && m_llmodel->isModelLoaded(); @@ -203,7 +192,7 @@ void ChatLLM::setModelName(const QString &modelName) void ChatLLM::modelNameChangeRequested(const QString &modelName) { - if (!loadModelPrivate(modelName)) + if (!loadModel(modelName)) qWarning() << "ERROR: Could not load model" << modelName; } @@ -247,8 +236,8 @@ bool ChatLLM::handleRecalculate(bool isRecalc) return !m_stopGenerating; } -bool ChatLLM::prompt(const QString &prompt, const QString &prompt_template, int32_t n_predict, int32_t top_k, float top_p, - float temp, int32_t n_batch, float repeat_penalty, int32_t repeat_penalty_tokens) +bool ChatLLM::prompt(const QString &prompt, const QString &prompt_template, int32_t n_predict, int32_t top_k, + float top_p, float temp, int32_t n_batch, float repeat_penalty, int32_t repeat_penalty_tokens, int n_threads) { if (!isModelLoaded()) return false; @@ -269,6 +258,7 @@ bool ChatLLM::prompt(const QString &prompt, const QString &prompt_template, int3 m_ctx.n_batch = n_batch; m_ctx.repeat_penalty = repeat_penalty; m_ctx.repeat_last_n = repeat_penalty_tokens; + m_llmodel->setThreadCount(n_threads); #if defined(DEBUG) printf("%s", qPrintable(instructPrompt)); fflush(stdout); @@ -288,19 +278,22 @@ bool ChatLLM::prompt(const QString &prompt, const QString &prompt_template, int3 return true; } -void ChatLLM::unload() +void ChatLLM::unloadModel() { + saveState(); delete m_llmodel; m_llmodel = nullptr; emit isModelLoadedChanged(); } -void ChatLLM::reload(const QString &modelName) +void ChatLLM::reloadModel(const QString &modelName) { - if (modelName.isEmpty()) - loadModel(); - else - loadModelPrivate(modelName); + if (modelName.isEmpty()) { + loadDefaultModel(); + } else { + loadModel(modelName); + } + restoreState(); } void ChatLLM::generateName() @@ -333,6 +326,11 @@ void ChatLLM::generateName() } } +void ChatLLM::handleChatIdChanged() +{ + m_llmThread.setObjectName(m_chat->id()); +} + bool ChatLLM::handleNamePrompt(int32_t token) { Q_UNUSED(token); @@ -354,3 +352,60 @@ bool ChatLLM::handleNameRecalculate(bool isRecalc) Q_UNREACHABLE(); return true; } + +bool ChatLLM::serialize(QDataStream &stream) +{ + stream << response(); + stream << generatedName(); + stream << m_promptResponseTokens; + stream << m_responseLogits; + stream << m_ctx.n_past; + stream << quint64(m_ctx.logits.size()); + stream.writeRawData(reinterpret_cast(m_ctx.logits.data()), m_ctx.logits.size() * sizeof(float)); + stream << quint64(m_ctx.tokens.size()); + stream.writeRawData(reinterpret_cast(m_ctx.tokens.data()), m_ctx.tokens.size() * sizeof(int)); + saveState(); + stream << m_state; + return stream.status() == QDataStream::Ok; +} + +bool ChatLLM::deserialize(QDataStream &stream) +{ + QString response; + stream >> response; + m_response = response.toStdString(); + QString nameResponse; + stream >> nameResponse; + m_nameResponse = nameResponse.toStdString(); + stream >> m_promptResponseTokens; + stream >> m_responseLogits; + stream >> m_ctx.n_past; + quint64 logitsSize; + stream >> logitsSize; + m_ctx.logits.resize(logitsSize); + stream.readRawData(reinterpret_cast(m_ctx.logits.data()), logitsSize * sizeof(float)); + quint64 tokensSize; + stream >> tokensSize; + m_ctx.tokens.resize(tokensSize); + stream.readRawData(reinterpret_cast(m_ctx.tokens.data()), tokensSize * sizeof(int)); + stream >> m_state; + return stream.status() == QDataStream::Ok; +} + +void ChatLLM::saveState() +{ + if (!isModelLoaded()) + return; + + const size_t stateSize = m_llmodel->stateSize(); + m_state.resize(stateSize); + m_llmodel->saveState(static_cast(reinterpret_cast(m_state.data()))); +} + +void ChatLLM::restoreState() +{ + if (!isModelLoaded()) + return; + + m_llmodel->restoreState(static_cast(reinterpret_cast(m_state.data()))); +} diff --git a/chatllm.h b/chatllm.h index ab2dcc8c..dc1260b8 100644 --- a/chatllm.h +++ b/chatllm.h @@ -6,18 +6,18 @@ #include "llmodel/llmodel.h" +class Chat; class ChatLLM : public QObject { Q_OBJECT Q_PROPERTY(bool isModelLoaded READ isModelLoaded NOTIFY isModelLoadedChanged) Q_PROPERTY(QString response READ response NOTIFY responseChanged) Q_PROPERTY(QString modelName READ modelName WRITE setModelName NOTIFY modelNameChanged) - Q_PROPERTY(int32_t threadCount READ threadCount WRITE setThreadCount NOTIFY threadCountChanged) Q_PROPERTY(bool isRecalc READ isRecalc NOTIFY recalcChanged) Q_PROPERTY(QString generatedName READ generatedName NOTIFY generatedNameChanged) public: - ChatLLM(); + ChatLLM(Chat *parent); bool isModelLoaded() const; void regenerateResponse(); @@ -25,8 +25,6 @@ public: void resetContext(); void stopGenerating() { m_stopGenerating = true; } - void setThreadCount(int32_t n_threads); - int32_t threadCount(); QString response() const; QString modelName() const; @@ -37,14 +35,20 @@ public: QString generatedName() const { return QString::fromStdString(m_nameResponse); } + bool serialize(QDataStream &stream); + bool deserialize(QDataStream &stream); + public Q_SLOTS: - bool prompt(const QString &prompt, const QString &prompt_template, int32_t n_predict, int32_t top_k, float top_p, - float temp, int32_t n_batch, float repeat_penalty, int32_t repeat_penalty_tokens); - bool loadModel(); + bool prompt(const QString &prompt, const QString &prompt_template, int32_t n_predict, + int32_t top_k, float top_p, float temp, int32_t n_batch, float repeat_penalty, int32_t repeat_penalty_tokens, + int32_t n_threads); + bool loadDefaultModel(); + bool loadModel(const QString &modelName); void modelNameChangeRequested(const QString &modelName); - void unload(); - void reload(const QString &modelName); + void unloadModel(); + void reloadModel(const QString &modelName); void generateName(); + void handleChatIdChanged(); Q_SIGNALS: void isModelLoadedChanged(); @@ -52,22 +56,23 @@ Q_SIGNALS: void responseStarted(); void responseStopped(); void modelNameChanged(); - void threadCountChanged(); void recalcChanged(); void sendStartup(); void sendModelLoaded(); void sendResetContext(); void generatedNameChanged(); + void stateChanged(); private: void resetContextPrivate(); - bool loadModelPrivate(const QString &modelName); bool handlePrompt(int32_t token); bool handleResponse(int32_t token, const std::string &response); bool handleRecalculate(bool isRecalc); bool handleNamePrompt(int32_t token); bool handleNameResponse(int32_t token, const std::string &response); bool handleNameRecalculate(bool isRecalc); + void saveState(); + void restoreState(); private: LLModel::PromptContext m_ctx; @@ -77,6 +82,8 @@ private: quint32 m_promptResponseTokens; quint32 m_responseLogits; QString m_modelName; + Chat *m_chat; + QByteArray m_state; QThread m_llmThread; std::atomic m_stopGenerating; bool m_isRecalc; diff --git a/chatmodel.h b/chatmodel.h index e5be2719..f3e59fa3 100644 --- a/chatmodel.h +++ b/chatmodel.h @@ -3,6 +3,7 @@ #include #include +#include struct ChatItem { @@ -209,6 +210,46 @@ public: int count() const { return m_chatItems.size(); } + bool serialize(QDataStream &stream) const + { + stream << count(); + for (auto c : m_chatItems) { + stream << c.id; + stream << c.name; + stream << c.value; + stream << c.prompt; + stream << c.newResponse; + stream << c.currentResponse; + stream << c.stopped; + stream << c.thumbsUpState; + stream << c.thumbsDownState; + } + return stream.status() == QDataStream::Ok; + } + + bool deserialize(QDataStream &stream) + { + int size; + stream >> size; + for (int i = 0; i < size; ++i) { + ChatItem c; + stream >> c.id; + stream >> c.name; + stream >> c.value; + stream >> c.prompt; + stream >> c.newResponse; + stream >> c.currentResponse; + stream >> c.stopped; + stream >> c.thumbsUpState; + stream >> c.thumbsDownState; + beginInsertRows(QModelIndex(), m_chatItems.size(), m_chatItems.size()); + m_chatItems.append(c); + endInsertRows(); + } + emit countChanged(); + return stream.status() == QDataStream::Ok; + } + Q_SIGNALS: void countChanged(); diff --git a/llm.cpp b/llm.cpp index f624b7a3..2689e47f 100644 --- a/llm.cpp +++ b/llm.cpp @@ -20,77 +20,22 @@ LLM *LLM::globalInstance() LLM::LLM() : QObject{nullptr} , m_chatListModel(new ChatListModel(this)) + , m_threadCount(std::min(4, (int32_t) std::thread::hardware_concurrency())) { - // Should be in the same thread - connect(Download::globalInstance(), &Download::modelListChanged, - this, &LLM::modelListChanged, Qt::DirectConnection); - connect(m_chatListModel, &ChatListModel::connectChat, - this, &LLM::connectChat, Qt::DirectConnection); - connect(m_chatListModel, &ChatListModel::disconnectChat, - this, &LLM::disconnectChat, Qt::DirectConnection); - - if (!m_chatListModel->count()) + connect(QCoreApplication::instance(), &QCoreApplication::aboutToQuit, + this, &LLM::aboutToQuit); + + m_chatListModel->restoreChats(); + if (m_chatListModel->count()) { + Chat *firstChat = m_chatListModel->get(0); + if (firstChat->chatModel()->count() < 2) + m_chatListModel->setNewChat(firstChat); + else + m_chatListModel->setCurrentChat(firstChat); + } else m_chatListModel->addChat(); } -QList LLM::modelList() const -{ - Q_ASSERT(m_chatListModel->currentChat()); - const Chat *currentChat = m_chatListModel->currentChat(); - // Build a model list from exepath and from the localpath - QList list; - - QString exePath = QCoreApplication::applicationDirPath() + QDir::separator(); - QString localPath = Download::globalInstance()->downloadLocalModelsPath(); - - { - QDir dir(exePath); - dir.setNameFilters(QStringList() << "ggml-*.bin"); - QStringList fileNames = dir.entryList(); - for (QString f : fileNames) { - QString filePath = exePath + f; - QFileInfo info(filePath); - QString name = info.completeBaseName().remove(0, 5); - if (info.exists()) { - if (name == currentChat->modelName()) - list.prepend(name); - else - list.append(name); - } - } - } - - if (localPath != exePath) { - QDir dir(localPath); - dir.setNameFilters(QStringList() << "ggml-*.bin"); - QStringList fileNames = dir.entryList(); - for (QString f : fileNames) { - QString filePath = localPath + f; - QFileInfo info(filePath); - QString name = info.completeBaseName().remove(0, 5); - if (info.exists() && !list.contains(name)) { // don't allow duplicates - if (name == currentChat->modelName()) - list.prepend(name); - else - list.append(name); - } - } - } - - if (list.isEmpty()) { - if (exePath != localPath) { - qWarning() << "ERROR: Could not find any applicable models in" - << exePath << "nor" << localPath; - } else { - qWarning() << "ERROR: Could not find any applicable models in" - << exePath; - } - return QList(); - } - - return list; -} - bool LLM::checkForUpdates() const { Network::globalInstance()->sendCheckForUpdates(); @@ -113,21 +58,20 @@ bool LLM::checkForUpdates() const return QProcess::startDetached(fileName); } -bool LLM::isRecalc() const +int32_t LLM::threadCount() const { - Q_ASSERT(m_chatListModel->currentChat()); - return m_chatListModel->currentChat()->isRecalc(); + return m_threadCount; } -void LLM::connectChat(Chat *chat) +void LLM::setThreadCount(int32_t n_threads) { - // Should be in the same thread - connect(chat, &Chat::modelNameChanged, this, &LLM::modelListChanged, Qt::DirectConnection); - connect(chat, &Chat::recalcChanged, this, &LLM::recalcChanged, Qt::DirectConnection); - connect(chat, &Chat::responseChanged, this, &LLM::responseChanged, Qt::DirectConnection); + if (n_threads <= 0) + n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency()); + m_threadCount = n_threads; + emit threadCountChanged(); } -void LLM::disconnectChat(Chat *chat) +void LLM::aboutToQuit() { - chat->disconnect(this); + m_chatListModel->saveChats(); } diff --git a/llm.h b/llm.h index e291ebbf..89fee17f 100644 --- a/llm.h +++ b/llm.h @@ -3,37 +3,33 @@ #include -#include "chat.h" #include "chatlistmodel.h" class LLM : public QObject { Q_OBJECT - Q_PROPERTY(QList modelList READ modelList NOTIFY modelListChanged) - Q_PROPERTY(bool isRecalc READ isRecalc NOTIFY recalcChanged) Q_PROPERTY(ChatListModel *chatListModel READ chatListModel NOTIFY chatListModelChanged) + Q_PROPERTY(int32_t threadCount READ threadCount WRITE setThreadCount NOTIFY threadCountChanged) public: static LLM *globalInstance(); - QList modelList() const; - bool isRecalc() const; ChatListModel *chatListModel() const { return m_chatListModel; } + int32_t threadCount() const; + void setThreadCount(int32_t n_threads); Q_INVOKABLE bool checkForUpdates() const; Q_SIGNALS: - void modelListChanged(); - void recalcChanged(); - void responseChanged(); void chatListModelChanged(); + void threadCountChanged(); private Q_SLOTS: - void connectChat(Chat*); - void disconnectChat(Chat*); + void aboutToQuit(); private: ChatListModel *m_chatListModel; + int32_t m_threadCount; private: explicit LLM(); diff --git a/llmodel/llamamodel.cpp b/llmodel/llamamodel.cpp index d89862f1..272633c7 100644 --- a/llmodel/llamamodel.cpp +++ b/llmodel/llamamodel.cpp @@ -67,6 +67,7 @@ int32_t LLamaModel::threadCount() { LLamaModel::~LLamaModel() { + llama_free(d_ptr->ctx); } bool LLamaModel::isModelLoaded() const @@ -74,6 +75,21 @@ bool LLamaModel::isModelLoaded() const return d_ptr->modelLoaded; } +size_t LLamaModel::stateSize() const +{ + return llama_get_state_size(d_ptr->ctx); +} + +size_t LLamaModel::saveState(uint8_t *dest) const +{ + return llama_copy_state_data(d_ptr->ctx, dest); +} + +size_t LLamaModel::restoreState(const uint8_t *src) +{ + return llama_set_state_data(d_ptr->ctx, src); +} + void LLamaModel::prompt(const std::string &prompt, std::function promptCallback, std::function responseCallback, diff --git a/llmodel/llamamodel.h b/llmodel/llamamodel.h index 13e221a7..7f487803 100644 --- a/llmodel/llamamodel.h +++ b/llmodel/llamamodel.h @@ -14,6 +14,9 @@ public: bool loadModel(const std::string &modelPath) override; bool isModelLoaded() const override; + size_t stateSize() const override; + size_t saveState(uint8_t *dest) const override; + size_t restoreState(const uint8_t *src) override; void prompt(const std::string &prompt, std::function promptCallback, std::function responseCallback, diff --git a/llmodel/llmodel.h b/llmodel/llmodel.h index 08dc1764..5ef900f4 100644 --- a/llmodel/llmodel.h +++ b/llmodel/llmodel.h @@ -12,6 +12,9 @@ public: virtual bool loadModel(const std::string &modelPath) = 0; virtual bool isModelLoaded() const = 0; + virtual size_t stateSize() const { return 0; } + virtual size_t saveState(uint8_t *dest) const { return 0; } + virtual size_t restoreState(const uint8_t *src) { return 0; } struct PromptContext { std::vector logits; // logits of current context std::vector tokens; // current tokens in the context window diff --git a/llmodel/llmodel_c.cpp b/llmodel/llmodel_c.cpp index 46eb1a7d..9788f1fb 100644 --- a/llmodel/llmodel_c.cpp +++ b/llmodel/llmodel_c.cpp @@ -48,6 +48,24 @@ bool llmodel_isModelLoaded(llmodel_model model) return wrapper->llModel->isModelLoaded(); } +uint64_t llmodel_get_state_size(llmodel_model model) +{ + LLModelWrapper *wrapper = reinterpret_cast(model); + return wrapper->llModel->stateSize(); +} + +uint64_t llmodel_save_state_data(llmodel_model model, uint8_t *dest) +{ + LLModelWrapper *wrapper = reinterpret_cast(model); + return wrapper->llModel->saveState(dest); +} + +uint64_t llmodel_restore_state_data(llmodel_model model, const uint8_t *src) +{ + LLModelWrapper *wrapper = reinterpret_cast(model); + return wrapper->llModel->restoreState(src); +} + // Wrapper functions for the C callbacks bool prompt_wrapper(int32_t token_id, void *user_data) { llmodel_prompt_callback callback = reinterpret_cast(user_data); diff --git a/llmodel/llmodel_c.h b/llmodel/llmodel_c.h index 45cc9cd2..0907d765 100644 --- a/llmodel/llmodel_c.h +++ b/llmodel/llmodel_c.h @@ -98,6 +98,32 @@ bool llmodel_loadModel(llmodel_model model, const char *model_path); */ bool llmodel_isModelLoaded(llmodel_model model); +/** + * Get the size of the internal state of the model. + * NOTE: This state data is specific to the type of model you have created. + * @param model A pointer to the llmodel_model instance. + * @return the size in bytes of the internal state of the model + */ +uint64_t llmodel_get_state_size(llmodel_model model); + +/** + * Saves the internal state of the model to the specified destination address. + * NOTE: This state data is specific to the type of model you have created. + * @param model A pointer to the llmodel_model instance. + * @param dest A pointer to the destination. + * @return the number of bytes copied + */ +uint64_t llmodel_save_state_data(llmodel_model model, uint8_t *dest); + +/** + * Restores the internal state of the model using data from the specified address. + * NOTE: This state data is specific to the type of model you have created. + * @param model A pointer to the llmodel_model instance. + * @param src A pointer to the src. + * @return the number of bytes read + */ +uint64_t llmodel_restore_state_data(llmodel_model model, const uint8_t *src); + /** * Generate a response using the model. * @param model A pointer to the llmodel_model instance. diff --git a/main.qml b/main.qml index d50087bc..a41730bb 100644 --- a/main.qml +++ b/main.qml @@ -65,7 +65,7 @@ Window { } // check for any current models and if not, open download dialog - if (LLM.modelList.length === 0 && !firstStartDialog.opened) { + if (currentChat.modelList.length === 0 && !firstStartDialog.opened) { downloadNewModels.open(); return; } @@ -125,7 +125,7 @@ Window { anchors.horizontalCenter: parent.horizontalCenter font.pixelSize: theme.fontSizeLarge spacing: 0 - model: LLM.modelList + model: currentChat.modelList Accessible.role: Accessible.ComboBox Accessible.name: qsTr("ComboBox for displaying/picking the current model") Accessible.description: qsTr("Use this for picking the current model to use; the first item is the current model") @@ -367,9 +367,9 @@ Window { text: qsTr("Recalculating context.") Connections { - target: LLM + target: currentChat function onRecalcChanged() { - if (LLM.isRecalc) + if (currentChat.isRecalc) recalcPopup.open() else recalcPopup.close() @@ -422,10 +422,7 @@ Window { var item = chatModel.get(i) var string = item.name; var isResponse = item.name === qsTr("Response: ") - if (item.currentResponse) - string += currentChat.response - else - string += chatModel.get(i).value + string += chatModel.get(i).value if (isResponse && item.stopped) string += " " string += "\n" @@ -440,10 +437,7 @@ Window { var item = chatModel.get(i) var isResponse = item.name === qsTr("Response: ") str += "{\"content\": "; - if (item.currentResponse) - str += JSON.stringify(currentChat.response) - else - str += JSON.stringify(item.value) + str += JSON.stringify(item.value) str += ", \"role\": \"" + (isResponse ? "assistant" : "user") + "\""; if (isResponse && item.thumbsUpState !== item.thumbsDownState) str += ", \"rating\": \"" + (item.thumbsUpState ? "positive" : "negative") + "\""; @@ -572,14 +566,14 @@ Window { Accessible.description: qsTr("This is the list of prompt/response pairs comprising the actual conversation with the model") delegate: TextArea { - text: currentResponse ? currentChat.response : (value ? value : "") + text: value width: listView.width color: theme.textColor wrapMode: Text.WordWrap focus: false readOnly: true font.pixelSize: theme.fontSizeLarge - cursorVisible: currentResponse ? (currentChat.response !== "" ? currentChat.responseInProgress : false) : false + cursorVisible: currentResponse ? currentChat.responseInProgress : false cursorPosition: text.length background: Rectangle { color: name === qsTr("Response: ") ? theme.backgroundLighter : theme.backgroundLight @@ -599,8 +593,8 @@ Window { anchors.leftMargin: 90 anchors.top: parent.top anchors.topMargin: 5 - visible: (currentResponse ? true : false) && currentChat.response === "" && currentChat.responseInProgress - running: (currentResponse ? true : false) && currentChat.response === "" && currentChat.responseInProgress + visible: (currentResponse ? true : false) && value === "" && currentChat.responseInProgress + running: (currentResponse ? true : false) && value === "" && currentChat.responseInProgress Accessible.role: Accessible.Animation Accessible.name: qsTr("Busy indicator") @@ -631,7 +625,7 @@ Window { window.height / 2 - height / 2) x: globalPoint.x y: globalPoint.y - property string text: currentResponse ? currentChat.response : (value ? value : "") + property string text: value response: newResponse === undefined || newResponse === "" ? text : newResponse onAccepted: { var responseHasChanged = response !== text && response !== newResponse @@ -711,7 +705,7 @@ Window { property bool isAutoScrolling: false Connections { - target: LLM + target: currentChat function onResponseChanged() { if (listView.shouldAutoScroll) { listView.isAutoScrolling = true @@ -762,7 +756,6 @@ Window { if (listElement.name === qsTr("Response: ")) { chatModel.updateCurrentResponse(index, true); chatModel.updateStopped(index, false); - chatModel.updateValue(index, currentChat.response); chatModel.updateThumbsUpState(index, false); chatModel.updateThumbsDownState(index, false); chatModel.updateNewResponse(index, ""); @@ -840,7 +833,6 @@ Window { var index = Math.max(0, chatModel.count - 1); var listElement = chatModel.get(index); chatModel.updateCurrentResponse(index, false); - chatModel.updateValue(index, currentChat.response); } currentChat.newPromptResponsePair(textInput.text); currentChat.prompt(textInput.text, settingsDialog.promptTemplate, diff --git a/network.cpp b/network.cpp index dfafddf8..ce77419e 100644 --- a/network.cpp +++ b/network.cpp @@ -458,7 +458,6 @@ void Network::handleIpifyFinished() void Network::handleMixpanelFinished() { - Q_ASSERT(m_usageStatsActive); QNetworkReply *reply = qobject_cast(sender()); if (!reply) return; diff --git a/qml/ChatDrawer.qml b/qml/ChatDrawer.qml index 1d581639..1bf37548 100644 --- a/qml/ChatDrawer.qml +++ b/qml/ChatDrawer.qml @@ -83,6 +83,7 @@ Drawer { opacity: 0.9 property bool isCurrent: LLM.chatListModel.currentChat === LLM.chatListModel.get(index) property bool trashQuestionDisplayed: false + z: isCurrent ? 199 : 1 color: index % 2 === 0 ? theme.backgroundLight : theme.backgroundLighter border.width: isCurrent border.color: chatName.readOnly ? theme.assistantColor : theme.userColor @@ -112,6 +113,11 @@ Drawer { color: "transparent" } onEditingFinished: { + // Work around a bug in qml where we're losing focus when the whole window + // goes out of focus even though this textfield should be marked as not + // having focus + if (chatName.readOnly) + return; changeName(); Network.sendRenameChat() } @@ -188,6 +194,7 @@ Drawer { visible: isCurrent && trashQuestionDisplayed opacity: 1.0 radius: 10 + z: 200 Row { spacing: 10 Button { diff --git a/qml/ModelDownloaderDialog.qml b/qml/ModelDownloaderDialog.qml index 92b564f8..cca4e494 100644 --- a/qml/ModelDownloaderDialog.qml +++ b/qml/ModelDownloaderDialog.qml @@ -12,7 +12,7 @@ Dialog { id: modelDownloaderDialog modal: true opacity: 0.9 - closePolicy: LLM.modelList.length === 0 ? Popup.NoAutoClose : (Popup.CloseOnEscape | Popup.CloseOnPressOutside) + closePolicy: LLM.chatListModel.currentChat.modelList.length === 0 ? Popup.NoAutoClose : (Popup.CloseOnEscape | Popup.CloseOnPressOutside) background: Rectangle { anchors.fill: parent anchors.margins: -20