|
|
|
@ -645,16 +645,12 @@ GPTJ::GPTJ()
|
|
|
|
|
d_ptr->modelLoaded = false;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/// Path-only model loading is not supported by this overload; it exists to
/// satisfy the interface and always fails. Callers must use the overload
/// that takes an open input stream instead.
/// @param modelPath path to the model file (ignored — unsupported here)
/// @return always false
bool GPTJ::loadModel(const std::string &modelPath)
{
    (void)modelPath; // deliberately unused — this entry point is a stub
    std::cerr << "GPTJ ERROR: loading gpt model from file unsupported!\n";
    return false;
}
|
|
|
|
|
|
|
|
|
|
bool GPTJ::loadModel(const std::string &modelPath, std::istream &fin) {
|
|
|
|
|
bool GPTJ::loadModel(const std::string &modelPath) {
|
|
|
|
|
std::mt19937 rng(time(NULL));
|
|
|
|
|
d_ptr->rng = rng;
|
|
|
|
|
|
|
|
|
|
auto fin = std::ifstream(modelPath, std::ios::binary);
|
|
|
|
|
|
|
|
|
|
// load the model
|
|
|
|
|
if (!gptj_model_load(modelPath, fin, d_ptr->model, d_ptr->vocab)) {
|
|
|
|
|
std::cerr << "GPT-J ERROR: failed to load model from " << modelPath;
|
|
|
|
|