diff --git a/chatllm.cpp b/chatllm.cpp index 2358457b..196c2be1 100644 --- a/chatllm.cpp +++ b/chatllm.cpp @@ -83,6 +83,7 @@ bool ChatLLM::loadModel(const QString &modelName) } bool isGPTJ = false; + bool isMPT = false; QString filePath = modelFilePath(modelName); QFileInfo info(filePath); if (info.exists()) { @@ -93,9 +94,13 @@ bool ChatLLM::loadModel(const QString &modelName) fin.seekg(0); fin.close(); isGPTJ = magic == 0x67676d6c; + isMPT = magic == 0x67676d6d; if (isGPTJ) { m_llmodel = new GPTJ; m_llmodel->loadModel(filePath.toStdString()); + } else if (isMPT) { + m_llmodel = new MPT; + m_llmodel->loadModel(filePath.toStdString()); } else { m_llmodel = new LLamaModel; m_llmodel->loadModel(filePath.toStdString()); diff --git a/llmodel/llmodel_c.cpp b/llmodel/llmodel_c.cpp index 9788f1fb..4361a900 100644 --- a/llmodel/llmodel_c.cpp +++ b/llmodel/llmodel_c.cpp @@ -2,6 +2,7 @@ #include "gptj.h" #include "llamamodel.h" +#include "mpt.h" struct LLModelWrapper { LLModel *llModel = nullptr; @@ -22,6 +23,20 @@ void llmodel_gptj_destroy(llmodel_model gptj) delete wrapper; } +llmodel_model llmodel_mpt_create() +{ + LLModelWrapper *wrapper = new LLModelWrapper; + wrapper->llModel = new MPT; + return reinterpret_cast(wrapper); +} + +void llmodel_mpt_destroy(llmodel_model mpt) +{ + LLModelWrapper *wrapper = reinterpret_cast(mpt); + delete wrapper->llModel; + delete wrapper; +} + llmodel_model llmodel_llama_create() { LLModelWrapper *wrapper = new LLModelWrapper; diff --git a/llmodel/llmodel_c.h b/llmodel/llmodel_c.h index 0907d765..f45bdd8d 100644 --- a/llmodel/llmodel_c.h +++ b/llmodel/llmodel_c.h @@ -71,6 +71,18 @@ llmodel_model llmodel_gptj_create(); */ void llmodel_gptj_destroy(llmodel_model gptj); +/** + * Create a MPT instance. + * @return A pointer to the MPT instance. + */ +llmodel_model llmodel_mpt_create(); + +/** + * Destroy a MPT instance. + * @param gptj A pointer to the MPT instance. + */ +void llmodel_mpt_destroy(llmodel_model mpt); + /** * Create a LLAMA instance. * @return A pointer to the LLAMA instance.