|
|
|
@ -94,7 +94,6 @@ ChatLLM::ChatLLM(Chat *parent, bool isServer)
|
|
|
|
|
, m_isRecalc(false)
|
|
|
|
|
, m_shouldBeLoaded(true)
|
|
|
|
|
, m_stopGenerating(false)
|
|
|
|
|
, m_chat(parent)
|
|
|
|
|
, m_timer(nullptr)
|
|
|
|
|
, m_isServer(isServer)
|
|
|
|
|
, m_isChatGPT(false)
|
|
|
|
@ -104,14 +103,15 @@ ChatLLM::ChatLLM(Chat *parent, bool isServer)
|
|
|
|
|
connect(this, &ChatLLM::sendModelLoaded, Network::globalInstance(), &Network::sendModelLoaded);
|
|
|
|
|
connect(this, &ChatLLM::shouldBeLoadedChanged, this, &ChatLLM::handleShouldBeLoadedChanged,
|
|
|
|
|
Qt::QueuedConnection); // explicitly queued
|
|
|
|
|
connect(m_chat, &Chat::idChanged, this, &ChatLLM::handleChatIdChanged);
|
|
|
|
|
connect(parent, &Chat::idChanged, this, &ChatLLM::handleChatIdChanged);
|
|
|
|
|
connect(parent, &Chat::defaultModelChanged, this, &ChatLLM::handleDefaultModelChanged);
|
|
|
|
|
connect(&m_llmThread, &QThread::started, this, &ChatLLM::handleThreadStarted);
|
|
|
|
|
|
|
|
|
|
// The following are blocking operations and will block the llm thread
|
|
|
|
|
connect(this, &ChatLLM::requestRetrieveFromDB, LocalDocs::globalInstance()->database(), &Database::retrieveFromDB,
|
|
|
|
|
Qt::BlockingQueuedConnection);
|
|
|
|
|
|
|
|
|
|
m_llmThread.setObjectName(m_chat->id());
|
|
|
|
|
m_llmThread.setObjectName(parent->id());
|
|
|
|
|
m_llmThread.start();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -137,14 +137,11 @@ void ChatLLM::handleThreadStarted()
|
|
|
|
|
|
|
|
|
|
bool ChatLLM::loadDefaultModel()
|
|
|
|
|
{
|
|
|
|
|
const QList<QString> models = m_chat->modelList();
|
|
|
|
|
if (models.isEmpty()) {
|
|
|
|
|
// try again when we get a list of models
|
|
|
|
|
connect(Download::globalInstance(), &Download::modelListChanged, this,
|
|
|
|
|
&ChatLLM::loadDefaultModel, Qt::SingleShotConnection);
|
|
|
|
|
if (m_defaultModel.isEmpty()) {
|
|
|
|
|
emit modelLoadingError(QString("Could not find default model to load"));
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
return loadModel(models.first());
|
|
|
|
|
return loadModel(m_defaultModel);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool ChatLLM::loadModel(const QString &modelName)
|
|
|
|
@ -170,7 +167,7 @@ bool ChatLLM::loadModel(const QString &modelName)
|
|
|
|
|
if (alreadyAcquired) {
|
|
|
|
|
resetContext();
|
|
|
|
|
#if defined(DEBUG_MODEL_LOADING)
|
|
|
|
|
qDebug() << "already acquired model deleted" << m_chat->id() << m_modelInfo.model;
|
|
|
|
|
qDebug() << "already acquired model deleted" << m_llmThread.objectName() << m_modelInfo.model;
|
|
|
|
|
#endif
|
|
|
|
|
delete m_modelInfo.model;
|
|
|
|
|
m_modelInfo.model = nullptr;
|
|
|
|
@ -181,14 +178,14 @@ bool ChatLLM::loadModel(const QString &modelName)
|
|
|
|
|
// returned to it, then the modelInfo.model pointer should be null which will happen on startup
|
|
|
|
|
m_modelInfo = LLModelStore::globalInstance()->acquireModel();
|
|
|
|
|
#if defined(DEBUG_MODEL_LOADING)
|
|
|
|
|
qDebug() << "acquired model from store" << m_chat->id() << m_modelInfo.model;
|
|
|
|
|
qDebug() << "acquired model from store" << m_llmThread.objectName() << m_modelInfo.model;
|
|
|
|
|
#endif
|
|
|
|
|
// At this point it is possible that while we were blocked waiting to acquire the model from the
|
|
|
|
|
// store, that our state was changed to not be loaded. If this is the case, release the model
|
|
|
|
|
// back into the store and quit loading
|
|
|
|
|
if (!m_shouldBeLoaded) {
|
|
|
|
|
#if defined(DEBUG_MODEL_LOADING)
|
|
|
|
|
qDebug() << "no longer need model" << m_chat->id() << m_modelInfo.model;
|
|
|
|
|
qDebug() << "no longer need model" << m_llmThread.objectName() << m_modelInfo.model;
|
|
|
|
|
#endif
|
|
|
|
|
LLModelStore::globalInstance()->releaseModel(m_modelInfo);
|
|
|
|
|
m_modelInfo = LLModelInfo();
|
|
|
|
@ -199,7 +196,7 @@ bool ChatLLM::loadModel(const QString &modelName)
|
|
|
|
|
// Check if the store just gave us exactly the model we were looking for
|
|
|
|
|
if (m_modelInfo.model && m_modelInfo.fileInfo == fileInfo) {
|
|
|
|
|
#if defined(DEBUG_MODEL_LOADING)
|
|
|
|
|
qDebug() << "store had our model" << m_chat->id() << m_modelInfo.model;
|
|
|
|
|
qDebug() << "store had our model" << m_llmThread.objectName() << m_modelInfo.model;
|
|
|
|
|
#endif
|
|
|
|
|
restoreState();
|
|
|
|
|
emit isModelLoadedChanged();
|
|
|
|
@ -207,7 +204,7 @@ bool ChatLLM::loadModel(const QString &modelName)
|
|
|
|
|
} else {
|
|
|
|
|
// Release the memory since we have to switch to a different model.
|
|
|
|
|
#if defined(DEBUG_MODEL_LOADING)
|
|
|
|
|
qDebug() << "deleting model" << m_chat->id() << m_modelInfo.model;
|
|
|
|
|
qDebug() << "deleting model" << m_llmThread.objectName() << m_modelInfo.model;
|
|
|
|
|
#endif
|
|
|
|
|
delete m_modelInfo.model;
|
|
|
|
|
m_modelInfo.model = nullptr;
|
|
|
|
@ -239,13 +236,28 @@ bool ChatLLM::loadModel(const QString &modelName)
|
|
|
|
|
} else {
|
|
|
|
|
m_modelInfo.model = LLModel::construct(filePath.toStdString());
|
|
|
|
|
if (m_modelInfo.model) {
|
|
|
|
|
m_modelInfo.model->loadModel(filePath.toStdString());
|
|
|
|
|
switch (m_modelInfo.model->implementation().modelType[0]) {
|
|
|
|
|
case 'L': m_modelType = LLModelType::LLAMA_; break;
|
|
|
|
|
case 'G': m_modelType = LLModelType::GPTJ_; break;
|
|
|
|
|
case 'M': m_modelType = LLModelType::MPT_; break;
|
|
|
|
|
case 'R': m_modelType = LLModelType::REPLIT_; break;
|
|
|
|
|
default: delete std::exchange(m_modelInfo.model, nullptr);
|
|
|
|
|
bool success = m_modelInfo.model->loadModel(filePath.toStdString());
|
|
|
|
|
if (!success) {
|
|
|
|
|
delete std::exchange(m_modelInfo.model, nullptr);
|
|
|
|
|
if (!m_isServer)
|
|
|
|
|
LLModelStore::globalInstance()->releaseModel(m_modelInfo); // release back into the store
|
|
|
|
|
m_modelInfo = LLModelInfo();
|
|
|
|
|
emit modelLoadingError(QString("Could not load model due to invalid model file for %1").arg(modelName));
|
|
|
|
|
} else {
|
|
|
|
|
switch (m_modelInfo.model->implementation().modelType[0]) {
|
|
|
|
|
case 'L': m_modelType = LLModelType::LLAMA_; break;
|
|
|
|
|
case 'G': m_modelType = LLModelType::GPTJ_; break;
|
|
|
|
|
case 'M': m_modelType = LLModelType::MPT_; break;
|
|
|
|
|
case 'R': m_modelType = LLModelType::REPLIT_; break;
|
|
|
|
|
default:
|
|
|
|
|
{
|
|
|
|
|
delete std::exchange(m_modelInfo.model, nullptr);
|
|
|
|
|
if (!m_isServer)
|
|
|
|
|
LLModelStore::globalInstance()->releaseModel(m_modelInfo); // release back into the store
|
|
|
|
|
m_modelInfo = LLModelInfo();
|
|
|
|
|
emit modelLoadingError(QString("Could not determine model type for %1").arg(modelName));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
if (!m_isServer)
|
|
|
|
@ -255,11 +267,11 @@ bool ChatLLM::loadModel(const QString &modelName)
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
#if defined(DEBUG_MODEL_LOADING)
|
|
|
|
|
qDebug() << "new model" << m_chat->id() << m_modelInfo.model;
|
|
|
|
|
qDebug() << "new model" << m_llmThread.objectName() << m_modelInfo.model;
|
|
|
|
|
#endif
|
|
|
|
|
restoreState();
|
|
|
|
|
#if defined(DEBUG)
|
|
|
|
|
qDebug() << "modelLoadedChanged" << m_chat->id();
|
|
|
|
|
qDebug() << "modelLoadedChanged" << m_llmThread.objectName();
|
|
|
|
|
fflush(stdout);
|
|
|
|
|
#endif
|
|
|
|
|
emit isModelLoadedChanged();
|
|
|
|
@ -368,7 +380,7 @@ bool ChatLLM::handlePrompt(int32_t token)
|
|
|
|
|
// m_promptResponseTokens is related to last prompt/response not
|
|
|
|
|
// the entire context window which we can reset on regenerate prompt
|
|
|
|
|
#if defined(DEBUG)
|
|
|
|
|
qDebug() << "prompt process" << m_chat->id() << token;
|
|
|
|
|
qDebug() << "prompt process" << m_llmThread.objectName() << token;
|
|
|
|
|
#endif
|
|
|
|
|
++m_promptTokens;
|
|
|
|
|
++m_promptResponseTokens;
|
|
|
|
@ -409,7 +421,7 @@ bool ChatLLM::handleRecalculate(bool isRecalc)
|
|
|
|
|
return !m_stopGenerating;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool ChatLLM::prompt(const QString &prompt, const QString &prompt_template, int32_t n_predict, int32_t top_k,
|
|
|
|
|
bool ChatLLM::prompt(const QList<QString> &collectionList, const QString &prompt, const QString &prompt_template, int32_t n_predict, int32_t top_k,
|
|
|
|
|
float top_p, float temp, int32_t n_batch, float repeat_penalty, int32_t repeat_penalty_tokens, int n_threads)
|
|
|
|
|
{
|
|
|
|
|
if (!isModelLoaded())
|
|
|
|
@ -417,7 +429,7 @@ bool ChatLLM::prompt(const QString &prompt, const QString &prompt_template, int3
|
|
|
|
|
|
|
|
|
|
QList<ResultInfo> databaseResults;
|
|
|
|
|
const int retrievalSize = LocalDocs::globalInstance()->retrievalSize();
|
|
|
|
|
emit requestRetrieveFromDB(m_chat->collectionList(), prompt, retrievalSize, &databaseResults); // blocks
|
|
|
|
|
emit requestRetrieveFromDB(collectionList, prompt, retrievalSize, &databaseResults); // blocks
|
|
|
|
|
emit databaseResultsChanged(databaseResults);
|
|
|
|
|
|
|
|
|
|
// Augment the prompt template with the results if any
|
|
|
|
@ -468,7 +480,7 @@ bool ChatLLM::prompt(const QString &prompt, const QString &prompt_template, int3
|
|
|
|
|
void ChatLLM::setShouldBeLoaded(bool b)
|
|
|
|
|
{
|
|
|
|
|
#if defined(DEBUG_MODEL_LOADING)
|
|
|
|
|
qDebug() << "setShouldBeLoaded" << m_chat->id() << b << m_modelInfo.model;
|
|
|
|
|
qDebug() << "setShouldBeLoaded" << m_llmThread.objectName() << b << m_modelInfo.model;
|
|
|
|
|
#endif
|
|
|
|
|
m_shouldBeLoaded = b; // atomic
|
|
|
|
|
emit shouldBeLoadedChanged();
|
|
|
|
@ -495,7 +507,7 @@ void ChatLLM::unloadModel()
|
|
|
|
|
|
|
|
|
|
saveState();
|
|
|
|
|
#if defined(DEBUG_MODEL_LOADING)
|
|
|
|
|
qDebug() << "unloadModel" << m_chat->id() << m_modelInfo.model;
|
|
|
|
|
qDebug() << "unloadModel" << m_llmThread.objectName() << m_modelInfo.model;
|
|
|
|
|
#endif
|
|
|
|
|
LLModelStore::globalInstance()->releaseModel(m_modelInfo);
|
|
|
|
|
m_modelInfo = LLModelInfo();
|
|
|
|
@ -508,7 +520,7 @@ void ChatLLM::reloadModel()
|
|
|
|
|
return;
|
|
|
|
|
|
|
|
|
|
#if defined(DEBUG_MODEL_LOADING)
|
|
|
|
|
qDebug() << "reloadModel" << m_chat->id() << m_modelInfo.model;
|
|
|
|
|
qDebug() << "reloadModel" << m_llmThread.objectName() << m_modelInfo.model;
|
|
|
|
|
#endif
|
|
|
|
|
if (m_modelName.isEmpty()) {
|
|
|
|
|
loadDefaultModel();
|
|
|
|
@ -547,9 +559,14 @@ void ChatLLM::generateName()
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void ChatLLM::handleChatIdChanged()
|
|
|
|
|
void ChatLLM::handleChatIdChanged(const QString &id)
|
|
|
|
|
{
|
|
|
|
|
m_llmThread.setObjectName(m_chat->id());
|
|
|
|
|
m_llmThread.setObjectName(id);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void ChatLLM::handleDefaultModelChanged(const QString &defaultModel)
|
|
|
|
|
{
|
|
|
|
|
m_defaultModel = defaultModel;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool ChatLLM::handleNamePrompt(int32_t token)
|
|
|
|
@ -605,7 +622,7 @@ bool ChatLLM::serialize(QDataStream &stream, int version)
|
|
|
|
|
QByteArray compressed = qCompress(m_state);
|
|
|
|
|
stream << compressed;
|
|
|
|
|
#if defined(DEBUG)
|
|
|
|
|
qDebug() << "serialize" << m_chat->id() << m_state.size();
|
|
|
|
|
qDebug() << "serialize" << m_llmThread.objectName() << m_state.size();
|
|
|
|
|
#endif
|
|
|
|
|
return stream.status() == QDataStream::Ok;
|
|
|
|
|
}
|
|
|
|
@ -645,7 +662,7 @@ bool ChatLLM::deserialize(QDataStream &stream, int version)
|
|
|
|
|
stream >> m_state;
|
|
|
|
|
}
|
|
|
|
|
#if defined(DEBUG)
|
|
|
|
|
qDebug() << "deserialize" << m_chat->id();
|
|
|
|
|
qDebug() << "deserialize" << m_llmThread.objectName();
|
|
|
|
|
#endif
|
|
|
|
|
return stream.status() == QDataStream::Ok;
|
|
|
|
|
}
|
|
|
|
@ -667,7 +684,7 @@ void ChatLLM::saveState()
|
|
|
|
|
const size_t stateSize = m_modelInfo.model->stateSize();
|
|
|
|
|
m_state.resize(stateSize);
|
|
|
|
|
#if defined(DEBUG)
|
|
|
|
|
qDebug() << "saveState" << m_chat->id() << "size:" << m_state.size();
|
|
|
|
|
qDebug() << "saveState" << m_llmThread.objectName() << "size:" << m_state.size();
|
|
|
|
|
#endif
|
|
|
|
|
m_modelInfo.model->saveState(static_cast<uint8_t*>(reinterpret_cast<void*>(m_state.data())));
|
|
|
|
|
}
|
|
|
|
@ -690,7 +707,7 @@ void ChatLLM::restoreState()
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#if defined(DEBUG)
|
|
|
|
|
qDebug() << "restoreState" << m_chat->id() << "size:" << m_state.size();
|
|
|
|
|
qDebug() << "restoreState" << m_llmThread.objectName() << "size:" << m_state.size();
|
|
|
|
|
#endif
|
|
|
|
|
m_modelInfo.model->restoreState(static_cast<const uint8_t*>(reinterpret_cast<void*>(m_state.data())));
|
|
|
|
|
m_state.clear();
|
|
|
|
|