Restore state from text if necessary.

This commit is contained in:
Adam Treat 2023-10-10 16:43:02 -04:00 committed by AT
parent 35f9cdb70a
commit f0742c22f4
5 changed files with 177 additions and 24 deletions

View File

@ -385,7 +385,7 @@ bool Chat::serialize(QDataStream &stream, int version) const
stream << m_modelInfo.filename();
if (version > 2)
stream << m_collections;
if (!m_llmodel->serialize(stream, version))
if (!m_llmodel->serialize(stream, version, true /*serializeKV*/))
return false;
if (!m_chatModel->serialize(stream, version))
return false;
@ -404,29 +404,36 @@ bool Chat::deserialize(QDataStream &stream, int version)
QString modelId;
stream >> modelId;
if (version > 4) {
if (!ModelList::globalInstance()->contains(modelId))
return false;
m_modelInfo = ModelList::globalInstance()->modelInfo(modelId);
if (ModelList::globalInstance()->contains(modelId))
m_modelInfo = ModelList::globalInstance()->modelInfo(modelId);
} else {
if (!ModelList::globalInstance()->containsByFilename(modelId))
return false;
m_modelInfo = ModelList::globalInstance()->modelInfoByFilename(modelId);
if (ModelList::globalInstance()->containsByFilename(modelId))
m_modelInfo = ModelList::globalInstance()->modelInfoByFilename(modelId);
}
emit modelInfoChanged();
if (!m_modelInfo.id().isEmpty())
emit modelInfoChanged();
bool deserializeKV = true; // make this a setting
bool discardKV = m_modelInfo.id().isEmpty();
// Prior to version 2 gptj models had a bug that fixed the kv_cache to F32 instead of F16 so
// unfortunately, we cannot deserialize these
if (version < 2 && m_modelInfo.filename().contains("gpt4all-j"))
return false;
discardKV = true;
if (version > 2) {
stream >> m_collections;
emit collectionListChanged(m_collections);
}
m_llmodel->setModelInfo(m_modelInfo);
if (!m_llmodel->deserialize(stream, version))
if (!m_llmodel->deserialize(stream, version, deserializeKV, discardKV))
return false;
if (!m_chatModel->deserialize(stream, version))
return false;
if (!deserializeKV || discardKV)
m_llmodel->setStateFromText(m_chatModel->text());
emit chatModelChanged();
return stream.status() == QDataStream::Ok;
}

View File

@ -232,7 +232,6 @@ void ChatsRestoreThread::run()
chat->moveToThread(qApp->thread());
if (!chat->deserialize(in, version)) {
qWarning() << "ERROR: Couldn't deserialize chat from file:" << file.fileName();
file.remove();
} else {
emit chatRestored(chat);
}

View File

