mirror of https://github.com/nomic-ai/gpt4all
Major refactor in prep for multiple conversations.
parent e005ab8c0a
commit 4d87c46948
@ -1 +1,117 @@
#include "chat.h"
#include "network.h"

Chat::Chat(QObject *parent)
    : QObject(parent)
    , m_llmodel(new ChatLLM)
    , m_id(Network::globalInstance()->generateUniqueId())
    , m_name(tr("New Chat"))
    , m_chatModel(new ChatModel(this))
    , m_responseInProgress(false)
    , m_desiredThreadCount(std::min(4, (int32_t) std::thread::hardware_concurrency()))
{
    connect(m_llmodel, &ChatLLM::isModelLoadedChanged, this, &Chat::isModelLoadedChanged, Qt::QueuedConnection);
    connect(m_llmodel, &ChatLLM::responseChanged, this, &Chat::responseChanged, Qt::QueuedConnection);
    connect(m_llmodel, &ChatLLM::responseStarted, this, &Chat::responseStarted, Qt::QueuedConnection);
    connect(m_llmodel, &ChatLLM::responseStopped, this, &Chat::responseStopped, Qt::QueuedConnection);
    connect(m_llmodel, &ChatLLM::modelNameChanged, this, &Chat::modelNameChanged, Qt::QueuedConnection);
    connect(m_llmodel, &ChatLLM::threadCountChanged, this, &Chat::threadCountChanged, Qt::QueuedConnection);
    connect(m_llmodel, &ChatLLM::threadCountChanged, this, &Chat::syncThreadCount, Qt::QueuedConnection);
    connect(m_llmodel, &ChatLLM::recalcChanged, this, &Chat::recalcChanged, Qt::QueuedConnection);

    connect(this, &Chat::promptRequested, m_llmodel, &ChatLLM::prompt, Qt::QueuedConnection);
    connect(this, &Chat::modelNameChangeRequested, m_llmodel, &ChatLLM::modelNameChangeRequested, Qt::QueuedConnection);

    // The following are blocking operations and will block the GUI thread, so the
    // connected slots must respond quickly (see the sketch after this constructor).
    connect(this, &Chat::regenerateResponseRequested, m_llmodel, &ChatLLM::regenerateResponse, Qt::BlockingQueuedConnection);
    connect(this, &Chat::resetResponseRequested, m_llmodel, &ChatLLM::resetResponse, Qt::BlockingQueuedConnection);
    connect(this, &Chat::resetContextRequested, m_llmodel, &ChatLLM::resetContext, Qt::BlockingQueuedConnection);
    connect(this, &Chat::setThreadCountRequested, m_llmodel, &ChatLLM::setThreadCount, Qt::QueuedConnection);
}
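
// Illustrative sketch (not part of this commit) of the two connection types used
// above, assuming the worker object's thread runs an event loop: a queued call is
// posted to the receiver thread's event loop and returns immediately, while a
// blocking queued call makes the calling (GUI) thread wait until the worker thread
// has finished executing the functor. The helper name and its argument are hypothetical.
[[maybe_unused]] static void exampleConnectionTypes(QObject *workerOnLlmThread)
{
    // Returns right away; the functor runs later on the worker's thread.
    QMetaObject::invokeMethod(workerOnLlmThread, []() {
        /* work happens on the worker thread */
    }, Qt::QueuedConnection);

    // Does not return until the worker thread has executed the functor, which is
    // why slots wired with Qt::BlockingQueuedConnection must stay fast.
    QMetaObject::invokeMethod(workerOnLlmThread, []() {
        /* work happens on the worker thread while the caller waits */
    }, Qt::BlockingQueuedConnection);
}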

void Chat::reset()
{
    m_id = Network::globalInstance()->generateUniqueId();
    m_chatModel->clear();
}

bool Chat::isModelLoaded() const
{
    return m_llmodel->isModelLoaded();
}

void Chat::prompt(const QString &prompt, const QString &prompt_template, int32_t n_predict, int32_t top_k, float top_p,
    float temp, int32_t n_batch, float repeat_penalty, int32_t repeat_penalty_tokens)
{
    emit promptRequested(prompt, prompt_template, n_predict, top_k, top_p, temp, n_batch, repeat_penalty, repeat_penalty_tokens);
}

void Chat::regenerateResponse()
{
    emit regenerateResponseRequested(); // blocking queued connection
}

void Chat::resetResponse()
{
    emit resetResponseRequested(); // blocking queued connection
}

void Chat::resetContext()
{
    emit resetContextRequested(); // blocking queued connection
}

void Chat::stopGenerating()
{
    m_llmodel->stopGenerating();
}

QString Chat::response() const
{
    return m_llmodel->response();
}

void Chat::responseStarted()
{
    m_responseInProgress = true;
    emit responseInProgressChanged();
}

void Chat::responseStopped()
{
    m_responseInProgress = false;
    emit responseInProgressChanged();
}

QString Chat::modelName() const
{
    return m_llmodel->modelName();
}

void Chat::setModelName(const QString &modelName)
{
    // Doesn't block, but unloads the old model and loads the new one; the GUI can
    // observe the change through the isModelLoaded property (see the sketch below).
    emit modelNameChangeRequested(modelName);
}
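
// Hypothetical caller-side sketch (not part of this commit): because the change
// request above is queued, a caller that needs to know when the new model is ready
// watches the isModelLoaded property instead of waiting on setModelName(). The
// helper name and the model name are illustrative.
[[maybe_unused]] static void exampleSwitchModel(Chat *chat)
{
    QObject::connect(chat, &Chat::isModelLoadedChanged, chat, [chat]() {
        if (chat->isModelLoaded()) {
            // The newly requested model is fully loaded at this point.
        }
    });
    chat->setModelName(QStringLiteral("gpt4all-j-v1.3-groovy"));
}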

void Chat::syncThreadCount() {
    emit setThreadCountRequested(m_desiredThreadCount);
}

void Chat::setThreadCount(int32_t n_threads) {
    if (n_threads <= 0)
        n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
    m_desiredThreadCount = n_threads;
    syncThreadCount();
}

int32_t Chat::threadCount() {
    return m_llmodel->threadCount();
}

bool Chat::isRecalc() const
{
    return m_llmodel->isRecalc();
}
@ -0,0 +1,290 @@
#include "chatllm.h"
#include "download.h"
#include "network.h"
#include "llm.h"
#include "llmodel/gptj.h"
#include "llmodel/llamamodel.h"

#include <QCoreApplication>
#include <QDir>
#include <QFile>
#include <QProcess>
#include <QResource>
#include <QSettings>
#include <fstream>

//#define DEBUG

static QString modelFilePath(const QString &modelName)
{
    QString appPath = QCoreApplication::applicationDirPath()
        + "/ggml-" + modelName + ".bin";
    QFileInfo infoAppPath(appPath);
    if (infoAppPath.exists())
        return appPath;

    QString downloadPath = Download::globalInstance()->downloadLocalModelsPath()
        + "/ggml-" + modelName + ".bin";

    QFileInfo infoLocalPath(downloadPath);
    if (infoLocalPath.exists())
        return downloadPath;
    return QString();
}

ChatLLM::ChatLLM()
    : QObject{nullptr}
    , m_llmodel(nullptr)
    , m_promptResponseTokens(0)
    , m_responseLogits(0)
    , m_isRecalc(false)
{
    // This object lives on its own worker thread, so queued slot invocations from
    // the GUI thread run off the GUI thread.
    moveToThread(&m_llmThread);
    connect(&m_llmThread, &QThread::started, this, &ChatLLM::loadModel);
    connect(this, &ChatLLM::sendStartup, Network::globalInstance(), &Network::sendStartup);
    connect(this, &ChatLLM::sendModelLoaded, Network::globalInstance(), &Network::sendModelLoaded);
    connect(this, &ChatLLM::sendResetContext, Network::globalInstance(), &Network::sendResetContext);
    m_llmThread.setObjectName("llm thread"); // FIXME: Should identify these with the chat name
    m_llmThread.start();
}

bool ChatLLM::loadModel()
{
    const QList<QString> models = LLM::globalInstance()->modelList();
    if (models.isEmpty()) {
        // try again when we get a list of models
        connect(Download::globalInstance(), &Download::modelListChanged, this,
            &ChatLLM::loadModel, Qt::SingleShotConnection);
        return false;
    }

    QSettings settings;
    settings.sync();
    QString defaultModel = settings.value("defaultModel", "gpt4all-j-v1.3-groovy").toString();
    if (defaultModel.isEmpty() || !models.contains(defaultModel))
        defaultModel = models.first();
    return loadModelPrivate(defaultModel);
}

bool ChatLLM::loadModelPrivate(const QString &modelName)
{
    if (isModelLoaded() && m_modelName == modelName)
        return true;

    bool isFirstLoad = false;
    if (isModelLoaded()) {
        resetContextPrivate();
        delete m_llmodel;
        m_llmodel = nullptr;
        emit isModelLoadedChanged();
    } else {
        isFirstLoad = true;
    }

    bool isGPTJ = false;
    QString filePath = modelFilePath(modelName);
    QFileInfo info(filePath);
    if (info.exists()) {

        auto fin = std::ifstream(filePath.toStdString(), std::ios::binary);
        uint32_t magic;
        fin.read((char *) &magic, sizeof(magic));
        fin.seekg(0);
        fin.close();
        // 0x67676d6c is the GGML magic ('ggml'); files carrying it are treated as
        // GPT-J models here, anything else is loaded as a llama model.
        isGPTJ = magic == 0x67676d6c;
        if (isGPTJ) {
            m_llmodel = new GPTJ;
            m_llmodel->loadModel(filePath.toStdString());
        } else {
            m_llmodel = new LLamaModel;
            m_llmodel->loadModel(filePath.toStdString());
        }

        emit isModelLoadedChanged();
        emit threadCountChanged();

        if (isFirstLoad)
            emit sendStartup();
        else
            emit sendModelLoaded();
    }

    if (m_llmodel)
        setModelName(info.completeBaseName().remove(0, 5)); // remove the ggml- prefix

    return m_llmodel;
}

void ChatLLM::setThreadCount(int32_t n_threads) {
    if (m_llmodel && m_llmodel->threadCount() != n_threads) {
        m_llmodel->setThreadCount(n_threads);
        emit threadCountChanged();
    }
}

int32_t ChatLLM::threadCount() {
    if (!m_llmodel)
        return 1;
    return m_llmodel->threadCount();
}

bool ChatLLM::isModelLoaded() const
{
    return m_llmodel && m_llmodel->isModelLoaded();
}

void ChatLLM::regenerateResponse()
{
    m_ctx.n_past -= m_promptResponseTokens;
    m_ctx.n_past = std::max(0, m_ctx.n_past);
    // FIXME: This does not seem to be needed in my testing and llama models don't do it. Remove?
    m_ctx.logits.erase(m_ctx.logits.end() - m_responseLogits, m_ctx.logits.end());
    m_ctx.tokens.erase(m_ctx.tokens.end() - m_promptResponseTokens, m_ctx.tokens.end());
    m_promptResponseTokens = 0;
    m_responseLogits = 0;
    m_response = std::string();
    emit responseChanged();
}

void ChatLLM::resetResponse()
{
    m_promptResponseTokens = 0;
    m_responseLogits = 0;
    m_response = std::string();
    emit responseChanged();
}

void ChatLLM::resetContext()
{
    resetContextPrivate();
    emit sendResetContext();
}

void ChatLLM::resetContextPrivate()
{
    regenerateResponse();
    m_ctx = LLModel::PromptContext();
}

std::string remove_leading_whitespace(const std::string& input) {
    auto first_non_whitespace = std::find_if(input.begin(), input.end(), [](unsigned char c) {
        return !std::isspace(c);
    });

    return std::string(first_non_whitespace, input.end());
}

std::string trim_whitespace(const std::string& input) {
    auto first_non_whitespace = std::find_if(input.begin(), input.end(), [](unsigned char c) {
        return !std::isspace(c);
    });

    auto last_non_whitespace = std::find_if(input.rbegin(), input.rend(), [](unsigned char c) {
        return !std::isspace(c);
    }).base();

    return std::string(first_non_whitespace, last_non_whitespace);
}

QString ChatLLM::response() const
{
    return QString::fromStdString(remove_leading_whitespace(m_response));
}

QString ChatLLM::modelName() const
{
    return m_modelName;
}

void ChatLLM::setModelName(const QString &modelName)
{
    m_modelName = modelName;
    emit modelNameChanged();
}

void ChatLLM::modelNameChangeRequested(const QString &modelName)
{
    if (!loadModelPrivate(modelName))
        qWarning() << "ERROR: Could not load model" << modelName;
}

bool ChatLLM::handlePrompt(int32_t token)
{
    // m_promptResponseTokens and m_responseLogits cover only the last prompt/response,
    // not the entire context window, so they can be reset when the response is regenerated.
    ++m_promptResponseTokens;
    return !m_stopGenerating;
}

bool ChatLLM::handleResponse(int32_t token, const std::string &response)
{
#if defined(DEBUG)
    printf("%s", response.c_str());
    fflush(stdout);
#endif

    // check for error
    if (token < 0) {
        m_response.append(response);
        emit responseChanged();
        return false;
    }

    // m_promptResponseTokens and m_responseLogits cover only the last prompt/response,
    // not the entire context window, so they can be reset when the response is regenerated.
    ++m_promptResponseTokens;
    Q_ASSERT(!response.empty());
    m_response.append(response);
    emit responseChanged();
    return !m_stopGenerating;
}

bool ChatLLM::handleRecalculate(bool isRecalc)
{
    if (m_isRecalc != isRecalc) {
        m_isRecalc = isRecalc;
        emit recalcChanged();
    }
    return !m_stopGenerating;
}

bool ChatLLM::prompt(const QString &prompt, const QString &prompt_template, int32_t n_predict, int32_t top_k, float top_p,
    float temp, int32_t n_batch, float repeat_penalty, int32_t repeat_penalty_tokens)
{
    if (!isModelLoaded())
        return false;

    // The template is expected to contain a "%1" placeholder for the user's prompt
    // (see the example after this function).
    QString instructPrompt = prompt_template.arg(prompt);

    m_stopGenerating = false;
    auto promptFunc = std::bind(&ChatLLM::handlePrompt, this, std::placeholders::_1);
    auto responseFunc = std::bind(&ChatLLM::handleResponse, this, std::placeholders::_1,
        std::placeholders::_2);
    auto recalcFunc = std::bind(&ChatLLM::handleRecalculate, this, std::placeholders::_1);
    emit responseStarted();
    qint32 logitsBefore = m_ctx.logits.size();
    m_ctx.n_predict = n_predict;
    m_ctx.top_k = top_k;
    m_ctx.top_p = top_p;
    m_ctx.temp = temp;
    m_ctx.n_batch = n_batch;
    m_ctx.repeat_penalty = repeat_penalty;
    m_ctx.repeat_last_n = repeat_penalty_tokens;
#if defined(DEBUG)
    printf("%s", qPrintable(instructPrompt));
    fflush(stdout);
#endif
    m_llmodel->prompt(instructPrompt.toStdString(), promptFunc, responseFunc, recalcFunc, m_ctx);
#if defined(DEBUG)
    printf("\n");
    fflush(stdout);
#endif
    m_responseLogits += m_ctx.logits.size() - logitsBefore;
    std::string trimmed = trim_whitespace(m_response);
    if (trimmed != m_response) {
        m_response = trimmed;
        emit responseChanged();
    }
    emit responseStopped();
    return true;
}
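
// Hypothetical example (the template text is an assumption, not taken from this
// commit): QString::arg() replaces the "%1" marker in prompt_template with the
// user's text to produce instructPrompt above. The helper name is illustrative.
[[maybe_unused]] static QString examplePromptTemplate()
{
    const QString promptTemplate = QStringLiteral("### Prompt:\n%1\n### Response:\n");
    // Yields "### Prompt:\nWhy is the sky blue?\n### Response:\n"
    return promptTemplate.arg(QStringLiteral("Why is the sky blue?"));
}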
@ -0,0 +1,75 @@
#ifndef CHATLLM_H
#define CHATLLM_H

#include <QObject>
#include <QThread>

#include "llmodel/llmodel.h"

class ChatLLM : public QObject
{
    Q_OBJECT
    Q_PROPERTY(bool isModelLoaded READ isModelLoaded NOTIFY isModelLoadedChanged)
    Q_PROPERTY(QString response READ response NOTIFY responseChanged)
    Q_PROPERTY(QString modelName READ modelName WRITE setModelName NOTIFY modelNameChanged)
    Q_PROPERTY(int32_t threadCount READ threadCount WRITE setThreadCount NOTIFY threadCountChanged)
    Q_PROPERTY(bool isRecalc READ isRecalc NOTIFY recalcChanged)

public:

    ChatLLM();

    bool isModelLoaded() const;
    void regenerateResponse();
    void resetResponse();
    void resetContext();

    void stopGenerating() { m_stopGenerating = true; }
    void setThreadCount(int32_t n_threads);
    int32_t threadCount();

    QString response() const;
    QString modelName() const;

    void setModelName(const QString &modelName);

    bool isRecalc() const { return m_isRecalc; }

public Q_SLOTS:
    bool prompt(const QString &prompt, const QString &prompt_template, int32_t n_predict, int32_t top_k, float top_p,
        float temp, int32_t n_batch, float repeat_penalty, int32_t repeat_penalty_tokens);
    bool loadModel();
    void modelNameChangeRequested(const QString &modelName);

Q_SIGNALS:
    void isModelLoadedChanged();
    void responseChanged();
    void responseStarted();
    void responseStopped();
    void modelNameChanged();
    void threadCountChanged();
    void recalcChanged();
    void sendStartup();
    void sendModelLoaded();
    void sendResetContext();

private:
    void resetContextPrivate();
    bool loadModelPrivate(const QString &modelName);
    bool handlePrompt(int32_t token);
    bool handleResponse(int32_t token, const std::string &response);
    bool handleRecalculate(bool isRecalc);

private:
    LLModel::PromptContext m_ctx;
    LLModel *m_llmodel;
    std::string m_response;
    quint32 m_promptResponseTokens;
    quint32 m_responseLogits;
    QString m_modelName;
    QThread m_llmThread;
    std::atomic<bool> m_stopGenerating;
    bool m_isRecalc;
};

#endif // CHATLLM_H
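
// Hypothetical end-to-end sketch (not part of this commit): driving a Chat from a
// plain C++ program instead of the QML UI. The signal, slot, and method names come
// from the classes above; the generation parameters and prompt text are illustrative,
// and a real caller would also watch responseInProgress to know when generation ends.
#include "chat.h"

#include <QCoreApplication>
#include <QDebug>

int main(int argc, char *argv[])
{
    QCoreApplication app(argc, argv);

    Chat chat(nullptr);

    // Stream partial responses as the LLM thread produces them.
    QObject::connect(&chat, &Chat::responseChanged, &chat, [&chat]() {
        qDebug().noquote() << chat.response();
    });

    // Once the default model has loaded on the LLM thread, send a single prompt.
    QObject::connect(&chat, &Chat::isModelLoadedChanged, &chat, [&chat]() {
        if (!chat.isModelLoaded())
            return;
        chat.prompt(QStringLiteral("Why is the sky blue?"),
                    QStringLiteral("%1"), // pass-through template; real templates wrap the prompt
                    /*n_predict*/ 128, /*top_k*/ 40, /*top_p*/ 0.95f,
                    /*temp*/ 0.7f, /*n_batch*/ 9, /*repeat_penalty*/ 1.1f,
                    /*repeat_penalty_tokens*/ 64);
    });

    return app.exec();
}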