From fc60f0c09cc44f05b969d1cba895e5e8b099815b Mon Sep 17 00:00:00 2001 From: niansa/tuxifan Date: Thu, 1 Jun 2023 16:51:46 +0200 Subject: [PATCH] Cleaned up implementation management (#787) * Cleaned up implementation management * Initialize LLModel::m_implementation to nullptr * llmodel.h: Moved dlhandle fwd declare above LLModel class --- gpt4all-backend/llmodel.cpp | 47 ++++++++++++++----------------------- gpt4all-backend/llmodel.h | 25 +++++++++----------- 2 files changed, 29 insertions(+), 43 deletions(-) diff --git a/gpt4all-backend/llmodel.cpp b/gpt4all-backend/llmodel.cpp index 93bf4997..02cdfa6f 100644 --- a/gpt4all-backend/llmodel.cpp +++ b/gpt4all-backend/llmodel.cpp @@ -1,4 +1,5 @@ #include "llmodel.h" +#include "dlhandle.h" #include #include @@ -20,24 +21,28 @@ static bool requires_avxonly() { #endif } -LLModel::Implementation::Implementation(Dlhandle &&dlhandle_) : dlhandle(std::move(dlhandle_)) { - auto get_model_type = dlhandle.get("get_model_type"); +LLModel::Implementation::Implementation(Dlhandle &&dlhandle_) : dlhandle(new Dlhandle(std::move(dlhandle_))) { + auto get_model_type = dlhandle->get("get_model_type"); assert(get_model_type); modelType = get_model_type(); - auto get_build_variant = dlhandle.get("get_build_variant"); + auto get_build_variant = dlhandle->get("get_build_variant"); assert(get_build_variant); buildVariant = get_build_variant(); - magicMatch = dlhandle.get("magic_match"); + magicMatch = dlhandle->get("magic_match"); assert(magicMatch); - construct_ = dlhandle.get("construct"); + construct_ = dlhandle->get("construct"); assert(construct_); } +LLModel::Implementation::~Implementation() { + delete dlhandle; +} + bool LLModel::Implementation::isImplementation(const Dlhandle &dl) { return dl.get("is_g4a_backend_model_implementation"); } -const std::vector &LLModel::getImplementationList() { +const std::vector &LLModel::implementationList() { // NOTE: allocated on heap so we leak intentionally on exit so we have a chance to clean up the // individual models without the cleanup of the static list interfering static auto* libs = new std::vector([] () { @@ -46,12 +51,7 @@ const std::vector &LLModel::getImplementationList() { auto search_in_directory = [&](const std::filesystem::path& path) { // Iterate over all libraries for (const auto& f : std::filesystem::directory_iterator(path)) { - // Get path - // FIXME: Remove useless comment and avoid usage of 'auto' where having the type is - // helpful for code readability so someone doesn't have to look up the docs for what - // type is returned by 'path' as it is not std::string - const auto& p = f.path(); - // Check extension + const std::filesystem::path& p = f.path(); if (p.extension() != LIB_FILE_EXT) continue; // Add to list if model implementation try { @@ -74,29 +74,18 @@ const std::vector &LLModel::getImplementationList() { return *libs; } -const LLModel::Implementation* LLModel::getImplementation(std::ifstream& f, const std::string& buildVariant) { - // FIXME: Please remove all these useless comments as the code itself is more than enough in these - // instances to tell what is going on - // Iterate over all libraries - for (const auto& i : getImplementationList()) { +const LLModel::Implementation* LLModel::implementation(std::ifstream& f, const std::string& buildVariant) { + for (const auto& i : implementationList()) { f.seekg(0); - // Check that magic matches - if (!i.magicMatch(f)) { - continue; - } - // Check that build variant is correct - if (buildVariant != i.buildVariant) { - continue; - } - // Looks like we're good to go, return this dlhandle + if (!i.magicMatch(f)) continue; + if (buildVariant != i.buildVariant) continue; return &i; } - // Nothing found, so return nothing return nullptr; } LLModel *LLModel::construct(const std::string &modelPath, std::string buildVariant) { - //TODO: Auto-detect + //TODO: Auto-detect CUDA/OpenCL if (buildVariant == "auto") { if (requires_avxonly()) { buildVariant = "avxonly"; @@ -108,7 +97,7 @@ LLModel *LLModel::construct(const std::string &modelPath, std::string buildVaria std::ifstream f(modelPath, std::ios::binary); if (!f) return nullptr; // Get correct implementation - auto impl = getImplementation(f, buildVariant); + auto impl = implementation(f, buildVariant); if (!impl) return nullptr; f.close(); // Construct and return llmodel implementation diff --git a/gpt4all-backend/llmodel.h b/gpt4all-backend/llmodel.h index 7c9bafc7..882a369c 100644 --- a/gpt4all-backend/llmodel.h +++ b/gpt4all-backend/llmodel.h @@ -1,8 +1,6 @@ #ifndef LLMODEL_H #define LLMODEL_H -#include "dlhandle.h" // FIXME: would be nice to move this into implementation file - #include #include #include @@ -10,24 +8,27 @@ #include #include +class Dlhandle; + class LLModel { public: class Implementation { LLModel *(*construct_)(); public: - // FIXME: Move the whole implementation details to cpp file Implementation(Dlhandle&&); + ~Implementation(); static bool isImplementation(const Dlhandle&); std::string_view modelType, buildVariant; bool (*magicMatch)(std::ifstream& f); - Dlhandle dlhandle; + Dlhandle *dlhandle; + // The only way an implementation should be constructed LLModel *construct() const { auto fres = construct_(); - fres->implementation = this; + fres->m_implementation = this; return fres; } }; @@ -64,20 +65,16 @@ public: virtual void setThreadCount(int32_t /*n_threads*/) {} virtual int32_t threadCount() const { return 1; } - // FIXME: This is unused?? - const Implementation& getImplementation() const { - return *implementation; + const Implementation& implementation() const { + return *m_implementation; } - // FIXME: Maybe have an 'ImplementationInfo' class for the GUI here, but the DLHandle stuff should - // be hidden in cpp file - // FIXME: Avoid usage of 'get' for getters - static const std::vector& getImplementationList(); - static const Implementation *getImplementation(std::ifstream& f, const std::string& buildVariant); + static const std::vector& implementationList(); + static const Implementation *implementation(std::ifstream& f, const std::string& buildVariant); static LLModel *construct(const std::string &modelPath, std::string buildVariant = "default"); protected: - const Implementation *implementation; // FIXME: This is dangling! You don't initialize it in ctor either + const Implementation *m_implementation = nullptr; virtual void recalculateContext(PromptContext &promptCtx, std::function recalculate) = 0;