@ -69,6 +69,7 @@ ChatLLM::ChatLLM(Chat *parent, bool isServer)
, m_forceMetal(MySettings::globalInstance()->forceMetal())
, m_reloadingToChangeVariant(false)
, m_processedSystemPrompt(false)
, m_restoreStateFromText(false)
{
moveToThread(&m_llmThread);
connect(this, &ChatLLM::sendStartup, Network::globalInstance(), &Network::sendStartup);
@ -726,7 +727,35 @@ bool ChatLLM::handleSystemRecalculate(bool isRecalc)
return false;
}
bool ChatLLM::serialize(QDataStream &stream, int version)
bool ChatLLM::handleRestoreStateFromTextPrompt(int32_t token)
{
#if defined(DEBUG)
qDebug() << "restore state from text prompt" << m_llmThread.objectName() << token << m_stopGenerating;
#endif
Q_UNUSED(token);
return !m_stopGenerating;
}
bool ChatLLM::handleRestoreStateFromTextResponse(int32_t token, const std::string &response)
{
#if defined(DEBUG)
qDebug() << "restore state from text response" << m_llmThread.objectName() << token << response << m_stopGenerating;
#endif
Q_UNUSED(token);
Q_UNUSED(response);
return false;
}
bool ChatLLM::handleRestoreStateFromTextRecalculate(bool isRecalc)
{
#if defined(DEBUG)
qDebug() << "restore state from text recalc" << m_llmThread.objectName() << isRecalc;
#endif
Q_UNUSED(isRecalc);
return false;
}
bool ChatLLM::serialize(QDataStream &stream, int version, bool serializeKV)
{
if (version > 1) {
stream << m_llModelType;
@ -741,6 +770,14 @@ bool ChatLLM::serialize(QDataStream &stream, int version)
stream << response();
stream << generatedName();
stream << m_promptResponseTokens;
if (!serializeKV) {
#if defined(DEBUG)
qDebug() << "serialize" << m_llmThread.objectName() << m_state.size();
#endif
return stream.status() == QDataStream::Ok;
}
if (version <= 3) {
int responseLogits;
stream << responseLogits;
@ -759,7 +796,7 @@ bool ChatLLM::serialize(QDataStream &stream, int version)
return stream.status() == QDataStream::Ok;
}
bool ChatLLM::deserialize(QDataStream &stream, int version)
bool ChatLLM::deserialize(QDataStream &stream, int version, bool deserializeKV, bool discardKV)
{
if (version > 1) {
int internalStateVersion;
@ -773,26 +810,60 @@ bool ChatLLM::deserialize(QDataStream &stream, int version)
stream >> nameResponse;
m_nameResponse = nameResponse.toStdString();
stream >> m_promptResponseTokens;
// If we do not deserialize the KV or it is discarded, then we need to restore the state from the
// text only. This will be a costly operation, but the chat has to be restored from the text archive
// alone.
m_restoreStateFromText = !deserializeKV || discardKV;
if (!deserializeKV) {
#if defined(DEBUG)
qDebug() << "deserialize" << m_llmThread.objectName();
#endif
return stream.status() == QDataStream::Ok;
}
if (version <= 3) {
int responseLogits;
stream >> responseLogits;
}
stream >> m_ctx.n_past;
int32_t n_past;
stream >> n_past;
if (!discardKV) m_ctx.n_past = n_past;
quint64 logitsSize;
stream >> logitsSize;
m_ctx.logits.resize(logitsSize);
stream.readRawData(reinterpret_cast<char*>(m_ctx.logits.data()), logitsSize * sizeof(float));
if (!discardKV) {
m_ctx.logits.resize(logitsSize);
stream.readRawData(reinterpret_cast<char*>(m_ctx.logits.data()), logitsSize * sizeof(float));
} else {
stream.skipRawData(logitsSize * sizeof(float));
}
quint64 tokensSize;
stream >> tokensSize;
m_ctx.tokens.resize(tokensSize);
stream.readRawData(reinterpret_cast<char*>(m_ctx.tokens.data()), tokensSize * sizeof(int));
if (!discardKV) {
m_ctx.tokens.resize(tokensSize);
stream.readRawData(reinterpret_cast<char*>(m_ctx.tokens.data()), tokensSize * sizeof(int));
} else {
stream.skipRawData(tokensSize * sizeof(int));
}
if (version > 0) {
QByteArray compressed;
stream >> compressed;
m_state = qUncompress(compressed);
if (!discardKV)
m_state = qUncompress(compressed);
} else {
stream >> m_state;
if (!discardKV)
stream >> m_state;
else {
QByteArray state;
stream >> m_state;
}
}
#if defined(DEBUG)
qDebug() << "deserialize" << m_llmThread.objectName();
#endif
@ -823,7 +894,7 @@ void ChatLLM::saveState()
void ChatLLM::restoreState()
{
if (!isModelLoaded() || m_state.isEmpty())
if (!isModelLoaded())
return;
if (m_llModelType == LLModelType::CHATGPT_) {
@ -838,10 +909,19 @@ void ChatLLM::restoreState()
return;
}
if (m_restoreStateFromText) {
Q_ASSERT(m_state.isEmpty());
processRestoreStateFromText();
}
#if defined(DEBUG)
qDebug() << "restoreState" << m_llmThread.objectName() << "size:" << m_state.size();
#endif
m_processedSystemPrompt = true;
if (m_state.isEmpty())
return;
m_llModelInfo.model->restoreState(static_cast<const uint8_t*>(reinterpret_cast<void*>(m_state.data())));
m_state.clear();
m_state.resize(0);
@ -859,7 +939,10 @@ void ChatLLM::processSystemPrompt()
return;
}
// Start with a whole new context
m_stopGenerating = false;
m_ctx = LLModel::PromptContext();
auto promptFunc = std::bind(&ChatLLM::handleSystemPrompt, this, std::placeholders::_1);
auto responseFunc = std::bind(&ChatLLM::handleSystemResponse, this, std::placeholders::_1,
std::placeholders::_2);
@ -890,5 +973,54 @@ void ChatLLM::processSystemPrompt()
printf("\n");
fflush(stdout);
#endif
m_processedSystemPrompt = true;
m_processedSystemPrompt = !m_stopGenerating;
}
void ChatLLM::processRestoreStateFromText()
{
Q_ASSERT(isModelLoaded());
if (!isModelLoaded() || !m_restoreStateFromText || m_isServer)
return;
m_isRecalc = true;
emit recalcChanged();
m_stopGenerating = false;
m_ctx = LLModel::PromptContext();
auto promptFunc = std::bind(&ChatLLM::handleRestoreStateFromTextPrompt, this, std::placeholders::_1);
auto responseFunc = std::bind(&ChatLLM::handleRestoreStateFromTextResponse, this, std::placeholders::_1,
std::placeholders::_2);
auto recalcFunc = std::bind(&ChatLLM::handleRestoreStateFromTextRecalculate, this, std::placeholders::_1);
const QString promptTemplate = MySettings::globalInstance()->modelPromptTemplate(m_modelInfo);
const int32_t n_predict = MySettings::globalInstance()->modelMaxLength(m_modelInfo);
const int32_t top_k = MySettings::globalInstance()->modelTopK(m_modelInfo);
const float top_p = MySettings::globalInstance()->modelTopP(m_modelInfo);
const float temp = MySettings::globalInstance()->modelTemperature(m_modelInfo);
const int32_t n_batch = MySettings::globalInstance()->modelPromptBatchSize(m_modelInfo);
const float repeat_penalty = MySettings::globalInstance()->modelRepeatPenalty(m_modelInfo);
const int32_t repeat_penalty_tokens = MySettings::globalInstance()->modelRepeatPenaltyTokens(m_modelInfo);
int n_threads = MySettings::globalInstance()->threadCount();
m_ctx.n_predict = n_predict;
m_ctx.top_k = top_k;
m_ctx.top_p = top_p;
m_ctx.temp = temp;
m_ctx.n_batch = n_batch;
m_ctx.repeat_penalty = repeat_penalty;
m_ctx.repeat_last_n = repeat_penalty_tokens;
m_llModelInfo.model->setThreadCount(n_threads);
for (auto pair : m_stateFromText) {
const QString str = pair.first == "Prompt: " ? promptTemplate.arg(pair.second) : pair.second;
m_llModelInfo.model->prompt(str.toStdString(), promptFunc, responseFunc, recalcFunc, m_ctx);
}
if (!m_stopGenerating) {
m_restoreStateFromText = false;
m_stateFromText.clear();
}
m_isRecalc = false;
emit recalcChanged();
}

View File

@ -92,8 +92,9 @@ public:
QString generatedName() const { return QString::fromStdString(m_nameResponse); }
bool serialize(QDataStream &stream, int version);
bool deserialize(QDataStream &stream, int version);
bool serialize(QDataStream &stream, int version, bool serializeKV);
bool deserialize(QDataStream &stream, int version, bool deserializeKV, bool discardKV);
void setStateFromText(const QVector<QPair<QString, QString>> &stateFromText) { m_stateFromText = stateFromText; }
public Q_SLOTS:
bool prompt(const QList<QString> &collectionList, const QString &prompt);
@ -110,6 +111,7 @@ public Q_SLOTS:
void handleForceMetalChanged(bool forceMetal);
void handleDeviceChanged();
void processSystemPrompt();
void processRestoreStateFromText();
Q_SIGNALS:
void recalcChanged();
@ -144,6 +146,9 @@ protected:
bool handleSystemPrompt(int32_t token);
bool handleSystemResponse(int32_t token, const std::string &response);
bool handleSystemRecalculate(bool isRecalc);
bool handleRestoreStateFromTextPrompt(int32_t token);
bool handleRestoreStateFromTextResponse(int32_t token, const std::string &response);
bool handleRestoreStateFromTextRecalculate(bool isRecalc);
void saveState();
void restoreState();
@ -168,6 +173,8 @@ private:
bool m_forceMetal;
bool m_reloadingToChangeVariant;
bool m_processedSystemPrompt;
bool m_restoreStateFromText;
QVector<QPair<QString, QString>> m_stateFromText;
};
#endif // CHATLLM_H

View File

@ -285,6 +285,14 @@ public:
return stream.status() == QDataStream::Ok;
}
QVector<QPair<QString, QString>> text() const
{
QVector<QPair<QString, QString>> result;
for (const auto &c : m_chatItems)
result << qMakePair(c.name, c.value);
return result;
}
Q_SIGNALS:
void countChanged();