@@ -69,6 +69,7 @@ ChatLLM::ChatLLM(Chat *parent, bool isServer)
     , m_forceMetal(MySettings::globalInstance()->forceMetal())
     , m_reloadingToChangeVariant(false)
     , m_processedSystemPrompt(false)
+    , m_restoreStateFromText(false)
 {
     moveToThread(&m_llmThread);
     connect(this, &ChatLLM::sendStartup, Network::globalInstance(), &Network::sendStartup);
@@ -726,7 +727,35 @@ bool ChatLLM::handleSystemRecalculate(bool isRecalc)
     return false;
 }
 
-bool ChatLLM::serialize(QDataStream &stream, int version)
+bool ChatLLM::handleRestoreStateFromTextPrompt(int32_t token)
+{
+#if defined(DEBUG)
+    qDebug() << "restore state from text prompt" << m_llmThread.objectName() << token << m_stopGenerating;
+#endif
+    Q_UNUSED(token);
+    return !m_stopGenerating;
+}
+
+bool ChatLLM::handleRestoreStateFromTextResponse(int32_t token, const std::string &response)
+{
+#if defined(DEBUG)
+    qDebug() << "restore state from text response" << m_llmThread.objectName() << token << response << m_stopGenerating;
+#endif
+    Q_UNUSED(token);
+    Q_UNUSED(response);
+    return false;
+}
+
+bool ChatLLM::handleRestoreStateFromTextRecalculate(bool isRecalc)
+{
+#if defined(DEBUG)
+    qDebug() << "restore state from text recalc" << m_llmThread.objectName() << isRecalc;
+#endif
+    Q_UNUSED(isRecalc);
+    return false;
+}
+
+bool ChatLLM::serialize(QDataStream &stream, int version, bool serializeKV)
 {
     if (version > 1) {
         stream << m_llModelType;
@@ -741,6 +770,14 @@ bool ChatLLM::serialize(QDataStream &stream, int version)
     stream << response();
     stream << generatedName();
     stream << m_promptResponseTokens;
+
+    if (!serializeKV) {
+#if defined(DEBUG)
+        qDebug() << "serialize" << m_llmThread.objectName() << m_state.size();
+#endif
+        return stream.status() == QDataStream::Ok;
+    }
+
     if (version <= 3) {
         int responseLogits;
         stream << responseLogits;
@@ -759,7 +796,7 @@ bool ChatLLM::serialize(QDataStream &stream, int version)
     return stream.status() == QDataStream::Ok;
 }
 
-bool ChatLLM::deserialize(QDataStream &stream, int version)
+bool ChatLLM::deserialize(QDataStream &stream, int version, bool deserializeKV, bool discardKV)
 {
     if (version > 1) {
         int internalStateVersion;
@@ -773,26 +810,60 @@ bool ChatLLM::deserialize(QDataStream &stream, int version)
     stream >> nameResponse;
     m_nameResponse = nameResponse.toStdString();
     stream >> m_promptResponseTokens;
+
+    // If we do not deserialize the KV or it is discarded, then we need to restore the state from the
+    // text only. This will be a costly operation, but the chat has to be restored from the text archive
+    // alone.
+    m_restoreStateFromText = !deserializeKV || discardKV;
+
+    if (!deserializeKV) {
+#if defined(DEBUG)
+        qDebug() << "deserialize" << m_llmThread.objectName();
+#endif
+        return stream.status() == QDataStream::Ok;
+    }
+
     if (version <= 3) {
         int responseLogits;
         stream >> responseLogits;
     }
-    stream >> m_ctx.n_past;
+
+    int32_t n_past;
+    stream >> n_past;
+    if (!discardKV) m_ctx.n_past = n_past;
+
     quint64 logitsSize;
     stream >> logitsSize;
-    m_ctx.logits.resize(logitsSize);
-    stream.readRawData(reinterpret_cast<char*>(m_ctx.logits.data()), logitsSize * sizeof(float));
+    if (!discardKV) {
+        m_ctx.logits.resize(logitsSize);
+        stream.readRawData(reinterpret_cast<char*>(m_ctx.logits.data()), logitsSize * sizeof(float));
+    } else {
+        stream.skipRawData(logitsSize * sizeof(float));
+    }
+
     quint64 tokensSize;
     stream >> tokensSize;
-    m_ctx.tokens.resize(tokensSize);
-    stream.readRawData(reinterpret_cast<char*>(m_ctx.tokens.data()), tokensSize * sizeof(int));
+    if (!discardKV) {
+        m_ctx.tokens.resize(tokensSize);
+        stream.readRawData(reinterpret_cast<char*>(m_ctx.tokens.data()), tokensSize * sizeof(int));
+    } else {
+        stream.skipRawData(tokensSize * sizeof(int));
+    }
+
     if (version > 0) {
         QByteArray compressed;
         stream >> compressed;
-        m_state = qUncompress(compressed);
+        if (!discardKV)
+            m_state = qUncompress(compressed);
     } else {
-        stream >> m_state;
+        if (!discardKV)
+            stream >> m_state;
+        else {
+            QByteArray state;
+            stream >> state;
+        }
     }
+
 #if defined(DEBUG)
     qDebug() << "deserialize" << m_llmThread.objectName();
 #endif
@@ -823,7 +894,7 @@ void ChatLLM::saveState()
 
 void ChatLLM::restoreState()
 {
-    if (!isModelLoaded() || m_state.isEmpty())
+    if (!isModelLoaded())
         return;
 
     if (m_llModelType == LLModelType::CHATGPT_) {
@@ -838,10 +909,19 @@ void ChatLLM::restoreState()
         return;
     }
 
+    if (m_restoreStateFromText) {
+        Q_ASSERT(m_state.isEmpty());
+        processRestoreStateFromText();
+    }
+
 #if defined(DEBUG)
     qDebug() << "restoreState" << m_llmThread.objectName() << "size:" << m_state.size();
 #endif
     m_processedSystemPrompt = true;
+
+    if (m_state.isEmpty())
+        return;
+
     m_llModelInfo.model->restoreState(static_cast<const uint8_t*>(reinterpret_cast<void*>(m_state.data())));
     m_state.clear();
     m_state.resize(0);
@@ -859,7 +939,10 @@ void ChatLLM::processSystemPrompt()
         return;
     }
+
+    // Start with a whole new context
     m_stopGenerating = false;
     m_ctx = LLModel::PromptContext();
+
     auto promptFunc = std::bind(&ChatLLM::handleSystemPrompt, this, std::placeholders::_1);
     auto responseFunc = std::bind(&ChatLLM::handleSystemResponse, this, std::placeholders::_1,
         std::placeholders::_2);
@@ -890,5 +973,54 @@ void ChatLLM::processSystemPrompt()
     printf("\n");
     fflush(stdout);
 #endif
-    m_processedSystemPrompt = true;
+    m_processedSystemPrompt = !m_stopGenerating;
 }
+
+void ChatLLM::processRestoreStateFromText()
+{
+    Q_ASSERT(isModelLoaded());
+    if (!isModelLoaded() || !m_restoreStateFromText || m_isServer)
+        return;
+
+    m_isRecalc = true;
+    emit recalcChanged();
+
+    m_stopGenerating = false;
+    m_ctx = LLModel::PromptContext();
+
+    auto promptFunc = std::bind(&ChatLLM::handleRestoreStateFromTextPrompt, this, std::placeholders::_1);
+    auto responseFunc = std::bind(&ChatLLM::handleRestoreStateFromTextResponse, this, std::placeholders::_1,
+        std::placeholders::_2);
+    auto recalcFunc = std::bind(&ChatLLM::handleRestoreStateFromTextRecalculate, this, std::placeholders::_1);
+
+    const QString promptTemplate = MySettings::globalInstance()->modelPromptTemplate(m_modelInfo);
+    const int32_t n_predict = MySettings::globalInstance()->modelMaxLength(m_modelInfo);
+    const int32_t top_k = MySettings::globalInstance()->modelTopK(m_modelInfo);
+    const float top_p = MySettings::globalInstance()->modelTopP(m_modelInfo);
+    const float temp = MySettings::globalInstance()->modelTemperature(m_modelInfo);
+    const int32_t n_batch = MySettings::globalInstance()->modelPromptBatchSize(m_modelInfo);
+    const float repeat_penalty = MySettings::globalInstance()->modelRepeatPenalty(m_modelInfo);
+    const int32_t repeat_penalty_tokens = MySettings::globalInstance()->modelRepeatPenaltyTokens(m_modelInfo);
+    int n_threads = MySettings::globalInstance()->threadCount();
+    m_ctx.n_predict = n_predict;
+    m_ctx.top_k = top_k;
+    m_ctx.top_p = top_p;
+    m_ctx.temp = temp;
+    m_ctx.n_batch = n_batch;
+    m_ctx.repeat_penalty = repeat_penalty;
+    m_ctx.repeat_last_n = repeat_penalty_tokens;
+    m_llModelInfo.model->setThreadCount(n_threads);
+
+    for (auto pair : m_stateFromText) {
+        const QString str = pair.first == "Prompt: " ? promptTemplate.arg(pair.second) : pair.second;
+        m_llModelInfo.model->prompt(str.toStdString(), promptFunc, responseFunc, recalcFunc, m_ctx);
+    }
+
+    if (!m_stopGenerating) {
+        m_restoreStateFromText = false;
+        m_stateFromText.clear();
+    }
+
+    m_isRecalc = false;
+    emit recalcChanged();
+}