diff --git a/CMakeLists.txt b/CMakeLists.txt index bddcdd1f..6fe03ed6 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -60,7 +60,7 @@ qt_add_executable(chat main.cpp chat.h chat.cpp chatllm.h chatllm.cpp - chatmodel.h chatlistmodel.h + chatmodel.h chatlistmodel.h chatlistmodel.cpp download.h download.cpp network.h network.cpp llm.h llm.cpp diff --git a/chat.cpp b/chat.cpp index 87f1dbd6..2350949f 100644 --- a/chat.cpp +++ b/chat.cpp @@ -1,32 +1,37 @@ #include "chat.h" +#include "llm.h" #include "network.h" +#include "download.h" Chat::Chat(QObject *parent) : QObject(parent) - , m_llmodel(new ChatLLM) , m_id(Network::globalInstance()->generateUniqueId()) , m_name(tr("New Chat")) , m_chatModel(new ChatModel(this)) , m_responseInProgress(false) - , m_desiredThreadCount(std::min(4, (int32_t) std::thread::hardware_concurrency())) + , m_creationDate(QDateTime::currentSecsSinceEpoch()) + , m_llmodel(new ChatLLM(this)) { + // Should be in same thread + connect(Download::globalInstance(), &Download::modelListChanged, this, &Chat::modelListChanged, Qt::DirectConnection); + connect(this, &Chat::modelNameChanged, this, &Chat::modelListChanged, Qt::DirectConnection); + + // Should be in different threads connect(m_llmodel, &ChatLLM::isModelLoadedChanged, this, &Chat::isModelLoadedChanged, Qt::QueuedConnection); - connect(m_llmodel, &ChatLLM::responseChanged, this, &Chat::responseChanged, Qt::QueuedConnection); + connect(m_llmodel, &ChatLLM::responseChanged, this, &Chat::handleResponseChanged, Qt::QueuedConnection); connect(m_llmodel, &ChatLLM::responseStarted, this, &Chat::responseStarted, Qt::QueuedConnection); connect(m_llmodel, &ChatLLM::responseStopped, this, &Chat::responseStopped, Qt::QueuedConnection); - connect(m_llmodel, &ChatLLM::modelNameChanged, this, &Chat::modelNameChanged, Qt::QueuedConnection); - connect(m_llmodel, &ChatLLM::threadCountChanged, this, &Chat::threadCountChanged, Qt::QueuedConnection); - connect(m_llmodel, &ChatLLM::threadCountChanged, this, &Chat::syncThreadCount, Qt::QueuedConnection); - connect(m_llmodel, &ChatLLM::recalcChanged, this, &Chat::recalcChanged, Qt::QueuedConnection); + connect(m_llmodel, &ChatLLM::modelNameChanged, this, &Chat::handleModelNameChanged, Qt::QueuedConnection); connect(m_llmodel, &ChatLLM::recalcChanged, this, &Chat::handleRecalculating, Qt::QueuedConnection); connect(m_llmodel, &ChatLLM::generatedNameChanged, this, &Chat::generatedNameChanged, Qt::QueuedConnection); connect(this, &Chat::promptRequested, m_llmodel, &ChatLLM::prompt, Qt::QueuedConnection); connect(this, &Chat::modelNameChangeRequested, m_llmodel, &ChatLLM::modelNameChangeRequested, Qt::QueuedConnection); - connect(this, &Chat::unloadRequested, m_llmodel, &ChatLLM::unload, Qt::QueuedConnection); - connect(this, &Chat::reloadRequested, m_llmodel, &ChatLLM::reload, Qt::QueuedConnection); + connect(this, &Chat::loadDefaultModelRequested, m_llmodel, &ChatLLM::loadDefaultModel, Qt::QueuedConnection); + connect(this, &Chat::loadModelRequested, m_llmodel, &ChatLLM::loadModel, Qt::QueuedConnection); + connect(this, &Chat::unloadModelRequested, m_llmodel, &ChatLLM::unloadModel, Qt::QueuedConnection); + connect(this, &Chat::reloadModelRequested, m_llmodel, &ChatLLM::reloadModel, Qt::QueuedConnection); connect(this, &Chat::generateNameRequested, m_llmodel, &ChatLLM::generateName, Qt::QueuedConnection); - connect(this, &Chat::setThreadCountRequested, m_llmodel, &ChatLLM::setThreadCount, Qt::QueuedConnection); // The following are blocking operations and will block the gui thread, 
therefore must be fast // to respond to @@ -38,9 +43,21 @@ Chat::Chat(QObject *parent) void Chat::reset() { stopGenerating(); + // Erase our current on disk representation as we're completely resetting the chat along with id + LLM::globalInstance()->chatListModel()->removeChatFile(this); emit resetContextRequested(); // blocking queued connection m_id = Network::globalInstance()->generateUniqueId(); emit idChanged(); + // NOTE: We deliberately do not reset the name or creation date to indicate that this was originally + // an older chat that was reset for another purpose. Resetting this data will lead to the chat + // name label changing back to 'New Chat' and showing up in the chat model list as a 'New Chat' + // further down in the list. This might surprise the user. In the future, we might get rid of + // the "reset context" button in the UI. Right now, by changing the model in the combobox dropdown + // we effectively do a reset context. We *have* to do this right now when switching between different + // types of models. The only way to get rid of that would be a very long recalculate where we rebuild + // the context if we switch between different types of models. Probably the right way to fix this + // is to allow switching models but throwing up a dialog warning users if we switch between types + // of models that a long recalculation will ensue. m_chatModel->clear(); } @@ -49,10 +66,12 @@ bool Chat::isModelLoaded() const return m_llmodel->isModelLoaded(); } -void Chat::prompt(const QString &prompt, const QString &prompt_template, int32_t n_predict, int32_t top_k, float top_p, - float temp, int32_t n_batch, float repeat_penalty, int32_t repeat_penalty_tokens) +void Chat::prompt(const QString &prompt, const QString &prompt_template, int32_t n_predict, + int32_t top_k, float top_p, float temp, int32_t n_batch, float repeat_penalty, + int32_t repeat_penalty_tokens) { - emit promptRequested(prompt, prompt_template, n_predict, top_k, top_p, temp, n_batch, repeat_penalty, repeat_penalty_tokens); + emit promptRequested(prompt, prompt_template, n_predict, top_k, top_p, temp, n_batch, + repeat_penalty, repeat_penalty_tokens, LLM::globalInstance()->threadCount()); } void Chat::regenerateResponse() @@ -70,6 +89,13 @@ QString Chat::response() const return m_llmodel->response(); } +void Chat::handleResponseChanged() +{ + const int index = m_chatModel->count() - 1; + m_chatModel->updateValue(index, response()); + emit responseChanged(); +} + void Chat::responseStarted() { m_responseInProgress = true; @@ -98,21 +124,6 @@ void Chat::setModelName(const QString &modelName) emit modelNameChangeRequested(modelName); } -void Chat::syncThreadCount() { - emit setThreadCountRequested(m_desiredThreadCount); -} - -void Chat::setThreadCount(int32_t n_threads) { - if (n_threads <= 0) - n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency()); - m_desiredThreadCount = n_threads; - syncThreadCount(); -} - -int32_t Chat::threadCount() { - return m_llmodel->threadCount(); -} - void Chat::newPromptResponsePair(const QString &prompt) { m_chatModel->appendPrompt(tr("Prompt: "), prompt); @@ -125,16 +136,25 @@ bool Chat::isRecalc() const return m_llmodel->isRecalc(); } -void Chat::unload() +void Chat::loadDefaultModel() +{ + emit loadDefaultModelRequested(); +} + +void Chat::loadModel(const QString &modelName) +{ + emit loadModelRequested(modelName); +} + +void Chat::unloadModel() { - m_savedModelName = m_llmodel->modelName(); stopGenerating(); - emit unloadRequested(); + emit unloadModelRequested();
} -void Chat::reload() +void Chat::reloadModel() { - emit reloadRequested(m_savedModelName); + emit reloadModelRequested(m_savedModelName); } void Chat::generatedNameChanged() @@ -150,4 +170,98 @@ void Chat::generatedNameChanged() void Chat::handleRecalculating() { Network::globalInstance()->sendRecalculatingContext(m_chatModel->count()); + emit recalcChanged(); +} + +void Chat::handleModelNameChanged() +{ + m_savedModelName = modelName(); + emit modelNameChanged(); +} + +bool Chat::serialize(QDataStream &stream) const +{ + stream << m_creationDate; + stream << m_id; + stream << m_name; + stream << m_userName; + stream << m_savedModelName; + if (!m_llmodel->serialize(stream)) + return false; + if (!m_chatModel->serialize(stream)) + return false; + return stream.status() == QDataStream::Ok; +} + +bool Chat::deserialize(QDataStream &stream) +{ + stream >> m_creationDate; + stream >> m_id; + emit idChanged(); + stream >> m_name; + stream >> m_userName; + emit nameChanged(); + stream >> m_savedModelName; + if (!m_llmodel->deserialize(stream)) + return false; + if (!m_chatModel->deserialize(stream)) + return false; + emit chatModelChanged(); + return stream.status() == QDataStream::Ok; +} + +QList<QString> Chat::modelList() const +{ + // Build a model list from exepath and from the localpath + QList<QString> list; + + QString exePath = QCoreApplication::applicationDirPath() + QDir::separator(); + QString localPath = Download::globalInstance()->downloadLocalModelsPath(); + + { + QDir dir(exePath); + dir.setNameFilters(QStringList() << "ggml-*.bin"); + QStringList fileNames = dir.entryList(); + for (QString f : fileNames) { + QString filePath = exePath + f; + QFileInfo info(filePath); + QString name = info.completeBaseName().remove(0, 5); + if (info.exists()) { + if (name == modelName()) + list.prepend(name); + else + list.append(name); + } + } + } + + if (localPath != exePath) { + QDir dir(localPath); + dir.setNameFilters(QStringList() << "ggml-*.bin"); + QStringList fileNames = dir.entryList(); + for (QString f : fileNames) { + QString filePath = localPath + f; + QFileInfo info(filePath); + QString name = info.completeBaseName().remove(0, 5); + if (info.exists() && !list.contains(name)) { // don't allow duplicates + if (name == modelName()) + list.prepend(name); + else + list.append(name); + } + } + } + + if (list.isEmpty()) { + if (exePath != localPath) { + qWarning() << "ERROR: Could not find any applicable models in" + << exePath << "nor" << localPath; + } else { + qWarning() << "ERROR: Could not find any applicable models in" + << exePath; + } + return QList<QString>(); + } + + return list; } diff --git a/chat.h b/chat.h index 18a26d6a..fa5db003 100644 --- a/chat.h +++ b/chat.h @@ -3,6 +3,7 @@ #include #include +#include #include "chatllm.h" #include "chatmodel.h" @@ -17,8 +18,8 @@ class Chat : public QObject Q_PROPERTY(QString response READ response NOTIFY responseChanged) Q_PROPERTY(QString modelName READ modelName WRITE setModelName NOTIFY modelNameChanged) Q_PROPERTY(bool responseInProgress READ responseInProgress NOTIFY responseInProgressChanged) - Q_PROPERTY(int32_t threadCount READ threadCount WRITE setThreadCount NOTIFY threadCountChanged) Q_PROPERTY(bool isRecalc READ isRecalc NOTIFY recalcChanged) + Q_PROPERTY(QList<QString> modelList READ modelList NOTIFY modelListChanged) QML_ELEMENT QML_UNCREATABLE("Only creatable from c++!") @@ -36,13 +37,10 @@ public: Q_INVOKABLE void reset(); Q_INVOKABLE bool isModelLoaded() const; - Q_INVOKABLE void prompt(const QString &prompt, const QString &prompt_template, int32_t
n_predict, int32_t top_k, float top_p, - float temp, int32_t n_batch, float repeat_penalty, int32_t repeat_penalty_tokens); + Q_INVOKABLE void prompt(const QString &prompt, const QString &prompt_template, int32_t n_predict, + int32_t top_k, float top_p, float temp, int32_t n_batch, float repeat_penalty, int32_t repeat_penalty_tokens); Q_INVOKABLE void regenerateResponse(); Q_INVOKABLE void stopGenerating(); - Q_INVOKABLE void syncThreadCount(); - Q_INVOKABLE void setThreadCount(int32_t n_threads); - Q_INVOKABLE int32_t threadCount(); Q_INVOKABLE void newPromptResponsePair(const QString &prompt); QString response() const; @@ -51,8 +49,16 @@ public: void setModelName(const QString &modelName); bool isRecalc() const; - void unload(); - void reload(); + void loadDefaultModel(); + void loadModel(const QString &modelName); + void unloadModel(); + void reloadModel(); + + qint64 creationDate() const { return m_creationDate; } + bool serialize(QDataStream &stream) const; + bool deserialize(QDataStream &stream); + + QList<QString> modelList() const; Q_SIGNALS: void idChanged(); @@ -61,35 +67,39 @@ Q_SIGNALS: void isModelLoadedChanged(); void responseChanged(); void responseInProgressChanged(); - void promptRequested(const QString &prompt, const QString &prompt_template, int32_t n_predict, int32_t top_k, float top_p, - float temp, int32_t n_batch, float repeat_penalty, int32_t repeat_penalty_tokens); + void promptRequested(const QString &prompt, const QString &prompt_template, int32_t n_predict, + int32_t top_k, float top_p, float temp, int32_t n_batch, float repeat_penalty, int32_t repeat_penalty_tokens, + int32_t n_threads); void regenerateResponseRequested(); void resetResponseRequested(); void resetContextRequested(); void modelNameChangeRequested(const QString &modelName); void modelNameChanged(); - void threadCountChanged(); - void setThreadCountRequested(int32_t threadCount); void recalcChanged(); - void unloadRequested(); - void reloadRequested(const QString &modelName); + void loadDefaultModelRequested(); + void loadModelRequested(const QString &modelName); + void unloadModelRequested(); + void reloadModelRequested(const QString &modelName); void generateNameRequested(); + void modelListChanged(); private Q_SLOTS: + void handleResponseChanged(); void responseStarted(); void responseStopped(); void generatedNameChanged(); void handleRecalculating(); + void handleModelNameChanged(); private: - ChatLLM *m_llmodel; QString m_id; QString m_name; QString m_userName; QString m_savedModelName; ChatModel *m_chatModel; bool m_responseInProgress; - int32_t m_desiredThreadCount; + qint64 m_creationDate; + ChatLLM *m_llmodel; }; #endif // CHAT_H diff --git a/chatlistmodel.cpp b/chatlistmodel.cpp new file mode 100644 index 00000000..5114e02d --- /dev/null +++ b/chatlistmodel.cpp @@ -0,0 +1,72 @@ +#include "chatlistmodel.h" + +#include +#include + +void ChatListModel::removeChatFile(Chat *chat) const +{ + QSettings settings; + QFileInfo settingsInfo(settings.fileName()); + QString settingsPath = settingsInfo.absolutePath(); + QFile file(settingsPath + "/gpt4all-" + chat->id() + ".chat"); + if (!file.exists()) + return; + bool success = file.remove(); + if (!success) + qWarning() << "ERROR: Couldn't remove chat file:" << file.fileName(); +} + +void ChatListModel::saveChats() const +{ + QSettings settings; + QFileInfo settingsInfo(settings.fileName()); + QString settingsPath = settingsInfo.absolutePath(); + for (Chat *chat : m_chats) { + QFile file(settingsPath + "/gpt4all-" + chat->id() + ".chat"); + bool success =
file.open(QIODevice::WriteOnly); + if (!success) { + qWarning() << "ERROR: Couldn't save chat to file:" << file.fileName(); + continue; + } + QDataStream out(&file); + if (!chat->serialize(out)) { + qWarning() << "ERROR: Couldn't serialize chat to file:" << file.fileName(); + file.remove(); + } + file.close(); + } +} + +void ChatListModel::restoreChats() +{ + QSettings settings; + QFileInfo settingsInfo(settings.fileName()); + QString settingsPath = settingsInfo.absolutePath(); + QDir dir(settingsPath); + dir.setNameFilters(QStringList() << "gpt4all-*.chat"); + QStringList fileNames = dir.entryList(); + beginResetModel(); + for (QString f : fileNames) { + QString filePath = settingsPath + "/" + f; + QFile file(filePath); + bool success = file.open(QIODevice::ReadOnly); + if (!success) { + qWarning() << "ERROR: Couldn't restore chat from file:" << file.fileName(); + continue; + } + QDataStream in(&file); + Chat *chat = new Chat(this); + if (!chat->deserialize(in)) { + qWarning() << "ERROR: Couldn't deserialize chat from file:" << file.fileName(); + file.remove(); + } else { + connect(chat, &Chat::nameChanged, this, &ChatListModel::nameChanged); + m_chats.append(chat); + } + file.close(); + } + std::sort(m_chats.begin(), m_chats.end(), [](const Chat* a, const Chat* b) { + return a->creationDate() > b->creationDate(); + }); + endResetModel(); +} diff --git a/chatlistmodel.h b/chatlistmodel.h index 20c6eeba..68633da9 100644 --- a/chatlistmodel.h +++ b/chatlistmodel.h @@ -55,7 +55,7 @@ public: Q_INVOKABLE void addChat() { - // Don't add a new chat if the current chat is empty + // Don't add a new chat if we already have one if (m_newChat) return; @@ -73,13 +73,29 @@ public: setCurrentChat(m_newChat); } + void setNewChat(Chat* chat) + { + // Don't add a new chat if we already have one + if (m_newChat) + return; + + m_newChat = chat; + connect(m_newChat->chatModel(), &ChatModel::countChanged, + this, &ChatListModel::newChatCountChanged); + connect(m_newChat, &Chat::nameChanged, + this, &ChatListModel::nameChanged); + setCurrentChat(m_newChat); + } + Q_INVOKABLE void removeChat(Chat* chat) { if (!m_chats.contains(chat)) { - qDebug() << "WARNING: Removing chat failed with id" << chat->id(); + qWarning() << "WARNING: Removing chat failed with id" << chat->id(); return; } + removeChatFile(chat); + emit disconnectChat(chat); if (chat == m_newChat) { m_newChat->disconnect(this); @@ -115,20 +131,20 @@ public: void setCurrentChat(Chat *chat) { if (!m_chats.contains(chat)) { - qDebug() << "ERROR: Setting current chat failed with id" << chat->id(); + qWarning() << "ERROR: Setting current chat failed with id" << chat->id(); return; } if (m_currentChat) { if (m_currentChat->isModelLoaded()) - m_currentChat->unload(); + m_currentChat->unloadModel(); emit disconnect(m_currentChat); } emit connectChat(chat); m_currentChat = chat; if (!m_currentChat->isModelLoaded()) - m_currentChat->reload(); + m_currentChat->reloadModel(); emit currentChatChanged(); } @@ -138,9 +154,12 @@ public: return m_chats.at(index); } - int count() const { return m_chats.size(); } + void removeChatFile(Chat *chat) const; + void saveChats() const; + void restoreChats(); + Q_SIGNALS: void countChanged(); void connectChat(Chat*); diff --git a/chatllm.cpp b/chatllm.cpp index 68230127..071a7ddb 100644 --- a/chatllm.cpp +++ b/chatllm.cpp @@ -1,7 +1,7 @@ #include "chatllm.h" +#include "chat.h" #include "download.h" #include "network.h" -#include "llm.h" #include "llmodel/gptj.h" #include "llmodel/llamamodel.h" @@ -32,28 +32,29 @@ static QString 
modelFilePath(const QString &modelName) return QString(); } -ChatLLM::ChatLLM() +ChatLLM::ChatLLM(Chat *parent) : QObject{nullptr} , m_llmodel(nullptr) , m_promptResponseTokens(0) , m_responseLogits(0) , m_isRecalc(false) + , m_chat(parent) { moveToThread(&m_llmThread); - connect(&m_llmThread, &QThread::started, this, &ChatLLM::loadModel); connect(this, &ChatLLM::sendStartup, Network::globalInstance(), &Network::sendStartup); connect(this, &ChatLLM::sendModelLoaded, Network::globalInstance(), &Network::sendModelLoaded); - m_llmThread.setObjectName("llm thread"); // FIXME: Should identify these with chat name + connect(m_chat, &Chat::idChanged, this, &ChatLLM::handleChatIdChanged); + m_llmThread.setObjectName(m_chat->id()); m_llmThread.start(); } -bool ChatLLM::loadModel() +bool ChatLLM::loadDefaultModel() { - const QList<QString> models = LLM::globalInstance()->modelList(); + const QList<QString> models = m_chat->modelList(); if (models.isEmpty()) { // try again when we get a list of models connect(Download::globalInstance(), &Download::modelListChanged, this, - &ChatLLM::loadModel, Qt::SingleShotConnection); + &ChatLLM::loadDefaultModel, Qt::SingleShotConnection); return false; } @@ -62,10 +63,10 @@ bool ChatLLM::loadModel() QString defaultModel = settings.value("defaultModel", "gpt4all-j-v1.3-groovy").toString(); if (defaultModel.isEmpty() || !models.contains(defaultModel)) defaultModel = models.first(); - return loadModelPrivate(defaultModel); + return loadModel(defaultModel); } -bool ChatLLM::loadModelPrivate(const QString &modelName) +bool ChatLLM::loadModel(const QString &modelName) { if (isModelLoaded() && m_modelName == modelName) return true; @@ -100,12 +101,13 @@ bool ChatLLM::loadModelPrivate(const QString &modelName) } emit isModelLoadedChanged(); - emit threadCountChanged(); if (isFirstLoad) emit sendStartup(); else emit sendModelLoaded(); + } else { + qWarning() << "ERROR: Could not find model at" << filePath; } if (m_llmodel) @@ -114,19 +116,6 @@ bool ChatLLM::loadModelPrivate(const QString &modelName) return m_llmodel; } -void ChatLLM::setThreadCount(int32_t n_threads) { - if (m_llmodel && m_llmodel->threadCount() != n_threads) { - m_llmodel->setThreadCount(n_threads); - emit threadCountChanged(); - } -} - -int32_t ChatLLM::threadCount() { - if (!m_llmodel) - return 1; - return m_llmodel->threadCount(); -} - bool ChatLLM::isModelLoaded() const { return m_llmodel && m_llmodel->isModelLoaded(); @@ -203,7 +192,7 @@ void ChatLLM::setModelName(const QString &modelName) void ChatLLM::modelNameChangeRequested(const QString &modelName) { - if (!loadModelPrivate(modelName)) + if (!loadModel(modelName)) qWarning() << "ERROR: Could not load model" << modelName; } @@ -247,8 +236,8 @@ bool ChatLLM::handleRecalculate(bool isRecalc) return !m_stopGenerating; } -bool ChatLLM::prompt(const QString &prompt, const QString &prompt_template, int32_t n_predict, int32_t top_k, float top_p, - float temp, int32_t n_batch, float repeat_penalty, int32_t repeat_penalty_tokens) +bool ChatLLM::prompt(const QString &prompt, const QString &prompt_template, int32_t n_predict, int32_t top_k, + float top_p, float temp, int32_t n_batch, float repeat_penalty, int32_t repeat_penalty_tokens, int n_threads) { if (!isModelLoaded()) return false; @@ -269,6 +258,7 @@ bool ChatLLM::prompt(const QString &prompt, const QString &prompt_template, int3 m_ctx.n_batch = n_batch; m_ctx.repeat_penalty = repeat_penalty; m_ctx.repeat_last_n = repeat_penalty_tokens; + m_llmodel->setThreadCount(n_threads); #if defined(DEBUG) printf("%s",
qPrintable(instructPrompt)); fflush(stdout); @@ -288,19 +278,22 @@ bool ChatLLM::prompt(const QString &prompt, const QString &prompt_template, int3 return true; } -void ChatLLM::unload() +void ChatLLM::unloadModel() { + saveState(); delete m_llmodel; m_llmodel = nullptr; emit isModelLoadedChanged(); } -void ChatLLM::reload(const QString &modelName) +void ChatLLM::reloadModel(const QString &modelName) { - if (modelName.isEmpty()) - loadModel(); - else - loadModelPrivate(modelName); + if (modelName.isEmpty()) { + loadDefaultModel(); + } else { + loadModel(modelName); + } + restoreState(); } void ChatLLM::generateName() @@ -333,6 +326,11 @@ void ChatLLM::generateName() } } +void ChatLLM::handleChatIdChanged() +{ + m_llmThread.setObjectName(m_chat->id()); +} + bool ChatLLM::handleNamePrompt(int32_t token) { Q_UNUSED(token); @@ -354,3 +352,60 @@ bool ChatLLM::handleNameRecalculate(bool isRecalc) Q_UNREACHABLE(); return true; } + +bool ChatLLM::serialize(QDataStream &stream) +{ + stream << response(); + stream << generatedName(); + stream << m_promptResponseTokens; + stream << m_responseLogits; + stream << m_ctx.n_past; + stream << quint64(m_ctx.logits.size()); + stream.writeRawData(reinterpret_cast<const char*>(m_ctx.logits.data()), m_ctx.logits.size() * sizeof(float)); + stream << quint64(m_ctx.tokens.size()); + stream.writeRawData(reinterpret_cast<const char*>(m_ctx.tokens.data()), m_ctx.tokens.size() * sizeof(int)); + saveState(); + stream << m_state; + return stream.status() == QDataStream::Ok; +} + +bool ChatLLM::deserialize(QDataStream &stream) +{ + QString response; + stream >> response; + m_response = response.toStdString(); + QString nameResponse; + stream >> nameResponse; + m_nameResponse = nameResponse.toStdString(); + stream >> m_promptResponseTokens; + stream >> m_responseLogits; + stream >> m_ctx.n_past; + quint64 logitsSize; + stream >> logitsSize; + m_ctx.logits.resize(logitsSize); + stream.readRawData(reinterpret_cast<char*>(m_ctx.logits.data()), logitsSize * sizeof(float)); + quint64 tokensSize; + stream >> tokensSize; + m_ctx.tokens.resize(tokensSize); + stream.readRawData(reinterpret_cast<char*>(m_ctx.tokens.data()), tokensSize * sizeof(int)); + stream >> m_state; + return stream.status() == QDataStream::Ok; +} + +void ChatLLM::saveState() +{ + if (!isModelLoaded()) + return; + + const size_t stateSize = m_llmodel->stateSize(); + m_state.resize(stateSize); + m_llmodel->saveState(static_cast<uint8_t*>(reinterpret_cast<void*>(m_state.data()))); +} + +void ChatLLM::restoreState() +{ + if (!isModelLoaded()) + return; + + m_llmodel->restoreState(static_cast<const uint8_t*>(reinterpret_cast<void*>(m_state.data()))); +} diff --git a/chatllm.h b/chatllm.h index ab2dcc8c..dc1260b8 100644 --- a/chatllm.h +++ b/chatllm.h @@ -6,18 +6,18 @@ #include "llmodel/llmodel.h" +class Chat; class ChatLLM : public QObject { Q_OBJECT Q_PROPERTY(bool isModelLoaded READ isModelLoaded NOTIFY isModelLoadedChanged) Q_PROPERTY(QString response READ response NOTIFY responseChanged) Q_PROPERTY(QString modelName READ modelName WRITE setModelName NOTIFY modelNameChanged) - Q_PROPERTY(int32_t threadCount READ threadCount WRITE setThreadCount NOTIFY threadCountChanged) Q_PROPERTY(bool isRecalc READ isRecalc NOTIFY recalcChanged) Q_PROPERTY(QString generatedName READ generatedName NOTIFY generatedNameChanged) public: - ChatLLM(); + ChatLLM(Chat *parent); bool isModelLoaded() const; void regenerateResponse(); @@ -25,8 +25,6 @@ public: void resetContext(); void stopGenerating() { m_stopGenerating = true; } - void setThreadCount(int32_t n_threads); - int32_t threadCount(); QString response()
const; QString modelName() const; @@ -37,14 +35,20 @@ public: QString generatedName() const { return QString::fromStdString(m_nameResponse); } + bool serialize(QDataStream &stream); + bool deserialize(QDataStream &stream); + public Q_SLOTS: - bool prompt(const QString &prompt, const QString &prompt_template, int32_t n_predict, int32_t top_k, float top_p, - float temp, int32_t n_batch, float repeat_penalty, int32_t repeat_penalty_tokens); - bool loadModel(); + bool prompt(const QString &prompt, const QString &prompt_template, int32_t n_predict, + int32_t top_k, float top_p, float temp, int32_t n_batch, float repeat_penalty, int32_t repeat_penalty_tokens, + int32_t n_threads); + bool loadDefaultModel(); + bool loadModel(const QString &modelName); void modelNameChangeRequested(const QString &modelName); - void unload(); - void reload(const QString &modelName); + void unloadModel(); + void reloadModel(const QString &modelName); void generateName(); + void handleChatIdChanged(); Q_SIGNALS: void isModelLoadedChanged(); @@ -52,22 +56,23 @@ Q_SIGNALS: void responseStarted(); void responseStopped(); void modelNameChanged(); - void threadCountChanged(); void recalcChanged(); void sendStartup(); void sendModelLoaded(); void sendResetContext(); void generatedNameChanged(); + void stateChanged(); private: void resetContextPrivate(); - bool loadModelPrivate(const QString &modelName); bool handlePrompt(int32_t token); bool handleResponse(int32_t token, const std::string &response); bool handleRecalculate(bool isRecalc); bool handleNamePrompt(int32_t token); bool handleNameResponse(int32_t token, const std::string &response); bool handleNameRecalculate(bool isRecalc); + void saveState(); + void restoreState(); private: LLModel::PromptContext m_ctx; @@ -77,6 +82,8 @@ private: quint32 m_promptResponseTokens; quint32 m_responseLogits; QString m_modelName; + Chat *m_chat; + QByteArray m_state; QThread m_llmThread; std::atomic m_stopGenerating; bool m_isRecalc; diff --git a/chatmodel.h b/chatmodel.h index e5be2719..f3e59fa3 100644 --- a/chatmodel.h +++ b/chatmodel.h @@ -3,6 +3,7 @@ #include #include +#include struct ChatItem { @@ -209,6 +210,46 @@ public: int count() const { return m_chatItems.size(); } + bool serialize(QDataStream &stream) const + { + stream << count(); + for (auto c : m_chatItems) { + stream << c.id; + stream << c.name; + stream << c.value; + stream << c.prompt; + stream << c.newResponse; + stream << c.currentResponse; + stream << c.stopped; + stream << c.thumbsUpState; + stream << c.thumbsDownState; + } + return stream.status() == QDataStream::Ok; + } + + bool deserialize(QDataStream &stream) + { + int size; + stream >> size; + for (int i = 0; i < size; ++i) { + ChatItem c; + stream >> c.id; + stream >> c.name; + stream >> c.value; + stream >> c.prompt; + stream >> c.newResponse; + stream >> c.currentResponse; + stream >> c.stopped; + stream >> c.thumbsUpState; + stream >> c.thumbsDownState; + beginInsertRows(QModelIndex(), m_chatItems.size(), m_chatItems.size()); + m_chatItems.append(c); + endInsertRows(); + } + emit countChanged(); + return stream.status() == QDataStream::Ok; + } + Q_SIGNALS: void countChanged(); diff --git a/llm.cpp b/llm.cpp index f624b7a3..2689e47f 100644 --- a/llm.cpp +++ b/llm.cpp @@ -20,77 +20,22 @@ LLM *LLM::globalInstance() LLM::LLM() : QObject{nullptr} , m_chatListModel(new ChatListModel(this)) + , m_threadCount(std::min(4, (int32_t) std::thread::hardware_concurrency())) { - // Should be in the same thread - connect(Download::globalInstance(), 
&Download::modelListChanged, - this, &LLM::modelListChanged, Qt::DirectConnection); - connect(m_chatListModel, &ChatListModel::connectChat, - this, &LLM::connectChat, Qt::DirectConnection); - connect(m_chatListModel, &ChatListModel::disconnectChat, - this, &LLM::disconnectChat, Qt::DirectConnection); - - if (!m_chatListModel->count()) + connect(QCoreApplication::instance(), &QCoreApplication::aboutToQuit, + this, &LLM::aboutToQuit); + + m_chatListModel->restoreChats(); + if (m_chatListModel->count()) { + Chat *firstChat = m_chatListModel->get(0); + if (firstChat->chatModel()->count() < 2) + m_chatListModel->setNewChat(firstChat); + else + m_chatListModel->setCurrentChat(firstChat); + } else m_chatListModel->addChat(); } -QList LLM::modelList() const -{ - Q_ASSERT(m_chatListModel->currentChat()); - const Chat *currentChat = m_chatListModel->currentChat(); - // Build a model list from exepath and from the localpath - QList list; - - QString exePath = QCoreApplication::applicationDirPath() + QDir::separator(); - QString localPath = Download::globalInstance()->downloadLocalModelsPath(); - - { - QDir dir(exePath); - dir.setNameFilters(QStringList() << "ggml-*.bin"); - QStringList fileNames = dir.entryList(); - for (QString f : fileNames) { - QString filePath = exePath + f; - QFileInfo info(filePath); - QString name = info.completeBaseName().remove(0, 5); - if (info.exists()) { - if (name == currentChat->modelName()) - list.prepend(name); - else - list.append(name); - } - } - } - - if (localPath != exePath) { - QDir dir(localPath); - dir.setNameFilters(QStringList() << "ggml-*.bin"); - QStringList fileNames = dir.entryList(); - for (QString f : fileNames) { - QString filePath = localPath + f; - QFileInfo info(filePath); - QString name = info.completeBaseName().remove(0, 5); - if (info.exists() && !list.contains(name)) { // don't allow duplicates - if (name == currentChat->modelName()) - list.prepend(name); - else - list.append(name); - } - } - } - - if (list.isEmpty()) { - if (exePath != localPath) { - qWarning() << "ERROR: Could not find any applicable models in" - << exePath << "nor" << localPath; - } else { - qWarning() << "ERROR: Could not find any applicable models in" - << exePath; - } - return QList(); - } - - return list; -} - bool LLM::checkForUpdates() const { Network::globalInstance()->sendCheckForUpdates(); @@ -113,21 +58,20 @@ bool LLM::checkForUpdates() const return QProcess::startDetached(fileName); } -bool LLM::isRecalc() const +int32_t LLM::threadCount() const { - Q_ASSERT(m_chatListModel->currentChat()); - return m_chatListModel->currentChat()->isRecalc(); + return m_threadCount; } -void LLM::connectChat(Chat *chat) +void LLM::setThreadCount(int32_t n_threads) { - // Should be in the same thread - connect(chat, &Chat::modelNameChanged, this, &LLM::modelListChanged, Qt::DirectConnection); - connect(chat, &Chat::recalcChanged, this, &LLM::recalcChanged, Qt::DirectConnection); - connect(chat, &Chat::responseChanged, this, &LLM::responseChanged, Qt::DirectConnection); + if (n_threads <= 0) + n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency()); + m_threadCount = n_threads; + emit threadCountChanged(); } -void LLM::disconnectChat(Chat *chat) +void LLM::aboutToQuit() { - chat->disconnect(this); + m_chatListModel->saveChats(); } diff --git a/llm.h b/llm.h index e291ebbf..89fee17f 100644 --- a/llm.h +++ b/llm.h @@ -3,37 +3,33 @@ #include -#include "chat.h" #include "chatlistmodel.h" class LLM : public QObject { Q_OBJECT - Q_PROPERTY(QList modelList READ modelList 
NOTIFY modelListChanged) - Q_PROPERTY(bool isRecalc READ isRecalc NOTIFY recalcChanged) Q_PROPERTY(ChatListModel *chatListModel READ chatListModel NOTIFY chatListModelChanged) + Q_PROPERTY(int32_t threadCount READ threadCount WRITE setThreadCount NOTIFY threadCountChanged) public: static LLM *globalInstance(); - QList<QString> modelList() const; - bool isRecalc() const; ChatListModel *chatListModel() const { return m_chatListModel; } + int32_t threadCount() const; + void setThreadCount(int32_t n_threads); Q_INVOKABLE bool checkForUpdates() const; Q_SIGNALS: - void modelListChanged(); - void recalcChanged(); - void responseChanged(); void chatListModelChanged(); + void threadCountChanged(); private Q_SLOTS: - void connectChat(Chat*); - void disconnectChat(Chat*); + void aboutToQuit(); private: ChatListModel *m_chatListModel; + int32_t m_threadCount; private: explicit LLM(); diff --git a/llmodel/llamamodel.cpp b/llmodel/llamamodel.cpp index d89862f1..272633c7 100644 --- a/llmodel/llamamodel.cpp +++ b/llmodel/llamamodel.cpp @@ -67,6 +67,7 @@ int32_t LLamaModel::threadCount() { LLamaModel::~LLamaModel() { + llama_free(d_ptr->ctx); } bool LLamaModel::isModelLoaded() const @@ -74,6 +75,21 @@ bool LLamaModel::isModelLoaded() const return d_ptr->modelLoaded; } +size_t LLamaModel::stateSize() const +{ + return llama_get_state_size(d_ptr->ctx); +} + +size_t LLamaModel::saveState(uint8_t *dest) const +{ + return llama_copy_state_data(d_ptr->ctx, dest); +} + +size_t LLamaModel::restoreState(const uint8_t *src) +{ + return llama_set_state_data(d_ptr->ctx, src); +} + void LLamaModel::prompt(const std::string &prompt, std::function<bool(int32_t)> promptCallback, std::function<bool(int32_t, const std::string&)> responseCallback, diff --git a/llmodel/llamamodel.h b/llmodel/llamamodel.h index 13e221a7..7f487803 100644 --- a/llmodel/llamamodel.h +++ b/llmodel/llamamodel.h @@ -14,6 +14,9 @@ public: bool loadModel(const std::string &modelPath) override; bool isModelLoaded() const override; + size_t stateSize() const override; + size_t saveState(uint8_t *dest) const override; + size_t restoreState(const uint8_t *src) override; void prompt(const std::string &prompt, std::function<bool(int32_t)> promptCallback, std::function<bool(int32_t, const std::string&)> responseCallback, diff --git a/llmodel/llmodel.h b/llmodel/llmodel.h index 08dc1764..5ef900f4 100644 --- a/llmodel/llmodel.h +++ b/llmodel/llmodel.h @@ -12,6 +12,9 @@ public: virtual bool loadModel(const std::string &modelPath) = 0; virtual bool isModelLoaded() const = 0; + virtual size_t stateSize() const { return 0; } + virtual size_t saveState(uint8_t *dest) const { return 0; } + virtual size_t restoreState(const uint8_t *src) { return 0; } struct PromptContext { std::vector<float> logits; // logits of current context std::vector<int32_t> tokens; // current tokens in the context window diff --git a/llmodel/llmodel_c.cpp b/llmodel/llmodel_c.cpp index 46eb1a7d..9788f1fb 100644 --- a/llmodel/llmodel_c.cpp +++ b/llmodel/llmodel_c.cpp @@ -48,6 +48,24 @@ bool llmodel_isModelLoaded(llmodel_model model) return wrapper->llModel->isModelLoaded(); } +uint64_t llmodel_get_state_size(llmodel_model model) +{ + LLModelWrapper *wrapper = reinterpret_cast<LLModelWrapper*>(model); + return wrapper->llModel->stateSize(); +} + +uint64_t llmodel_save_state_data(llmodel_model model, uint8_t *dest) +{ + LLModelWrapper *wrapper = reinterpret_cast<LLModelWrapper*>(model); + return wrapper->llModel->saveState(dest); +} + +uint64_t llmodel_restore_state_data(llmodel_model model, const uint8_t *src) +{ + LLModelWrapper *wrapper = reinterpret_cast<LLModelWrapper*>(model); + return wrapper->llModel->restoreState(src); +} + // Wrapper functions for the C
callbacks bool prompt_wrapper(int32_t token_id, void *user_data) { llmodel_prompt_callback callback = reinterpret_cast<llmodel_prompt_callback>(user_data); diff --git a/llmodel/llmodel_c.h b/llmodel/llmodel_c.h index 45cc9cd2..0907d765 100644 --- a/llmodel/llmodel_c.h +++ b/llmodel/llmodel_c.h @@ -98,6 +98,32 @@ bool llmodel_loadModel(llmodel_model model, const char *model_path); */ bool llmodel_isModelLoaded(llmodel_model model); +/** + * Get the size of the internal state of the model. + * NOTE: This state data is specific to the type of model you have created. + * @param model A pointer to the llmodel_model instance. + * @return the size in bytes of the internal state of the model + */ +uint64_t llmodel_get_state_size(llmodel_model model); + +/** + * Saves the internal state of the model to the specified destination address. + * NOTE: This state data is specific to the type of model you have created. + * @param model A pointer to the llmodel_model instance. + * @param dest A pointer to the destination. + * @return the number of bytes copied + */ +uint64_t llmodel_save_state_data(llmodel_model model, uint8_t *dest); + +/** + * Restores the internal state of the model using data from the specified address. + * NOTE: This state data is specific to the type of model you have created. + * @param model A pointer to the llmodel_model instance. + * @param src A pointer to the source. + * @return the number of bytes read + */ +uint64_t llmodel_restore_state_data(llmodel_model model, const uint8_t *src); + /** * Generate a response using the model. * @param model A pointer to the llmodel_model instance. diff --git a/main.qml b/main.qml index d50087bc..a41730bb 100644 --- a/main.qml +++ b/main.qml @@ -65,7 +65,7 @@ Window { } // check for any current models and if not, open download dialog - if (LLM.modelList.length === 0 && !firstStartDialog.opened) { + if (currentChat.modelList.length === 0 && !firstStartDialog.opened) { downloadNewModels.open(); return; } @@ -125,7 +125,7 @@ Window { anchors.horizontalCenter: parent.horizontalCenter font.pixelSize: theme.fontSizeLarge spacing: 0 - model: LLM.modelList + model: currentChat.modelList Accessible.role: Accessible.ComboBox Accessible.name: qsTr("ComboBox for displaying/picking the current model") Accessible.description: qsTr("Use this for picking the current model to use; the first item is the current model") @@ -367,9 +367,9 @@ Window { text: qsTr("Recalculating context.") Connections { - target: LLM + target: currentChat function onRecalcChanged() { - if (LLM.isRecalc) + if (currentChat.isRecalc) recalcPopup.open() else recalcPopup.close() @@ -422,10 +422,7 @@ Window { var item = chatModel.get(i) var string = item.name; var isResponse = item.name === qsTr("Response: ") - if (item.currentResponse) - string += currentChat.response - else - string += chatModel.get(i).value + string += chatModel.get(i).value if (isResponse && item.stopped) string += " " string += "\n" @@ -440,10 +437,7 @@ Window { var item = chatModel.get(i) var isResponse = item.name === qsTr("Response: ") str += "{\"content\": "; - if (item.currentResponse) - str += JSON.stringify(currentChat.response) - else - str += JSON.stringify(item.value) + str += JSON.stringify(item.value) str += ", \"role\": \"" + (isResponse ? "assistant" : "user") + "\""; if (isResponse && item.thumbsUpState !== item.thumbsDownState) str += ", \"rating\": \"" + (item.thumbsUpState ?
"positive" : "negative") + "\""; @@ -572,14 +566,14 @@ Window { Accessible.description: qsTr("This is the list of prompt/response pairs comprising the actual conversation with the model") delegate: TextArea { - text: currentResponse ? currentChat.response : (value ? value : "") + text: value width: listView.width color: theme.textColor wrapMode: Text.WordWrap focus: false readOnly: true font.pixelSize: theme.fontSizeLarge - cursorVisible: currentResponse ? (currentChat.response !== "" ? currentChat.responseInProgress : false) : false + cursorVisible: currentResponse ? currentChat.responseInProgress : false cursorPosition: text.length background: Rectangle { color: name === qsTr("Response: ") ? theme.backgroundLighter : theme.backgroundLight @@ -599,8 +593,8 @@ Window { anchors.leftMargin: 90 anchors.top: parent.top anchors.topMargin: 5 - visible: (currentResponse ? true : false) && currentChat.response === "" && currentChat.responseInProgress - running: (currentResponse ? true : false) && currentChat.response === "" && currentChat.responseInProgress + visible: (currentResponse ? true : false) && value === "" && currentChat.responseInProgress + running: (currentResponse ? true : false) && value === "" && currentChat.responseInProgress Accessible.role: Accessible.Animation Accessible.name: qsTr("Busy indicator") @@ -631,7 +625,7 @@ Window { window.height / 2 - height / 2) x: globalPoint.x y: globalPoint.y - property string text: currentResponse ? currentChat.response : (value ? value : "") + property string text: value response: newResponse === undefined || newResponse === "" ? text : newResponse onAccepted: { var responseHasChanged = response !== text && response !== newResponse @@ -711,7 +705,7 @@ Window { property bool isAutoScrolling: false Connections { - target: LLM + target: currentChat function onResponseChanged() { if (listView.shouldAutoScroll) { listView.isAutoScrolling = true @@ -762,7 +756,6 @@ Window { if (listElement.name === qsTr("Response: ")) { chatModel.updateCurrentResponse(index, true); chatModel.updateStopped(index, false); - chatModel.updateValue(index, currentChat.response); chatModel.updateThumbsUpState(index, false); chatModel.updateThumbsDownState(index, false); chatModel.updateNewResponse(index, ""); @@ -840,7 +833,6 @@ Window { var index = Math.max(0, chatModel.count - 1); var listElement = chatModel.get(index); chatModel.updateCurrentResponse(index, false); - chatModel.updateValue(index, currentChat.response); } currentChat.newPromptResponsePair(textInput.text); currentChat.prompt(textInput.text, settingsDialog.promptTemplate, diff --git a/network.cpp b/network.cpp index dfafddf8..ce77419e 100644 --- a/network.cpp +++ b/network.cpp @@ -458,7 +458,6 @@ void Network::handleIpifyFinished() void Network::handleMixpanelFinished() { - Q_ASSERT(m_usageStatsActive); QNetworkReply *reply = qobject_cast(sender()); if (!reply) return; diff --git a/qml/ChatDrawer.qml b/qml/ChatDrawer.qml index 1d581639..1bf37548 100644 --- a/qml/ChatDrawer.qml +++ b/qml/ChatDrawer.qml @@ -83,6 +83,7 @@ Drawer { opacity: 0.9 property bool isCurrent: LLM.chatListModel.currentChat === LLM.chatListModel.get(index) property bool trashQuestionDisplayed: false + z: isCurrent ? 199 : 1 color: index % 2 === 0 ? theme.backgroundLight : theme.backgroundLighter border.width: isCurrent border.color: chatName.readOnly ? 
theme.assistantColor : theme.userColor @@ -112,6 +113,11 @@ Drawer { color: "transparent" } onEditingFinished: { + // Work around a bug in qml where we're losing focus when the whole window + // goes out of focus even though this textfield should be marked as not + // having focus + if (chatName.readOnly) + return; changeName(); Network.sendRenameChat() } @@ -188,6 +194,7 @@ Drawer { visible: isCurrent && trashQuestionDisplayed opacity: 1.0 radius: 10 + z: 200 Row { spacing: 10 Button { diff --git a/qml/ModelDownloaderDialog.qml b/qml/ModelDownloaderDialog.qml index 92b564f8..cca4e494 100644 --- a/qml/ModelDownloaderDialog.qml +++ b/qml/ModelDownloaderDialog.qml @@ -12,7 +12,7 @@ Dialog { id: modelDownloaderDialog modal: true opacity: 0.9 - closePolicy: LLM.modelList.length === 0 ? Popup.NoAutoClose : (Popup.CloseOnEscape | Popup.CloseOnPressOutside) + closePolicy: LLM.chatListModel.currentChat.modelList.length === 0 ? Popup.NoAutoClose : (Popup.CloseOnEscape | Popup.CloseOnPressOutside) background: Rectangle { anchors.fill: parent anchors.margins: -20
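
For reference, the C-level contract documented in `llmodel_c.h` above (state size query, save, restore) mirrors what `ChatLLM::saveState()`/`restoreState()` do through the C++ `LLModel` interface. The sketch below is illustrative only: it uses just the three functions declared in this patch, and the helper names and buffer handling are assumptions, not part of this change.

```cpp
// Illustrative sketch only: exercising the C state API declared in llmodel_c.h.
// Assumes `model` is an already-created and loaded llmodel_model handle; the
// snapshot/restore helpers below are hypothetical, not part of this patch.
#include <cstdint>
#include <vector>
#include "llmodel_c.h"

std::vector<uint8_t> snapshotModelState(llmodel_model model)
{
    std::vector<uint8_t> buffer(llmodel_get_state_size(model));
    const uint64_t written = llmodel_save_state_data(model, buffer.data());
    buffer.resize(written); // bytes actually copied may be less than the reported size
    return buffer;
}

void restoreModelState(llmodel_model model, const std::vector<uint8_t> &buffer)
{
    // Per the header note, only valid for the same type of model that produced the snapshot.
    llmodel_restore_state_data(model, buffer.data());
}
```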
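Similarly, the `gpt4all-<id>.chat` files written by `ChatListModel::saveChats()` begin with the plain `QDataStream` fields emitted by `Chat::serialize()` before the `ChatLLM` and `ChatModel` payloads. A minimal reader for just that header, as a sketch (the function name and debug output are illustrative, not part of the patch):

```cpp
// Illustrative sketch only: reads back the leading fields that Chat::serialize()
// writes into a gpt4all-<id>.chat file. The ChatLLM and ChatModel payloads that
// follow are not parsed here.
#include <QDataStream>
#include <QDebug>
#include <QFile>
#include <QString>

bool dumpChatHeader(const QString &path)
{
    QFile file(path);
    if (!file.open(QIODevice::ReadOnly))
        return false;

    QDataStream in(&file);
    qint64 creationDate;
    QString id, name, userName, savedModelName;
    in >> creationDate >> id >> name >> userName >> savedModelName;
    qDebug() << "chat" << id << name << "created" << creationDate
             << "saved model" << savedModelName;
    return in.status() == QDataStream::Ok;
}
```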