llmodel: add model wrapper destructor, fix mem leak in golang bindings (#862)

Signed-off-by: Juuso Alasuutari <juuso.alasuutari@gmail.com>
This commit is contained in:
Juuso Alasuutari 2023-06-12 19:41:22 +03:00 committed by GitHub
parent ae4a275bcd
commit 5cfb1bda89
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 12 additions and 10 deletions

View File

@ -9,6 +9,7 @@
struct LLModelWrapper {
LLModel *llModel = nullptr;
LLModel::PromptContext promptContext;
~LLModelWrapper() { delete llModel; }
};
@ -25,33 +26,33 @@ llmodel_model llmodel_model_create(const char *model_path) {
llmodel_model llmodel_model_create2(const char *model_path, const char *build_variant, llmodel_error *error) {
auto wrapper = new LLModelWrapper;
llmodel_error new_error{};
int error_code = 0;
try {
wrapper->llModel = LLModel::construct(model_path, build_variant);
} catch (const std::exception& e) {
new_error.code = EINVAL;
error_code = EINVAL;
last_error_message = e.what();
}
if (!wrapper->llModel) {
delete std::exchange(wrapper, nullptr);
// Get errno and error message if none
if (new_error.code == 0) {
new_error.code = errno;
last_error_message = strerror(errno);
if (error_code == 0) {
error_code = errno;
last_error_message = std::strerror(error_code);
}
// Set message pointer
new_error.message = last_error_message.c_str();
// Set error argument
if (error) *error = new_error;
if (error) {
error->message = last_error_message.c_str();
error->code = error_code;
}
}
return reinterpret_cast<llmodel_model*>(wrapper);
}
void llmodel_model_destroy(llmodel_model model) {
LLModelWrapper *wrapper = reinterpret_cast<LLModelWrapper*>(model);
delete wrapper->llModel;
delete reinterpret_cast<LLModelWrapper*>(model);
}
bool llmodel_loadModel(llmodel_model model, const char *model_path)

View File

@ -25,6 +25,7 @@ void* load_model(const char *fname, int n_threads) {
return nullptr;
}
if (!llmodel_loadModel(model, fname)) {
llmodel_model_destroy(model);
return nullptr;
}