diff --git a/llm.cpp b/llm.cpp
index 5b08bcb0..4125eb7e 100644
--- a/llm.cpp
+++ b/llm.cpp
@@ -91,14 +91,14 @@ bool LLMObject::loadModelPrivate(const QString &modelName)
     if (info.exists()) {
 
         auto fin = std::ifstream(filePath.toStdString(), std::ios::binary);
-
         uint32_t magic;
         fin.read((char *) &magic, sizeof(magic));
         fin.seekg(0);
+        fin.close();
         isGPTJ = magic == 0x67676d6c;
         if (isGPTJ) {
             m_llmodel = new GPTJ;
-            m_llmodel->loadModel(modelName.toStdString(), fin);
+            m_llmodel->loadModel(filePath.toStdString());
         } else {
             m_llmodel = new LLamaModel;
             m_llmodel->loadModel(filePath.toStdString());
diff --git a/llmodel/gptj.cpp b/llmodel/gptj.cpp
index c3ee6585..0d65c5cb 100644
--- a/llmodel/gptj.cpp
+++ b/llmodel/gptj.cpp
@@ -645,16 +645,12 @@ GPTJ::GPTJ()
     d_ptr->modelLoaded = false;
 }
 
-bool GPTJ::loadModel(const std::string &modelPath)
-{
-    std::cerr << "GPTJ ERROR: loading gpt model from file unsupported!\n";
-    return false;
-}
-
-bool GPTJ::loadModel(const std::string &modelPath, std::istream &fin) {
+bool GPTJ::loadModel(const std::string &modelPath) {
     std::mt19937 rng(time(NULL));
     d_ptr->rng = rng;
 
+    auto fin = std::ifstream(modelPath, std::ios::binary);
+
     // load the model
     if (!gptj_model_load(modelPath, fin, d_ptr->model, d_ptr->vocab)) {
         std::cerr << "GPT-J ERROR: failed to load model from " << modelPath;
diff --git a/llmodel/gptj.h b/llmodel/gptj.h
index 6f19dcd1..70a4655a 100644
--- a/llmodel/gptj.h
+++ b/llmodel/gptj.h
@@ -13,7 +13,6 @@ public:
     ~GPTJ();
 
     bool loadModel(const std::string &modelPath) override;
-    bool loadModel(const std::string &modelPath, std::istream &fin) override;
     bool isModelLoaded() const override;
     void prompt(const std::string &prompt, std::function promptCallback,
diff --git a/llmodel/llamamodel.cpp b/llmodel/llamamodel.cpp
index c1638c10..0da930b6 100644
--- a/llmodel/llamamodel.cpp
+++ b/llmodel/llamamodel.cpp
@@ -31,12 +31,6 @@ LLamaModel::LLamaModel()
     d_ptr->modelLoaded = false;
 }
 
-bool LLamaModel::loadModel(const std::string &modelPath, std::istream &fin)
-{
-    std::cerr << "LLAMA ERROR: loading llama model from stream unsupported!\n";
-    return false;
-}
-
 bool LLamaModel::loadModel(const std::string &modelPath)
 {
     // load the model
diff --git a/llmodel/llamamodel.h b/llmodel/llamamodel.h
index c97f80b7..13e221a7 100644
--- a/llmodel/llamamodel.h
+++ b/llmodel/llamamodel.h
@@ -13,7 +13,6 @@ public:
     ~LLamaModel();
 
     bool loadModel(const std::string &modelPath) override;
-    bool loadModel(const std::string &modelPath, std::istream &fin) override;
     bool isModelLoaded() const override;
     void prompt(const std::string &prompt, std::function promptCallback,
diff --git a/llmodel/llmodel.h b/llmodel/llmodel.h
index 0cc53689..08dc1764 100644
--- a/llmodel/llmodel.h
+++ b/llmodel/llmodel.h
@@ -11,7 +11,6 @@ public:
    virtual ~LLModel() {}
 
    virtual bool loadModel(const std::string &modelPath) = 0;
-    virtual bool loadModel(const std::string &modelPath, std::istream &fin) = 0;
    virtual bool isModelLoaded() const = 0;
    struct PromptContext {
        std::vector<float> logits; // logits of current context
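
Note (not part of the patch): after this change both backends expose only the path-based LLModel::loadModel(), and the caller decides which backend to construct by peeking at the file's magic bytes before handing over the path. The sketch below illustrates that dispatch pattern under the declarations shown in the headers above; the helper name pickBackend is hypothetical and not part of the codebase.

// Minimal sketch, assuming the LLModel, GPTJ and LLamaModel classes from llmodel/.
#include <cstdint>
#include <fstream>
#include <string>

#include "llmodel/gptj.h"
#include "llmodel/llamamodel.h"

// Hypothetical helper mirroring the logic in LLMObject::loadModelPrivate:
// read the first four bytes, close the stream, then let the backend reopen
// the file itself via the single path-based loadModel().
static LLModel *pickBackend(const std::string &filePath)
{
    uint32_t magic = 0;
    {
        std::ifstream fin(filePath, std::ios::binary);
        fin.read(reinterpret_cast<char *>(&magic), sizeof(magic));
    } // fin closes here; the backend does its own file I/O.

    // Files carrying the plain 0x67676d6c ("ggml") magic are treated as GPT-J;
    // anything else falls through to the llama backend, as in the diff above.
    const bool isGPTJ = (magic == 0x67676d6c);
    LLModel *model = isGPTJ ? static_cast<LLModel *>(new GPTJ)
                            : static_cast<LLModel *>(new LLamaModel);
    if (!model->loadModel(filePath)) {
        delete model; // safe: LLModel has a virtual destructor
        return nullptr;
    }
    return model;
}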