diff --git a/chat.cpp b/chat.cpp index 6d1782ff..75cad7eb 100644 --- a/chat.cpp +++ b/chat.cpp @@ -202,6 +202,12 @@ bool Chat::deserialize(QDataStream &stream, int version) stream >> m_userName; emit nameChanged(); stream >> m_savedModelName; + + // Prior to version 2 gptj models had a bug that fixed the kv_cache to F32 instead of F16 so + // unfortunately, we cannot deserialize these + if (version < 2 && m_savedModelName.contains("gpt4all-j")) + return false; + if (!m_llmodel->deserialize(stream, version)) return false; if (!m_chatModel->deserialize(stream, version)) diff --git a/chatlistmodel.cpp b/chatlistmodel.cpp index 152d7f6f..1fbd0110 100644 --- a/chatlistmodel.cpp +++ b/chatlistmodel.cpp @@ -5,7 +5,7 @@ #include #define CHAT_FORMAT_MAGIC 0xF5D553CC -#define CHAT_FORMAT_VERSION 1 +#define CHAT_FORMAT_VERSION 2 ChatListModel::ChatListModel(QObject *parent) : QAbstractListModel(parent) diff --git a/chatllm.cpp b/chatllm.cpp index 196c2be1..7e95bf0b 100644 --- a/chatllm.cpp +++ b/chatllm.cpp @@ -16,6 +16,10 @@ //#define DEBUG +#define MPT_INTERNAL_STATE_VERSION 0 +#define GPTJ_INTERNAL_STATE_VERSION 0 +#define LLAMA_INTERNAL_STATE_VERSION 0 + static QString modelFilePath(const QString &modelName) { QString appPath = QCoreApplication::applicationDirPath() @@ -96,12 +100,15 @@ bool ChatLLM::loadModel(const QString &modelName) isGPTJ = magic == 0x67676d6c; isMPT = magic == 0x67676d6d; if (isGPTJ) { + m_modelType = ModelType::GPTJ_; m_llmodel = new GPTJ; m_llmodel->loadModel(filePath.toStdString()); } else if (isMPT) { + m_modelType = ModelType::MPT_; m_llmodel = new MPT; m_llmodel->loadModel(filePath.toStdString()); } else { + m_modelType = ModelType::LLAMA_; m_llmodel = new LLamaModel; m_llmodel->loadModel(filePath.toStdString()); } @@ -380,6 +387,15 @@ bool ChatLLM::handleNameRecalculate(bool isRecalc) bool ChatLLM::serialize(QDataStream &stream, int version) { + if (version > 1) { + stream << m_modelType; + switch (m_modelType) { + case MPT_: stream << MPT_INTERNAL_STATE_VERSION; break; + case GPTJ_: stream << GPTJ_INTERNAL_STATE_VERSION; break; + case LLAMA_: stream << LLAMA_INTERNAL_STATE_VERSION; break; + default: Q_UNREACHABLE(); + } + } stream << response(); stream << generatedName(); stream << m_promptResponseTokens; @@ -400,6 +416,11 @@ bool ChatLLM::serialize(QDataStream &stream, int version) bool ChatLLM::deserialize(QDataStream &stream, int version) { + if (version > 1) { + int internalStateVersion; + stream >> m_modelType; + stream >> internalStateVersion; // for future use + } QString response; stream >> response; m_response = response.toStdString(); diff --git a/chatllm.h b/chatllm.h index 8a2732d1..9e0b932f 100644 --- a/chatllm.h +++ b/chatllm.h @@ -17,6 +17,12 @@ class ChatLLM : public QObject Q_PROPERTY(QString generatedName READ generatedName NOTIFY generatedNameChanged) public: + enum ModelType { + MPT_, + GPTJ_, + LLAMA_ + }; + ChatLLM(Chat *parent); bool isModelLoaded() const; @@ -82,6 +88,7 @@ private: quint32 m_promptResponseTokens; quint32 m_responseLogits; QString m_modelName; + ModelType m_modelType; Chat *m_chat; QByteArray m_state; QThread m_llmThread; diff --git a/llmodel/gptj.cpp b/llmodel/gptj.cpp index a5d04ae7..8e5145f4 100644 --- a/llmodel/gptj.cpp +++ b/llmodel/gptj.cpp @@ -352,7 +352,7 @@ bool gptj_model_load(const std::string &fname, std::istream &fin, gptj_model & m const int n_mem = n_layer*n_ctx; const int n_elements = n_embd*n_mem; - if (!kv_cache_init(hparams, model.kv_self, GGML_TYPE_F32, model.hparams.n_ctx)) { + if (!kv_cache_init(hparams, model.kv_self, GGML_TYPE_F16, model.hparams.n_ctx)) { fprintf(stderr, "%s: kv_cache_init() failed for self-attention cache\n", __func__); ggml_free(ctx); return false;