From 0600f551b32df504e78e884832867a181976ff5a Mon Sep 17 00:00:00 2001
From: Jared Van Bortel
Date: Tue, 12 Dec 2023 11:45:03 -0500
Subject: [PATCH] chatllm: do not attempt to serialize incompatible state
 (#1742)

---
 gpt4all-backend/llamamodel.cpp |  4 ++++
 gpt4all-chat/chat.cpp          |  3 +--
 gpt4all-chat/chatllm.cpp       | 20 +++++++++++++-------
 3 files changed, 18 insertions(+), 9 deletions(-)

diff --git a/gpt4all-backend/llamamodel.cpp b/gpt4all-backend/llamamodel.cpp
index 65374854..882674e3 100644
--- a/gpt4all-backend/llamamodel.cpp
+++ b/gpt4all-backend/llamamodel.cpp
@@ -163,6 +163,10 @@ bool LLamaModel::loadModel(const std::string &modelPath)
     d_ptr->ctx_params.seed = params.seed;
     d_ptr->ctx_params.f16_kv = params.memory_f16;
 
+    // The new batch API provides space for n_vocab*n_tokens logits. Tell llama.cpp early
+    // that we want this many logits so the state serializes consistently.
+    d_ptr->ctx_params.logits_all = true;
+
     d_ptr->n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
     d_ptr->ctx_params.n_threads = d_ptr->n_threads;
     d_ptr->ctx_params.n_threads_batch = d_ptr->n_threads;
diff --git a/gpt4all-chat/chat.cpp b/gpt4all-chat/chat.cpp
index 3e7f91c1..17a92cf9 100644
--- a/gpt4all-chat/chat.cpp
+++ b/gpt4all-chat/chat.cpp
@@ -435,8 +435,7 @@ bool Chat::deserialize(QDataStream &stream, int version)
     if (!m_chatModel->deserialize(stream, version))
         return false;
 
-    if (!deserializeKV || discardKV)
-        m_llmodel->setStateFromText(m_chatModel->text());
+    m_llmodel->setStateFromText(m_chatModel->text());
 
     emit chatModelChanged();
     return stream.status() == QDataStream::Ok;
diff --git a/gpt4all-chat/chatllm.cpp b/gpt4all-chat/chatllm.cpp
index e5a84548..78f73cd4 100644
--- a/gpt4all-chat/chatllm.cpp
+++ b/gpt4all-chat/chatllm.cpp
@@ -863,11 +863,11 @@ bool ChatLLM::deserialize(QDataStream &stream, int version, bool deserializeKV,
         if (!discardKV)
             m_state = qUncompress(compressed);
     } else {
-        if (!discardKV)
+        if (!discardKV) {
             stream >> m_state;
-        else {
+        } else {
             QByteArray state;
-            stream >> m_state;
+            stream >> state;
         }
     }
 
@@ -912,7 +912,7 @@ void ChatLLM::restoreState()
         stream >> context;
         chatGPT->setContext(context);
         m_state.clear();
-        m_state.resize(0);
+        m_state.squeeze();
         return;
     }
 
@@ -923,10 +923,16 @@ void ChatLLM::restoreState()
     if (m_state.isEmpty())
         return;
 
-    m_processedSystemPrompt = true;
-    m_llModelInfo.model->restoreState(static_cast<const uint8_t*>(reinterpret_cast<void*>(m_state.data())));
+    if (m_llModelInfo.model->stateSize() == m_state.size()) {
+        m_llModelInfo.model->restoreState(static_cast<const uint8_t*>(reinterpret_cast<void*>(m_state.data())));
+        m_processedSystemPrompt = true;
+    } else {
+        qWarning() << "restoring state from text because" << m_llModelInfo.model->stateSize() << "!=" << m_state.size() << "\n";
+        m_restoreStateFromText = true;
+    }
+
     m_state.clear();
-    m_state.resize(0);
+    m_state.squeeze();
 }
 
 void ChatLLM::processSystemPrompt()