|
|
|
@ -16,6 +16,10 @@
|
|
|
|
|
|
|
|
|
|
//#define DEBUG
|
|
|
|
|
|
|
|
|
|
#define MPT_INTERNAL_STATE_VERSION 0
|
|
|
|
|
#define GPTJ_INTERNAL_STATE_VERSION 0
|
|
|
|
|
#define LLAMA_INTERNAL_STATE_VERSION 0
|
|
|
|
|
|
|
|
|
|
static QString modelFilePath(const QString &modelName)
|
|
|
|
|
{
|
|
|
|
|
QString appPath = QCoreApplication::applicationDirPath()
|
|
|
|
@ -96,12 +100,15 @@ bool ChatLLM::loadModel(const QString &modelName)
|
|
|
|
|
isGPTJ = magic == 0x67676d6c;
|
|
|
|
|
isMPT = magic == 0x67676d6d;
|
|
|
|
|
if (isGPTJ) {
|
|
|
|
|
m_modelType = ModelType::GPTJ_;
|
|
|
|
|
m_llmodel = new GPTJ;
|
|
|
|
|
m_llmodel->loadModel(filePath.toStdString());
|
|
|
|
|
} else if (isMPT) {
|
|
|
|
|
m_modelType = ModelType::MPT_;
|
|
|
|
|
m_llmodel = new MPT;
|
|
|
|
|
m_llmodel->loadModel(filePath.toStdString());
|
|
|
|
|
} else {
|
|
|
|
|
m_modelType = ModelType::LLAMA_;
|
|
|
|
|
m_llmodel = new LLamaModel;
|
|
|
|
|
m_llmodel->loadModel(filePath.toStdString());
|
|
|
|
|
}
|
|
|
|
@ -380,6 +387,15 @@ bool ChatLLM::handleNameRecalculate(bool isRecalc)
|
|
|
|
|
|
|
|
|
|
bool ChatLLM::serialize(QDataStream &stream, int version)
|
|
|
|
|
{
|
|
|
|
|
if (version > 1) {
|
|
|
|
|
stream << m_modelType;
|
|
|
|
|
switch (m_modelType) {
|
|
|
|
|
case MPT_: stream << MPT_INTERNAL_STATE_VERSION; break;
|
|
|
|
|
case GPTJ_: stream << GPTJ_INTERNAL_STATE_VERSION; break;
|
|
|
|
|
case LLAMA_: stream << LLAMA_INTERNAL_STATE_VERSION; break;
|
|
|
|
|
default: Q_UNREACHABLE();
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
stream << response();
|
|
|
|
|
stream << generatedName();
|
|
|
|
|
stream << m_promptResponseTokens;
|
|
|
|
@ -400,6 +416,11 @@ bool ChatLLM::serialize(QDataStream &stream, int version)
|
|
|
|
|
|
|
|
|
|
bool ChatLLM::deserialize(QDataStream &stream, int version)
|
|
|
|
|
{
|
|
|
|
|
if (version > 1) {
|
|
|
|
|
int internalStateVersion;
|
|
|
|
|
stream >> m_modelType;
|
|
|
|
|
stream >> internalStateVersion; // for future use
|
|
|
|
|
}
|
|
|
|
|
QString response;
|
|
|
|
|
stream >> response;
|
|
|
|
|
m_response = response.toStdString();
|
|
|
|
|