diff --git a/CMakeLists.txt b/CMakeLists.txt index b9182434..57561c6b 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -59,6 +59,7 @@ set (CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin) qt_add_executable(chat main.cpp chat.h chat.cpp chatmodel.h + chatllm.h chatllm.cpp download.h download.cpp network.h network.cpp llm.h llm.cpp diff --git a/chat.cpp b/chat.cpp index fbc66a34..3a330d24 100644 --- a/chat.cpp +++ b/chat.cpp @@ -1 +1,117 @@ #include "chat.h" +#include "network.h" + +Chat::Chat(QObject *parent) + : QObject(parent) + , m_llmodel(new ChatLLM) + , m_id(Network::globalInstance()->generateUniqueId()) + , m_name(tr("New Chat")) + , m_chatModel(new ChatModel(this)) + , m_responseInProgress(false) + , m_desiredThreadCount(std::min(4, (int32_t) std::thread::hardware_concurrency())) +{ + connect(m_llmodel, &ChatLLM::isModelLoadedChanged, this, &Chat::isModelLoadedChanged, Qt::QueuedConnection); + connect(m_llmodel, &ChatLLM::responseChanged, this, &Chat::responseChanged, Qt::QueuedConnection); + connect(m_llmodel, &ChatLLM::responseStarted, this, &Chat::responseStarted, Qt::QueuedConnection); + connect(m_llmodel, &ChatLLM::responseStopped, this, &Chat::responseStopped, Qt::QueuedConnection); + connect(m_llmodel, &ChatLLM::modelNameChanged, this, &Chat::modelNameChanged, Qt::QueuedConnection); + connect(m_llmodel, &ChatLLM::threadCountChanged, this, &Chat::threadCountChanged, Qt::QueuedConnection); + connect(m_llmodel, &ChatLLM::threadCountChanged, this, &Chat::syncThreadCount, Qt::QueuedConnection); + connect(m_llmodel, &ChatLLM::recalcChanged, this, &Chat::recalcChanged, Qt::QueuedConnection); + + connect(this, &Chat::promptRequested, m_llmodel, &ChatLLM::prompt, Qt::QueuedConnection); + connect(this, &Chat::modelNameChangeRequested, m_llmodel, &ChatLLM::modelNameChangeRequested, Qt::QueuedConnection); + + // The following are blocking operations and will block the gui thread, therefore must be fast + // to respond to + connect(this, &Chat::regenerateResponseRequested, m_llmodel, &ChatLLM::regenerateResponse, Qt::BlockingQueuedConnection); + connect(this, &Chat::resetResponseRequested, m_llmodel, &ChatLLM::resetResponse, Qt::BlockingQueuedConnection); + connect(this, &Chat::resetContextRequested, m_llmodel, &ChatLLM::resetContext, Qt::BlockingQueuedConnection); + connect(this, &Chat::setThreadCountRequested, m_llmodel, &ChatLLM::setThreadCount, Qt::QueuedConnection); +} + +void Chat::reset() +{ + m_id = Network::globalInstance()->generateUniqueId(); + m_chatModel->clear(); +} + +bool Chat::isModelLoaded() const +{ + return m_llmodel->isModelLoaded(); +} + +void Chat::prompt(const QString &prompt, const QString &prompt_template, int32_t n_predict, int32_t top_k, float top_p, + float temp, int32_t n_batch, float repeat_penalty, int32_t repeat_penalty_tokens) +{ + emit promptRequested(prompt, prompt_template, n_predict, top_k, top_p, temp, n_batch, repeat_penalty, repeat_penalty_tokens); +} + +void Chat::regenerateResponse() +{ + emit regenerateResponseRequested(); // blocking queued connection +} + +void Chat::resetResponse() +{ + emit resetResponseRequested(); // blocking queued connection +} + +void Chat::resetContext() +{ + emit resetContextRequested(); // blocking queued connection +} + +void Chat::stopGenerating() +{ + m_llmodel->stopGenerating(); +} + +QString Chat::response() const +{ + return m_llmodel->response(); +} + +void Chat::responseStarted() +{ + m_responseInProgress = true; + emit responseInProgressChanged(); +} + +void Chat::responseStopped() +{ + 
m_responseInProgress = false; + emit responseInProgressChanged(); +} + +QString Chat::modelName() const +{ + return m_llmodel->modelName(); +} + +void Chat::setModelName(const QString &modelName) +{ + // doesn't block but will unload old model and load new one which the gui can see through changes + // to the isModelLoaded property + emit modelNameChangeRequested(modelName); +} + +void Chat::syncThreadCount() { + emit setThreadCountRequested(m_desiredThreadCount); +} + +void Chat::setThreadCount(int32_t n_threads) { + if (n_threads <= 0) + n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency()); + m_desiredThreadCount = n_threads; + syncThreadCount(); +} + +int32_t Chat::threadCount() { + return m_llmodel->threadCount(); +} + +bool Chat::isRecalc() const +{ + return m_llmodel->isRecalc(); +} diff --git a/chat.h b/chat.h index 1934b3df..66687ee4 100644 --- a/chat.h +++ b/chat.h @@ -4,8 +4,8 @@ #include #include +#include "chatllm.h" #include "chatmodel.h" -#include "network.h" class Chat : public QObject { @@ -13,36 +13,69 @@ class Chat : public QObject Q_PROPERTY(QString id READ id NOTIFY idChanged) Q_PROPERTY(QString name READ name NOTIFY nameChanged) Q_PROPERTY(ChatModel *chatModel READ chatModel NOTIFY chatModelChanged) + Q_PROPERTY(bool isModelLoaded READ isModelLoaded NOTIFY isModelLoadedChanged) + Q_PROPERTY(QString response READ response NOTIFY responseChanged) + Q_PROPERTY(QString modelName READ modelName WRITE setModelName NOTIFY modelNameChanged) + Q_PROPERTY(bool responseInProgress READ responseInProgress NOTIFY responseInProgressChanged) + Q_PROPERTY(int32_t threadCount READ threadCount WRITE setThreadCount NOTIFY threadCountChanged) + Q_PROPERTY(bool isRecalc READ isRecalc NOTIFY recalcChanged) QML_ELEMENT QML_UNCREATABLE("Only creatable from c++!") public: - explicit Chat(QObject *parent = nullptr) : QObject(parent) - { - m_id = Network::globalInstance()->generateUniqueId(); - m_name = tr("New Chat"); - m_chatModel = new ChatModel(this); - } + explicit Chat(QObject *parent = nullptr); QString id() const { return m_id; } QString name() const { return m_name; } ChatModel *chatModel() { return m_chatModel; } - Q_INVOKABLE void reset() - { - m_id = Network::globalInstance()->generateUniqueId(); - m_chatModel->clear(); - } + Q_INVOKABLE void reset(); + Q_INVOKABLE bool isModelLoaded() const; + Q_INVOKABLE void prompt(const QString &prompt, const QString &prompt_template, int32_t n_predict, int32_t top_k, float top_p, + float temp, int32_t n_batch, float repeat_penalty, int32_t repeat_penalty_tokens); + Q_INVOKABLE void regenerateResponse(); + Q_INVOKABLE void resetResponse(); + Q_INVOKABLE void resetContext(); + Q_INVOKABLE void stopGenerating(); + Q_INVOKABLE void syncThreadCount(); + Q_INVOKABLE void setThreadCount(int32_t n_threads); + Q_INVOKABLE int32_t threadCount(); + + QString response() const; + bool responseInProgress() const { return m_responseInProgress; } + QString modelName() const; + void setModelName(const QString &modelName); + bool isRecalc() const; Q_SIGNALS: void idChanged(); void nameChanged(); void chatModelChanged(); + void isModelLoadedChanged(); + void responseChanged(); + void responseInProgressChanged(); + void promptRequested(const QString &prompt, const QString &prompt_template, int32_t n_predict, int32_t top_k, float top_p, + float temp, int32_t n_batch, float repeat_penalty, int32_t repeat_penalty_tokens); + void regenerateResponseRequested(); + void resetResponseRequested(); + void resetContextRequested(); + void 
modelNameChangeRequested(const QString &modelName); + void modelNameChanged(); + void threadCountChanged(); + void setThreadCountRequested(int32_t threadCount); + void recalcChanged(); + +private Q_SLOTS: + void responseStarted(); + void responseStopped(); private: + ChatLLM *m_llmodel; QString m_id; QString m_name; ChatModel *m_chatModel; + bool m_responseInProgress; + int32_t m_desiredThreadCount; }; #endif // CHAT_H diff --git a/chatllm.cpp b/chatllm.cpp new file mode 100644 index 00000000..51e36aaa --- /dev/null +++ b/chatllm.cpp @@ -0,0 +1,290 @@ +#include "chatllm.h" +#include "download.h" +#include "network.h" +#include "llm.h" +#include "llmodel/gptj.h" +#include "llmodel/llamamodel.h" + +#include +#include +#include +#include +#include +#include +#include + +//#define DEBUG + +static QString modelFilePath(const QString &modelName) +{ + QString appPath = QCoreApplication::applicationDirPath() + + "/ggml-" + modelName + ".bin"; + QFileInfo infoAppPath(appPath); + if (infoAppPath.exists()) + return appPath; + + QString downloadPath = Download::globalInstance()->downloadLocalModelsPath() + + "/ggml-" + modelName + ".bin"; + + QFileInfo infoLocalPath(downloadPath); + if (infoLocalPath.exists()) + return downloadPath; + return QString(); +} + +ChatLLM::ChatLLM() + : QObject{nullptr} + , m_llmodel(nullptr) + , m_promptResponseTokens(0) + , m_responseLogits(0) + , m_isRecalc(false) +{ + moveToThread(&m_llmThread); + connect(&m_llmThread, &QThread::started, this, &ChatLLM::loadModel); + connect(this, &ChatLLM::sendStartup, Network::globalInstance(), &Network::sendStartup); + connect(this, &ChatLLM::sendModelLoaded, Network::globalInstance(), &Network::sendModelLoaded); + connect(this, &ChatLLM::sendResetContext, Network::globalInstance(), &Network::sendResetContext); + m_llmThread.setObjectName("llm thread"); // FIXME: Should identify these with chat name + m_llmThread.start(); +} + +bool ChatLLM::loadModel() +{ + const QList models = LLM::globalInstance()->modelList(); + if (models.isEmpty()) { + // try again when we get a list of models + connect(Download::globalInstance(), &Download::modelListChanged, this, + &ChatLLM::loadModel, Qt::SingleShotConnection); + return false; + } + + QSettings settings; + settings.sync(); + QString defaultModel = settings.value("defaultModel", "gpt4all-j-v1.3-groovy").toString(); + if (defaultModel.isEmpty() || !models.contains(defaultModel)) + defaultModel = models.first(); + return loadModelPrivate(defaultModel); +} + +bool ChatLLM::loadModelPrivate(const QString &modelName) +{ + if (isModelLoaded() && m_modelName == modelName) + return true; + + bool isFirstLoad = false; + if (isModelLoaded()) { + resetContextPrivate(); + delete m_llmodel; + m_llmodel = nullptr; + emit isModelLoadedChanged(); + } else { + isFirstLoad = true; + } + + bool isGPTJ = false; + QString filePath = modelFilePath(modelName); + QFileInfo info(filePath); + if (info.exists()) { + + auto fin = std::ifstream(filePath.toStdString(), std::ios::binary); + uint32_t magic; + fin.read((char *) &magic, sizeof(magic)); + fin.seekg(0); + fin.close(); + isGPTJ = magic == 0x67676d6c; + if (isGPTJ) { + m_llmodel = new GPTJ; + m_llmodel->loadModel(filePath.toStdString()); + } else { + m_llmodel = new LLamaModel; + m_llmodel->loadModel(filePath.toStdString()); + } + + emit isModelLoadedChanged(); + emit threadCountChanged(); + + if (isFirstLoad) + emit sendStartup(); + else + emit sendModelLoaded(); + } + + if (m_llmodel) + setModelName(info.completeBaseName().remove(0, 5)); // remove the ggml- 
prefix + + return m_llmodel; +} + +void ChatLLM::setThreadCount(int32_t n_threads) { + if (m_llmodel && m_llmodel->threadCount() != n_threads) { + m_llmodel->setThreadCount(n_threads); + emit threadCountChanged(); + } +} + +int32_t ChatLLM::threadCount() { + if (!m_llmodel) + return 1; + return m_llmodel->threadCount(); +} + +bool ChatLLM::isModelLoaded() const +{ + return m_llmodel && m_llmodel->isModelLoaded(); +} + +void ChatLLM::regenerateResponse() +{ + m_ctx.n_past -= m_promptResponseTokens; + m_ctx.n_past = std::max(0, m_ctx.n_past); + // FIXME: This does not seem to be needed in my testing and llama models don't to it. Remove? + m_ctx.logits.erase(m_ctx.logits.end() -= m_responseLogits, m_ctx.logits.end()); + m_ctx.tokens.erase(m_ctx.tokens.end() -= m_promptResponseTokens, m_ctx.tokens.end()); + m_promptResponseTokens = 0; + m_responseLogits = 0; + m_response = std::string(); + emit responseChanged(); +} + +void ChatLLM::resetResponse() +{ + m_promptResponseTokens = 0; + m_responseLogits = 0; + m_response = std::string(); + emit responseChanged(); +} + +void ChatLLM::resetContext() +{ + resetContextPrivate(); + emit sendResetContext(); +} + +void ChatLLM::resetContextPrivate() +{ + regenerateResponse(); + m_ctx = LLModel::PromptContext(); +} + +std::string remove_leading_whitespace(const std::string& input) { + auto first_non_whitespace = std::find_if(input.begin(), input.end(), [](unsigned char c) { + return !std::isspace(c); + }); + + return std::string(first_non_whitespace, input.end()); +} + +std::string trim_whitespace(const std::string& input) { + auto first_non_whitespace = std::find_if(input.begin(), input.end(), [](unsigned char c) { + return !std::isspace(c); + }); + + auto last_non_whitespace = std::find_if(input.rbegin(), input.rend(), [](unsigned char c) { + return !std::isspace(c); + }).base(); + + return std::string(first_non_whitespace, last_non_whitespace); +} + +QString ChatLLM::response() const +{ + return QString::fromStdString(remove_leading_whitespace(m_response)); +} + +QString ChatLLM::modelName() const +{ + return m_modelName; +} + +void ChatLLM::setModelName(const QString &modelName) +{ + m_modelName = modelName; + emit modelNameChanged(); +} + +void ChatLLM::modelNameChangeRequested(const QString &modelName) +{ + if (!loadModelPrivate(modelName)) + qWarning() << "ERROR: Could not load model" << modelName; +} + +bool ChatLLM::handlePrompt(int32_t token) +{ + // m_promptResponseTokens and m_responseLogits are related to last prompt/response not + // the entire context window which we can reset on regenerate prompt + ++m_promptResponseTokens; + return !m_stopGenerating; +} + +bool ChatLLM::handleResponse(int32_t token, const std::string &response) +{ +#if defined(DEBUG) + printf("%s", response.c_str()); + fflush(stdout); +#endif + + // check for error + if (token < 0) { + m_response.append(response); + emit responseChanged(); + return false; + } + + // m_promptResponseTokens and m_responseLogits are related to last prompt/response not + // the entire context window which we can reset on regenerate prompt + ++m_promptResponseTokens; + Q_ASSERT(!response.empty()); + m_response.append(response); + emit responseChanged(); + return !m_stopGenerating; +} + +bool ChatLLM::handleRecalculate(bool isRecalc) +{ + if (m_isRecalc != isRecalc) { + m_isRecalc = isRecalc; + emit recalcChanged(); + } + return !m_stopGenerating; +} + +bool ChatLLM::prompt(const QString &prompt, const QString &prompt_template, int32_t n_predict, int32_t top_k, float top_p, + float temp, 
int32_t n_batch, float repeat_penalty, int32_t repeat_penalty_tokens) +{ + if (!isModelLoaded()) + return false; + + QString instructPrompt = prompt_template.arg(prompt); + + m_stopGenerating = false; + auto promptFunc = std::bind(&ChatLLM::handlePrompt, this, std::placeholders::_1); + auto responseFunc = std::bind(&ChatLLM::handleResponse, this, std::placeholders::_1, + std::placeholders::_2); + auto recalcFunc = std::bind(&ChatLLM::handleRecalculate, this, std::placeholders::_1); + emit responseStarted(); + qint32 logitsBefore = m_ctx.logits.size(); + m_ctx.n_predict = n_predict; + m_ctx.top_k = top_k; + m_ctx.top_p = top_p; + m_ctx.temp = temp; + m_ctx.n_batch = n_batch; + m_ctx.repeat_penalty = repeat_penalty; + m_ctx.repeat_last_n = repeat_penalty_tokens; +#if defined(DEBUG) + printf("%s", qPrintable(instructPrompt)); + fflush(stdout); +#endif + m_llmodel->prompt(instructPrompt.toStdString(), promptFunc, responseFunc, recalcFunc, m_ctx); +#if defined(DEBUG) + printf("\n"); + fflush(stdout); +#endif + m_responseLogits += m_ctx.logits.size() - logitsBefore; + std::string trimmed = trim_whitespace(m_response); + if (trimmed != m_response) { + m_response = trimmed; + emit responseChanged(); + } + emit responseStopped(); + return true; +} diff --git a/chatllm.h b/chatllm.h new file mode 100644 index 00000000..eb14cdf1 --- /dev/null +++ b/chatllm.h @@ -0,0 +1,75 @@ +#ifndef CHATLLM_H +#define CHATLLM_H + +#include +#include + +#include "llmodel/llmodel.h" + +class ChatLLM : public QObject +{ + Q_OBJECT + Q_PROPERTY(bool isModelLoaded READ isModelLoaded NOTIFY isModelLoadedChanged) + Q_PROPERTY(QString response READ response NOTIFY responseChanged) + Q_PROPERTY(QString modelName READ modelName WRITE setModelName NOTIFY modelNameChanged) + Q_PROPERTY(int32_t threadCount READ threadCount WRITE setThreadCount NOTIFY threadCountChanged) + Q_PROPERTY(bool isRecalc READ isRecalc NOTIFY recalcChanged) + +public: + + ChatLLM(); + + bool isModelLoaded() const; + void regenerateResponse(); + void resetResponse(); + void resetContext(); + + void stopGenerating() { m_stopGenerating = true; } + void setThreadCount(int32_t n_threads); + int32_t threadCount(); + + QString response() const; + QString modelName() const; + + void setModelName(const QString &modelName); + + bool isRecalc() const { return m_isRecalc; } + +public Q_SLOTS: + bool prompt(const QString &prompt, const QString &prompt_template, int32_t n_predict, int32_t top_k, float top_p, + float temp, int32_t n_batch, float repeat_penalty, int32_t repeat_penalty_tokens); + bool loadModel(); + void modelNameChangeRequested(const QString &modelName); + +Q_SIGNALS: + void isModelLoadedChanged(); + void responseChanged(); + void responseStarted(); + void responseStopped(); + void modelNameChanged(); + void threadCountChanged(); + void recalcChanged(); + void sendStartup(); + void sendModelLoaded(); + void sendResetContext(); + +private: + void resetContextPrivate(); + bool loadModelPrivate(const QString &modelName); + bool handlePrompt(int32_t token); + bool handleResponse(int32_t token, const std::string &response); + bool handleRecalculate(bool isRecalc); + +private: + LLModel::PromptContext m_ctx; + LLModel *m_llmodel; + std::string m_response; + quint32 m_promptResponseTokens; + quint32 m_responseLogits; + QString m_modelName; + QThread m_llmThread; + std::atomic m_stopGenerating; + bool m_isRecalc; +}; + +#endif // CHATLLM_H diff --git a/llm.cpp b/llm.cpp index 8e3abe0d..a0773ee1 100644 --- a/llm.cpp +++ b/llm.cpp @@ -19,203 +19,18 @@ LLM 
*LLM::globalInstance() return llmInstance(); } -static LLModel::PromptContext s_ctx; - -static QString modelFilePath(const QString &modelName) -{ - QString appPath = QCoreApplication::applicationDirPath() - + "/ggml-" + modelName + ".bin"; - QFileInfo infoAppPath(appPath); - if (infoAppPath.exists()) - return appPath; - - QString downloadPath = Download::globalInstance()->downloadLocalModelsPath() - + "/ggml-" + modelName + ".bin"; - - QFileInfo infoLocalPath(downloadPath); - if (infoLocalPath.exists()) - return downloadPath; - return QString(); -} - -LLMObject::LLMObject() +LLM::LLM() : QObject{nullptr} - , m_llmodel(nullptr) - , m_promptResponseTokens(0) - , m_responseLogits(0) - , m_isRecalc(false) -{ - moveToThread(&m_llmThread); - connect(&m_llmThread, &QThread::started, this, &LLMObject::loadModel); - connect(this, &LLMObject::sendStartup, Network::globalInstance(), &Network::sendStartup); - connect(this, &LLMObject::sendModelLoaded, Network::globalInstance(), &Network::sendModelLoaded); - connect(this, &LLMObject::sendResetContext, Network::globalInstance(), &Network::sendResetContext); - m_llmThread.setObjectName("llm thread"); - m_llmThread.start(); -} - -bool LLMObject::loadModel() -{ - const QList models = modelList(); - if (models.isEmpty()) { - // try again when we get a list of models - connect(Download::globalInstance(), &Download::modelListChanged, this, - &LLMObject::loadModel, Qt::SingleShotConnection); - return false; - } - - QSettings settings; - settings.sync(); - QString defaultModel = settings.value("defaultModel", "gpt4all-j-v1.3-groovy").toString(); - if (defaultModel.isEmpty() || !models.contains(defaultModel)) - defaultModel = models.first(); - return loadModelPrivate(defaultModel); -} - -bool LLMObject::loadModelPrivate(const QString &modelName) -{ - if (isModelLoaded() && m_modelName == modelName) - return true; - - bool isFirstLoad = false; - if (isModelLoaded()) { - resetContextPrivate(); - delete m_llmodel; - m_llmodel = nullptr; - emit isModelLoadedChanged(); - } else { - isFirstLoad = true; - } - - bool isGPTJ = false; - QString filePath = modelFilePath(modelName); - QFileInfo info(filePath); - if (info.exists()) { - - auto fin = std::ifstream(filePath.toStdString(), std::ios::binary); - uint32_t magic; - fin.read((char *) &magic, sizeof(magic)); - fin.seekg(0); - fin.close(); - isGPTJ = magic == 0x67676d6c; - if (isGPTJ) { - m_llmodel = new GPTJ; - m_llmodel->loadModel(filePath.toStdString()); - } else { - m_llmodel = new LLamaModel; - m_llmodel->loadModel(filePath.toStdString()); - } - - emit isModelLoadedChanged(); - emit threadCountChanged(); - - if (isFirstLoad) - emit sendStartup(); - else - emit sendModelLoaded(); - } - - if (m_llmodel) - setModelName(info.completeBaseName().remove(0, 5)); // remove the ggml- prefix - - return m_llmodel; -} - -void LLMObject::setThreadCount(int32_t n_threads) { - if (m_llmodel && m_llmodel->threadCount() != n_threads) { - m_llmodel->setThreadCount(n_threads); - emit threadCountChanged(); - } -} - -int32_t LLMObject::threadCount() { - if (!m_llmodel) - return 1; - return m_llmodel->threadCount(); -} - -bool LLMObject::isModelLoaded() const -{ - return m_llmodel && m_llmodel->isModelLoaded(); -} - -void LLMObject::regenerateResponse() -{ - s_ctx.n_past -= m_promptResponseTokens; - s_ctx.n_past = std::max(0, s_ctx.n_past); - // FIXME: This does not seem to be needed in my testing and llama models don't to it. Remove? 
- s_ctx.logits.erase(s_ctx.logits.end() -= m_responseLogits, s_ctx.logits.end()); - s_ctx.tokens.erase(s_ctx.tokens.end() -= m_promptResponseTokens, s_ctx.tokens.end()); - m_promptResponseTokens = 0; - m_responseLogits = 0; - m_response = std::string(); - emit responseChanged(); -} - -void LLMObject::resetResponse() -{ - m_promptResponseTokens = 0; - m_responseLogits = 0; - m_response = std::string(); - emit responseChanged(); -} - -void LLMObject::resetContext() -{ - resetContextPrivate(); - emit sendResetContext(); -} - -void LLMObject::resetContextPrivate() -{ - regenerateResponse(); - s_ctx = LLModel::PromptContext(); -} - -std::string remove_leading_whitespace(const std::string& input) { - auto first_non_whitespace = std::find_if(input.begin(), input.end(), [](unsigned char c) { - return !std::isspace(c); - }); - - return std::string(first_non_whitespace, input.end()); -} - -std::string trim_whitespace(const std::string& input) { - auto first_non_whitespace = std::find_if(input.begin(), input.end(), [](unsigned char c) { - return !std::isspace(c); - }); - - auto last_non_whitespace = std::find_if(input.rbegin(), input.rend(), [](unsigned char c) { - return !std::isspace(c); - }).base(); - - return std::string(first_non_whitespace, last_non_whitespace); -} - -QString LLMObject::response() const -{ - return QString::fromStdString(remove_leading_whitespace(m_response)); -} - -QString LLMObject::modelName() const -{ - return m_modelName; -} - -void LLMObject::setModelName(const QString &modelName) -{ - m_modelName = modelName; - emit modelNameChanged(); - emit modelListChanged(); -} - -void LLMObject::modelNameChangeRequested(const QString &modelName) + , m_currentChat(new Chat) { - if (!loadModelPrivate(modelName)) - qWarning() << "ERROR: Could not load model" << modelName; + connect(Download::globalInstance(), &Download::modelListChanged, + this, &LLM::modelListChanged, Qt::QueuedConnection); + // FIXME: This should be moved to connect whenever we make a new chat object in future + connect(m_currentChat, &Chat::modelNameChanged, + this, &LLM::modelListChanged, Qt::QueuedConnection); } -QList LLMObject::modelList() const +QList LLM::modelList() const { // Build a model list from exepath and from the localpath QList list; @@ -232,7 +47,7 @@ QList LLMObject::modelList() const QFileInfo info(filePath); QString name = info.completeBaseName().remove(0, 5); if (info.exists()) { - if (name == m_modelName) + if (name == m_currentChat->modelName()) list.prepend(name); else list.append(name); @@ -249,7 +64,7 @@ QList LLMObject::modelList() const QFileInfo info(filePath); QString name = info.completeBaseName().remove(0, 5); if (info.exists() && !list.contains(name)) { // don't allow duplicates - if (name == m_modelName) + if (name == m_currentChat->modelName()) list.prepend(name); else list.append(name); @@ -271,189 +86,6 @@ QList LLMObject::modelList() const return list; } -bool LLMObject::handlePrompt(int32_t token) -{ - // m_promptResponseTokens and m_responseLogits are related to last prompt/response not - // the entire context window which we can reset on regenerate prompt - ++m_promptResponseTokens; - return !m_stopGenerating; -} - -bool LLMObject::handleResponse(int32_t token, const std::string &response) -{ -#if 0 - printf("%s", response.c_str()); - fflush(stdout); -#endif - - // check for error - if (token < 0) { - m_response.append(response); - emit responseChanged(); - return false; - } - - // m_promptResponseTokens and m_responseLogits are related to last prompt/response not - // the 
entire context window which we can reset on regenerate prompt - ++m_promptResponseTokens; - Q_ASSERT(!response.empty()); - m_response.append(response); - emit responseChanged(); - return !m_stopGenerating; -} - -bool LLMObject::handleRecalculate(bool isRecalc) -{ - if (m_isRecalc != isRecalc) { - m_isRecalc = isRecalc; - emit recalcChanged(); - } - return !m_stopGenerating; -} - -bool LLMObject::prompt(const QString &prompt, const QString &prompt_template, int32_t n_predict, int32_t top_k, float top_p, - float temp, int32_t n_batch, float repeat_penalty, int32_t repeat_penalty_tokens) -{ - if (!isModelLoaded()) - return false; - - QString instructPrompt = prompt_template.arg(prompt); - - m_stopGenerating = false; - auto promptFunc = std::bind(&LLMObject::handlePrompt, this, std::placeholders::_1); - auto responseFunc = std::bind(&LLMObject::handleResponse, this, std::placeholders::_1, - std::placeholders::_2); - auto recalcFunc = std::bind(&LLMObject::handleRecalculate, this, std::placeholders::_1); - emit responseStarted(); - qint32 logitsBefore = s_ctx.logits.size(); - s_ctx.n_predict = n_predict; - s_ctx.top_k = top_k; - s_ctx.top_p = top_p; - s_ctx.temp = temp; - s_ctx.n_batch = n_batch; - s_ctx.repeat_penalty = repeat_penalty; - s_ctx.repeat_last_n = repeat_penalty_tokens; - m_llmodel->prompt(instructPrompt.toStdString(), promptFunc, responseFunc, recalcFunc, s_ctx); - m_responseLogits += s_ctx.logits.size() - logitsBefore; - std::string trimmed = trim_whitespace(m_response); - if (trimmed != m_response) { - m_response = trimmed; - emit responseChanged(); - } - emit responseStopped(); - - return true; -} - -LLM::LLM() - : QObject{nullptr} - , m_currentChat(new Chat) - , m_llmodel(new LLMObject) - , m_responseInProgress(false) -{ - connect(Download::globalInstance(), &Download::modelListChanged, this, &LLM::modelListChanged, Qt::QueuedConnection); - connect(m_llmodel, &LLMObject::isModelLoadedChanged, this, &LLM::isModelLoadedChanged, Qt::QueuedConnection); - connect(m_llmodel, &LLMObject::responseChanged, this, &LLM::responseChanged, Qt::QueuedConnection); - connect(m_llmodel, &LLMObject::responseStarted, this, &LLM::responseStarted, Qt::QueuedConnection); - connect(m_llmodel, &LLMObject::responseStopped, this, &LLM::responseStopped, Qt::QueuedConnection); - connect(m_llmodel, &LLMObject::modelNameChanged, this, &LLM::modelNameChanged, Qt::QueuedConnection); - connect(m_llmodel, &LLMObject::modelListChanged, this, &LLM::modelListChanged, Qt::QueuedConnection); - connect(m_llmodel, &LLMObject::threadCountChanged, this, &LLM::threadCountChanged, Qt::QueuedConnection); - connect(m_llmodel, &LLMObject::threadCountChanged, this, &LLM::syncThreadCount, Qt::QueuedConnection); - connect(m_llmodel, &LLMObject::recalcChanged, this, &LLM::recalcChanged, Qt::QueuedConnection); - - connect(this, &LLM::promptRequested, m_llmodel, &LLMObject::prompt, Qt::QueuedConnection); - connect(this, &LLM::modelNameChangeRequested, m_llmodel, &LLMObject::modelNameChangeRequested, Qt::QueuedConnection); - - // The following are blocking operations and will block the gui thread, therefore must be fast - // to respond to - connect(this, &LLM::regenerateResponseRequested, m_llmodel, &LLMObject::regenerateResponse, Qt::BlockingQueuedConnection); - connect(this, &LLM::resetResponseRequested, m_llmodel, &LLMObject::resetResponse, Qt::BlockingQueuedConnection); - connect(this, &LLM::resetContextRequested, m_llmodel, &LLMObject::resetContext, Qt::BlockingQueuedConnection); - connect(this, &LLM::setThreadCountRequested, 
m_llmodel, &LLMObject::setThreadCount, Qt::QueuedConnection); -} - -bool LLM::isModelLoaded() const -{ - return m_llmodel->isModelLoaded(); -} - -void LLM::prompt(const QString &prompt, const QString &prompt_template, int32_t n_predict, int32_t top_k, float top_p, - float temp, int32_t n_batch, float repeat_penalty, int32_t repeat_penalty_tokens) -{ - emit promptRequested(prompt, prompt_template, n_predict, top_k, top_p, temp, n_batch, repeat_penalty, repeat_penalty_tokens); -} - -void LLM::regenerateResponse() -{ - emit regenerateResponseRequested(); // blocking queued connection -} - -void LLM::resetResponse() -{ - emit resetResponseRequested(); // blocking queued connection -} - -void LLM::resetContext() -{ - emit resetContextRequested(); // blocking queued connection -} - -void LLM::stopGenerating() -{ - m_llmodel->stopGenerating(); -} - -QString LLM::response() const -{ - return m_llmodel->response(); -} - -void LLM::responseStarted() -{ - m_responseInProgress = true; - emit responseInProgressChanged(); -} - -void LLM::responseStopped() -{ - m_responseInProgress = false; - emit responseInProgressChanged(); -} - -QString LLM::modelName() const -{ - return m_llmodel->modelName(); -} - -void LLM::setModelName(const QString &modelName) -{ - // doesn't block but will unload old model and load new one which the gui can see through changes - // to the isModelLoaded property - emit modelNameChangeRequested(modelName); -} - -QList LLM::modelList() const -{ - return m_llmodel->modelList(); -} - -void LLM::syncThreadCount() { - emit setThreadCountRequested(m_desiredThreadCount); -} - -void LLM::setThreadCount(int32_t n_threads) { - if (n_threads <= 0) { - n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency()); - } - m_desiredThreadCount = n_threads; - syncThreadCount(); -} - -int32_t LLM::threadCount() { - return m_llmodel->threadCount(); -} - bool LLM::checkForUpdates() const { Network::globalInstance()->sendCheckForUpdates(); @@ -475,8 +107,3 @@ bool LLM::checkForUpdates() const return QProcess::startDetached(fileName); } - -bool LLM::isRecalc() const -{ - return m_llmodel->isRecalc(); -} diff --git a/llm.h b/llm.h index ddcd7d48..2cbfe5f5 100644 --- a/llm.h +++ b/llm.h @@ -2,146 +2,29 @@ #define LLM_H #include -#include #include "chat.h" -#include "llmodel/llmodel.h" - -class LLMObject : public QObject -{ - Q_OBJECT - Q_PROPERTY(QList modelList READ modelList NOTIFY modelListChanged) - Q_PROPERTY(bool isModelLoaded READ isModelLoaded NOTIFY isModelLoadedChanged) - Q_PROPERTY(QString response READ response NOTIFY responseChanged) - Q_PROPERTY(QString modelName READ modelName WRITE setModelName NOTIFY modelNameChanged) - Q_PROPERTY(int32_t threadCount READ threadCount WRITE setThreadCount NOTIFY threadCountChanged) - Q_PROPERTY(bool isRecalc READ isRecalc NOTIFY recalcChanged) - -public: - - LLMObject(); - - bool isModelLoaded() const; - void regenerateResponse(); - void resetResponse(); - void resetContext(); - - void stopGenerating() { m_stopGenerating = true; } - void setThreadCount(int32_t n_threads); - int32_t threadCount(); - - QString response() const; - QString modelName() const; - - QList modelList() const; - void setModelName(const QString &modelName); - - bool isRecalc() const { return m_isRecalc; } - -public Q_SLOTS: - bool prompt(const QString &prompt, const QString &prompt_template, int32_t n_predict, int32_t top_k, float top_p, - float temp, int32_t n_batch, float repeat_penalty, int32_t repeat_penalty_tokens); - bool loadModel(); - void 
modelNameChangeRequested(const QString &modelName); - -Q_SIGNALS: - void isModelLoadedChanged(); - void responseChanged(); - void responseStarted(); - void responseStopped(); - void modelNameChanged(); - void modelListChanged(); - void threadCountChanged(); - void recalcChanged(); - void sendStartup(); - void sendModelLoaded(); - void sendResetContext(); - -private: - void resetContextPrivate(); - bool loadModelPrivate(const QString &modelName); - bool handlePrompt(int32_t token); - bool handleResponse(int32_t token, const std::string &response); - bool handleRecalculate(bool isRecalc); - -private: - LLModel *m_llmodel; - std::string m_response; - quint32 m_promptResponseTokens; - quint32 m_responseLogits; - QString m_modelName; - QThread m_llmThread; - std::atomic m_stopGenerating; - bool m_isRecalc; -}; class LLM : public QObject { Q_OBJECT Q_PROPERTY(QList modelList READ modelList NOTIFY modelListChanged) - Q_PROPERTY(bool isModelLoaded READ isModelLoaded NOTIFY isModelLoadedChanged) - Q_PROPERTY(QString response READ response NOTIFY responseChanged) - Q_PROPERTY(QString modelName READ modelName WRITE setModelName NOTIFY modelNameChanged) - Q_PROPERTY(bool responseInProgress READ responseInProgress NOTIFY responseInProgressChanged) - Q_PROPERTY(int32_t threadCount READ threadCount WRITE setThreadCount NOTIFY threadCountChanged) - Q_PROPERTY(bool isRecalc READ isRecalc NOTIFY recalcChanged) Q_PROPERTY(Chat *currentChat READ currentChat NOTIFY currentChatChanged) public: static LLM *globalInstance(); - Q_INVOKABLE bool isModelLoaded() const; - Q_INVOKABLE void prompt(const QString &prompt, const QString &prompt_template, int32_t n_predict, int32_t top_k, float top_p, - float temp, int32_t n_batch, float repeat_penalty, int32_t repeat_penalty_tokens); - Q_INVOKABLE void regenerateResponse(); - Q_INVOKABLE void resetResponse(); - Q_INVOKABLE void resetContext(); - Q_INVOKABLE void stopGenerating(); - Q_INVOKABLE void syncThreadCount(); - Q_INVOKABLE void setThreadCount(int32_t n_threads); - Q_INVOKABLE int32_t threadCount(); - - QString response() const; - bool responseInProgress() const { return m_responseInProgress; } - QList modelList() const; - - QString modelName() const; - void setModelName(const QString &modelName); - Q_INVOKABLE bool checkForUpdates() const; - - bool isRecalc() const; - Chat *currentChat() const { return m_currentChat; } Q_SIGNALS: - void isModelLoadedChanged(); - void responseChanged(); - void responseInProgressChanged(); - void promptRequested(const QString &prompt, const QString &prompt_template, int32_t n_predict, int32_t top_k, float top_p, - float temp, int32_t n_batch, float repeat_penalty, int32_t repeat_penalty_tokens); - void regenerateResponseRequested(); - void resetResponseRequested(); - void resetContextRequested(); - void modelNameChangeRequested(const QString &modelName); - void modelNameChanged(); void modelListChanged(); - void threadCountChanged(); - void setThreadCountRequested(int32_t threadCount); - void recalcChanged(); void currentChatChanged(); -private Q_SLOTS: - void responseStarted(); - void responseStopped(); - private: Chat *m_currentChat; - LLMObject *m_llmodel; - int32_t m_desiredThreadCount; - bool m_responseInProgress; private: explicit LLM(); diff --git a/main.qml b/main.qml index bd2288b5..9f6b2989 100644 --- a/main.qml +++ b/main.qml @@ -94,7 +94,7 @@ Window { Item { anchors.centerIn: parent height: childrenRect.height - visible: LLM.isModelLoaded + visible: LLM.currentChat.isModelLoaded Label { id: modelLabel @@ -169,8 +169,8 @@ 
Window { } onActivated: { - LLM.stopGenerating() - LLM.modelName = comboBox.currentText + LLM.currentChat.stopGenerating() + LLM.currentChat.modelName = comboBox.currentText LLM.currentChat.reset(); } } @@ -178,8 +178,8 @@ Window { BusyIndicator { anchors.centerIn: parent - visible: !LLM.isModelLoaded - running: !LLM.isModelLoaded + visible: !LLM.currentChat.isModelLoaded + running: !LLM.currentChat.isModelLoaded Accessible.role: Accessible.Animation Accessible.name: qsTr("Busy indicator") Accessible.description: qsTr("Displayed when the model is loading") @@ -353,9 +353,10 @@ Window { text: qsTr("Recalculating context.") Connections { - target: LLM + // FIXME: This connection has to be setup everytime a new chat object is created + target: LLM.currentChat function onRecalcChanged() { - if (LLM.isRecalc) + if (LLM.currentChat.isRecalc) recalcPopup.open() else recalcPopup.close() @@ -409,7 +410,7 @@ Window { var string = item.name; var isResponse = item.name === qsTr("Response: ") if (item.currentResponse) - string += LLM.response + string += LLM.currentChat.response else string += chatModel.get(i).value if (isResponse && item.stopped) @@ -427,7 +428,7 @@ Window { var isResponse = item.name === qsTr("Response: ") str += "{\"content\": "; if (item.currentResponse) - str += JSON.stringify(LLM.response) + str += JSON.stringify(LLM.currentChat.response) else str += JSON.stringify(item.value) str += ", \"role\": \"" + (isResponse ? "assistant" : "user") + "\""; @@ -471,8 +472,8 @@ Window { } onClicked: { - LLM.stopGenerating() - LLM.resetContext() + LLM.currentChat.stopGenerating() + LLM.currentChat.resetContext() LLM.currentChat.reset(); } } @@ -679,14 +680,14 @@ Window { Accessible.description: qsTr("This is the list of prompt/response pairs comprising the actual conversation with the model") delegate: TextArea { - text: currentResponse ? LLM.response : (value ? value : "") + text: currentResponse ? LLM.currentChat.response : (value ? value : "") width: listView.width color: theme.textColor wrapMode: Text.WordWrap focus: false readOnly: true font.pixelSize: theme.fontSizeLarge - cursorVisible: currentResponse ? (LLM.response !== "" ? LLM.responseInProgress : false) : false + cursorVisible: currentResponse ? (LLM.currentChat.response !== "" ? LLM.currentChat.responseInProgress : false) : false cursorPosition: text.length background: Rectangle { color: name === qsTr("Response: ") ? theme.backgroundLighter : theme.backgroundLight @@ -706,8 +707,8 @@ Window { anchors.leftMargin: 90 anchors.top: parent.top anchors.topMargin: 5 - visible: (currentResponse ? true : false) && LLM.response === "" && LLM.responseInProgress - running: (currentResponse ? true : false) && LLM.response === "" && LLM.responseInProgress + visible: (currentResponse ? true : false) && LLM.currentChat.response === "" && LLM.currentChat.responseInProgress + running: (currentResponse ? true : false) && LLM.currentChat.response === "" && LLM.currentChat.responseInProgress Accessible.role: Accessible.Animation Accessible.name: qsTr("Busy indicator") @@ -738,7 +739,7 @@ Window { window.height / 2 - height / 2) x: globalPoint.x y: globalPoint.y - property string text: currentResponse ? LLM.response : (value ? value : "") + property string text: currentResponse ? LLM.currentChat.response : (value ? value : "") response: newResponse === undefined || newResponse === "" ? 
text : newResponse onAccepted: { var responseHasChanged = response !== text && response !== newResponse @@ -754,7 +755,7 @@ Window { Column { visible: name === qsTr("Response: ") && - (!currentResponse || !LLM.responseInProgress) && Network.isActive + (!currentResponse || !LLM.currentChat.responseInProgress) && Network.isActive anchors.right: parent.right anchors.rightMargin: 20 anchors.top: parent.top @@ -818,7 +819,8 @@ Window { property bool isAutoScrolling: false Connections { - target: LLM + // FIXME: This connection has to be setup everytime a new chat object is created + target: LLM.currentChat function onResponseChanged() { if (listView.shouldAutoScroll) { listView.isAutoScrolling = true @@ -853,27 +855,27 @@ Window { anchors.verticalCenter: parent.verticalCenter anchors.left: parent.left anchors.leftMargin: 15 - source: LLM.responseInProgress ? "qrc:/gpt4all/icons/stop_generating.svg" : "qrc:/gpt4all/icons/regenerate.svg" + source: LLM.currentChat.responseInProgress ? "qrc:/gpt4all/icons/stop_generating.svg" : "qrc:/gpt4all/icons/regenerate.svg" } leftPadding: 50 onClicked: { var index = Math.max(0, chatModel.count - 1); var listElement = chatModel.get(index); - if (LLM.responseInProgress) { + if (LLM.currentChat.responseInProgress) { listElement.stopped = true - LLM.stopGenerating() + LLM.currentChat.stopGenerating() } else { - LLM.regenerateResponse() + LLM.currentChat.regenerateResponse() if (chatModel.count) { if (listElement.name === qsTr("Response: ")) { chatModel.updateCurrentResponse(index, true); chatModel.updateStopped(index, false); - chatModel.updateValue(index, LLM.response); + chatModel.updateValue(index, LLM.currentChat.response); chatModel.updateThumbsUpState(index, false); chatModel.updateThumbsDownState(index, false); chatModel.updateNewResponse(index, ""); - LLM.prompt(listElement.prompt, settingsDialog.promptTemplate, + LLM.currentChat.prompt(listElement.prompt, settingsDialog.promptTemplate, settingsDialog.maxLength, settingsDialog.topK, settingsDialog.topP, settingsDialog.temperature, @@ -889,7 +891,7 @@ Window { anchors.bottomMargin: 40 padding: 15 contentItem: Text { - text: LLM.responseInProgress ? qsTr("Stop generating") : qsTr("Regenerate response") + text: LLM.currentChat.responseInProgress ? 
qsTr("Stop generating") : qsTr("Regenerate response") color: theme.textColor Accessible.role: Accessible.Button Accessible.name: text @@ -917,7 +919,7 @@ Window { color: theme.textColor padding: 20 rightPadding: 40 - enabled: LLM.isModelLoaded + enabled: LLM.currentChat.isModelLoaded wrapMode: Text.WordWrap font.pixelSize: theme.fontSizeLarge placeholderText: qsTr("Send a message...") @@ -941,19 +943,18 @@ Window { if (textInput.text === "") return - LLM.stopGenerating() + LLM.currentChat.stopGenerating() if (chatModel.count) { var index = Math.max(0, chatModel.count - 1); var listElement = chatModel.get(index); chatModel.updateCurrentResponse(index, false); - chatModel.updateValue(index, LLM.response); + chatModel.updateValue(index, LLM.currentChat.response); } - var prompt = textInput.text + "\n" chatModel.appendPrompt(qsTr("Prompt: "), textInput.text); - chatModel.appendResponse(qsTr("Response: "), prompt); - LLM.resetResponse() - LLM.prompt(prompt, settingsDialog.promptTemplate, + chatModel.appendResponse(qsTr("Response: "), textInput.text); + LLM.currentChat.resetResponse() + LLM.currentChat.prompt(textInput.text, settingsDialog.promptTemplate, settingsDialog.maxLength, settingsDialog.topK, settingsDialog.topP, diff --git a/network.cpp b/network.cpp index 07616e21..db5b24b7 100644 --- a/network.cpp +++ b/network.cpp @@ -86,7 +86,7 @@ bool Network::packageAndSendJson(const QString &ingestId, const QString &json) Q_ASSERT(doc.isObject()); QJsonObject object = doc.object(); object.insert("source", "gpt4all-chat"); - object.insert("agent_id", LLM::globalInstance()->modelName()); + object.insert("agent_id", LLM::globalInstance()->currentChat()->modelName()); object.insert("submitter_id", m_uniqueId); object.insert("ingest_id", ingestId); @@ -230,7 +230,7 @@ void Network::sendMixpanelEvent(const QString &ev) properties.insert("ip", m_ipify); properties.insert("name", QCoreApplication::applicationName() + " v" + QCoreApplication::applicationVersion()); - properties.insert("model", LLM::globalInstance()->modelName()); + properties.insert("model", LLM::globalInstance()->currentChat()->modelName()); QJsonObject event; event.insert("event", ev);