Mirror of https://github.com/nomic-ai/gpt4all, synced 2024-11-18 03:25:46 +00:00
Fix gptj to have lower memory requirements for kv cache and add versioning to the internal state to smoothly handle such a fix in the future.
parent: ccbd16cf18
commit: 8c4b8f215f
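For a sense of scale: the kv cache holds one K and one V tensor of n_embd * n_layer * n_ctx elements each, so dropping from F32 (4 bytes/element) to F16 (2 bytes/element) halves it. A minimal back-of-the-envelope sketch, assuming the GPT-J 6B hyperparameters (n_layer = 28, n_embd = 4096) and an n_ctx of 2048; the arithmetic mirrors the n_mem/n_elements lines in the gptj.cpp hunk below:

    #include <cstdio>

    int main()
    {
        // Assumed GPT-J 6B hyperparameters, for illustration only
        const long n_layer = 28;
        const long n_embd  = 4096;
        const long n_ctx   = 2048;

        // Same arithmetic as gptj_model_load in the diff below
        const long n_mem      = n_layer * n_ctx;  // 57344
        const long n_elements = n_embd * n_mem;   // ~235M elements per tensor

        // Two tensors (K and V), 4 bytes/elem for F32 vs 2 for F16
        printf("F32 kv cache: %.2f GiB\n", 2.0 * n_elements * 4 / (1 << 30));
        printf("F16 kv cache: %.2f GiB\n", 2.0 * n_elements * 2 / (1 << 30));
        // F32 kv cache: 1.75 GiB
        // F16 kv cache: 0.88 GiB
    }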
chat.cpp (+6)
@@ -202,6 +202,12 @@ bool Chat::deserialize(QDataStream &stream, int version)
     stream >> m_userName;
     emit nameChanged();
     stream >> m_savedModelName;
+
+    // Prior to version 2 gptj models had a bug that fixed the kv_cache to F32 instead of F16 so
+    // unfortunately, we cannot deserialize these
+    if (version < 2 && m_savedModelName.contains("gpt4all-j"))
+        return false;
+
     if (!m_llmodel->deserialize(stream, version))
         return false;
     if (!m_chatModel->deserialize(stream, version))
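The gate keys on the saved model name, since nothing inside an old stream identifies the kv cache type. A minimal sketch of the predicate, with assumed model-name strings for illustration:

    #include <QString>
    #include <cassert>

    // Mirrors the new gate in Chat::deserialize; the model names are
    // assumed examples, not an exhaustive list of shipped gptj models
    static bool canDeserialize(int version, const QString &savedModelName)
    {
        return !(version < 2 && savedModelName.contains("gpt4all-j"));
    }

    int main()
    {
        assert(!canDeserialize(1, "gpt4all-j-v1.3-groovy")); // old gptj state: rejected
        assert(canDeserialize(2, "gpt4all-j-v1.3-groovy"));  // re-saved under v2: fine
        assert(canDeserialize(1, "gpt4all-lora"));           // non-gptj v1 state: fine
    }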
chatlistmodel.cpp

@@ -5,7 +5,7 @@
 #include <QDataStream>
 
 #define CHAT_FORMAT_MAGIC 0xF5D553CC
-#define CHAT_FORMAT_VERSION 1
+#define CHAT_FORMAT_VERSION 2
 
 ChatListModel::ChatListModel(QObject *parent)
     : QAbstractListModel(parent)
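Bumping CHAT_FORMAT_VERSION from 1 to 2 is what routes saved chats into the new version < 2 branch above. A minimal sketch of the read side, under the assumption (suggested by the macro names) that each saved chat stream opens with the magic number followed by the format version; readChat is hypothetical, not the actual ChatListModel code:

    #include <QDataStream>
    #include <QFile>

    #define CHAT_FORMAT_MAGIC 0xF5D553CC
    #define CHAT_FORMAT_VERSION 2

    // Hypothetical loader illustrating the header handshake
    bool readChat(const QString &path)
    {
        QFile file(path);
        if (!file.open(QIODevice::ReadOnly))
            return false;

        QDataStream stream(&file);
        quint32 magic;
        qint32 version;
        stream >> magic;
        if (magic != CHAT_FORMAT_MAGIC)
            return false;                 // not a chat file at all
        stream >> version;
        if (version > CHAT_FORMAT_VERSION)
            return false;                 // written by a newer build

        // Older versions are passed down so each layer can gate on them,
        // e.g. Chat::deserialize(stream, version) rejecting pre-2 gptj state
        return true;
    }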
chatllm.cpp (+21)
@@ -16,6 +16,10 @@
 
 //#define DEBUG
 
+#define MPT_INTERNAL_STATE_VERSION 0
+#define GPTJ_INTERNAL_STATE_VERSION 0
+#define LLAMA_INTERNAL_STATE_VERSION 0
+
 static QString modelFilePath(const QString &modelName)
 {
     QString appPath = QCoreApplication::applicationDirPath()
@@ -96,12 +100,15 @@ bool ChatLLM::loadModel(const QString &modelName)
     isGPTJ = magic == 0x67676d6c;
     isMPT = magic == 0x67676d6d;
     if (isGPTJ) {
+        m_modelType = ModelType::GPTJ_;
         m_llmodel = new GPTJ;
         m_llmodel->loadModel(filePath.toStdString());
     } else if (isMPT) {
+        m_modelType = ModelType::MPT_;
         m_llmodel = new MPT;
         m_llmodel->loadModel(filePath.toStdString());
     } else {
+        m_modelType = ModelType::LLAMA_;
         m_llmodel = new LLamaModel;
         m_llmodel->loadModel(filePath.toStdString());
     }
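Aside: the two magic constants checked in loadModel are ASCII tags read from the head of the model file; 0x67676d6c spells "ggml" and 0x67676d6d spells "ggmm". A quick decode:

    #include <cstdio>

    int main()
    {
        const unsigned magics[] = { 0x67676d6c, 0x67676d6d }; // gptj, mpt
        for (unsigned m : magics)
            printf("0x%08x = \"%c%c%c%c\"\n", m,
                   (m >> 24) & 0xff, (m >> 16) & 0xff, (m >> 8) & 0xff, m & 0xff);
        // 0x67676d6c = "ggml"
        // 0x67676d6d = "ggmm"
    }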
@@ -380,6 +387,15 @@ bool ChatLLM::handleNameRecalculate(bool isRecalc)
 
 bool ChatLLM::serialize(QDataStream &stream, int version)
 {
+    if (version > 1) {
+        stream << m_modelType;
+        switch (m_modelType) {
+        case MPT_: stream << MPT_INTERNAL_STATE_VERSION; break;
+        case GPTJ_: stream << GPTJ_INTERNAL_STATE_VERSION; break;
+        case LLAMA_: stream << LLAMA_INTERNAL_STATE_VERSION; break;
+        default: Q_UNREACHABLE();
+        }
+    }
     stream << response();
     stream << generatedName();
     stream << m_promptResponseTokens;
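All three internal state versions start at 0, and the value is read back unused; the point is to reserve room in the format. A hedged sketch of how a future bump might be consumed inside deserialize — none of this branching exists in the commit, and the migration body is purely hypothetical:

    // Hypothetical future ChatLLM::deserialize fragment: if the gptj state
    // layout changes again, bump GPTJ_INTERNAL_STATE_VERSION and branch on
    // the value read back instead of invalidating every saved chat
    if (version > 1) {
        int internalStateVersion;
        stream >> m_modelType;
        stream >> internalStateVersion;
        if (m_modelType == GPTJ_ && internalStateVersion < 1) {
            // e.g. convert the old kv cache layout here, or skip restoring it
        }
    }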
@@ -400,6 +416,11 @@ bool ChatLLM::serialize(QDataStream &stream, int version)
 
 bool ChatLLM::deserialize(QDataStream &stream, int version)
 {
+    if (version > 1) {
+        int internalStateVersion;
+        stream >> m_modelType;
+        stream >> internalStateVersion; // for future use
+    }
     QString response;
     stream >> response;
     m_response = response.toStdString();
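Note that deserialize reads fields in exactly the order serialize wrote them; QDataStream carries no field names, so any drift silently corrupts everything after the mismatch. A minimal round-trip check of that invariant (QDataStream has streamed plain enums since Qt 5.14, so stream >> m_modelType works as written):

    #include <QBuffer>
    #include <QDataStream>

    enum ModelType { MPT_, GPTJ_, LLAMA_ }; // mirrors ChatLLM::ModelType below

    int main()
    {
        QBuffer buf;
        buf.open(QIODevice::ReadWrite);

        QDataStream out(&buf);
        out << ModelType::GPTJ_ << 0 << QString("hello"); // type, state version, response

        buf.seek(0);
        QDataStream in(&buf);
        ModelType type;
        int stateVersion;
        QString response;
        in >> type >> stateVersion >> response;           // must match the write order
        return (type == GPTJ_ && stateVersion == 0 && response == "hello") ? 0 : 1;
    }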
chatllm.h

@@ -17,6 +17,12 @@ class ChatLLM : public QObject
     Q_PROPERTY(QString generatedName READ generatedName NOTIFY generatedNameChanged)
 
 public:
+    enum ModelType {
+        MPT_,
+        GPTJ_,
+        LLAMA_
+    };
+
     ChatLLM(Chat *parent);
 
     bool isModelLoaded() const;
@@ -82,6 +88,7 @@ private:
     quint32 m_promptResponseTokens;
     quint32 m_responseLogits;
     QString m_modelName;
+    ModelType m_modelType;
     Chat *m_chat;
     QByteArray m_state;
     QThread m_llmThread;
gptj.cpp

@@ -352,7 +352,7 @@ bool gptj_model_load(const std::string &fname, std::istream &fin, gptj_model & m
     const int n_mem = n_layer*n_ctx;
     const int n_elements = n_embd*n_mem;
 
-    if (!kv_cache_init(hparams, model.kv_self, GGML_TYPE_F32, model.hparams.n_ctx)) {
+    if (!kv_cache_init(hparams, model.kv_self, GGML_TYPE_F16, model.hparams.n_ctx)) {
         fprintf(stderr, "%s: kv_cache_init() failed for self-attention cache\n", __func__);
         ggml_free(ctx);
         return false;
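The model-side fix is that single enum: ggml sizes the cache's K and V tensors from their element type, so F16 halves the footprint computed above, with essentially no effect on inference quality (an F16 kv cache is the llama.cpp default). A tiny illustration via ggml's type-size helper, assuming ggml.h from the vendored llama.cpp is on the include path:

    #include <cstdio>
    #include "ggml.h"

    int main()
    {
        // ggml_type_size() reports bytes per element for a tensor type
        printf("F32: %zu bytes/elem, F16: %zu bytes/elem\n",
               ggml_type_size(GGML_TYPE_F32),  // 4
               ggml_type_size(GGML_TYPE_F16)); // 2
    }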