Fix gptj to have lower memory requirements for the kv cache, and add versioning to the internal state so that a fix like this can be handled smoothly in the future.

Adam Treat 2023-05-08 17:23:02 -04:00
parent ccbd16cf18
commit 8c4b8f215f
5 changed files with 36 additions and 2 deletions

chat.cpp

@@ -202,6 +202,12 @@ bool Chat::deserialize(QDataStream &stream, int version)
     stream >> m_userName;
     emit nameChanged();
     stream >> m_savedModelName;
+    // Prior to version 2, gptj models had a bug that pinned the kv_cache to F32 instead of F16,
+    // so unfortunately we cannot deserialize these chats
+    if (version < 2 && m_savedModelName.contains("gpt4all-j"))
+        return false;
     if (!m_llmodel->deserialize(stream, version))
         return false;
     if (!m_chatModel->deserialize(stream, version))

chatlistmodel.cpp

@@ -5,7 +5,7 @@
 #include <QDataStream>
 
 #define CHAT_FORMAT_MAGIC 0xF5D553CC
-#define CHAT_FORMAT_VERSION 1
+#define CHAT_FORMAT_VERSION 2
 
 ChatListModel::ChatListModel(QObject *parent)
     : QAbstractListModel(parent)
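
Bumping CHAT_FORMAT_VERSION to 2 is what routes the new version number into Chat::deserialize above. A minimal sketch of the framing, assuming the saved-chats file simply leads with the magic and the version (this helper is hypothetical, not the app's actual loader):

#include <QDataStream>
#include <QFile>

// Hypothetical: validate a saved-chats file header and recover its version.
bool checkChatFile(const QString &path, qint32 &version)
{
    QFile file(path);
    if (!file.open(QIODevice::ReadOnly))
        return false;
    QDataStream in(&file);
    quint32 magic;
    in >> magic;
    if (magic != CHAT_FORMAT_MAGIC)
        return false;              // not a saved-chats file at all
    in >> version;                 // 1 for pre-fix files, 2 from now on
    return version <= CHAT_FORMAT_VERSION;
}

Old files still parse as version 1; only chats whose saved model name contains "gpt4all-j" are then rejected by the new gate.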

chatllm.cpp

@@ -16,6 +16,10 @@
 //#define DEBUG
 
+#define MPT_INTERNAL_STATE_VERSION 0
+#define GPTJ_INTERNAL_STATE_VERSION 0
+#define LLAMA_INTERNAL_STATE_VERSION 0
+
 static QString modelFilePath(const QString &modelName)
 {
     QString appPath = QCoreApplication::applicationDirPath()
@@ -96,12 +100,15 @@ bool ChatLLM::loadModel(const QString &modelName)
         isGPTJ = magic == 0x67676d6c;
         isMPT = magic == 0x67676d6d;
         if (isGPTJ) {
+            m_modelType = ModelType::GPTJ_;
             m_llmodel = new GPTJ;
             m_llmodel->loadModel(filePath.toStdString());
         } else if (isMPT) {
+            m_modelType = ModelType::MPT_;
             m_llmodel = new MPT;
             m_llmodel->loadModel(filePath.toStdString());
         } else {
+            m_modelType = ModelType::LLAMA_;
             m_llmodel = new LLamaModel;
             m_llmodel->loadModel(filePath.toStdString());
         }
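
For context on the branch above, the two magic values are ASCII tags at the front of the model file: 0x67676d6c spells "ggml" and 0x67676d6d spells "ggmm" when the hex bytes are read most-significant first. A tiny probe in the same spirit (hypothetical helper; assumes the tag is read as a host-endian uint32):

#include <cstdint>
#include <cstdio>

// Hypothetical: read the leading 4-byte tag of a ggml-format model file.
static uint32_t readMagic(const char *path)
{
    uint32_t magic = 0;
    if (FILE *f = std::fopen(path, "rb")) {
        std::fread(&magic, sizeof magic, 1, f);
        std::fclose(f);
    }
    return magic; // compare against 0x67676d6c (gptj) or 0x67676d6d (mpt)
}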
@@ -380,6 +387,15 @@ bool ChatLLM::handleNameRecalculate(bool isRecalc)
 
 bool ChatLLM::serialize(QDataStream &stream, int version)
 {
+    if (version > 1) {
+        stream << m_modelType;
+        switch (m_modelType) {
+        case MPT_: stream << MPT_INTERNAL_STATE_VERSION; break;
+        case GPTJ_: stream << GPTJ_INTERNAL_STATE_VERSION; break;
+        case LLAMA_: stream << LLAMA_INTERNAL_STATE_VERSION; break;
+        default: Q_UNREACHABLE();
+        }
+    }
     stream << response();
     stream << generatedName();
     stream << m_promptResponseTokens;
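
On the wire, a version-2 ChatLLM record therefore gains two fields ahead of the otherwise unchanged payload. Roughly, assuming QDataStream streams the enum and the integer define as 32-bit values:

// version >= 2 record layout (sketch):
//   modelType                qint32    MPT_ / GPTJ_ / LLAMA_
//   internal state version   qint32    currently 0 for all three backends
//   response                 QString
//   generatedName            QString
//   m_promptResponseTokens   quint32
//   ...rest unchanged from version 1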
@@ -400,6 +416,11 @@ bool ChatLLM::serialize(QDataStream &stream, int version)
 
 bool ChatLLM::deserialize(QDataStream &stream, int version)
 {
+    if (version > 1) {
+        int internalStateVersion;
+        stream >> m_modelType;
+        stream >> internalStateVersion; // for future use
+    }
     QString response;
     stream >> response;
     m_response = response.toStdString();
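
Reading internalStateVersion now, even though it is discarded, is what makes the scheme forward-compatible: a later backend change only needs to bump its *_INTERNAL_STATE_VERSION define, and the reader can branch on it without another CHAT_FORMAT_VERSION bump. A hypothetical future shape of this block:

    if (version > 1) {
        int internalStateVersion;
        stream >> m_modelType;
        stream >> internalStateVersion;
        // Hypothetical, not in this commit: per-backend migration.
        if (m_modelType == GPTJ_ && internalStateVersion < 1)
            return false; // or migrate the old state in place
    }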

chatllm.h

@@ -17,6 +17,12 @@ class ChatLLM : public QObject
     Q_PROPERTY(QString generatedName READ generatedName NOTIFY generatedNameChanged)
 
 public:
+    enum ModelType {
+        MPT_,
+        GPTJ_,
+        LLAMA_
+    };
+
     ChatLLM(Chat *parent);
 
     bool isModelLoaded() const;
@@ -82,6 +88,7 @@ private:
     quint32 m_promptResponseTokens;
     quint32 m_responseLogits;
     QString m_modelName;
+    ModelType m_modelType;
     Chat *m_chat;
     QByteArray m_state;
     QThread m_llmThread;

gptj.cpp

@@ -352,7 +352,7 @@ bool gptj_model_load(const std::string &fname, std::istream &fin, gptj_model &model
     const int n_mem = n_layer*n_ctx;
     const int n_elements = n_embd*n_mem;
 
-    if (!kv_cache_init(hparams, model.kv_self, GGML_TYPE_F32, model.hparams.n_ctx)) {
+    if (!kv_cache_init(hparams, model.kv_self, GGML_TYPE_F16, model.hparams.n_ctx)) {
         fprintf(stderr, "%s: kv_cache_init() failed for self-attention cache\n", __func__);
         ggml_free(ctx);
         return false;
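
This one-character change (F32 to F16) halves the element size of both kv tensors. A back-of-envelope check, assuming GPT-J 6B's published hyperparameters (n_layer = 28, n_embd = 4096, n_ctx = 2048); the formulas mirror n_mem and n_elements above:

#include <cstdio>

// Rough kv cache sizing for GPT-J 6B (hyperparameter values assumed).
int main()
{
    const long long n_layer = 28, n_embd = 4096, n_ctx = 2048;
    const long long n_mem      = n_layer * n_ctx;    // as in gptj_model_load
    const long long n_elements = n_embd * n_mem;     // per K or V tensor
    const long long f32 = 2 * n_elements * 4;        // K + V at 4 bytes/elem
    const long long f16 = 2 * n_elements * 2;        // K + V at 2 bytes/elem
    std::printf("F32: %.2f GiB  F16: %.2f GiB\n",
                f32 / 1073741824.0, f16 / 1073741824.0); // ~1.75 vs ~0.88
    return 0;
}

Roughly 0.9 GiB saved at full context. It also explains the gate in Chat::deserialize: a kv state serialized as F32 no longer matches the F16 tensors the fixed loader allocates, so pre-fix gpt4all-j chats are rejected rather than restored incorrectly.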