https://github.com/nomic-ai/gpt4all

commit f0742c22f4 (parent 35f9cdb70a)

    Restore state from text if necessary.
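
Summary of the change: ChatLLM::serialize() gains a serializeKV flag and
ChatLLM::deserialize() gains deserializeKV and discardKV flags. When the KV
cache is not deserialized or must be discarded (for example, the saved chat
references a model that is no longer installed, or it is a pre-version-2
gpt4all-j chat with the F32 kv_cache bug), Chat::deserialize() hands the
restored transcript (the new ChatModel::text()) to ChatLLM::setStateFromText(),
and ChatLLM::restoreState() later rebuilds the model context by replaying that
transcript in the new processRestoreStateFromText().
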
--- a/gpt4all-chat/chat.cpp
+++ b/gpt4all-chat/chat.cpp
@@ -385,7 +385,7 @@ bool Chat::serialize(QDataStream &stream, int version) const
     stream << m_modelInfo.filename();
     if (version > 2)
         stream << m_collections;
-    if (!m_llmodel->serialize(stream, version))
+    if (!m_llmodel->serialize(stream, version, true /*serializeKV*/))
         return false;
     if (!m_chatModel->serialize(stream, version))
         return false;
@@ -404,29 +404,36 @@ bool Chat::deserialize(QDataStream &stream, int version)
     QString modelId;
     stream >> modelId;
     if (version > 4) {
-        if (!ModelList::globalInstance()->contains(modelId))
-            return false;
-        m_modelInfo = ModelList::globalInstance()->modelInfo(modelId);
+        if (ModelList::globalInstance()->contains(modelId))
+            m_modelInfo = ModelList::globalInstance()->modelInfo(modelId);
     } else {
-        if (!ModelList::globalInstance()->containsByFilename(modelId))
-            return false;
-        m_modelInfo = ModelList::globalInstance()->modelInfoByFilename(modelId);
+        if (ModelList::globalInstance()->containsByFilename(modelId))
+            m_modelInfo = ModelList::globalInstance()->modelInfoByFilename(modelId);
     }
-    emit modelInfoChanged();
-
+    if (!m_modelInfo.id().isEmpty())
+        emit modelInfoChanged();
+
+    bool deserializeKV = true; // make this a setting
+    bool discardKV = m_modelInfo.id().isEmpty();
+
     // Prior to version 2 gptj models had a bug that fixed the kv_cache to F32 instead of F16 so
     // unfortunately, we cannot deserialize these
     if (version < 2 && m_modelInfo.filename().contains("gpt4all-j"))
-        return false;
+        discardKV = true;
+
     if (version > 2) {
         stream >> m_collections;
         emit collectionListChanged(m_collections);
     }
     m_llmodel->setModelInfo(m_modelInfo);
-    if (!m_llmodel->deserialize(stream, version))
+    if (!m_llmodel->deserialize(stream, version, deserializeKV, discardKV))
         return false;
     if (!m_chatModel->deserialize(stream, version))
         return false;
+
+    if (!deserializeKV || discardKV)
+        m_llmodel->setStateFromText(m_chatModel->text());
+
     emit chatModelChanged();
     return stream.status() == QDataStream::Ok;
 }
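
Aside (not part of the commit): Chat::serialize() and Chat::deserialize() above
depend on QDataStream's strict write/read symmetry: every field must be read
back in exactly the order and type it was written, with the format version
gating newer fields. A minimal, self-contained sketch of that pattern (names
and values illustrative):

    #include <QBuffer>
    #include <QDataStream>
    #include <QStringList>

    int main()
    {
        QBuffer buf;
        buf.open(QIODevice::ReadWrite);

        QDataStream out(&buf);
        const int version = 5;
        out << QString("model-id");       // always present
        if (version > 2)
            out << QStringList{"docs"};   // field that exists only in newer formats

        buf.seek(0);
        QDataStream in(&buf);
        QString modelId;
        in >> modelId;                    // must mirror the writes, in order
        QStringList collections;
        if (version > 2)
            in >> collections;
        return in.status() == QDataStream::Ok ? 0 : 1;
    }
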
--- a/gpt4all-chat/chatlistmodel.cpp
+++ b/gpt4all-chat/chatlistmodel.cpp
@@ -232,7 +232,6 @@ void ChatsRestoreThread::run()
             chat->moveToThread(qApp->thread());
             if (!chat->deserialize(in, version)) {
                 qWarning() << "ERROR: Couldn't deserialize chat from file:" << file.fileName();
-                file.remove();
             } else {
                 emit chatRestored(chat);
             }
--- a/gpt4all-chat/chatllm.cpp
+++ b/gpt4all-chat/chatllm.cpp
@@ -69,6 +69,7 @@ ChatLLM::ChatLLM(Chat *parent, bool isServer)
     , m_forceMetal(MySettings::globalInstance()->forceMetal())
     , m_reloadingToChangeVariant(false)
     , m_processedSystemPrompt(false)
+    , m_restoreStateFromText(false)
 {
     moveToThread(&m_llmThread);
     connect(this, &ChatLLM::sendStartup, Network::globalInstance(), &Network::sendStartup);
@@ -726,7 +727,35 @@ bool ChatLLM::handleSystemRecalculate(bool isRecalc)
     return false;
 }
 
-bool ChatLLM::serialize(QDataStream &stream, int version)
+bool ChatLLM::handleRestoreStateFromTextPrompt(int32_t token)
+{
+#if defined(DEBUG)
+    qDebug() << "restore state from text prompt" << m_llmThread.objectName() << token << m_stopGenerating;
+#endif
+    Q_UNUSED(token);
+    return !m_stopGenerating;
+}
+
+bool ChatLLM::handleRestoreStateFromTextResponse(int32_t token, const std::string &response)
+{
+#if defined(DEBUG)
+    qDebug() << "restore state from text response" << m_llmThread.objectName() << token << response << m_stopGenerating;
+#endif
+    Q_UNUSED(token);
+    Q_UNUSED(response);
+    return false;
+}
+
+bool ChatLLM::handleRestoreStateFromTextRecalculate(bool isRecalc)
+{
+#if defined(DEBUG)
+    qDebug() << "restore state from text recalc" << m_llmThread.objectName() << isRecalc;
+#endif
+    Q_UNUSED(isRecalc);
+    return false;
+}
+
+bool ChatLLM::serialize(QDataStream &stream, int version, bool serializeKV)
 {
     if (version > 1) {
         stream << m_llModelType;
@@ -741,6 +770,14 @@ bool ChatLLM::serialize(QDataStream &stream, int version)
     stream << response();
     stream << generatedName();
     stream << m_promptResponseTokens;
+
+    if (!serializeKV) {
+#if defined(DEBUG)
+        qDebug() << "serialize" << m_llmThread.objectName() << m_state.size();
+#endif
+        return stream.status() == QDataStream::Ok;
+    }
+
     if (version <= 3) {
         int responseLogits;
         stream << responseLogits;
@@ -759,7 +796,7 @@ bool ChatLLM::serialize(QDataStream &stream, int version)
     return stream.status() == QDataStream::Ok;
 }
 
-bool ChatLLM::deserialize(QDataStream &stream, int version)
+bool ChatLLM::deserialize(QDataStream &stream, int version, bool deserializeKV, bool discardKV)
 {
     if (version > 1) {
         int internalStateVersion;
@@ -773,26 +810,60 @@ bool ChatLLM::deserialize(QDataStream &stream, int version)
     stream >> nameResponse;
     m_nameResponse = nameResponse.toStdString();
     stream >> m_promptResponseTokens;
+
+    // If we do not deserialize the KV or it is discarded, then we need to restore the state from the
+    // text only. This will be a costly operation, but the chat has to be restored from the text archive
+    // alone.
+    m_restoreStateFromText = !deserializeKV || discardKV;
+
+    if (!deserializeKV) {
+#if defined(DEBUG)
+        qDebug() << "deserialize" << m_llmThread.objectName();
+#endif
+        return stream.status() == QDataStream::Ok;
+    }
+
     if (version <= 3) {
         int responseLogits;
         stream >> responseLogits;
     }
-    stream >> m_ctx.n_past;
+
+    int32_t n_past;
+    stream >> n_past;
+    if (!discardKV) m_ctx.n_past = n_past;
+
     quint64 logitsSize;
     stream >> logitsSize;
-    m_ctx.logits.resize(logitsSize);
-    stream.readRawData(reinterpret_cast<char*>(m_ctx.logits.data()), logitsSize * sizeof(float));
+    if (!discardKV) {
+        m_ctx.logits.resize(logitsSize);
+        stream.readRawData(reinterpret_cast<char*>(m_ctx.logits.data()), logitsSize * sizeof(float));
+    } else {
+        stream.skipRawData(logitsSize * sizeof(float));
+    }
+
     quint64 tokensSize;
     stream >> tokensSize;
-    m_ctx.tokens.resize(tokensSize);
-    stream.readRawData(reinterpret_cast<char*>(m_ctx.tokens.data()), tokensSize * sizeof(int));
+    if (!discardKV) {
+        m_ctx.tokens.resize(tokensSize);
+        stream.readRawData(reinterpret_cast<char*>(m_ctx.tokens.data()), tokensSize * sizeof(int));
+    } else {
+        stream.skipRawData(tokensSize * sizeof(int));
+    }
+
     if (version > 0) {
         QByteArray compressed;
         stream >> compressed;
-        m_state = qUncompress(compressed);
+        if (!discardKV)
+            m_state = qUncompress(compressed);
     } else {
-        stream >> m_state;
+        if (!discardKV)
+            stream >> m_state;
+        else {
+            QByteArray state;
+            stream >> state;
+        }
     }
+
 #if defined(DEBUG)
     qDebug() << "deserialize" << m_llmThread.objectName();
 #endif
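
Aside (not part of the commit): note that the discardKV path still has to
consume the logits and tokens payloads, because later fields follow at fixed
positions in the stream. A self-contained sketch of the read-or-skip pattern
used above (function name illustrative):

    #include <QDataStream>
    #include <QVector>

    // Read a length-prefixed float buffer, or consume it without storing it.
    QVector<float> readOrSkip(QDataStream &in, bool keep)
    {
        quint64 n;
        in >> n;
        QVector<float> data;
        if (keep) {
            data.resize(n);
            in.readRawData(reinterpret_cast<char*>(data.data()), n * sizeof(float));
        } else {
            in.skipRawData(n * sizeof(float));  // keeps the stream aligned for the next field
        }
        return data;
    }
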
@@ -823,7 +894,7 @@ void ChatLLM::saveState()
 
 void ChatLLM::restoreState()
 {
-    if (!isModelLoaded() || m_state.isEmpty())
+    if (!isModelLoaded())
         return;
 
     if (m_llModelType == LLModelType::CHATGPT_) {
@@ -838,10 +909,19 @@ void ChatLLM::restoreState()
         return;
     }
 
+    if (m_restoreStateFromText) {
+        Q_ASSERT(m_state.isEmpty());
+        processRestoreStateFromText();
+    }
+
 #if defined(DEBUG)
     qDebug() << "restoreState" << m_llmThread.objectName() << "size:" << m_state.size();
 #endif
     m_processedSystemPrompt = true;
+
+    if (m_state.isEmpty())
+        return;
+
     m_llModelInfo.model->restoreState(static_cast<const uint8_t*>(reinterpret_cast<void*>(m_state.data())));
     m_state.clear();
     m_state.resize(0);
@@ -859,7 +939,10 @@ void ChatLLM::processSystemPrompt()
         return;
     }
 
+    // Start with a whole new context
+    m_stopGenerating = false;
     m_ctx = LLModel::PromptContext();
+
     auto promptFunc = std::bind(&ChatLLM::handleSystemPrompt, this, std::placeholders::_1);
     auto responseFunc = std::bind(&ChatLLM::handleSystemResponse, this, std::placeholders::_1,
         std::placeholders::_2);
@@ -890,5 +973,54 @@ void ChatLLM::processSystemPrompt()
     printf("\n");
     fflush(stdout);
 #endif
-    m_processedSystemPrompt = true;
+
+    m_processedSystemPrompt = !m_stopGenerating;
 }
+
+void ChatLLM::processRestoreStateFromText()
+{
+    Q_ASSERT(isModelLoaded());
+    if (!isModelLoaded() || !m_restoreStateFromText || m_isServer)
+        return;
+
+    m_isRecalc = true;
+    emit recalcChanged();
+
+    m_stopGenerating = false;
+    m_ctx = LLModel::PromptContext();
+
+    auto promptFunc = std::bind(&ChatLLM::handleRestoreStateFromTextPrompt, this, std::placeholders::_1);
+    auto responseFunc = std::bind(&ChatLLM::handleRestoreStateFromTextResponse, this, std::placeholders::_1,
+        std::placeholders::_2);
+    auto recalcFunc = std::bind(&ChatLLM::handleRestoreStateFromTextRecalculate, this, std::placeholders::_1);
+
+    const QString promptTemplate = MySettings::globalInstance()->modelPromptTemplate(m_modelInfo);
+    const int32_t n_predict = MySettings::globalInstance()->modelMaxLength(m_modelInfo);
+    const int32_t top_k = MySettings::globalInstance()->modelTopK(m_modelInfo);
+    const float top_p = MySettings::globalInstance()->modelTopP(m_modelInfo);
+    const float temp = MySettings::globalInstance()->modelTemperature(m_modelInfo);
+    const int32_t n_batch = MySettings::globalInstance()->modelPromptBatchSize(m_modelInfo);
+    const float repeat_penalty = MySettings::globalInstance()->modelRepeatPenalty(m_modelInfo);
+    const int32_t repeat_penalty_tokens = MySettings::globalInstance()->modelRepeatPenaltyTokens(m_modelInfo);
+    int n_threads = MySettings::globalInstance()->threadCount();
+    m_ctx.n_predict = n_predict;
+    m_ctx.top_k = top_k;
+    m_ctx.top_p = top_p;
+    m_ctx.temp = temp;
+    m_ctx.n_batch = n_batch;
+    m_ctx.repeat_penalty = repeat_penalty;
+    m_ctx.repeat_last_n = repeat_penalty_tokens;
+    m_llModelInfo.model->setThreadCount(n_threads);
+    for (auto pair : m_stateFromText) {
+        const QString str = pair.first == "Prompt: " ? promptTemplate.arg(pair.second) : pair.second;
+        m_llModelInfo.model->prompt(str.toStdString(), promptFunc, responseFunc, recalcFunc, m_ctx);
+    }
+
+    if (!m_stopGenerating) {
+        m_restoreStateFromText = false;
+        m_stateFromText.clear();
+    }
+
+    m_isRecalc = false;
+    emit recalcChanged();
+}
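
Aside (not part of the commit): processRestoreStateFromText() is a transcript
replay. Each saved turn is pushed back through LLModel::prompt() so the KV
cache ends up as if the conversation had just happened; the response callback
returns false, so no new tokens are generated. A self-contained sketch of the
idea, with stand-in names (Turn, ingest) that are not from the codebase:

    #include <string>
    #include <utility>
    #include <vector>

    int main()
    {
        using Turn = std::pair<std::string, std::string>; // (speaker tag, text)
        const std::vector<Turn> transcript = {
            {"Prompt: ", "What is the capital of France?"},
            {"Response: ", "Paris."},
        };

        // Stand-in for LLModel::prompt(): the real call tokenizes the text and
        // advances the model's n_past / KV cache; here we only count characters.
        std::size_t n_past = 0;
        auto ingest = [&](const std::string &text) { n_past += text.size(); };

        const std::string tmpl = "### Human:\n%1\n### Assistant:\n";
        for (const auto &turn : transcript) {
            std::string str = turn.second;
            if (turn.first == "Prompt: ") {  // user turns get the prompt template
                std::string wrapped = tmpl;
                wrapped.replace(wrapped.find("%1"), 2, turn.second); // mimic QString::arg()
                str = wrapped;
            }
            ingest(str);
        }
        // n_past now reflects the whole conversation having been re-fed.
        return n_past > 0 ? 0 : 1;
    }
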
--- a/gpt4all-chat/chatllm.h
+++ b/gpt4all-chat/chatllm.h
@@ -92,8 +92,9 @@ public:
 
     QString generatedName() const { return QString::fromStdString(m_nameResponse); }
 
-    bool serialize(QDataStream &stream, int version);
-    bool deserialize(QDataStream &stream, int version);
+    bool serialize(QDataStream &stream, int version, bool serializeKV);
+    bool deserialize(QDataStream &stream, int version, bool deserializeKV, bool discardKV);
+    void setStateFromText(const QVector<QPair<QString, QString>> &stateFromText) { m_stateFromText = stateFromText; }
 
 public Q_SLOTS:
     bool prompt(const QList<QString> &collectionList, const QString &prompt);
@@ -110,6 +111,7 @@ public Q_SLOTS:
     void handleForceMetalChanged(bool forceMetal);
     void handleDeviceChanged();
     void processSystemPrompt();
+    void processRestoreStateFromText();
 
 Q_SIGNALS:
     void recalcChanged();
@@ -144,6 +146,9 @@ protected:
     bool handleSystemPrompt(int32_t token);
     bool handleSystemResponse(int32_t token, const std::string &response);
     bool handleSystemRecalculate(bool isRecalc);
+    bool handleRestoreStateFromTextPrompt(int32_t token);
+    bool handleRestoreStateFromTextResponse(int32_t token, const std::string &response);
+    bool handleRestoreStateFromTextRecalculate(bool isRecalc);
     void saveState();
     void restoreState();
 
@@ -168,6 +173,8 @@ private:
     bool m_forceMetal;
     bool m_reloadingToChangeVariant;
    bool m_processedSystemPrompt;
+    bool m_restoreStateFromText;
+    QVector<QPair<QString, QString>> m_stateFromText;
 };
 
 #endif // CHATLLM_H
|
@ -285,6 +285,14 @@ public:
|
||||
return stream.status() == QDataStream::Ok;
|
||||
}
|
||||
|
||||
QVector<QPair<QString, QString>> text() const
|
||||
{
|
||||
QVector<QPair<QString, QString>> result;
|
||||
for (const auto &c : m_chatItems)
|
||||
result << qMakePair(c.name, c.value);
|
||||
return result;
|
||||
}
|
||||
|
||||
Q_SIGNALS:
|
||||
void countChanged();
|
||||
|
||||
|
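
Taken together: ChatModel::text() above is the text archive. When
Chat::deserialize() finds the KV cache unusable, it calls
m_llmodel->setStateFromText(m_chatModel->text()), and the next restoreState()
replays those pairs through processRestoreStateFromText(). A sketch of the
hand-off with illustrative values ("Prompt: " matches the tag checked in the
replay loop; "Response: " is assumed for reply items):

    QVector<QPair<QString, QString>> transcript = {
        qMakePair(QStringLiteral("Prompt: "), QStringLiteral("What is the capital of France?")),
        qMakePair(QStringLiteral("Response: "), QStringLiteral("Paris.")),
    };
    m_llmodel->setStateFromText(transcript);  // stored in ChatLLM::m_stateFromText
    // Later, on the LLM thread, restoreState() sees m_restoreStateFromText and
    // calls processRestoreStateFromText(), which feeds each pair back through
    // LLModel::prompt() to rebuild the KV cache.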