|
|
|
@ -15,6 +15,7 @@
|
|
|
|
|
#include <fstream>
|
|
|
|
|
|
|
|
|
|
//#define DEBUG
|
|
|
|
|
//#define DEBUG_MODEL_LOADING
|
|
|
|
|
|
|
|
|
|
#define MPT_INTERNAL_STATE_VERSION 0
|
|
|
|
|
#define GPTJ_INTERNAL_STATE_VERSION 0
|
|
|
|
@ -37,9 +38,51 @@ static QString modelFilePath(const QString &modelName)
|
|
|
|
|
return QString();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
class LLModelStore {
|
|
|
|
|
public:
|
|
|
|
|
static LLModelStore *globalInstance();
|
|
|
|
|
|
|
|
|
|
LLModelInfo acquireModel(); // will block until llmodel is ready
|
|
|
|
|
void releaseModel(const LLModelInfo &info); // must be called when you are done
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
LLModelStore()
|
|
|
|
|
{
|
|
|
|
|
// seed with empty model
|
|
|
|
|
m_availableModels.append(LLModelInfo());
|
|
|
|
|
}
|
|
|
|
|
~LLModelStore() {}
|
|
|
|
|
QVector<LLModelInfo> m_availableModels;
|
|
|
|
|
QMutex m_mutex;
|
|
|
|
|
QWaitCondition m_condition;
|
|
|
|
|
friend class MyLLModelStore;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
class MyLLModelStore : public LLModelStore { };
|
|
|
|
|
Q_GLOBAL_STATIC(MyLLModelStore, storeInstance)
|
|
|
|
|
LLModelStore *LLModelStore::globalInstance()
|
|
|
|
|
{
|
|
|
|
|
return storeInstance();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
LLModelInfo LLModelStore::acquireModel()
|
|
|
|
|
{
|
|
|
|
|
QMutexLocker locker(&m_mutex);
|
|
|
|
|
while (m_availableModels.isEmpty())
|
|
|
|
|
m_condition.wait(locker.mutex());
|
|
|
|
|
return m_availableModels.takeFirst();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void LLModelStore::releaseModel(const LLModelInfo &info)
|
|
|
|
|
{
|
|
|
|
|
QMutexLocker locker(&m_mutex);
|
|
|
|
|
m_availableModels.append(info);
|
|
|
|
|
Q_ASSERT(m_availableModels.count() < 2);
|
|
|
|
|
m_condition.wakeAll();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
ChatLLM::ChatLLM(Chat *parent)
|
|
|
|
|
: QObject{nullptr}
|
|
|
|
|
, m_llmodel(nullptr)
|
|
|
|
|
, m_promptResponseTokens(0)
|
|
|
|
|
, m_promptTokens(0)
|
|
|
|
|
, m_responseLogits(0)
|
|
|
|
@ -49,6 +92,7 @@ ChatLLM::ChatLLM(Chat *parent)
|
|
|
|
|
moveToThread(&m_llmThread);
|
|
|
|
|
connect(this, &ChatLLM::sendStartup, Network::globalInstance(), &Network::sendStartup);
|
|
|
|
|
connect(this, &ChatLLM::sendModelLoaded, Network::globalInstance(), &Network::sendModelLoaded);
|
|
|
|
|
connect(this, &ChatLLM::shouldBeLoadedChanged, this, &ChatLLM::handleShouldBeLoadedChanged, Qt::QueuedConnection);
|
|
|
|
|
connect(m_chat, &Chat::idChanged, this, &ChatLLM::handleChatIdChanged);
|
|
|
|
|
connect(&m_llmThread, &QThread::started, this, &ChatLLM::threadStarted);
|
|
|
|
|
m_llmThread.setObjectName(m_chat->id());
|
|
|
|
@ -59,7 +103,13 @@ ChatLLM::~ChatLLM()
|
|
|
|
|
{
|
|
|
|
|
m_llmThread.quit();
|
|
|
|
|
m_llmThread.wait();
|
|
|
|
|
delete m_llmodel;
|
|
|
|
|
|
|
|
|
|
// The only time we should have a model loaded here is on shutdown
|
|
|
|
|
// as we explicitly unload the model in all other circumstances
|
|
|
|
|
if (isModelLoaded()) {
|
|
|
|
|
delete m_modelInfo.model;
|
|
|
|
|
m_modelInfo.model = nullptr;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool ChatLLM::loadDefaultModel()
|
|
|
|
@ -76,50 +126,103 @@ bool ChatLLM::loadDefaultModel()
|
|
|
|
|
|
|
|
|
|
bool ChatLLM::loadModel(const QString &modelName)
|
|
|
|
|
{
|
|
|
|
|
// This is a complicated method because N different possible threads are interested in the outcome
|
|
|
|
|
// of this method. Why? Because we have a main/gui thread trying to monitor the state of N different
|
|
|
|
|
// possible chat threads all vying for a single resource - the currently loaded model - as the user
|
|
|
|
|
// switches back and forth between chats. It is important for our main/gui thread to never block
|
|
|
|
|
// but simultaneously always have up2date information with regards to which chat has the model loaded
|
|
|
|
|
// and what the type and name of that model is. I've tried to comment extensively in this method
|
|
|
|
|
// to provide an overview of what we're doing here.
|
|
|
|
|
|
|
|
|
|
// We're already loaded with this model
|
|
|
|
|
if (isModelLoaded() && m_modelName == modelName)
|
|
|
|
|
return true;
|
|
|
|
|
|
|
|
|
|
if (isModelLoaded()) {
|
|
|
|
|
QString filePath = modelFilePath(modelName);
|
|
|
|
|
QFileInfo fileInfo(filePath);
|
|
|
|
|
|
|
|
|
|
// We have a live model, but it isn't the one we want
|
|
|
|
|
bool alreadyAcquired = isModelLoaded();
|
|
|
|
|
if (alreadyAcquired) {
|
|
|
|
|
resetContextProtected();
|
|
|
|
|
delete m_llmodel;
|
|
|
|
|
m_llmodel = nullptr;
|
|
|
|
|
#if defined(DEBUG_MODEL_LOADING)
|
|
|
|
|
qDebug() << "already acquired model deleted" << m_chat->id() << m_modelInfo.model;
|
|
|
|
|
#endif
|
|
|
|
|
delete m_modelInfo.model;
|
|
|
|
|
m_modelInfo.model = nullptr;
|
|
|
|
|
emit isModelLoadedChanged();
|
|
|
|
|
} else {
|
|
|
|
|
// This is a blocking call that tries to retrieve the model we need from the model store.
|
|
|
|
|
// If it succeeds, then we just have to restore state. If the store has never had a model
|
|
|
|
|
// returned to it, then the modelInfo.model pointer should be null which will happen on startup
|
|
|
|
|
m_modelInfo = LLModelStore::globalInstance()->acquireModel();
|
|
|
|
|
#if defined(DEBUG_MODEL_LOADING)
|
|
|
|
|
qDebug() << "acquired model from store" << m_chat->id() << m_modelInfo.model;
|
|
|
|
|
#endif
|
|
|
|
|
// At this point it is possible that while we were blocked waiting to acquire the model from the
|
|
|
|
|
// store, that our state was changed to not be loaded. If this is the case, release the model
|
|
|
|
|
// back into the store and quit loading
|
|
|
|
|
if (!m_shouldBeLoaded) {
|
|
|
|
|
qDebug() << "no longer need model" << m_chat->id() << m_modelInfo.model;
|
|
|
|
|
LLModelStore::globalInstance()->releaseModel(m_modelInfo);
|
|
|
|
|
m_modelInfo = LLModelInfo();
|
|
|
|
|
emit isModelLoadedChanged();
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Check if the store just gave us exactly the model we were looking for
|
|
|
|
|
if (m_modelInfo.model && m_modelInfo.fileInfo == fileInfo) {
|
|
|
|
|
#if defined(DEBUG_MODEL_LOADING)
|
|
|
|
|
qDebug() << "store had our model" << m_chat->id() << m_modelInfo.model;
|
|
|
|
|
#endif
|
|
|
|
|
restoreState();
|
|
|
|
|
emit isModelLoadedChanged();
|
|
|
|
|
return true;
|
|
|
|
|
} else {
|
|
|
|
|
// Release the memory since we have to switch to a different model.
|
|
|
|
|
#if defined(DEBUG_MODEL_LOADING)
|
|
|
|
|
qDebug() << "deleting model" << m_chat->id() << m_modelInfo.model;
|
|
|
|
|
#endif
|
|
|
|
|
delete m_modelInfo.model;
|
|
|
|
|
m_modelInfo.model = nullptr;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool isGPTJ = false;
|
|
|
|
|
bool isMPT = false;
|
|
|
|
|
QString filePath = modelFilePath(modelName);
|
|
|
|
|
QFileInfo info(filePath);
|
|
|
|
|
if (info.exists()) {
|
|
|
|
|
// Guarantee we've released the previous models memory
|
|
|
|
|
Q_ASSERT(!m_modelInfo.model);
|
|
|
|
|
|
|
|
|
|
// Store the file info in the modelInfo in case we have an error loading
|
|
|
|
|
m_modelInfo.fileInfo = fileInfo;
|
|
|
|
|
|
|
|
|
|
if (fileInfo.exists()) {
|
|
|
|
|
auto fin = std::ifstream(filePath.toStdString(), std::ios::binary);
|
|
|
|
|
uint32_t magic;
|
|
|
|
|
fin.read((char *) &magic, sizeof(magic));
|
|
|
|
|
fin.seekg(0);
|
|
|
|
|
fin.close();
|
|
|
|
|
isGPTJ = magic == 0x67676d6c;
|
|
|
|
|
isMPT = magic == 0x67676d6d;
|
|
|
|
|
const bool isGPTJ = magic == 0x67676d6c;
|
|
|
|
|
const bool isMPT = magic == 0x67676d6d;
|
|
|
|
|
if (isGPTJ) {
|
|
|
|
|
m_modelType = ModelType::GPTJ_;
|
|
|
|
|
m_llmodel = new GPTJ;
|
|
|
|
|
m_llmodel->loadModel(filePath.toStdString());
|
|
|
|
|
m_modelType = LLModelType::GPTJ_;
|
|
|
|
|
m_modelInfo.model = new GPTJ;
|
|
|
|
|
m_modelInfo.model->loadModel(filePath.toStdString());
|
|
|
|
|
} else if (isMPT) {
|
|
|
|
|
m_modelType = ModelType::MPT_;
|
|
|
|
|
m_llmodel = new MPT;
|
|
|
|
|
m_llmodel->loadModel(filePath.toStdString());
|
|
|
|
|
m_modelType = LLModelType::MPT_;
|
|
|
|
|
m_modelInfo.model = new MPT;
|
|
|
|
|
m_modelInfo.model->loadModel(filePath.toStdString());
|
|
|
|
|
} else {
|
|
|
|
|
m_modelType = ModelType::LLAMA_;
|
|
|
|
|
m_llmodel = new LLamaModel;
|
|
|
|
|
m_llmodel->loadModel(filePath.toStdString());
|
|
|
|
|
m_modelType = LLModelType::LLAMA_;
|
|
|
|
|
m_modelInfo.model = new LLamaModel;
|
|
|
|
|
m_modelInfo.model->loadModel(filePath.toStdString());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#if defined(DEBUG_MODEL_LOADING)
|
|
|
|
|
qDebug() << "new model" << m_chat->id() << m_modelInfo.model;
|
|
|
|
|
#endif
|
|
|
|
|
restoreState();
|
|
|
|
|
|
|
|
|
|
#if defined(DEBUG)
|
|
|
|
|
qDebug() << "chatllm modelLoadedChanged" << m_chat->id();
|
|
|
|
|
fflush(stdout);
|
|
|
|
|
qDebug() << "modelLoadedChanged" << m_chat->id();
|
|
|
|
|
fflush(stdout);
|
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
|
emit isModelLoadedChanged();
|
|
|
|
|
|
|
|
|
|
static bool isFirstLoad = true;
|
|
|
|
@ -129,19 +232,20 @@ bool ChatLLM::loadModel(const QString &modelName)
|
|
|
|
|
} else
|
|
|
|
|
emit sendModelLoaded();
|
|
|
|
|
} else {
|
|
|
|
|
LLModelStore::globalInstance()->releaseModel(m_modelInfo); // release back into the store
|
|
|
|
|
const QString error = QString("Could not find model %1").arg(modelName);
|
|
|
|
|
emit modelLoadingError(error);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (m_llmodel)
|
|
|
|
|
setModelName(info.completeBaseName().remove(0, 5)); // remove the ggml- prefix
|
|
|
|
|
if (m_modelInfo.model)
|
|
|
|
|
setModelName(fileInfo.completeBaseName().remove(0, 5)); // remove the ggml- prefix
|
|
|
|
|
|
|
|
|
|
return m_llmodel;
|
|
|
|
|
return m_modelInfo.model;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool ChatLLM::isModelLoaded() const
|
|
|
|
|
{
|
|
|
|
|
return m_llmodel && m_llmodel->isModelLoaded();
|
|
|
|
|
return m_modelInfo.model && m_modelInfo.model->isModelLoaded();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void ChatLLM::regenerateResponse()
|
|
|
|
@ -226,7 +330,7 @@ bool ChatLLM::handlePrompt(int32_t token)
|
|
|
|
|
// m_promptResponseTokens and m_responseLogits are related to last prompt/response not
|
|
|
|
|
// the entire context window which we can reset on regenerate prompt
|
|
|
|
|
#if defined(DEBUG)
|
|
|
|
|
qDebug() << "chatllm prompt process" << m_chat->id() << token;
|
|
|
|
|
qDebug() << "prompt process" << m_chat->id() << token;
|
|
|
|
|
#endif
|
|
|
|
|
++m_promptTokens;
|
|
|
|
|
++m_promptResponseTokens;
|
|
|
|
@ -287,12 +391,12 @@ bool ChatLLM::prompt(const QString &prompt, const QString &prompt_template, int3
|
|
|
|
|
m_ctx.n_batch = n_batch;
|
|
|
|
|
m_ctx.repeat_penalty = repeat_penalty;
|
|
|
|
|
m_ctx.repeat_last_n = repeat_penalty_tokens;
|
|
|
|
|
m_llmodel->setThreadCount(n_threads);
|
|
|
|
|
m_modelInfo.model->setThreadCount(n_threads);
|
|
|
|
|
#if defined(DEBUG)
|
|
|
|
|
printf("%s", qPrintable(instructPrompt));
|
|
|
|
|
fflush(stdout);
|
|
|
|
|
#endif
|
|
|
|
|
m_llmodel->prompt(instructPrompt.toStdString(), promptFunc, responseFunc, recalcFunc, m_ctx);
|
|
|
|
|
m_modelInfo.model->prompt(instructPrompt.toStdString(), promptFunc, responseFunc, recalcFunc, m_ctx);
|
|
|
|
|
#if defined(DEBUG)
|
|
|
|
|
printf("\n");
|
|
|
|
|
fflush(stdout);
|
|
|
|
@ -307,26 +411,55 @@ bool ChatLLM::prompt(const QString &prompt, const QString &prompt_template, int3
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void ChatLLM::unloadModel()
|
|
|
|
|
void ChatLLM::setShouldBeLoaded(bool b)
|
|
|
|
|
{
|
|
|
|
|
#if defined(DEBUG)
|
|
|
|
|
qDebug() << "chatllm unloadModel" << m_chat->id();
|
|
|
|
|
#if defined(DEBUG_MODEL_LOADING)
|
|
|
|
|
qDebug() << "setShouldBeLoaded" << m_chat->id() << b << m_modelInfo.model;
|
|
|
|
|
#endif
|
|
|
|
|
m_shouldBeLoaded = b; // atomic
|
|
|
|
|
emit shouldBeLoadedChanged();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void ChatLLM::handleShouldBeLoadedChanged()
|
|
|
|
|
{
|
|
|
|
|
if (m_shouldBeLoaded)
|
|
|
|
|
reloadModel();
|
|
|
|
|
else
|
|
|
|
|
unloadModel();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void ChatLLM::forceUnloadModel()
|
|
|
|
|
{
|
|
|
|
|
m_shouldBeLoaded = false; // atomic
|
|
|
|
|
unloadModel();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void ChatLLM::unloadModel()
|
|
|
|
|
{
|
|
|
|
|
if (!isModelLoaded())
|
|
|
|
|
return;
|
|
|
|
|
|
|
|
|
|
saveState();
|
|
|
|
|
delete m_llmodel;
|
|
|
|
|
m_llmodel = nullptr;
|
|
|
|
|
#if defined(DEBUG_MODEL_LOADING)
|
|
|
|
|
qDebug() << "unloadModel" << m_chat->id() << m_modelInfo.model;
|
|
|
|
|
#endif
|
|
|
|
|
LLModelStore::globalInstance()->releaseModel(m_modelInfo);
|
|
|
|
|
m_modelInfo = LLModelInfo();
|
|
|
|
|
emit isModelLoadedChanged();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void ChatLLM::reloadModel(const QString &modelName)
|
|
|
|
|
void ChatLLM::reloadModel()
|
|
|
|
|
{
|
|
|
|
|
#if defined(DEBUG)
|
|
|
|
|
qDebug() << "chatllm reloadModel" << m_chat->id();
|
|
|
|
|
if (isModelLoaded())
|
|
|
|
|
return;
|
|
|
|
|
|
|
|
|
|
#if defined(DEBUG_MODEL_LOADING)
|
|
|
|
|
qDebug() << "reloadModel" << m_chat->id() << m_modelInfo.model;
|
|
|
|
|
#endif
|
|
|
|
|
if (modelName.isEmpty()) {
|
|
|
|
|
if (m_modelName.isEmpty()) {
|
|
|
|
|
loadDefaultModel();
|
|
|
|
|
} else {
|
|
|
|
|
loadModel(modelName);
|
|
|
|
|
loadModel(m_modelName);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -348,7 +481,7 @@ void ChatLLM::generateName()
|
|
|
|
|
printf("%s", qPrintable(instructPrompt));
|
|
|
|
|
fflush(stdout);
|
|
|
|
|
#endif
|
|
|
|
|
m_llmodel->prompt(instructPrompt.toStdString(), promptFunc, responseFunc, recalcFunc, ctx);
|
|
|
|
|
m_modelInfo.model->prompt(instructPrompt.toStdString(), promptFunc, responseFunc, recalcFunc, ctx);
|
|
|
|
|
#if defined(DEBUG)
|
|
|
|
|
printf("\n");
|
|
|
|
|
fflush(stdout);
|
|
|
|
@ -415,7 +548,7 @@ bool ChatLLM::serialize(QDataStream &stream, int version)
|
|
|
|
|
QByteArray compressed = qCompress(m_state);
|
|
|
|
|
stream << compressed;
|
|
|
|
|
#if defined(DEBUG)
|
|
|
|
|
qDebug() << "chatllm serialize" << m_chat->id() << m_state.size();
|
|
|
|
|
qDebug() << "serialize" << m_chat->id() << m_state.size();
|
|
|
|
|
#endif
|
|
|
|
|
return stream.status() == QDataStream::Ok;
|
|
|
|
|
}
|
|
|
|
@ -452,7 +585,7 @@ bool ChatLLM::deserialize(QDataStream &stream, int version)
|
|
|
|
|
stream >> m_state;
|
|
|
|
|
}
|
|
|
|
|
#if defined(DEBUG)
|
|
|
|
|
qDebug() << "chatllm deserialize" << m_chat->id();
|
|
|
|
|
qDebug() << "deserialize" << m_chat->id();
|
|
|
|
|
#endif
|
|
|
|
|
return stream.status() == QDataStream::Ok;
|
|
|
|
|
}
|
|
|
|
@ -462,12 +595,12 @@ void ChatLLM::saveState()
|
|
|
|
|
if (!isModelLoaded())
|
|
|
|
|
return;
|
|
|
|
|
|
|
|
|
|
const size_t stateSize = m_llmodel->stateSize();
|
|
|
|
|
const size_t stateSize = m_modelInfo.model->stateSize();
|
|
|
|
|
m_state.resize(stateSize);
|
|
|
|
|
#if defined(DEBUG)
|
|
|
|
|
qDebug() << "chatllm saveState" << m_chat->id() << "size:" << m_state.size();
|
|
|
|
|
qDebug() << "saveState" << m_chat->id() << "size:" << m_state.size();
|
|
|
|
|
#endif
|
|
|
|
|
m_llmodel->saveState(static_cast<uint8_t*>(reinterpret_cast<void*>(m_state.data())));
|
|
|
|
|
m_modelInfo.model->saveState(static_cast<uint8_t*>(reinterpret_cast<void*>(m_state.data())));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void ChatLLM::restoreState()
|
|
|
|
@ -476,9 +609,9 @@ void ChatLLM::restoreState()
|
|
|
|
|
return;
|
|
|
|
|
|
|
|
|
|
#if defined(DEBUG)
|
|
|
|
|
qDebug() << "chatllm restoreState" << m_chat->id() << "size:" << m_state.size();
|
|
|
|
|
qDebug() << "restoreState" << m_chat->id() << "size:" << m_state.size();
|
|
|
|
|
#endif
|
|
|
|
|
m_llmodel->restoreState(static_cast<const uint8_t*>(reinterpret_cast<void*>(m_state.data())));
|
|
|
|
|
m_modelInfo.model->restoreState(static_cast<const uint8_t*>(reinterpret_cast<void*>(m_state.data())));
|
|
|
|
|
m_state.clear();
|
|
|
|
|
m_state.resize(0);
|
|
|
|
|
}
|
|
|
|
|