Major refactor in prep for multiple conversations.

pull/520/head
Adam Treat 1 year ago
parent e005ab8c0a
commit 4d87c46948

@ -59,6 +59,7 @@ set (CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin)
qt_add_executable(chat
main.cpp
chat.h chat.cpp chatmodel.h
chatllm.h chatllm.cpp
download.h download.cpp
network.h network.cpp
llm.h llm.cpp

@ -1 +1,117 @@
#include "chat.h"
#include "network.h"
Chat::Chat(QObject *parent)
: QObject(parent)
, m_llmodel(new ChatLLM)
, m_id(Network::globalInstance()->generateUniqueId())
, m_name(tr("New Chat"))
, m_chatModel(new ChatModel(this))
, m_responseInProgress(false)
, m_desiredThreadCount(std::min(4, (int32_t) std::thread::hardware_concurrency()))
{
connect(m_llmodel, &ChatLLM::isModelLoadedChanged, this, &Chat::isModelLoadedChanged, Qt::QueuedConnection);
connect(m_llmodel, &ChatLLM::responseChanged, this, &Chat::responseChanged, Qt::QueuedConnection);
connect(m_llmodel, &ChatLLM::responseStarted, this, &Chat::responseStarted, Qt::QueuedConnection);
connect(m_llmodel, &ChatLLM::responseStopped, this, &Chat::responseStopped, Qt::QueuedConnection);
connect(m_llmodel, &ChatLLM::modelNameChanged, this, &Chat::modelNameChanged, Qt::QueuedConnection);
connect(m_llmodel, &ChatLLM::threadCountChanged, this, &Chat::threadCountChanged, Qt::QueuedConnection);
connect(m_llmodel, &ChatLLM::threadCountChanged, this, &Chat::syncThreadCount, Qt::QueuedConnection);
connect(m_llmodel, &ChatLLM::recalcChanged, this, &Chat::recalcChanged, Qt::QueuedConnection);
connect(this, &Chat::promptRequested, m_llmodel, &ChatLLM::prompt, Qt::QueuedConnection);
connect(this, &Chat::modelNameChangeRequested, m_llmodel, &ChatLLM::modelNameChangeRequested, Qt::QueuedConnection);
// The following are blocking operations and will block the GUI thread, so the connected
// slots must be fast to respond (see the connection-type sketch after this constructor)
connect(this, &Chat::regenerateResponseRequested, m_llmodel, &ChatLLM::regenerateResponse, Qt::BlockingQueuedConnection);
connect(this, &Chat::resetResponseRequested, m_llmodel, &ChatLLM::resetResponse, Qt::BlockingQueuedConnection);
connect(this, &Chat::resetContextRequested, m_llmodel, &ChatLLM::resetContext, Qt::BlockingQueuedConnection);
connect(this, &Chat::setThreadCountRequested, m_llmodel, &ChatLLM::setThreadCount, Qt::QueuedConnection);
}
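
A minimal, self-contained sketch (not part of this commit; the plain QObject worker is a stand-in for ChatLLM) of the two connection types used above: Qt::QueuedConnection posts the call to the receiver's thread and returns immediately, while Qt::BlockingQueuedConnection additionally makes the caller wait until the functor has finished on the worker thread, which is why those handlers must stay fast.

#include <QCoreApplication>
#include <QDebug>
#include <QObject>
#include <QThread>
#include <QTimer>

int main(int argc, char *argv[])
{
    QCoreApplication app(argc, argv);

    QThread workerThread;
    QObject worker;                      // stands in for the ChatLLM object living on its own thread
    worker.moveToThread(&workerThread);
    workerThread.start();

    // Blocking queued: the caller is stalled until the functor has finished on workerThread.
    QMetaObject::invokeMethod(&worker, [] {
        qDebug() << "blocking call ran in" << QThread::currentThread();
    }, Qt::BlockingQueuedConnection);

    // Plain queued: the call is posted to workerThread and the caller continues immediately.
    QMetaObject::invokeMethod(&worker, [] {
        qDebug() << "queued call ran in" << QThread::currentThread();
    }, Qt::QueuedConnection);

    QTimer::singleShot(100, &app, [] { QCoreApplication::quit(); }); // give the queued call time to run
    int rc = app.exec();
    workerThread.quit();
    workerThread.wait();
    return rc;
}
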
void Chat::reset()
{
m_id = Network::globalInstance()->generateUniqueId();
m_chatModel->clear();
}
bool Chat::isModelLoaded() const
{
return m_llmodel->isModelLoaded();
}
void Chat::prompt(const QString &prompt, const QString &prompt_template, int32_t n_predict, int32_t top_k, float top_p,
float temp, int32_t n_batch, float repeat_penalty, int32_t repeat_penalty_tokens)
{
emit promptRequested(prompt, prompt_template, n_predict, top_k, top_p, temp, n_batch, repeat_penalty, repeat_penalty_tokens);
}
void Chat::regenerateResponse()
{
emit regenerateResponseRequested(); // blocking queued connection
}
void Chat::resetResponse()
{
emit resetResponseRequested(); // blocking queued connection
}
void Chat::resetContext()
{
emit resetContextRequested(); // blocking queued connection
}
void Chat::stopGenerating()
{
m_llmodel->stopGenerating();
}
QString Chat::response() const
{
return m_llmodel->response();
}
void Chat::responseStarted()
{
m_responseInProgress = true;
emit responseInProgressChanged();
}
void Chat::responseStopped()
{
m_responseInProgress = false;
emit responseInProgressChanged();
}
QString Chat::modelName() const
{
return m_llmodel->modelName();
}
void Chat::setModelName(const QString &modelName)
{
// Doesn't block, but it will unload the old model and load the new one, which the GUI can
// observe through changes to the isModelLoaded property
emit modelNameChangeRequested(modelName);
}
void Chat::syncThreadCount() {
emit setThreadCountRequested(m_desiredThreadCount);
}
void Chat::setThreadCount(int32_t n_threads) {
if (n_threads <= 0)
n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
m_desiredThreadCount = n_threads;
syncThreadCount();
}
int32_t Chat::threadCount() {
return m_llmodel->threadCount();
}
bool Chat::isRecalc() const
{
return m_llmodel->isRecalc();
}

@ -4,8 +4,8 @@
#include <QObject>
#include <QtQml>
#include "chatllm.h"
#include "chatmodel.h"
#include "network.h"
class Chat : public QObject
{
@ -13,36 +13,69 @@ class Chat : public QObject
Q_PROPERTY(QString id READ id NOTIFY idChanged)
Q_PROPERTY(QString name READ name NOTIFY nameChanged)
Q_PROPERTY(ChatModel *chatModel READ chatModel NOTIFY chatModelChanged)
Q_PROPERTY(bool isModelLoaded READ isModelLoaded NOTIFY isModelLoadedChanged)
Q_PROPERTY(QString response READ response NOTIFY responseChanged)
Q_PROPERTY(QString modelName READ modelName WRITE setModelName NOTIFY modelNameChanged)
Q_PROPERTY(bool responseInProgress READ responseInProgress NOTIFY responseInProgressChanged)
Q_PROPERTY(int32_t threadCount READ threadCount WRITE setThreadCount NOTIFY threadCountChanged)
Q_PROPERTY(bool isRecalc READ isRecalc NOTIFY recalcChanged)
QML_ELEMENT
QML_UNCREATABLE("Only creatable from c++!")
public:
explicit Chat(QObject *parent = nullptr) : QObject(parent)
{
m_id = Network::globalInstance()->generateUniqueId();
m_name = tr("New Chat");
m_chatModel = new ChatModel(this);
}
explicit Chat(QObject *parent = nullptr);
QString id() const { return m_id; }
QString name() const { return m_name; }
ChatModel *chatModel() { return m_chatModel; }
Q_INVOKABLE void reset()
{
m_id = Network::globalInstance()->generateUniqueId();
m_chatModel->clear();
}
Q_INVOKABLE void reset();
Q_INVOKABLE bool isModelLoaded() const;
Q_INVOKABLE void prompt(const QString &prompt, const QString &prompt_template, int32_t n_predict, int32_t top_k, float top_p,
float temp, int32_t n_batch, float repeat_penalty, int32_t repeat_penalty_tokens);
Q_INVOKABLE void regenerateResponse();
Q_INVOKABLE void resetResponse();
Q_INVOKABLE void resetContext();
Q_INVOKABLE void stopGenerating();
Q_INVOKABLE void syncThreadCount();
Q_INVOKABLE void setThreadCount(int32_t n_threads);
Q_INVOKABLE int32_t threadCount();
QString response() const;
bool responseInProgress() const { return m_responseInProgress; }
QString modelName() const;
void setModelName(const QString &modelName);
bool isRecalc() const;
Q_SIGNALS:
void idChanged();
void nameChanged();
void chatModelChanged();
void isModelLoadedChanged();
void responseChanged();
void responseInProgressChanged();
void promptRequested(const QString &prompt, const QString &prompt_template, int32_t n_predict, int32_t top_k, float top_p,
float temp, int32_t n_batch, float repeat_penalty, int32_t repeat_penalty_tokens);
void regenerateResponseRequested();
void resetResponseRequested();
void resetContextRequested();
void modelNameChangeRequested(const QString &modelName);
void modelNameChanged();
void threadCountChanged();
void setThreadCountRequested(int32_t threadCount);
void recalcChanged();
private Q_SLOTS:
void responseStarted();
void responseStopped();
private:
ChatLLM *m_llmodel;
QString m_id;
QString m_name;
ChatModel *m_chatModel;
bool m_responseInProgress;
int32_t m_desiredThreadCount;
};
#endif // CHAT_H

@ -0,0 +1,290 @@
#include "chatllm.h"
#include "download.h"
#include "network.h"
#include "llm.h"
#include "llmodel/gptj.h"
#include "llmodel/llamamodel.h"
#include <QCoreApplication>
#include <QDir>
#include <QFile>
#include <QProcess>
#include <QResource>
#include <QSettings>
#include <fstream>
//#define DEBUG
static QString modelFilePath(const QString &modelName)
{
QString appPath = QCoreApplication::applicationDirPath()
+ "/ggml-" + modelName + ".bin";
QFileInfo infoAppPath(appPath);
if (infoAppPath.exists())
return appPath;
QString downloadPath = Download::globalInstance()->downloadLocalModelsPath()
+ "/ggml-" + modelName + ".bin";
QFileInfo infoLocalPath(downloadPath);
if (infoLocalPath.exists())
return downloadPath;
return QString();
}
ChatLLM::ChatLLM()
: QObject{nullptr}
, m_llmodel(nullptr)
, m_promptResponseTokens(0)
, m_responseLogits(0)
, m_isRecalc(false)
{
moveToThread(&m_llmThread);
connect(&m_llmThread, &QThread::started, this, &ChatLLM::loadModel);
connect(this, &ChatLLM::sendStartup, Network::globalInstance(), &Network::sendStartup);
connect(this, &ChatLLM::sendModelLoaded, Network::globalInstance(), &Network::sendModelLoaded);
connect(this, &ChatLLM::sendResetContext, Network::globalInstance(), &Network::sendResetContext);
m_llmThread.setObjectName("llm thread"); // FIXME: Should identify these with chat name
m_llmThread.start();
}
bool ChatLLM::loadModel()
{
const QList<QString> models = LLM::globalInstance()->modelList();
if (models.isEmpty()) {
// try again when we get a list of models
connect(Download::globalInstance(), &Download::modelListChanged, this,
&ChatLLM::loadModel, Qt::SingleShotConnection);
return false;
}
QSettings settings;
settings.sync();
QString defaultModel = settings.value("defaultModel", "gpt4all-j-v1.3-groovy").toString();
if (defaultModel.isEmpty() || !models.contains(defaultModel))
defaultModel = models.first();
return loadModelPrivate(defaultModel);
}
bool ChatLLM::loadModelPrivate(const QString &modelName)
{
if (isModelLoaded() && m_modelName == modelName)
return true;
bool isFirstLoad = false;
if (isModelLoaded()) {
resetContextPrivate();
delete m_llmodel;
m_llmodel = nullptr;
emit isModelLoadedChanged();
} else {
isFirstLoad = true;
}
bool isGPTJ = false;
QString filePath = modelFilePath(modelName);
QFileInfo info(filePath);
if (info.exists()) {
auto fin = std::ifstream(filePath.toStdString(), std::ios::binary);
uint32_t magic;
fin.read((char *) &magic, sizeof(magic));
fin.seekg(0);
fin.close();
isGPTJ = magic == 0x67676d6c;
if (isGPTJ) {
m_llmodel = new GPTJ;
m_llmodel->loadModel(filePath.toStdString());
} else {
m_llmodel = new LLamaModel;
m_llmodel->loadModel(filePath.toStdString());
}
emit isModelLoadedChanged();
emit threadCountChanged();
if (isFirstLoad)
emit sendStartup();
else
emit sendModelLoaded();
}
if (m_llmodel)
setModelName(info.completeBaseName().remove(0, 5)); // remove the ggml- prefix
return m_llmodel;
}
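
A minimal sketch (not part of this commit; the command-line wrapper is made up) of the format sniffing loadModelPrivate() performs: the first four bytes of the weights file are read as a 32-bit magic value, and 0x67676d6c, the ASCII characters "ggml" packed into an integer, is taken to mean a GPT-J style file, with anything else handled as a llama model.

#include <cstdint>
#include <fstream>
#include <iostream>

int main(int argc, char *argv[])
{
    if (argc < 2) {
        std::cerr << "usage: " << argv[0] << " <ggml-model.bin>\n";
        return 1;
    }
    std::ifstream fin(argv[1], std::ios::binary);
    if (!fin) {
        std::cerr << "could not open " << argv[1] << "\n";
        return 1;
    }
    uint32_t magic = 0;
    fin.read(reinterpret_cast<char *>(&magic), sizeof(magic));
    const bool isGPTJ = magic == 0x67676d6c; // same check as in loadModelPrivate()
    std::cout << (isGPTJ ? "gptj-style ggml file\n" : "no gptj magic; would be treated as llama\n");
    return 0;
}
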
void ChatLLM::setThreadCount(int32_t n_threads) {
if (m_llmodel && m_llmodel->threadCount() != n_threads) {
m_llmodel->setThreadCount(n_threads);
emit threadCountChanged();
}
}
int32_t ChatLLM::threadCount() {
if (!m_llmodel)
return 1;
return m_llmodel->threadCount();
}
bool ChatLLM::isModelLoaded() const
{
return m_llmodel && m_llmodel->isModelLoaded();
}
void ChatLLM::regenerateResponse()
{
m_ctx.n_past -= m_promptResponseTokens;
m_ctx.n_past = std::max(0, m_ctx.n_past);
// FIXME: This does not seem to be needed in my testing and llama models don't do it. Remove?
m_ctx.logits.erase(m_ctx.logits.end() -= m_responseLogits, m_ctx.logits.end());
m_ctx.tokens.erase(m_ctx.tokens.end() -= m_promptResponseTokens, m_ctx.tokens.end());
m_promptResponseTokens = 0;
m_responseLogits = 0;
m_response = std::string();
emit responseChanged();
}
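
A minimal sketch (not part of this commit; the numbers are made up) of the rollback regenerateResponse() performs on the shared PromptContext: the tokens contributed by the last prompt/response pair are erased from the end and n_past is rewound, so the next generation effectively restarts from before that exchange.

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

int main()
{
    std::vector<int32_t> tokens = {1, 2, 3, 4, 5, 6};   // stand-in for ctx.tokens
    int n_past = (int) tokens.size();
    const int promptResponseTokens = 2;                 // tokens added by the last exchange

    n_past = std::max(0, n_past - promptResponseTokens);
    tokens.erase(tokens.end() - promptResponseTokens, tokens.end());

    std::cout << "n_past=" << n_past << " tokens left=" << tokens.size() << "\n"; // 4 and 4
    return 0;
}
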
void ChatLLM::resetResponse()
{
m_promptResponseTokens = 0;
m_responseLogits = 0;
m_response = std::string();
emit responseChanged();
}
void ChatLLM::resetContext()
{
resetContextPrivate();
emit sendResetContext();
}
void ChatLLM::resetContextPrivate()
{
regenerateResponse();
m_ctx = LLModel::PromptContext();
}
std::string remove_leading_whitespace(const std::string& input) {
auto first_non_whitespace = std::find_if(input.begin(), input.end(), [](unsigned char c) {
return !std::isspace(c);
});
return std::string(first_non_whitespace, input.end());
}
std::string trim_whitespace(const std::string& input) {
auto first_non_whitespace = std::find_if(input.begin(), input.end(), [](unsigned char c) {
return !std::isspace(c);
});
auto last_non_whitespace = std::find_if(input.rbegin(), input.rend(), [](unsigned char c) {
return !std::isspace(c);
}).base();
return std::string(first_non_whitespace, last_non_whitespace);
}
QString ChatLLM::response() const
{
return QString::fromStdString(remove_leading_whitespace(m_response));
}
QString ChatLLM::modelName() const
{
return m_modelName;
}
void ChatLLM::setModelName(const QString &modelName)
{
m_modelName = modelName;
emit modelNameChanged();
}
void ChatLLM::modelNameChangeRequested(const QString &modelName)
{
if (!loadModelPrivate(modelName))
qWarning() << "ERROR: Could not load model" << modelName;
}
bool ChatLLM::handlePrompt(int32_t token)
{
// m_promptResponseTokens and m_responseLogits relate to the last prompt/response, not to
// the entire context window, so they can be reset when regenerating the prompt
++m_promptResponseTokens;
return !m_stopGenerating;
}
bool ChatLLM::handleResponse(int32_t token, const std::string &response)
{
#if defined(DEBUG)
printf("%s", response.c_str());
fflush(stdout);
#endif
// check for error
if (token < 0) {
m_response.append(response);
emit responseChanged();
return false;
}
// m_promptResponseTokens and m_responseLogits relate to the last prompt/response, not to
// the entire context window, so they can be reset when regenerating the prompt
++m_promptResponseTokens;
Q_ASSERT(!response.empty());
m_response.append(response);
emit responseChanged();
return !m_stopGenerating;
}
bool ChatLLM::handleRecalculate(bool isRecalc)
{
if (m_isRecalc != isRecalc) {
m_isRecalc = isRecalc;
emit recalcChanged();
}
return !m_stopGenerating;
}
bool ChatLLM::prompt(const QString &prompt, const QString &prompt_template, int32_t n_predict, int32_t top_k, float top_p,
float temp, int32_t n_batch, float repeat_penalty, int32_t repeat_penalty_tokens)
{
if (!isModelLoaded())
return false;
QString instructPrompt = prompt_template.arg(prompt);
m_stopGenerating = false;
auto promptFunc = std::bind(&ChatLLM::handlePrompt, this, std::placeholders::_1);
auto responseFunc = std::bind(&ChatLLM::handleResponse, this, std::placeholders::_1,
std::placeholders::_2);
auto recalcFunc = std::bind(&ChatLLM::handleRecalculate, this, std::placeholders::_1);
emit responseStarted();
qint32 logitsBefore = m_ctx.logits.size();
m_ctx.n_predict = n_predict;
m_ctx.top_k = top_k;
m_ctx.top_p = top_p;
m_ctx.temp = temp;
m_ctx.n_batch = n_batch;
m_ctx.repeat_penalty = repeat_penalty;
m_ctx.repeat_last_n = repeat_penalty_tokens;
#if defined(DEBUG)
printf("%s", qPrintable(instructPrompt));
fflush(stdout);
#endif
m_llmodel->prompt(instructPrompt.toStdString(), promptFunc, responseFunc, recalcFunc, m_ctx);
#if defined(DEBUG)
printf("\n");
fflush(stdout);
#endif
m_responseLogits += m_ctx.logits.size() - logitsBefore;
std::string trimmed = trim_whitespace(m_response);
if (trimmed != m_response) {
m_response = trimmed;
emit responseChanged();
}
emit responseStopped();
return true;
}
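
A minimal sketch (not part of this commit; generate() is a made-up stand-in for LLModel::prompt()) of the callback wiring used above: member functions are adapted with std::bind so the streaming generator can report prompt and response tokens back to the owning object and stop as soon as a handler returns false.

#include <atomic>
#include <cstdint>
#include <functional>
#include <iostream>
#include <string>

// Made-up stand-in for LLModel::prompt(): streams tokens to the callbacks and
// stops as soon as one of them returns false.
static void generate(const std::string &prompt,
                     std::function<bool(int32_t)> onPromptToken,
                     std::function<bool(int32_t, const std::string &)> onResponseToken)
{
    for (int32_t t = 0; t < (int32_t) prompt.size(); ++t)
        if (!onPromptToken(t))
            return;
    const std::string pieces[] = {"Hello", ",", " world"};
    for (int32_t t = 0; t < 3; ++t)
        if (!onResponseToken(t, pieces[t]))
            return;
}

struct Session {
    std::atomic<bool> stopGenerating{false};
    std::string response;

    bool handlePrompt(int32_t) { return !stopGenerating; }
    bool handleResponse(int32_t, const std::string &piece) {
        response += piece;               // accumulate the streamed response text
        return !stopGenerating;
    }
};

int main()
{
    using namespace std::placeholders;
    Session session;
    auto promptFunc = std::bind(&Session::handlePrompt, &session, _1);
    auto responseFunc = std::bind(&Session::handleResponse, &session, _1, _2);
    generate("Hi", promptFunc, responseFunc);
    std::cout << session.response << "\n"; // prints "Hello, world"
    return 0;
}
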

@ -0,0 +1,75 @@
#ifndef CHATLLM_H
#define CHATLLM_H
#include <QObject>
#include <QThread>
#include "llmodel/llmodel.h"
class ChatLLM : public QObject
{
Q_OBJECT
Q_PROPERTY(bool isModelLoaded READ isModelLoaded NOTIFY isModelLoadedChanged)
Q_PROPERTY(QString response READ response NOTIFY responseChanged)
Q_PROPERTY(QString modelName READ modelName WRITE setModelName NOTIFY modelNameChanged)
Q_PROPERTY(int32_t threadCount READ threadCount WRITE setThreadCount NOTIFY threadCountChanged)
Q_PROPERTY(bool isRecalc READ isRecalc NOTIFY recalcChanged)
public:
ChatLLM();
bool isModelLoaded() const;
void regenerateResponse();
void resetResponse();
void resetContext();
void stopGenerating() { m_stopGenerating = true; }
void setThreadCount(int32_t n_threads);
int32_t threadCount();
QString response() const;
QString modelName() const;
void setModelName(const QString &modelName);
bool isRecalc() const { return m_isRecalc; }
public Q_SLOTS:
bool prompt(const QString &prompt, const QString &prompt_template, int32_t n_predict, int32_t top_k, float top_p,
float temp, int32_t n_batch, float repeat_penalty, int32_t repeat_penalty_tokens);
bool loadModel();
void modelNameChangeRequested(const QString &modelName);
Q_SIGNALS:
void isModelLoadedChanged();
void responseChanged();
void responseStarted();
void responseStopped();
void modelNameChanged();
void threadCountChanged();
void recalcChanged();
void sendStartup();
void sendModelLoaded();
void sendResetContext();
private:
void resetContextPrivate();
bool loadModelPrivate(const QString &modelName);
bool handlePrompt(int32_t token);
bool handleResponse(int32_t token, const std::string &response);
bool handleRecalculate(bool isRecalc);
private:
LLModel::PromptContext m_ctx;
LLModel *m_llmodel;
std::string m_response;
quint32 m_promptResponseTokens;
quint32 m_responseLogits;
QString m_modelName;
QThread m_llmThread;
std::atomic<bool> m_stopGenerating;
bool m_isRecalc;
};
#endif // CHATLLM_H

@ -19,203 +19,18 @@ LLM *LLM::globalInstance()
return llmInstance();
}
static LLModel::PromptContext s_ctx;
static QString modelFilePath(const QString &modelName)
{
QString appPath = QCoreApplication::applicationDirPath()
+ "/ggml-" + modelName + ".bin";
QFileInfo infoAppPath(appPath);
if (infoAppPath.exists())
return appPath;
QString downloadPath = Download::globalInstance()->downloadLocalModelsPath()
+ "/ggml-" + modelName + ".bin";
QFileInfo infoLocalPath(downloadPath);
if (infoLocalPath.exists())
return downloadPath;
return QString();
}
LLMObject::LLMObject()
LLM::LLM()
: QObject{nullptr}
, m_llmodel(nullptr)
, m_promptResponseTokens(0)
, m_responseLogits(0)
, m_isRecalc(false)
{
moveToThread(&m_llmThread);
connect(&m_llmThread, &QThread::started, this, &LLMObject::loadModel);
connect(this, &LLMObject::sendStartup, Network::globalInstance(), &Network::sendStartup);
connect(this, &LLMObject::sendModelLoaded, Network::globalInstance(), &Network::sendModelLoaded);
connect(this, &LLMObject::sendResetContext, Network::globalInstance(), &Network::sendResetContext);
m_llmThread.setObjectName("llm thread");
m_llmThread.start();
}
bool LLMObject::loadModel()
{
const QList<QString> models = modelList();
if (models.isEmpty()) {
// try again when we get a list of models
connect(Download::globalInstance(), &Download::modelListChanged, this,
&LLMObject::loadModel, Qt::SingleShotConnection);
return false;
}
QSettings settings;
settings.sync();
QString defaultModel = settings.value("defaultModel", "gpt4all-j-v1.3-groovy").toString();
if (defaultModel.isEmpty() || !models.contains(defaultModel))
defaultModel = models.first();
return loadModelPrivate(defaultModel);
}
bool LLMObject::loadModelPrivate(const QString &modelName)
{
if (isModelLoaded() && m_modelName == modelName)
return true;
bool isFirstLoad = false;
if (isModelLoaded()) {
resetContextPrivate();
delete m_llmodel;
m_llmodel = nullptr;
emit isModelLoadedChanged();
} else {
isFirstLoad = true;
}
bool isGPTJ = false;
QString filePath = modelFilePath(modelName);
QFileInfo info(filePath);
if (info.exists()) {
auto fin = std::ifstream(filePath.toStdString(), std::ios::binary);
uint32_t magic;
fin.read((char *) &magic, sizeof(magic));
fin.seekg(0);
fin.close();
isGPTJ = magic == 0x67676d6c;
if (isGPTJ) {
m_llmodel = new GPTJ;
m_llmodel->loadModel(filePath.toStdString());
} else {
m_llmodel = new LLamaModel;
m_llmodel->loadModel(filePath.toStdString());
}
emit isModelLoadedChanged();
emit threadCountChanged();
if (isFirstLoad)
emit sendStartup();
else
emit sendModelLoaded();
}
if (m_llmodel)
setModelName(info.completeBaseName().remove(0, 5)); // remove the ggml- prefix
return m_llmodel;
}
void LLMObject::setThreadCount(int32_t n_threads) {
if (m_llmodel && m_llmodel->threadCount() != n_threads) {
m_llmodel->setThreadCount(n_threads);
emit threadCountChanged();
}
}
int32_t LLMObject::threadCount() {
if (!m_llmodel)
return 1;
return m_llmodel->threadCount();
}
bool LLMObject::isModelLoaded() const
{
return m_llmodel && m_llmodel->isModelLoaded();
}
void LLMObject::regenerateResponse()
{
s_ctx.n_past -= m_promptResponseTokens;
s_ctx.n_past = std::max(0, s_ctx.n_past);
// FIXME: This does not seem to be needed in my testing and llama models don't do it. Remove?
s_ctx.logits.erase(s_ctx.logits.end() -= m_responseLogits, s_ctx.logits.end());
s_ctx.tokens.erase(s_ctx.tokens.end() -= m_promptResponseTokens, s_ctx.tokens.end());
m_promptResponseTokens = 0;
m_responseLogits = 0;
m_response = std::string();
emit responseChanged();
}
void LLMObject::resetResponse()
{
m_promptResponseTokens = 0;
m_responseLogits = 0;
m_response = std::string();
emit responseChanged();
}
void LLMObject::resetContext()
{
resetContextPrivate();
emit sendResetContext();
}
void LLMObject::resetContextPrivate()
{
regenerateResponse();
s_ctx = LLModel::PromptContext();
}
std::string remove_leading_whitespace(const std::string& input) {
auto first_non_whitespace = std::find_if(input.begin(), input.end(), [](unsigned char c) {
return !std::isspace(c);
});
return std::string(first_non_whitespace, input.end());
}
std::string trim_whitespace(const std::string& input) {
auto first_non_whitespace = std::find_if(input.begin(), input.end(), [](unsigned char c) {
return !std::isspace(c);
});
auto last_non_whitespace = std::find_if(input.rbegin(), input.rend(), [](unsigned char c) {
return !std::isspace(c);
}).base();
return std::string(first_non_whitespace, last_non_whitespace);
}
QString LLMObject::response() const
{
return QString::fromStdString(remove_leading_whitespace(m_response));
}
QString LLMObject::modelName() const
{
return m_modelName;
}
void LLMObject::setModelName(const QString &modelName)
{
m_modelName = modelName;
emit modelNameChanged();
emit modelListChanged();
}
void LLMObject::modelNameChangeRequested(const QString &modelName)
, m_currentChat(new Chat)
{
if (!loadModelPrivate(modelName))
qWarning() << "ERROR: Could not load model" << modelName;
connect(Download::globalInstance(), &Download::modelListChanged,
this, &LLM::modelListChanged, Qt::QueuedConnection);
// FIXME: This connect should be moved to wherever we create a new chat object in the future
connect(m_currentChat, &Chat::modelNameChanged,
this, &LLM::modelListChanged, Qt::QueuedConnection);
}
QList<QString> LLMObject::modelList() const
QList<QString> LLM::modelList() const
{
// Build a model list from exepath and from the localpath
QList<QString> list;
@ -232,7 +47,7 @@ QList<QString> LLMObject::modelList() const
QFileInfo info(filePath);
QString name = info.completeBaseName().remove(0, 5);
if (info.exists()) {
if (name == m_modelName)
if (name == m_currentChat->modelName())
list.prepend(name);
else
list.append(name);
@ -249,7 +64,7 @@ QList<QString> LLMObject::modelList() const
QFileInfo info(filePath);
QString name = info.completeBaseName().remove(0, 5);
if (info.exists() && !list.contains(name)) { // don't allow duplicates
if (name == m_modelName)
if (name == m_currentChat->modelName())
list.prepend(name);
else
list.append(name);
@ -271,189 +86,6 @@ QList<QString> LLMObject::modelList() const
return list;
}
bool LLMObject::handlePrompt(int32_t token)
{
// m_promptResponseTokens and m_responseLogits are related to last prompt/response not
// the entire context window which we can reset on regenerate prompt
++m_promptResponseTokens;
return !m_stopGenerating;
}
bool LLMObject::handleResponse(int32_t token, const std::string &response)
{
#if 0
printf("%s", response.c_str());
fflush(stdout);
#endif
// check for error
if (token < 0) {
m_response.append(response);
emit responseChanged();
return false;
}
// m_promptResponseTokens and m_responseLogits are related to last prompt/response not
// the entire context window which we can reset on regenerate prompt
++m_promptResponseTokens;
Q_ASSERT(!response.empty());
m_response.append(response);
emit responseChanged();
return !m_stopGenerating;
}
bool LLMObject::handleRecalculate(bool isRecalc)
{
if (m_isRecalc != isRecalc) {
m_isRecalc = isRecalc;
emit recalcChanged();
}
return !m_stopGenerating;
}
bool LLMObject::prompt(const QString &prompt, const QString &prompt_template, int32_t n_predict, int32_t top_k, float top_p,
float temp, int32_t n_batch, float repeat_penalty, int32_t repeat_penalty_tokens)
{
if (!isModelLoaded())
return false;
QString instructPrompt = prompt_template.arg(prompt);
m_stopGenerating = false;
auto promptFunc = std::bind(&LLMObject::handlePrompt, this, std::placeholders::_1);
auto responseFunc = std::bind(&LLMObject::handleResponse, this, std::placeholders::_1,
std::placeholders::_2);
auto recalcFunc = std::bind(&LLMObject::handleRecalculate, this, std::placeholders::_1);
emit responseStarted();
qint32 logitsBefore = s_ctx.logits.size();
s_ctx.n_predict = n_predict;
s_ctx.top_k = top_k;
s_ctx.top_p = top_p;
s_ctx.temp = temp;
s_ctx.n_batch = n_batch;
s_ctx.repeat_penalty = repeat_penalty;
s_ctx.repeat_last_n = repeat_penalty_tokens;
m_llmodel->prompt(instructPrompt.toStdString(), promptFunc, responseFunc, recalcFunc, s_ctx);
m_responseLogits += s_ctx.logits.size() - logitsBefore;
std::string trimmed = trim_whitespace(m_response);
if (trimmed != m_response) {
m_response = trimmed;
emit responseChanged();
}
emit responseStopped();
return true;
}
LLM::LLM()
: QObject{nullptr}
, m_currentChat(new Chat)
, m_llmodel(new LLMObject)
, m_responseInProgress(false)
{
connect(Download::globalInstance(), &Download::modelListChanged, this, &LLM::modelListChanged, Qt::QueuedConnection);
connect(m_llmodel, &LLMObject::isModelLoadedChanged, this, &LLM::isModelLoadedChanged, Qt::QueuedConnection);
connect(m_llmodel, &LLMObject::responseChanged, this, &LLM::responseChanged, Qt::QueuedConnection);
connect(m_llmodel, &LLMObject::responseStarted, this, &LLM::responseStarted, Qt::QueuedConnection);
connect(m_llmodel, &LLMObject::responseStopped, this, &LLM::responseStopped, Qt::QueuedConnection);
connect(m_llmodel, &LLMObject::modelNameChanged, this, &LLM::modelNameChanged, Qt::QueuedConnection);
connect(m_llmodel, &LLMObject::modelListChanged, this, &LLM::modelListChanged, Qt::QueuedConnection);
connect(m_llmodel, &LLMObject::threadCountChanged, this, &LLM::threadCountChanged, Qt::QueuedConnection);
connect(m_llmodel, &LLMObject::threadCountChanged, this, &LLM::syncThreadCount, Qt::QueuedConnection);
connect(m_llmodel, &LLMObject::recalcChanged, this, &LLM::recalcChanged, Qt::QueuedConnection);
connect(this, &LLM::promptRequested, m_llmodel, &LLMObject::prompt, Qt::QueuedConnection);
connect(this, &LLM::modelNameChangeRequested, m_llmodel, &LLMObject::modelNameChangeRequested, Qt::QueuedConnection);
// The following are blocking operations and will block the gui thread, therefore must be fast
// to respond to
connect(this, &LLM::regenerateResponseRequested, m_llmodel, &LLMObject::regenerateResponse, Qt::BlockingQueuedConnection);
connect(this, &LLM::resetResponseRequested, m_llmodel, &LLMObject::resetResponse, Qt::BlockingQueuedConnection);
connect(this, &LLM::resetContextRequested, m_llmodel, &LLMObject::resetContext, Qt::BlockingQueuedConnection);
connect(this, &LLM::setThreadCountRequested, m_llmodel, &LLMObject::setThreadCount, Qt::QueuedConnection);
}
bool LLM::isModelLoaded() const
{
return m_llmodel->isModelLoaded();
}
void LLM::prompt(const QString &prompt, const QString &prompt_template, int32_t n_predict, int32_t top_k, float top_p,
float temp, int32_t n_batch, float repeat_penalty, int32_t repeat_penalty_tokens)
{
emit promptRequested(prompt, prompt_template, n_predict, top_k, top_p, temp, n_batch, repeat_penalty, repeat_penalty_tokens);
}
void LLM::regenerateResponse()
{
emit regenerateResponseRequested(); // blocking queued connection
}
void LLM::resetResponse()
{
emit resetResponseRequested(); // blocking queued connection
}
void LLM::resetContext()
{
emit resetContextRequested(); // blocking queued connection
}
void LLM::stopGenerating()
{
m_llmodel->stopGenerating();
}
QString LLM::response() const
{
return m_llmodel->response();
}
void LLM::responseStarted()
{
m_responseInProgress = true;
emit responseInProgressChanged();
}
void LLM::responseStopped()
{
m_responseInProgress = false;
emit responseInProgressChanged();
}
QString LLM::modelName() const
{
return m_llmodel->modelName();
}
void LLM::setModelName(const QString &modelName)
{
// doesn't block but will unload old model and load new one which the gui can see through changes
// to the isModelLoaded property
emit modelNameChangeRequested(modelName);
}
QList<QString> LLM::modelList() const
{
return m_llmodel->modelList();
}
void LLM::syncThreadCount() {
emit setThreadCountRequested(m_desiredThreadCount);
}
void LLM::setThreadCount(int32_t n_threads) {
if (n_threads <= 0) {
n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
}
m_desiredThreadCount = n_threads;
syncThreadCount();
}
int32_t LLM::threadCount() {
return m_llmodel->threadCount();
}
bool LLM::checkForUpdates() const
{
Network::globalInstance()->sendCheckForUpdates();
@ -475,8 +107,3 @@ bool LLM::checkForUpdates() const
return QProcess::startDetached(fileName);
}
bool LLM::isRecalc() const
{
return m_llmodel->isRecalc();
}

llm.h

@ -2,146 +2,29 @@
#define LLM_H
#include <QObject>
#include <QThread>
#include "chat.h"
#include "llmodel/llmodel.h"
class LLMObject : public QObject
{
Q_OBJECT
Q_PROPERTY(QList<QString> modelList READ modelList NOTIFY modelListChanged)
Q_PROPERTY(bool isModelLoaded READ isModelLoaded NOTIFY isModelLoadedChanged)
Q_PROPERTY(QString response READ response NOTIFY responseChanged)
Q_PROPERTY(QString modelName READ modelName WRITE setModelName NOTIFY modelNameChanged)
Q_PROPERTY(int32_t threadCount READ threadCount WRITE setThreadCount NOTIFY threadCountChanged)
Q_PROPERTY(bool isRecalc READ isRecalc NOTIFY recalcChanged)
public:
LLMObject();
bool isModelLoaded() const;
void regenerateResponse();
void resetResponse();
void resetContext();
void stopGenerating() { m_stopGenerating = true; }
void setThreadCount(int32_t n_threads);
int32_t threadCount();
QString response() const;
QString modelName() const;
QList<QString> modelList() const;
void setModelName(const QString &modelName);
bool isRecalc() const { return m_isRecalc; }
public Q_SLOTS:
bool prompt(const QString &prompt, const QString &prompt_template, int32_t n_predict, int32_t top_k, float top_p,
float temp, int32_t n_batch, float repeat_penalty, int32_t repeat_penalty_tokens);
bool loadModel();
void modelNameChangeRequested(const QString &modelName);
Q_SIGNALS:
void isModelLoadedChanged();
void responseChanged();
void responseStarted();
void responseStopped();
void modelNameChanged();
void modelListChanged();
void threadCountChanged();
void recalcChanged();
void sendStartup();
void sendModelLoaded();
void sendResetContext();
private:
void resetContextPrivate();
bool loadModelPrivate(const QString &modelName);
bool handlePrompt(int32_t token);
bool handleResponse(int32_t token, const std::string &response);
bool handleRecalculate(bool isRecalc);
private:
LLModel *m_llmodel;
std::string m_response;
quint32 m_promptResponseTokens;
quint32 m_responseLogits;
QString m_modelName;
QThread m_llmThread;
std::atomic<bool> m_stopGenerating;
bool m_isRecalc;
};
class LLM : public QObject
{
Q_OBJECT
Q_PROPERTY(QList<QString> modelList READ modelList NOTIFY modelListChanged)
Q_PROPERTY(bool isModelLoaded READ isModelLoaded NOTIFY isModelLoadedChanged)
Q_PROPERTY(QString response READ response NOTIFY responseChanged)
Q_PROPERTY(QString modelName READ modelName WRITE setModelName NOTIFY modelNameChanged)
Q_PROPERTY(bool responseInProgress READ responseInProgress NOTIFY responseInProgressChanged)
Q_PROPERTY(int32_t threadCount READ threadCount WRITE setThreadCount NOTIFY threadCountChanged)
Q_PROPERTY(bool isRecalc READ isRecalc NOTIFY recalcChanged)
Q_PROPERTY(Chat *currentChat READ currentChat NOTIFY currentChatChanged)
public:
static LLM *globalInstance();
Q_INVOKABLE bool isModelLoaded() const;
Q_INVOKABLE void prompt(const QString &prompt, const QString &prompt_template, int32_t n_predict, int32_t top_k, float top_p,
float temp, int32_t n_batch, float repeat_penalty, int32_t repeat_penalty_tokens);
Q_INVOKABLE void regenerateResponse();
Q_INVOKABLE void resetResponse();
Q_INVOKABLE void resetContext();
Q_INVOKABLE void stopGenerating();
Q_INVOKABLE void syncThreadCount();
Q_INVOKABLE void setThreadCount(int32_t n_threads);
Q_INVOKABLE int32_t threadCount();
QString response() const;
bool responseInProgress() const { return m_responseInProgress; }
QList<QString> modelList() const;
QString modelName() const;
void setModelName(const QString &modelName);
Q_INVOKABLE bool checkForUpdates() const;
bool isRecalc() const;
Chat *currentChat() const { return m_currentChat; }
Q_SIGNALS:
void isModelLoadedChanged();
void responseChanged();
void responseInProgressChanged();
void promptRequested(const QString &prompt, const QString &prompt_template, int32_t n_predict, int32_t top_k, float top_p,
float temp, int32_t n_batch, float repeat_penalty, int32_t repeat_penalty_tokens);
void regenerateResponseRequested();
void resetResponseRequested();
void resetContextRequested();
void modelNameChangeRequested(const QString &modelName);
void modelNameChanged();
void modelListChanged();
void threadCountChanged();
void setThreadCountRequested(int32_t threadCount);
void recalcChanged();
void currentChatChanged();
private Q_SLOTS:
void responseStarted();
void responseStopped();
private:
Chat *m_currentChat;
LLMObject *m_llmodel;
int32_t m_desiredThreadCount;
bool m_responseInProgress;
private:
explicit LLM();

@ -94,7 +94,7 @@ Window {
Item {
anchors.centerIn: parent
height: childrenRect.height
visible: LLM.isModelLoaded
visible: LLM.currentChat.isModelLoaded
Label {
id: modelLabel
@ -169,8 +169,8 @@ Window {
}
onActivated: {
LLM.stopGenerating()
LLM.modelName = comboBox.currentText
LLM.currentChat.stopGenerating()
LLM.currentChat.modelName = comboBox.currentText
LLM.currentChat.reset();
}
}
@ -178,8 +178,8 @@ Window {
BusyIndicator {
anchors.centerIn: parent
visible: !LLM.isModelLoaded
running: !LLM.isModelLoaded
visible: !LLM.currentChat.isModelLoaded
running: !LLM.currentChat.isModelLoaded
Accessible.role: Accessible.Animation
Accessible.name: qsTr("Busy indicator")
Accessible.description: qsTr("Displayed when the model is loading")
@ -353,9 +353,10 @@ Window {
text: qsTr("Recalculating context.")
Connections {
target: LLM
// FIXME: This connection has to be set up every time a new chat object is created
target: LLM.currentChat
function onRecalcChanged() {
if (LLM.isRecalc)
if (LLM.currentChat.isRecalc)
recalcPopup.open()
else
recalcPopup.close()
@ -409,7 +410,7 @@ Window {
var string = item.name;
var isResponse = item.name === qsTr("Response: ")
if (item.currentResponse)
string += LLM.response
string += LLM.currentChat.response
else
string += chatModel.get(i).value
if (isResponse && item.stopped)
@ -427,7 +428,7 @@ Window {
var isResponse = item.name === qsTr("Response: ")
str += "{\"content\": ";
if (item.currentResponse)
str += JSON.stringify(LLM.response)
str += JSON.stringify(LLM.currentChat.response)
else
str += JSON.stringify(item.value)
str += ", \"role\": \"" + (isResponse ? "assistant" : "user") + "\"";
@ -471,8 +472,8 @@ Window {
}
onClicked: {
LLM.stopGenerating()
LLM.resetContext()
LLM.currentChat.stopGenerating()
LLM.currentChat.resetContext()
LLM.currentChat.reset();
}
}
@ -679,14 +680,14 @@ Window {
Accessible.description: qsTr("This is the list of prompt/response pairs comprising the actual conversation with the model")
delegate: TextArea {
text: currentResponse ? LLM.response : (value ? value : "")
text: currentResponse ? LLM.currentChat.response : (value ? value : "")
width: listView.width
color: theme.textColor
wrapMode: Text.WordWrap
focus: false
readOnly: true
font.pixelSize: theme.fontSizeLarge
cursorVisible: currentResponse ? (LLM.response !== "" ? LLM.responseInProgress : false) : false
cursorVisible: currentResponse ? (LLM.currentChat.response !== "" ? LLM.currentChat.responseInProgress : false) : false
cursorPosition: text.length
background: Rectangle {
color: name === qsTr("Response: ") ? theme.backgroundLighter : theme.backgroundLight
@ -706,8 +707,8 @@ Window {
anchors.leftMargin: 90
anchors.top: parent.top
anchors.topMargin: 5
visible: (currentResponse ? true : false) && LLM.response === "" && LLM.responseInProgress
running: (currentResponse ? true : false) && LLM.response === "" && LLM.responseInProgress
visible: (currentResponse ? true : false) && LLM.currentChat.response === "" && LLM.currentChat.responseInProgress
running: (currentResponse ? true : false) && LLM.currentChat.response === "" && LLM.currentChat.responseInProgress
Accessible.role: Accessible.Animation
Accessible.name: qsTr("Busy indicator")
@ -738,7 +739,7 @@ Window {
window.height / 2 - height / 2)
x: globalPoint.x
y: globalPoint.y
property string text: currentResponse ? LLM.response : (value ? value : "")
property string text: currentResponse ? LLM.currentChat.response : (value ? value : "")
response: newResponse === undefined || newResponse === "" ? text : newResponse
onAccepted: {
var responseHasChanged = response !== text && response !== newResponse
@ -754,7 +755,7 @@ Window {
Column {
visible: name === qsTr("Response: ") &&
(!currentResponse || !LLM.responseInProgress) && Network.isActive
(!currentResponse || !LLM.currentChat.responseInProgress) && Network.isActive
anchors.right: parent.right
anchors.rightMargin: 20
anchors.top: parent.top
@ -818,7 +819,8 @@ Window {
property bool isAutoScrolling: false
Connections {
target: LLM
// FIXME: This connection has to be set up every time a new chat object is created
target: LLM.currentChat
function onResponseChanged() {
if (listView.shouldAutoScroll) {
listView.isAutoScrolling = true
@ -853,27 +855,27 @@ Window {
anchors.verticalCenter: parent.verticalCenter
anchors.left: parent.left
anchors.leftMargin: 15
source: LLM.responseInProgress ? "qrc:/gpt4all/icons/stop_generating.svg" : "qrc:/gpt4all/icons/regenerate.svg"
source: LLM.currentChat.responseInProgress ? "qrc:/gpt4all/icons/stop_generating.svg" : "qrc:/gpt4all/icons/regenerate.svg"
}
leftPadding: 50
onClicked: {
var index = Math.max(0, chatModel.count - 1);
var listElement = chatModel.get(index);
if (LLM.responseInProgress) {
if (LLM.currentChat.responseInProgress) {
listElement.stopped = true
LLM.stopGenerating()
LLM.currentChat.stopGenerating()
} else {
LLM.regenerateResponse()
LLM.currentChat.regenerateResponse()
if (chatModel.count) {
if (listElement.name === qsTr("Response: ")) {
chatModel.updateCurrentResponse(index, true);
chatModel.updateStopped(index, false);
chatModel.updateValue(index, LLM.response);
chatModel.updateValue(index, LLM.currentChat.response);
chatModel.updateThumbsUpState(index, false);
chatModel.updateThumbsDownState(index, false);
chatModel.updateNewResponse(index, "");
LLM.prompt(listElement.prompt, settingsDialog.promptTemplate,
LLM.currentChat.prompt(listElement.prompt, settingsDialog.promptTemplate,
settingsDialog.maxLength,
settingsDialog.topK, settingsDialog.topP,
settingsDialog.temperature,
@ -889,7 +891,7 @@ Window {
anchors.bottomMargin: 40
padding: 15
contentItem: Text {
text: LLM.responseInProgress ? qsTr("Stop generating") : qsTr("Regenerate response")
text: LLM.currentChat.responseInProgress ? qsTr("Stop generating") : qsTr("Regenerate response")
color: theme.textColor
Accessible.role: Accessible.Button
Accessible.name: text
@ -917,7 +919,7 @@ Window {
color: theme.textColor
padding: 20
rightPadding: 40
enabled: LLM.isModelLoaded
enabled: LLM.currentChat.isModelLoaded
wrapMode: Text.WordWrap
font.pixelSize: theme.fontSizeLarge
placeholderText: qsTr("Send a message...")
@ -941,19 +943,18 @@ Window {
if (textInput.text === "")
return
LLM.stopGenerating()
LLM.currentChat.stopGenerating()
if (chatModel.count) {
var index = Math.max(0, chatModel.count - 1);
var listElement = chatModel.get(index);
chatModel.updateCurrentResponse(index, false);
chatModel.updateValue(index, LLM.response);
chatModel.updateValue(index, LLM.currentChat.response);
}
var prompt = textInput.text + "\n"
chatModel.appendPrompt(qsTr("Prompt: "), textInput.text);
chatModel.appendResponse(qsTr("Response: "), prompt);
LLM.resetResponse()
LLM.prompt(prompt, settingsDialog.promptTemplate,
chatModel.appendResponse(qsTr("Response: "), textInput.text);
LLM.currentChat.resetResponse()
LLM.currentChat.prompt(textInput.text, settingsDialog.promptTemplate,
settingsDialog.maxLength,
settingsDialog.topK,
settingsDialog.topP,

@ -86,7 +86,7 @@ bool Network::packageAndSendJson(const QString &ingestId, const QString &json)
Q_ASSERT(doc.isObject());
QJsonObject object = doc.object();
object.insert("source", "gpt4all-chat");
object.insert("agent_id", LLM::globalInstance()->modelName());
object.insert("agent_id", LLM::globalInstance()->currentChat()->modelName());
object.insert("submitter_id", m_uniqueId);
object.insert("ingest_id", ingestId);
@ -230,7 +230,7 @@ void Network::sendMixpanelEvent(const QString &ev)
properties.insert("ip", m_ipify);
properties.insert("name", QCoreApplication::applicationName() + " v"
+ QCoreApplication::applicationVersion());
properties.insert("model", LLM::globalInstance()->modelName());
properties.insert("model", LLM::globalInstance()->currentChat()->modelName());
QJsonObject event;
event.insert("event", ev);
