diff --git a/.gitmodules b/.gitmodules index e00584ea..50de0692 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,3 +1,9 @@ -[submodule "llama.cpp"] - path = gpt4all-backend/llama.cpp +[submodule "llama.cpp-230519"] + path = gpt4all-backend/llama.cpp-230519 + url = https://github.com/ggerganov/llama.cpp.git +[submodule "llama.cpp-230511"] + path = gpt4all-backend/llama.cpp-230511 url = https://github.com/manyoso/llama.cpp.git +[submodule "llama.cpp-mainline"] + path = gpt4all-backend/llama.cpp-mainline + url = https://github.com/ggerganov/llama.cpp.git diff --git a/gpt4all-backend/CMakeLists.txt b/gpt4all-backend/CMakeLists.txt index 0c06b60e..69917f46 100644 --- a/gpt4all-backend/CMakeLists.txt +++ b/gpt4all-backend/CMakeLists.txt @@ -17,36 +17,97 @@ endif() include_directories("${CMAKE_CURRENT_BINARY_DIR}") set(LLMODEL_VERSION_MAJOR 0) -set(LLMODEL_VERSION_MINOR 1) -set(LLMODEL_VERSION_PATCH 1) +set(LLMODEL_VERSION_MINOR 2) +set(LLMODEL_VERSION_PATCH 0) set(LLMODEL_VERSION "${LLMODEL_VERSION_MAJOR}.${LLMODEL_VERSION_MINOR}.${LLMODEL_VERSION_PATCH}") project(llmodel VERSION ${LLMODEL_VERSION} LANGUAGES CXX C) +set(CMAKE_CXX_STANDARD 20) set(CMAKE_CXX_STANDARD_REQUIRED ON) +set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}) +set(BUILD_SHARED_LIBS ON) -set(LLAMA_BUILD_EXAMPLES ON CACHE BOOL "llama: build examples" FORCE) -set(BUILD_SHARED_LIBS ON FORCE) +# Check for IPO support +include(CheckIPOSupported) +check_ipo_supported(RESULT IPO_SUPPORTED OUTPUT IPO_ERROR) +if (NOT IPO_SUPPORTED) + message(WARNING "Interprocedural optimization is not supported by your toolchain! This will lead to bigger file sizes and worse performance: ${IPO_ERROR}") +else() + message(STATUS "Interprocedural optimization support detected") +endif() + +include(llama.cpp.cmake) + +set(BUILD_VARIANTS default avxonly) set(CMAKE_VERBOSE_MAKEFILE ON) -if (GPT4ALL_AVX_ONLY) - set(LLAMA_AVX2 OFF CACHE BOOL "llama: enable AVX2" FORCE) - set(LLAMA_F16C OFF CACHE BOOL "llama: enable F16C" FORCE) - set(LLAMA_FMA OFF CACHE BOOL "llama: enable FMA" FORCE) -endif() -add_subdirectory(llama.cpp) +# Go through each build variant +foreach(BUILD_VARIANT IN LISTS BUILD_VARIANTS) + # Determine flags + if (BUILD_VARIANT STREQUAL avxonly) + set(GPT4ALL_ALLOW_NON_AVX NO) + else() + set(GPT4ALL_ALLOW_NON_AVX YES) + endif() + set(LLAMA_AVX2 ${GPT4ALL_ALLOW_NON_AVX}) + set(LLAMA_F16C ${GPT4ALL_ALLOW_NON_AVX}) + set(LLAMA_FMA ${GPT4ALL_ALLOW_NON_AVX}) + + # Include GGML + include_ggml(llama.cpp-mainline -mainline-${BUILD_VARIANT} ON) + include_ggml(llama.cpp-230511 -230511-${BUILD_VARIANT} ON) + include_ggml(llama.cpp-230519 -230519-${BUILD_VARIANT} ON) + + # Function for preparing individual implementations + function(prepare_target TARGET_NAME BASE_LIB) + set(TARGET_NAME ${TARGET_NAME}-${BUILD_VARIANT}) + message(STATUS "Configuring model implementation target ${TARGET_NAME}") + # Link to ggml/llama + target_link_libraries(${TARGET_NAME} + PUBLIC ${BASE_LIB}-${BUILD_VARIANT}) + # Let it know about its build variant + target_compile_definitions(${TARGET_NAME} + PRIVATE GGML_BUILD_VARIANT="${BUILD_VARIANT}") + # Enable IPO if possible + set_property(TARGET ${TARGET_NAME} + PROPERTY INTERPROCEDURAL_OPTIMIZATION ${IPO_SUPPORTED}) + endfunction() + + # Add each individual implementations + add_library(llamamodel-mainline-${BUILD_VARIANT} SHARED + llamamodel.cpp) + target_compile_definitions(llamamodel-mainline-${BUILD_VARIANT} PRIVATE + LLAMA_VERSIONS=>=3 LLAMA_DATE=999999) + prepare_target(llamamodel-mainline llama-mainline) + 
+ add_library(llamamodel-230519-${BUILD_VARIANT} SHARED + llamamodel.cpp) + target_compile_definitions(llamamodel-230519-${BUILD_VARIANT} PRIVATE + LLAMA_VERSIONS===2 LLAMA_DATE=230519) + prepare_target(llamamodel-230519 llama-230519) + + add_library(llamamodel-230511-${BUILD_VARIANT} SHARED + llamamodel.cpp) + target_compile_definitions(llamamodel-230511-${BUILD_VARIANT} PRIVATE + LLAMA_VERSIONS=<=1 LLAMA_DATE=230511) + prepare_target(llamamodel-230511 llama-230511) + + add_library(gptj-${BUILD_VARIANT} SHARED + gptj.cpp utils.h utils.cpp) + prepare_target(gptj ggml-230511) + + add_library(mpt-${BUILD_VARIANT} SHARED + mpt.cpp utils.h utils.cpp) + prepare_target(mpt ggml-230511) +endforeach() add_library(llmodel - gptj.h gptj.cpp - llamamodel.h llamamodel.cpp - llama.cpp/examples/common.cpp - llmodel.h llmodel_c.h llmodel_c.cpp - mpt.h mpt.cpp - utils.h utils.cpp + llmodel.h llmodel.cpp + llmodel_c.h llmodel_c.cpp + dlhandle.h ) - -target_link_libraries(llmodel - PRIVATE llama) +target_compile_definitions(llmodel PRIVATE LIB_FILE_EXT="${CMAKE_SHARED_LIBRARY_SUFFIX}") set_target_properties(llmodel PROPERTIES VERSION ${PROJECT_VERSION} diff --git a/gpt4all-backend/dlhandle.h b/gpt4all-backend/dlhandle.h new file mode 100644 index 00000000..1c23c101 --- /dev/null +++ b/gpt4all-backend/dlhandle.h @@ -0,0 +1,101 @@ +#ifndef DLHANDLE_H +#define DLHANDLE_H +#ifndef _WIN32 +#include +#include +#include +#include + + + +class Dlhandle { + void *chandle; + +public: + class Exception : public std::runtime_error { + public: + using std::runtime_error::runtime_error; + }; + + Dlhandle() : chandle(nullptr) {} + Dlhandle(const std::string& fpath, int flags = RTLD_LAZY) { + chandle = dlopen(fpath.c_str(), flags); + if (!chandle) { + throw Exception("dlopen(\""+fpath+"\"): "+dlerror()); + } + } + Dlhandle(const Dlhandle& o) = delete; + Dlhandle(Dlhandle&& o) : chandle(o.chandle) { + o.chandle = nullptr; + } + ~Dlhandle() { + if (chandle) dlclose(chandle); + } + + auto operator =(Dlhandle&& o) { + chandle = std::exchange(o.chandle, nullptr); + } + + bool is_valid() const { + return chandle != nullptr; + } + operator bool() const { + return is_valid(); + } + + template + T* get(const std::string& fname) { + auto fres = reinterpret_cast(dlsym(chandle, fname.c_str())); + return (dlerror()==NULL)?fres:nullptr; + } + auto get_fnc(const std::string& fname) { + return get(fname); + } +}; +#else +#include +#include +#include +#include +#include + + + +class Dlhandle { + HMODULE chandle; + +public: + class Exception : public std::runtime_error { + public: + using std::runtime_error::runtime_error; + }; + + Dlhandle() : chandle(nullptr) {} + Dlhandle(const std::string& fpath) { + chandle = LoadLibraryA(fpath.c_str()); + if (!chandle) { + throw Exception("dlopen(\""+fpath+"\"): Error"); + } + } + Dlhandle(const Dlhandle& o) = delete; + Dlhandle(Dlhandle&& o) : chandle(o.chandle) { + o.chandle = nullptr; + } + ~Dlhandle() { + if (chandle) FreeLibrary(chandle); + } + + bool is_valid() const { + return chandle != nullptr; + } + + template + T* get(const std::string& fname) { + return reinterpret_cast(GetProcAddress(chandle, fname.c_str())); + } + auto get_fnc(const std::string& fname) { + return get(fname); + } +}; +#endif +#endif // DLHANDLE_H diff --git a/gpt4all-backend/gptj.cpp b/gpt4all-backend/gptj.cpp index 946fdeb6..302c7ee4 100644 --- a/gpt4all-backend/gptj.cpp +++ b/gpt4all-backend/gptj.cpp @@ -1,5 +1,5 @@ -#include "gptj.h" -#include "llama.cpp/ggml.h" +#define 
GPTJ_H_I_KNOW_WHAT_I_AM_DOING_WHEN_INCLUDING_THIS_FILE +#include "gptj_impl.h" #include "utils.h" @@ -25,10 +25,16 @@ #endif #include #include +#include + + +namespace { +const char *modelType_ = "MPT"; -// default hparams (GPT-J 6B) static const size_t MB = 1024*1024; +} +// default hparams (GPT-J 6B) struct gptj_hparams { int32_t n_vocab = 50400; int32_t n_ctx = 2048; @@ -229,8 +235,6 @@ bool gptj_model_load(const std::string &fname, std::istream &fin, gptj_model & m } } - const ggml_type wtype2 = GGML_TYPE_F32; - auto & ctx = model.ctx; size_t ctx_size = 0; @@ -279,6 +283,7 @@ bool gptj_model_load(const std::string &fname, std::istream &fin, gptj_model & m struct ggml_init_params params = { .mem_size = ctx_size, .mem_buffer = NULL, + .no_alloc = false }; model.ctx = ggml_init(params); @@ -294,7 +299,6 @@ bool gptj_model_load(const std::string &fname, std::istream &fin, gptj_model & m const int n_embd = hparams.n_embd; const int n_layer = hparams.n_layer; - const int n_ctx = hparams.n_ctx; const int n_vocab = hparams.n_vocab; model.layers.resize(n_layer); @@ -355,14 +359,6 @@ bool gptj_model_load(const std::string &fname, std::istream &fin, gptj_model & m // key + value memory { const auto & hparams = model.hparams; - - const int n_embd = hparams.n_embd; - const int n_layer = hparams.n_layer; - const int n_ctx = hparams.n_ctx; - - const int n_mem = n_layer*n_ctx; - const int n_elements = n_embd*n_mem; - if (!kv_cache_init(hparams, model.kv_self, GGML_TYPE_F16, model.hparams.n_ctx)) { fprintf(stderr, "%s: kv_cache_init() failed for self-attention cache\n", __func__); ggml_free(ctx); @@ -505,8 +501,6 @@ bool gptj_eval( const int n_vocab = hparams.n_vocab; const int n_rot = hparams.n_rot; - const int d_key = n_embd/n_head; - const size_t init_buf_size = 1024u*MB; if (!model.buf.addr || model.buf.size < init_buf_size) model.buf.resize(init_buf_size); @@ -526,10 +520,12 @@ bool gptj_eval( struct ggml_init_params params = { .mem_size = model.buf.size, .mem_buffer = model.buf.addr, + .no_alloc = false }; struct ggml_context * ctx0 = ggml_init(params); - struct ggml_cgraph gf = { .n_threads = n_threads }; + struct ggml_cgraph gf = {}; + gf.n_threads = n_threads; struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N); memcpy(embd->data, embd_inp.data(), N*ggml_element_size(embd)); @@ -772,8 +768,7 @@ size_t gptj_copy_state_data(const gptj_model &model, const std::mt19937 &rng, ui } const size_t written = out - dest; - const size_t expected = gptj_get_state_size(model); - assert(written == expected); + assert(written == gptj_get_state_size(model)); fflush(stdout); return written; } @@ -822,8 +817,7 @@ size_t gptj_set_state_data(gptj_model *model, std::mt19937 *rng, const uint8_t * } const size_t nread = in - src; - const size_t expected = gptj_get_state_size(*model); - assert(nread == expected); + assert(nread == gptj_get_state_size(*model)); fflush(stdout); return nread; } @@ -840,6 +834,7 @@ struct GPTJPrivate { GPTJ::GPTJ() : d_ptr(new GPTJPrivate) { + modelType = modelType_; d_ptr->model = new gptj_model; d_ptr->modelLoaded = false; @@ -908,12 +903,6 @@ void GPTJ::prompt(const std::string &prompt, return; } - const int64_t t_main_start_us = ggml_time_us(); - - int64_t t_sample_us = 0; - int64_t t_predict_us = 0; - int64_t t_prompt_us = 0; - // tokenize the prompt std::vector embd_inp = ::gpt_tokenize(d_ptr->vocab, prompt); @@ -942,20 +931,19 @@ void GPTJ::prompt(const std::string &prompt, // process the prompt in batches size_t i = 0; - const int64_t t_start_prompt_us = 
ggml_time_us(); while (i < embd_inp.size()) { size_t batch_end = std::min(i + promptCtx.n_batch, embd_inp.size()); std::vector batch(embd_inp.begin() + i, embd_inp.begin() + batch_end); // Check if the context has run out... - if (promptCtx.n_past + batch.size() > promptCtx.n_ctx) { + if (promptCtx.n_past + int32_t(batch.size()) > promptCtx.n_ctx) { const int32_t erasePoint = promptCtx.n_ctx * promptCtx.contextErase; // Erase the first percentage of context from the tokens... std::cerr << "GPTJ: reached the end of the context window so resizing\n"; promptCtx.tokens.erase(promptCtx.tokens.begin(), promptCtx.tokens.begin() + erasePoint); promptCtx.n_past = promptCtx.tokens.size(); recalculateContext(promptCtx, recalculateCallback); - assert(promptCtx.n_past + batch.size() <= promptCtx.n_ctx); + assert(promptCtx.n_past + int32_t(batch.size()) <= promptCtx.n_ctx); } if (!gptj_eval(*d_ptr->model, d_ptr->n_threads, promptCtx.n_past, batch, promptCtx.logits, @@ -966,7 +954,7 @@ void GPTJ::prompt(const std::string &prompt, size_t tokens = batch_end - i; for (size_t t = 0; t < tokens; ++t) { - if (promptCtx.tokens.size() == promptCtx.n_ctx) + if (int32_t(promptCtx.tokens.size()) == promptCtx.n_ctx) promptCtx.tokens.erase(promptCtx.tokens.begin()); promptCtx.tokens.push_back(batch.at(t)); if (!promptCallback(batch.at(t))) @@ -975,10 +963,6 @@ void GPTJ::prompt(const std::string &prompt, promptCtx.n_past += batch.size(); i = batch_end; } - t_prompt_us += ggml_time_us() - t_start_prompt_us; - - int p_instructFound = 0; - int r_instructFound = 0; std::string cachedResponse; std::vector cachedTokens; @@ -986,24 +970,20 @@ void GPTJ::prompt(const std::string &prompt, = { "### Instruction", "### Prompt", "### Response", "### Human", "### Assistant", "### Context" }; // predict next tokens - int32_t totalPredictions = 0; for (int i = 0; i < promptCtx.n_predict; i++) { // sample next token const int n_vocab = d_ptr->model->hparams.n_vocab; gpt_vocab::id id = 0; { - const int64_t t_start_sample_us = ggml_time_us(); const size_t n_prev_toks = std::min((size_t) promptCtx.repeat_last_n, promptCtx.tokens.size()); - id = gpt_sample_top_k_top_p(d_ptr->vocab, n_vocab, + id = gpt_sample_top_k_top_p(n_vocab, promptCtx.tokens.data() + promptCtx.tokens.size() - n_prev_toks, n_prev_toks, promptCtx.logits, promptCtx.top_k, promptCtx.top_p, promptCtx.temp, promptCtx.repeat_penalty, d_ptr->rng); - - t_sample_us += ggml_time_us() - t_start_sample_us; } // Check if the context has run out... 
@@ -1017,29 +997,24 @@ void GPTJ::prompt(const std::string &prompt, assert(promptCtx.n_past + 1 <= promptCtx.n_ctx); } - const int64_t t_start_predict_us = ggml_time_us(); if (!gptj_eval(*d_ptr->model, d_ptr->n_threads, promptCtx.n_past, { id }, promptCtx.logits, d_ptr->mem_per_token)) { std::cerr << "GPT-J ERROR: Failed to predict next token\n"; return; } - t_predict_us += ggml_time_us() - t_start_predict_us; promptCtx.n_past += 1; // display text - ++totalPredictions; - if (id == 50256 /*end of text*/) - goto stop_generating; + return; const std::string str = d_ptr->vocab.id_to_token[id]; // Check if the provided str is part of our reverse prompts bool foundPartialReversePrompt = false; const std::string completed = cachedResponse + str; - if (reversePrompts.find(completed) != reversePrompts.end()) { - goto stop_generating; - } + if (reversePrompts.find(completed) != reversePrompts.end()) + return; // Check if it partially matches our reverse prompts and if so, cache for (auto s : reversePrompts) { @@ -1059,32 +1034,14 @@ void GPTJ::prompt(const std::string &prompt, // Empty the cache for (auto t : cachedTokens) { - if (promptCtx.tokens.size() == promptCtx.n_ctx) + if (int32_t(promptCtx.tokens.size()) == promptCtx.n_ctx) promptCtx.tokens.erase(promptCtx.tokens.begin()); promptCtx.tokens.push_back(t); if (!responseCallback(t, d_ptr->vocab.id_to_token[t])) - goto stop_generating; + return; } cachedTokens.clear(); } - -stop_generating: - -#if 0 - // report timing - { - const int64_t t_main_end_us = ggml_time_us(); - - std::cout << "GPT-J INFO: mem per token = " << mem_per_token << " bytes\n"; - std::cout << "GPT-J INFO: sample time = " << t_sample_us/1000.0f << " ms\n"; - std::cout << "GPT-J INFO: prompt time = " << t_prompt_us/1000.0f << " ms\n"; - std::cout << "GPT-J INFO: predict time = " << t_predict_us/1000.0f << " ms / " << t_predict_us/1000.0f/totalPredictions << " ms per token\n"; - std::cout << "GPT-J INFO: total time = " << (t_main_end_us - t_main_start_us)/1000.0f << " ms\n"; - fflush(stdout); - } -#endif - - return; } void GPTJ::recalculateContext(PromptContext &promptCtx, std::function recalculate) @@ -1095,7 +1052,7 @@ void GPTJ::recalculateContext(PromptContext &promptCtx, std::function batch(promptCtx.tokens.begin() + i, promptCtx.tokens.begin() + batch_end); - assert(promptCtx.n_past + batch.size() <= promptCtx.n_ctx); + assert(promptCtx.n_past + int32_t(batch.size()) <= promptCtx.n_ctx); if (!gptj_eval(*d_ptr->model, d_ptr->n_threads, promptCtx.n_past, batch, promptCtx.logits, d_ptr->mem_per_token)) { @@ -1107,8 +1064,38 @@ void GPTJ::recalculateContext(PromptContext &promptCtx, std::function(&magic), sizeof(magic)); + return magic == 0x67676d6c; +} + +DLL_EXPORT LLModel *construct() { + return new GPTJ; +} +} diff --git a/gpt4all-backend/gptj.h b/gpt4all-backend/gptj_impl.h similarity index 76% rename from gpt4all-backend/gptj.h rename to gpt4all-backend/gptj_impl.h index 48da82dd..c5c99d21 100644 --- a/gpt4all-backend/gptj.h +++ b/gpt4all-backend/gptj_impl.h @@ -1,3 +1,6 @@ +#ifndef GPTJ_H_I_KNOW_WHAT_I_AM_DOING_WHEN_INCLUDING_THIS_FILE +#error This file is NOT meant to be included outside of gptj.cpp. Doing so is DANGEROUS. 
Be sure to know what you are doing before proceeding to #define GPTJ_H_I_KNOW_WHAT_I_AM_DOING_WHEN_INCLUDING_THIS_FILE +#endif #ifndef GPTJ_H #define GPTJ_H @@ -6,7 +9,7 @@ #include #include "llmodel.h" -class GPTJPrivate; +struct GPTJPrivate; class GPTJ : public LLModel { public: GPTJ(); diff --git a/gpt4all-backend/llama.cpp b/gpt4all-backend/llama.cpp-230511 similarity index 100% rename from gpt4all-backend/llama.cpp rename to gpt4all-backend/llama.cpp-230511 diff --git a/gpt4all-backend/llama.cpp-230519 b/gpt4all-backend/llama.cpp-230519 new file mode 160000 index 00000000..5ea43392 --- /dev/null +++ b/gpt4all-backend/llama.cpp-230519 @@ -0,0 +1 @@ +Subproject commit 5ea43392731040b454c293123839b90e159cbb99 diff --git a/gpt4all-backend/llama.cpp-mainline b/gpt4all-backend/llama.cpp-mainline new file mode 160000 index 00000000..ea600071 --- /dev/null +++ b/gpt4all-backend/llama.cpp-mainline @@ -0,0 +1 @@ +Subproject commit ea600071cb005267e9e8f2629c1e406dd5fde083 diff --git a/gpt4all-backend/llama.cpp.cmake b/gpt4all-backend/llama.cpp.cmake new file mode 100644 index 00000000..db85d4c6 --- /dev/null +++ b/gpt4all-backend/llama.cpp.cmake @@ -0,0 +1,364 @@ +cmake_minimum_required(VERSION 3.12) # Don't bump this version for no reason + +set(CMAKE_EXPORT_COMPILE_COMMANDS ON) + +if (NOT XCODE AND NOT MSVC AND NOT CMAKE_BUILD_TYPE) + set(CMAKE_BUILD_TYPE Release CACHE STRING "Build type" FORCE) + set_property(CACHE CMAKE_BUILD_TYPE PROPERTY STRINGS "Debug" "Release" "MinSizeRel" "RelWithDebInfo") +endif() + +set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin) + +if(CMAKE_SOURCE_DIR STREQUAL CMAKE_CURRENT_SOURCE_DIR) + set(LLAMA_STANDALONE ON) + + # configure project version + # TODO +else() + set(LLAMA_STANDALONE OFF) +endif() + +if (EMSCRIPTEN) + set(BUILD_SHARED_LIBS_DEFAULT OFF) + + option(LLAMA_WASM_SINGLE_FILE "llama: embed WASM inside the generated llama.js" ON) +else() + if (MINGW) + set(BUILD_SHARED_LIBS_DEFAULT OFF) + else() + set(BUILD_SHARED_LIBS_DEFAULT ON) + endif() +endif() + + +# +# Option list +# + +# general +option(LLAMA_STATIC "llama: static link libraries" OFF) +option(LLAMA_NATIVE "llama: enable -march=native flag" OFF) +option(LLAMA_LTO "llama: enable link time optimization" OFF) + +# debug +option(LLAMA_ALL_WARNINGS "llama: enable all compiler warnings" ON) +option(LLAMA_ALL_WARNINGS_3RD_PARTY "llama: enable all compiler warnings in 3rd party libs" OFF) +option(LLAMA_GPROF "llama: enable gprof" OFF) + +# sanitizers +option(LLAMA_SANITIZE_THREAD "llama: enable thread sanitizer" OFF) +option(LLAMA_SANITIZE_ADDRESS "llama: enable address sanitizer" OFF) +option(LLAMA_SANITIZE_UNDEFINED "llama: enable undefined sanitizer" OFF) + +# instruction set specific +#option(LLAMA_AVX "llama: enable AVX" ON) +#option(LLAMA_AVX2 "llama: enable AVX2" ON) +#option(LLAMA_AVX512 "llama: enable AVX512" OFF) +#option(LLAMA_AVX512_VBMI "llama: enable AVX512-VBMI" OFF) +#option(LLAMA_AVX512_VNNI "llama: enable AVX512-VNNI" OFF) +#option(LLAMA_FMA "llama: enable FMA" ON) +# in MSVC F16C is implied with AVX2/AVX512 +#if (NOT MSVC) +# option(LLAMA_F16C "llama: enable F16C" ON) +#endif() + +# 3rd party libs +option(LLAMA_ACCELERATE "llama: enable Accelerate framework" ON) +option(LLAMA_OPENBLAS "llama: use OpenBLAS" OFF) +option(LLAMA_CUBLAS "llama: use cuBLAS" OFF) +option(LLAMA_CLBLAST "llama: use CLBlast" OFF) + +# +# Compile flags +# + +set(CMAKE_C_STANDARD 11) +set(CMAKE_C_STANDARD_REQUIRED true) +set(THREADS_PREFER_PTHREAD_FLAG ON) +find_package(Threads REQUIRED) + +if (NOT MSVC) + 
if (LLAMA_SANITIZE_THREAD) + add_compile_options(-fsanitize=thread) + link_libraries(-fsanitize=thread) + endif() + + if (LLAMA_SANITIZE_ADDRESS) + add_compile_options(-fsanitize=address -fno-omit-frame-pointer) + link_libraries(-fsanitize=address) + endif() + + if (LLAMA_SANITIZE_UNDEFINED) + add_compile_options(-fsanitize=undefined) + link_libraries(-fsanitize=undefined) + endif() +endif() + +if (APPLE AND LLAMA_ACCELERATE) + find_library(ACCELERATE_FRAMEWORK Accelerate) + if (ACCELERATE_FRAMEWORK) + message(STATUS "Accelerate framework found") + + add_compile_definitions(GGML_USE_ACCELERATE) + set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} ${ACCELERATE_FRAMEWORK}) + else() + message(WARNING "Accelerate framework not found") + endif() +endif() + +if (LLAMA_OPENBLAS) + if (LLAMA_STATIC) + set(BLA_STATIC ON) + endif() + + set(BLA_VENDOR OpenBLAS) + find_package(BLAS) + if (BLAS_FOUND) + message(STATUS "OpenBLAS found") + + add_compile_definitions(GGML_USE_OPENBLAS) + add_link_options(${BLAS_LIBRARIES}) + set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} openblas) + + # find header file + set(OPENBLAS_INCLUDE_SEARCH_PATHS + /usr/include + /usr/include/openblas + /usr/include/openblas-base + /usr/local/include + /usr/local/include/openblas + /usr/local/include/openblas-base + /opt/OpenBLAS/include + $ENV{OpenBLAS_HOME} + $ENV{OpenBLAS_HOME}/include + ) + find_path(OPENBLAS_INC NAMES cblas.h PATHS ${OPENBLAS_INCLUDE_SEARCH_PATHS}) + add_compile_options(-I${OPENBLAS_INC}) + else() + message(WARNING "OpenBLAS not found") + endif() +endif() + +if (LLAMA_ALL_WARNINGS) + if (NOT MSVC) + set(c_flags + -Wall + -Wextra + -Wpedantic + -Wcast-qual + -Wdouble-promotion + -Wshadow + -Wstrict-prototypes + -Wpointer-arith + ) + set(cxx_flags + -Wall + -Wextra + -Wpedantic + -Wcast-qual + -Wno-unused-function + -Wno-multichar + ) + else() + # todo : msvc + endif() + + add_compile_options( + "$<$:${c_flags}>" + "$<$:${cxx_flags}>" + ) + +endif() + +if (MSVC) + add_compile_definitions(_CRT_SECURE_NO_WARNINGS) + + if (BUILD_SHARED_LIBS) + set(CMAKE_WINDOWS_EXPORT_ALL_SYMBOLS ON) + endif() +endif() + +if (LLAMA_LTO) + include(CheckIPOSupported) + check_ipo_supported(RESULT result OUTPUT output) + if (result) + set(CMAKE_INTERPROCEDURAL_OPTIMIZATION TRUE) + else() + message(WARNING "IPO is not supported: ${output}") + endif() +endif() + +# Architecture specific +# TODO: probably these flags need to be tweaked on some architectures +# feel free to update the Makefile for your architecture and send a pull request or issue +message(STATUS "CMAKE_SYSTEM_PROCESSOR: ${CMAKE_SYSTEM_PROCESSOR}") +if (NOT MSVC) + if (LLAMA_STATIC) + add_link_options(-static) + if (MINGW) + add_link_options(-static-libgcc -static-libstdc++) + endif() + endif() + if (LLAMA_GPROF) + add_compile_options(-pg) + endif() + if (LLAMA_NATIVE) + add_compile_options(-march=native) + endif() +endif() + +function(include_ggml DIRECTORY SUFFIX WITH_LLAMA) + message(STATUS "Configuring ggml implementation target llama${SUFFIX} in ${CMAKE_CURRENT_SOURCE_DIR}/${DIRECTORY}") + + if (${CMAKE_SYSTEM_PROCESSOR} MATCHES "arm" OR ${CMAKE_SYSTEM_PROCESSOR} MATCHES "aarch64") + message(STATUS "ARM detected") + if (MSVC) + # TODO: arm msvc? 
+ else() + if (${CMAKE_SYSTEM_PROCESSOR} MATCHES "aarch64") + add_compile_options(-mcpu=native) + endif() + # TODO: armv6,7,8 version specific flags + endif() + elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "^(x86_64|i686|AMD64)$") + message(STATUS "x86 detected") + if (MSVC) + if (LLAMA_AVX512) + add_compile_options($<$:/arch:AVX512>) + add_compile_options($<$:/arch:AVX512>) + # MSVC has no compile-time flags enabling specific + # AVX512 extensions, neither it defines the + # macros corresponding to the extensions. + # Do it manually. + if (LLAMA_AVX512_VBMI) + add_compile_definitions($<$:__AVX512VBMI__>) + add_compile_definitions($<$:__AVX512VBMI__>) + endif() + if (LLAMA_AVX512_VNNI) + add_compile_definitions($<$:__AVX512VNNI__>) + add_compile_definitions($<$:__AVX512VNNI__>) + endif() + elseif (LLAMA_AVX2) + add_compile_options($<$:/arch:AVX2>) + add_compile_options($<$:/arch:AVX2>) + elseif (LLAMA_AVX) + add_compile_options($<$:/arch:AVX>) + add_compile_options($<$:/arch:AVX>) + endif() + else() + if (LLAMA_F16C) + add_compile_options(-mf16c) + endif() + if (LLAMA_FMA) + add_compile_options(-mfma) + endif() + if (LLAMA_AVX) + add_compile_options(-mavx) + endif() + if (LLAMA_AVX2) + add_compile_options(-mavx2) + endif() + if (LLAMA_AVX512) + add_compile_options(-mavx512f) + add_compile_options(-mavx512bw) + endif() + if (LLAMA_AVX512_VBMI) + add_compile_options(-mavx512vbmi) + endif() + if (LLAMA_AVX512_VNNI) + add_compile_options(-mavx512vnni) + endif() + endif() + else() + # TODO: support PowerPC + message(STATUS "Unknown architecture") + endif() + + # + # Build libraries + # + + if (LLAMA_CUBLAS AND EXISTS ${DIRECTORY}/ggml-cuda.h) + cmake_minimum_required(VERSION 3.17) + + find_package(CUDAToolkit) + if (CUDAToolkit_FOUND) + message(STATUS "cuBLAS found") + + enable_language(CUDA) + + set(GGML_CUDA_SOURCES ${DIRECTORY}/ggml-cuda.cu ${DIRECTORY}/ggml-cuda.h) + + add_compile_definitions(GGML_USE_CUBLAS) + + if (LLAMA_STATIC) + set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} CUDA::cudart_static CUDA::cublas_static CUDA::cublasLt_static) + else() + set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} CUDA::cudart CUDA::cublas CUDA::cublasLt) + endif() + + else() + message(WARNING "cuBLAS not found") + endif() + endif() + + if (LLAMA_CLBLAST AND EXISTS ${DIRECTORY}/ggml-opencl.h) + find_package(CLBlast) + if (CLBlast_FOUND) + message(STATUS "CLBlast found") + + set(GGML_OPENCL_SOURCES ${DIRECTORY}/ggml-opencl.c ${DIRECTORY}/ggml-opencl.h) + + add_compile_definitions(GGML_USE_CLBLAST) + + set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} clblast) + else() + message(WARNING "CLBlast not found") + endif() + endif() + + add_library(ggml${SUFFIX} OBJECT + ${DIRECTORY}/ggml.c + ${DIRECTORY}/ggml.h + ${GGML_CUDA_SOURCES} + ${GGML_OPENCL_SOURCES}) + + target_include_directories(ggml${SUFFIX} PUBLIC ${DIRECTORY}) + target_compile_features(ggml${SUFFIX} PUBLIC c_std_11) # don't bump + target_link_libraries(ggml${SUFFIX} PUBLIC Threads::Threads ${LLAMA_EXTRA_LIBS}) + + if (BUILD_SHARED_LIBS) + set_target_properties(ggml${SUFFIX} PROPERTIES POSITION_INDEPENDENT_CODE ON) + endif() + + if (WITH_LLAMA) + # Backwards compatibility with old llama.cpp versions + set(LLAMA_UTIL_SOURCE_FILE llama-util.h) + if (NOT EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${DIRECTORY}/${LLAMA_UTIL_SOURCE_FILE}) + set(LLAMA_UTIL_SOURCE_FILE llama_util.h) + endif() + + add_library(llama${SUFFIX} + ${DIRECTORY}/llama.cpp + ${DIRECTORY}/llama.h + ${DIRECTORY}/${LLAMA_UTIL_SOURCE_FILE}) + + target_include_directories(llama${SUFFIX} PUBLIC ${DIRECTORY}) + 
target_compile_features(llama${SUFFIX} PUBLIC cxx_std_11) # don't bump + target_link_libraries(llama${SUFFIX} PRIVATE ggml${SUFFIX} ${LLAMA_EXTRA_LIBS}) + + if (BUILD_SHARED_LIBS) + set_target_properties(llama${SUFFIX} PROPERTIES POSITION_INDEPENDENT_CODE ON) + target_compile_definitions(llama${SUFFIX} PRIVATE LLAMA_SHARED LLAMA_BUILD) + endif() + endif() + + if (GGML_CUDA_SOURCES) + message(STATUS "GGML CUDA sources found, configuring CUDA architecture") + set_property(TARGET ggml${SUFFIX} PROPERTY CUDA_ARCHITECTURES OFF) + set_property(TARGET ggml${SUFFIX} PROPERTY CUDA_SELECT_NVCC_ARCH_FLAGS "Auto") + if (WITH_LLAMA) + set_property(TARGET llama${SUFFIX} PROPERTY CUDA_ARCHITECTURES OFF) + endif() + endif() +endfunction() diff --git a/gpt4all-backend/llamamodel.cpp b/gpt4all-backend/llamamodel.cpp index 3149af82..9830d08a 100644 --- a/gpt4all-backend/llamamodel.cpp +++ b/gpt4all-backend/llamamodel.cpp @@ -1,8 +1,5 @@ -#include "llamamodel.h" - -#include "llama.cpp/examples/common.h" -#include "llama.cpp/llama.h" -#include "llama.cpp/ggml.h" +#define LLAMAMODEL_H_I_KNOW_WHAT_I_AM_DOING_WHEN_INCLUDING_THIS_FILE +#include "llamamodel_impl.h" #include #include @@ -28,16 +25,77 @@ #include #include +#include +#include + + +namespace { +const char *modelType_ = "LLaMA"; +} + +struct gpt_params { + int32_t seed = -1; // RNG seed + int32_t n_keep = 0; // number of tokens to keep from initial prompt +#if LLAMA_DATE <= 230511 + int32_t n_parts = -1; // amount of model parts (-1 = determine from model dimensions) +#endif + +#if LLAMA_DATE >= 230519 + // sampling parameters + float tfs_z = 1.0f; // 1.0 = disabled + float typical_p = 1.0f; // 1.0 = disabled +#endif + + std::string prompt = ""; + + bool memory_f16 = true; // use f16 instead of f32 for memory kv + + bool use_mmap = true; // use mmap for faster loads + bool use_mlock = false; // use mlock to keep model in memory +}; + +#if LLAMA_DATE >= 230519 +static int llama_sample_top_p_top_k( + llama_context *ctx, + const llama_token *last_n_tokens_data, + int last_n_tokens_size, + int top_k, + float top_p, + float temp, + float repeat_penalty) { + auto logits = llama_get_logits(ctx); + auto n_vocab = llama_n_vocab(ctx); + // Populate initial list of all candidates + std::vector candidates; + candidates.reserve(n_vocab); + for (int token_id = 0; token_id < n_vocab; token_id++) { + candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f}); + } + llama_token_data_array candidates_p = {candidates.data(), candidates.size(), false}; + // Sample repeat penalty + llama_sample_repetition_penalty(nullptr, &candidates_p, last_n_tokens_data, last_n_tokens_size, repeat_penalty); + // Temperature sampling + llama_sample_top_k(ctx, &candidates_p, top_k, 1); + llama_sample_tail_free(ctx, &candidates_p, 1.0f, 1); + llama_sample_typical(ctx, &candidates_p, 1.0f, 1); + llama_sample_top_p(ctx, &candidates_p, top_p, 1); + llama_sample_temperature(ctx, &candidates_p, temp); + return llama_sample_token(ctx, &candidates_p); +} +#endif + struct LLamaPrivate { const std::string modelPath; bool modelLoaded; llama_context *ctx = nullptr; llama_context_params params; int64_t n_threads = 0; + bool empty = true; }; LLamaModel::LLamaModel() : d_ptr(new LLamaPrivate) { + modelType = modelType_; d_ptr->modelLoaded = false; } @@ -49,14 +107,12 @@ bool LLamaModel::loadModel(const std::string &modelPath) gpt_params params; d_ptr->params.n_ctx = 2048; - d_ptr->params.n_parts = params.n_parts; d_ptr->params.seed = params.seed; d_ptr->params.f16_kv = params.memory_f16; 
d_ptr->params.use_mmap = params.use_mmap; -#if defined (__APPLE__) - d_ptr->params.use_mlock = true; -#else d_ptr->params.use_mlock = params.use_mlock; +#if LLAMA_DATE <= 230511 + d_ptr->params.n_parts = params.n_parts; #endif d_ptr->ctx = llama_init_from_file(modelPath.c_str(), d_ptr->params); @@ -75,8 +131,7 @@ void LLamaModel::setThreadCount(int32_t n_threads) { d_ptr->n_threads = n_threads; } -int32_t LLamaModel::threadCount() const -{ +int32_t LLamaModel::threadCount() const { return d_ptr->n_threads; } @@ -102,7 +157,8 @@ size_t LLamaModel::saveState(uint8_t *dest) const size_t LLamaModel::restoreState(const uint8_t *src) { - return llama_set_state_data(d_ptr->ctx, src); + // const_cast is required, see: https://github.com/ggerganov/llama.cpp/pull/1540 + return llama_set_state_data(d_ptr->ctx, const_cast(src)); } void LLamaModel::prompt(const std::string &prompt, @@ -123,7 +179,11 @@ void LLamaModel::prompt(const std::string &prompt, params.prompt.insert(0, 1, ' '); // tokenize the prompt - auto embd_inp = ::llama_tokenize(d_ptr->ctx, params.prompt, false); + std::vector embd_inp(params.prompt.size() + 4); + int n = llama_tokenize(d_ptr->ctx, params.prompt.c_str(), embd_inp.data(), embd_inp.size(), d_ptr->empty); + assert(n >= 0); + embd_inp.resize(n); + d_ptr->empty = false; // save the context size promptCtx.n_ctx = llama_n_ctx(d_ptr->ctx); @@ -143,20 +203,19 @@ void LLamaModel::prompt(const std::string &prompt, // process the prompt in batches size_t i = 0; - const int64_t t_start_prompt_us = ggml_time_us(); while (i < embd_inp.size()) { size_t batch_end = std::min(i + promptCtx.n_batch, embd_inp.size()); std::vector batch(embd_inp.begin() + i, embd_inp.begin() + batch_end); // Check if the context has run out... - if (promptCtx.n_past + batch.size() > promptCtx.n_ctx) { + if (promptCtx.n_past + int32_t(batch.size()) > promptCtx.n_ctx) { const int32_t erasePoint = promptCtx.n_ctx * promptCtx.contextErase; // Erase the first percentage of context from the tokens... 
std::cerr << "LLAMA: reached the end of the context window so resizing\n"; promptCtx.tokens.erase(promptCtx.tokens.begin(), promptCtx.tokens.begin() + erasePoint); promptCtx.n_past = promptCtx.tokens.size(); recalculateContext(promptCtx, recalculateCallback); - assert(promptCtx.n_past + batch.size() <= promptCtx.n_ctx); + assert(promptCtx.n_past + int32_t(batch.size()) <= promptCtx.n_ctx); } if (llama_eval(d_ptr->ctx, batch.data(), batch.size(), promptCtx.n_past, d_ptr->n_threads)) { @@ -166,7 +225,7 @@ void LLamaModel::prompt(const std::string &prompt, size_t tokens = batch_end - i; for (size_t t = 0; t < tokens; ++t) { - if (promptCtx.tokens.size() == promptCtx.n_ctx) + if (int32_t(promptCtx.tokens.size()) == promptCtx.n_ctx) promptCtx.tokens.erase(promptCtx.tokens.begin()); promptCtx.tokens.push_back(batch.at(t)); if (!promptCallback(batch.at(t))) @@ -179,10 +238,9 @@ void LLamaModel::prompt(const std::string &prompt, std::string cachedResponse; std::vector cachedTokens; std::unordered_set reversePrompts - = { "### Instruction", "### Prompt", "### Response", "### Human", "### Assistant", "### Context" }; + = { "### Instruction", "### Prompt", "### Response", "### Human", "### Assistant" }; // predict next tokens - int32_t totalPredictions = 0; for (int i = 0; i < promptCtx.n_predict; i++) { // sample next token const size_t n_prev_toks = std::min((size_t) promptCtx.repeat_last_n, promptCtx.tokens.size()); @@ -209,7 +267,6 @@ void LLamaModel::prompt(const std::string &prompt, promptCtx.n_past += 1; // display text - ++totalPredictions; if (id == llama_token_eos()) return; @@ -240,7 +297,7 @@ void LLamaModel::prompt(const std::string &prompt, // Empty the cache for (auto t : cachedTokens) { - if (promptCtx.tokens.size() == promptCtx.n_ctx) + if (int32_t(promptCtx.tokens.size()) == promptCtx.n_ctx) promptCtx.tokens.erase(promptCtx.tokens.begin()); promptCtx.tokens.push_back(t); if (!responseCallback(t, llama_token_to_str(d_ptr->ctx, t))) @@ -258,7 +315,7 @@ void LLamaModel::recalculateContext(PromptContext &promptCtx, std::function batch(promptCtx.tokens.begin() + i, promptCtx.tokens.begin() + batch_end); - assert(promptCtx.n_past + batch.size() <= promptCtx.n_ctx); + assert(promptCtx.n_past + int32_t(batch.size()) <= promptCtx.n_ctx); if (llama_eval(d_ptr->ctx, batch.data(), batch.size(), promptCtx.n_past, d_ptr->n_threads)) { std::cerr << "LLAMA ERROR: Failed to process prompt\n"; @@ -269,8 +326,43 @@ void LLamaModel::recalculateContext(PromptContext &promptCtx, std::function(&magic), sizeof(magic)); + if (magic != 0x67676a74) return false; + // Check version + uint32_t version = 0; + f.read(reinterpret_cast(&version), sizeof(version)); + return version LLAMA_VERSIONS; +} + +DLL_EXPORT LLModel *construct() { + return new LLamaModel; +} +} diff --git a/gpt4all-backend/llamamodel.h b/gpt4all-backend/llamamodel_impl.h similarity index 74% rename from gpt4all-backend/llamamodel.h rename to gpt4all-backend/llamamodel_impl.h index f4e07782..a8f18936 100644 --- a/gpt4all-backend/llamamodel.h +++ b/gpt4all-backend/llamamodel_impl.h @@ -1,3 +1,6 @@ +#ifndef LLAMAMODEL_H_I_KNOW_WHAT_I_AM_DOING_WHEN_INCLUDING_THIS_FILE +#error This file is NOT meant to be included outside of llamamodel.cpp. Doing so is DANGEROUS. 
Be sure to know what you are doing before proceeding to #define LLAMAMODEL_H_I_KNOW_WHAT_I_AM_DOING_WHEN_INCLUDING_THIS_FILE +#endif #ifndef LLAMAMODEL_H #define LLAMAMODEL_H @@ -6,7 +9,7 @@ #include #include "llmodel.h" -class LLamaPrivate; +struct LLamaPrivate; class LLamaModel : public LLModel { public: LLamaModel(); @@ -33,4 +36,4 @@ private: LLamaPrivate *d_ptr; }; -#endif // LLAMAMODEL_H \ No newline at end of file +#endif // LLAMAMODEL_H diff --git a/gpt4all-backend/llmodel.cpp b/gpt4all-backend/llmodel.cpp new file mode 100644 index 00000000..bd466921 --- /dev/null +++ b/gpt4all-backend/llmodel.cpp @@ -0,0 +1,90 @@ +#include "llmodel.h" +#include "dlhandle.h" + +#include +#include +#include +#include + +static Dlhandle *get_implementation(std::ifstream& f, const std::string& buildVariant) { + // Collect all model implementation libraries + // NOTE: allocated on heap so we leak intentionally on exit so we have a chance to clean up the + // individual models without the cleanup of the static list interfering + static auto* libs = new std::vector([] () { + std::vector fres; + + auto search_in_directory = [&](const std::filesystem::path& path) { + // Iterate over all libraries + for (const auto& f : std::filesystem::directory_iterator(path)) { + // Get path + const auto& p = f.path(); + // Check extension + if (p.extension() != LIB_FILE_EXT) continue; + // Add to list if model implementation + try { + Dlhandle dl(p.string()); + if (dl.get("is_g4a_backend_model_implementation")) { + fres.emplace_back(std::move(dl)); + } + } catch (...) {} + } + }; + + search_in_directory("."); +#if defined(__APPLE__) + search_in_directory("../../../"); +#endif + return fres; + }()); + // Iterate over all libraries + for (auto& dl : *libs) { + f.seekg(0); + // Check that magic matches + auto magic_match = dl.get("magic_match"); + if (!magic_match || !magic_match(f)) { + continue; + } + // Check that build variant is correct + auto get_build_variant = dl.get("get_build_variant"); + if (buildVariant != (get_build_variant?get_build_variant():"default")) { + continue; + } + // Looks like we're good to go, return this dlhandle + return &dl; + } + // Nothing found, so return nothing + return nullptr; +} + +static bool requires_avxonly() { +#ifdef __x86_64__ + return !__builtin_cpu_supports("avx2") && !__builtin_cpu_supports("fma"); +#else + return false; // Don't know how to handle ARM +#endif +} + +LLModel *LLModel::construct(const std::string &modelPath, std::string buildVariant) { + //TODO: Auto-detect + if (buildVariant == "auto") { + if (requires_avxonly()) { + buildVariant = "avxonly"; + } else { + buildVariant = "default"; + } + } + // Read magic + std::ifstream f(modelPath, std::ios::binary); + if (!f) return nullptr; + // Get correct implementation + auto impl = get_implementation(f, buildVariant); + if (!impl) return nullptr; + f.close(); + // Get inference constructor + auto constructor = impl->get("construct"); + if (!constructor) return nullptr; + // Construct llmodel implementation + auto fres = constructor(); + // Return final instance + return fres; +} diff --git a/gpt4all-backend/llmodel.h b/gpt4all-backend/llmodel.h index 50c313e6..4bcf2716 100644 --- a/gpt4all-backend/llmodel.h +++ b/gpt4all-backend/llmodel.h @@ -1,21 +1,23 @@ #ifndef LLMODEL_H #define LLMODEL_H - #include #include #include #include + class LLModel { public: explicit LLModel() {} virtual ~LLModel() {} + static LLModel *construct(const std::string &modelPath, std::string buildVariant = "default"); + virtual bool 
loadModel(const std::string &modelPath) = 0; virtual bool isModelLoaded() const = 0; virtual size_t stateSize() const { return 0; } - virtual size_t saveState(uint8_t *dest) const { return 0; } - virtual size_t restoreState(const uint8_t *src) { return 0; } + virtual size_t saveState(uint8_t */*dest*/) const { return 0; } + virtual size_t restoreState(const uint8_t */*src*/) { return 0; } struct PromptContext { std::vector logits; // logits of current context std::vector tokens; // current tokens in the context window @@ -36,12 +38,18 @@ public: std::function responseCallback, std::function recalculateCallback, PromptContext &ctx) = 0; - virtual void setThreadCount(int32_t n_threads) {} + virtual void setThreadCount(int32_t /*n_threads*/) {} virtual int32_t threadCount() const { return 1; } + const char *getModelType() const { + return modelType; + } + protected: virtual void recalculateContext(PromptContext &promptCtx, std::function recalculate) = 0; + + const char *modelType; }; #endif // LLMODEL_H diff --git a/gpt4all-backend/llmodel_c.cpp b/gpt4all-backend/llmodel_c.cpp index 3e20c4d5..a09e20dc 100644 --- a/gpt4all-backend/llmodel_c.cpp +++ b/gpt4all-backend/llmodel_c.cpp @@ -1,79 +1,57 @@ #include "llmodel_c.h" +#include "llmodel.h" + +#include +#include +#include -#include "gptj.h" -#include "llamamodel.h" -#include "mpt.h" struct LLModelWrapper { LLModel *llModel = nullptr; LLModel::PromptContext promptContext; }; -llmodel_model llmodel_gptj_create() -{ - LLModelWrapper *wrapper = new LLModelWrapper; - wrapper->llModel = new GPTJ; - return reinterpret_cast(wrapper); -} - -void llmodel_gptj_destroy(llmodel_model gptj) -{ - LLModelWrapper *wrapper = reinterpret_cast(gptj); - delete wrapper->llModel; - delete wrapper; -} -llmodel_model llmodel_mpt_create() -{ - LLModelWrapper *wrapper = new LLModelWrapper; - wrapper->llModel = new MPT; - return reinterpret_cast(wrapper); -} +thread_local static std::string last_error_message; -void llmodel_mpt_destroy(llmodel_model mpt) -{ - LLModelWrapper *wrapper = reinterpret_cast(mpt); - delete wrapper->llModel; - delete wrapper; -} - -llmodel_model llmodel_llama_create() -{ - LLModelWrapper *wrapper = new LLModelWrapper; - wrapper->llModel = new LLamaModel; - return reinterpret_cast(wrapper); -} - -void llmodel_llama_destroy(llmodel_model llama) -{ - LLModelWrapper *wrapper = reinterpret_cast(llama); - delete wrapper->llModel; - delete wrapper; -} llmodel_model llmodel_model_create(const char *model_path) { - - uint32_t magic; - llmodel_model model; - FILE *f = fopen(model_path, "rb"); - fread(&magic, sizeof(magic), 1, f); - - if (magic == 0x67676d6c) { model = llmodel_gptj_create(); } - else if (magic == 0x67676a74) { model = llmodel_llama_create(); } - else if (magic == 0x67676d6d) { model = llmodel_mpt_create(); } - else {fprintf(stderr, "Invalid model file\n");} - fclose(f); - return model; + auto fres = llmodel_model_create2(model_path, "auto", nullptr); + if (!fres) { + fprintf(stderr, "Invalid model file\n"); + } + return fres; +} + +llmodel_model llmodel_model_create2(const char *model_path, const char *build_variant, llmodel_error *error) { + auto wrapper = new LLModelWrapper; + llmodel_error new_error{}; + + try { + wrapper->llModel = LLModel::construct(model_path, build_variant); + } catch (const std::exception& e) { + new_error.code = EINVAL; + last_error_message = e.what(); + } + + if (!wrapper->llModel) { + delete std::exchange(wrapper, nullptr); + // Get errno and error message if none + if (new_error.code == 0) { + new_error.code = 
errno; + last_error_message = strerror(errno); + } + // Set message pointer + new_error.message = last_error_message.c_str(); + // Set error argument + if (error) *error = new_error; + } + return reinterpret_cast(wrapper); } void llmodel_model_destroy(llmodel_model model) { - LLModelWrapper *wrapper = reinterpret_cast(model); - const std::type_info &modelTypeInfo = typeid(*wrapper->llModel); - - if (modelTypeInfo == typeid(GPTJ)) { llmodel_gptj_destroy(model); } - if (modelTypeInfo == typeid(LLamaModel)) { llmodel_llama_destroy(model); } - if (modelTypeInfo == typeid(MPT)) { llmodel_mpt_destroy(model); } + delete wrapper->llModel; } bool llmodel_loadModel(llmodel_model model, const char *model_path) @@ -84,20 +62,20 @@ bool llmodel_loadModel(llmodel_model model, const char *model_path) bool llmodel_isModelLoaded(llmodel_model model) { - const auto *llm = reinterpret_cast(model)->llModel; - return llm->isModelLoaded(); + LLModelWrapper *wrapper = reinterpret_cast(model); + return wrapper->llModel->isModelLoaded(); } uint64_t llmodel_get_state_size(llmodel_model model) { - const auto *llm = reinterpret_cast(model)->llModel; - return llm->stateSize(); + LLModelWrapper *wrapper = reinterpret_cast(model); + return wrapper->llModel->stateSize(); } uint64_t llmodel_save_state_data(llmodel_model model, uint8_t *dest) { - const auto *llm = reinterpret_cast(model)->llModel; - return llm->saveState(dest); + LLModelWrapper *wrapper = reinterpret_cast(model); + return wrapper->llModel->saveState(dest); } uint64_t llmodel_restore_state_data(llmodel_model model, const uint8_t *src) @@ -181,6 +159,6 @@ void llmodel_setThreadCount(llmodel_model model, int32_t n_threads) int32_t llmodel_threadCount(llmodel_model model) { - const auto *llm = reinterpret_cast(model)->llModel; - return llm->threadCount(); + LLModelWrapper *wrapper = reinterpret_cast(model); + return wrapper->llModel->threadCount(); } diff --git a/gpt4all-backend/llmodel_c.h b/gpt4all-backend/llmodel_c.h index 0b6972b7..ebbd4782 100644 --- a/gpt4all-backend/llmodel_c.h +++ b/gpt4all-backend/llmodel_c.h @@ -5,6 +5,15 @@ #include #include +#ifdef __GNUC__ +#define DEPRECATED __attribute__ ((deprecated)) +#elif defined(_MSC_VER) +#define DEPRECATED __declspec(deprecated) +#else +#pragma message("WARNING: You need to implement DEPRECATED for this compiler") +#define DEPRECATED +#endif + #ifdef __cplusplus extern "C" { #endif @@ -14,13 +23,24 @@ extern "C" { */ typedef void *llmodel_model; +/** + * Structure containing any errors that may eventually occur + */ +struct llmodel_error { + const char *message; // Human readable error description; Thread-local; guaranteed to survive until next llmodel C API call + int code; // errno; 0 if none +}; +#ifndef __cplusplus +typedef struct llmodel_error llmodel_error; +#endif + /** * llmodel_prompt_context structure for holding the prompt context. * NOTE: The implementation takes care of all the memory handling of the raw logits pointer and the * raw tokens pointer. Attempting to resize them or modify them in any way can lead to undefined * behavior. 
*/ -typedef struct { +struct llmodel_prompt_context { float *logits; // logits of current context size_t logits_size; // the size of the raw logits vector int32_t *tokens; // current tokens in the context window @@ -35,7 +55,10 @@ typedef struct { float repeat_penalty; // penalty factor for repeated tokens int32_t repeat_last_n; // last n tokens to penalize float context_erase; // percent of context to erase if we exceed the context window -} llmodel_prompt_context; +}; +#ifndef __cplusplus +typedef struct llmodel_prompt_context llmodel_prompt_context; +#endif /** * Callback type for prompt processing. @@ -60,48 +83,22 @@ typedef bool (*llmodel_response_callback)(int32_t token_id, const char *response typedef bool (*llmodel_recalculate_callback)(bool is_recalculating); /** - * Create a GPTJ instance. - * @return A pointer to the GPTJ instance. - */ -llmodel_model llmodel_gptj_create(); - -/** - * Destroy a GPTJ instance. - * @param gptj A pointer to the GPTJ instance. - */ -void llmodel_gptj_destroy(llmodel_model gptj); - -/** - * Create a MPT instance. - * @return A pointer to the MPT instance. - */ -llmodel_model llmodel_mpt_create(); - -/** - * Destroy a MPT instance. - * @param gptj A pointer to the MPT instance. - */ -void llmodel_mpt_destroy(llmodel_model mpt); - -/** - * Create a LLAMA instance. - * @return A pointer to the LLAMA instance. - */ -llmodel_model llmodel_llama_create(); - -/** - * Destroy a LLAMA instance. - * @param llama A pointer to the LLAMA instance. + * Create a llmodel instance. + * Recognises correct model type from file at model_path + * @param model_path A string representing the path to the model file. + * @return A pointer to the llmodel_model instance; NULL on error. */ -void llmodel_llama_destroy(llmodel_model llama); +DEPRECATED llmodel_model llmodel_model_create(const char *model_path); /** * Create a llmodel instance. * Recognises correct model type from file at model_path - * @param model_path A string representing the path to the model file. - * @return A pointer to the llmodel_model instance. + * @param model_path A string representing the path to the model file; will only be used to detect model type. + * @param build_variant A string representing the implementation to use (auto, default, avxonly, ...), + * @param error A pointer to a llmodel_error; will only be set on error. + * @return A pointer to the llmodel_model instance; NULL on error. */ -llmodel_model llmodel_model_create(const char *model_path); +llmodel_model llmodel_model_create2(const char *model_path, const char *build_variant, llmodel_error *error); /** * Destroy a llmodel instance. @@ -110,7 +107,6 @@ llmodel_model llmodel_model_create(const char *model_path); */ void llmodel_model_destroy(llmodel_model model); - /** * Load a model from a file. * @param model A pointer to the llmodel_model instance. 
diff --git a/gpt4all-backend/mpt.cpp b/gpt4all-backend/mpt.cpp index 61e71cc4..e526f5ac 100644 --- a/gpt4all-backend/mpt.cpp +++ b/gpt4all-backend/mpt.cpp @@ -1,5 +1,5 @@ -#include "mpt.h" -#include "llama.cpp/ggml.h" +#define MPT_H_I_KNOW_WHAT_I_AM_DOING_WHEN_INCLUDING_THIS_FILE +#include "mpt_impl.h" #include "utils.h" @@ -28,8 +28,14 @@ #include #include #include +#include + + +namespace { +const char *modelType_ = "MPT"; static const size_t MB = 1024*1024; +} // default hparams (MPT 7B) struct mpt_hparams { @@ -293,7 +299,6 @@ bool mpt_model_load(const std::string &fname, std::istream &fin, mpt_model & mod const int n_embd = hparams.n_embd; const int n_layer = hparams.n_layer; - const int n_ctx = hparams.n_ctx; const int n_vocab = hparams.n_vocab; const int expand = hparams.expand; @@ -331,14 +336,6 @@ bool mpt_model_load(const std::string &fname, std::istream &fin, mpt_model & mod // key + value memory { const auto & hparams = model.hparams; - - const int n_embd = hparams.n_embd; - const int n_layer = hparams.n_layer; - const int n_ctx = hparams.n_ctx; - - const int n_mem = n_layer*n_ctx; - const int n_elements = n_embd*n_mem; - if (!kv_cache_init(hparams, model.kv_self, GGML_TYPE_F16, model.hparams.n_ctx)) { fprintf(stderr, "%s: kv_cache_init() failed for self-attention cache\n", __func__); ggml_free(ctx); @@ -457,9 +454,6 @@ bool mpt_eval( const int n_ctx = hparams.n_ctx; const int n_head = hparams.n_head; const int n_vocab = hparams.n_vocab; - const int expand = hparams.expand; - - const int d_key = n_embd/n_head; const size_t init_buf_size = 1024u*MB; if (!model.buf.addr || model.buf.size < init_buf_size) @@ -480,10 +474,12 @@ bool mpt_eval( struct ggml_init_params params = { .mem_size = model.buf.size, .mem_buffer = model.buf.addr, + .no_alloc = false }; struct ggml_context * ctx0 = ggml_init(params); - struct ggml_cgraph gf = { .n_threads = n_threads }; + struct ggml_cgraph gf = {}; + gf.n_threads = n_threads; struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N); memcpy(embd->data, embd_inp.data(), N*ggml_element_size(embd)); @@ -695,8 +691,7 @@ size_t mpt_copy_state_data(const mpt_model &model, const std::mt19937 &rng, uint } const size_t written = out - dest; - const size_t expected = mpt_get_state_size(model); - assert(written == expected); + assert(written == mpt_get_state_size(model)); fflush(stdout); return written; } @@ -745,8 +740,7 @@ size_t mpt_set_state_data(mpt_model *model, std::mt19937 *rng, const uint8_t *sr } const size_t nread = in - src; - const size_t expected = mpt_get_state_size(*model); - assert(nread == expected); + assert(nread == mpt_get_state_size(*model)); fflush(stdout); return nread; } @@ -764,6 +758,7 @@ struct MPTPrivate { MPT::MPT() : d_ptr(new MPTPrivate) { + modelType = modelType_; d_ptr->model = new mpt_model; d_ptr->modelLoaded = false; @@ -833,12 +828,6 @@ void MPT::prompt(const std::string &prompt, return; } - const int64_t t_main_start_us = ggml_time_us(); - - int64_t t_sample_us = 0; - int64_t t_predict_us = 0; - int64_t t_prompt_us = 0; - // tokenize the prompt std::vector embd_inp = gpt_tokenize(d_ptr->vocab, prompt); @@ -867,20 +856,19 @@ void MPT::prompt(const std::string &prompt, // process the prompt in batches size_t i = 0; - const int64_t t_start_prompt_us = ggml_time_us(); while (i < embd_inp.size()) { size_t batch_end = std::min(i + promptCtx.n_batch, embd_inp.size()); std::vector batch(embd_inp.begin() + i, embd_inp.begin() + batch_end); // Check if the context has run out... 
- if (promptCtx.n_past + batch.size() > promptCtx.n_ctx) { + if (promptCtx.n_past + int32_t(batch.size()) > promptCtx.n_ctx) { const int32_t erasePoint = promptCtx.n_ctx * promptCtx.contextErase; // Erase the first percentage of context from the tokens... std::cerr << "MPT: reached the end of the context window so resizing\n"; promptCtx.tokens.erase(promptCtx.tokens.begin(), promptCtx.tokens.begin() + erasePoint); promptCtx.n_past = promptCtx.tokens.size(); recalculateContext(promptCtx, recalculateCallback); - assert(promptCtx.n_past + batch.size() <= promptCtx.n_ctx); + assert(promptCtx.n_past + int32_t(batch.size()) <= promptCtx.n_ctx); } if (!mpt_eval(*d_ptr->model, d_ptr->n_threads, promptCtx.n_past, batch, promptCtx.logits, @@ -891,7 +879,7 @@ void MPT::prompt(const std::string &prompt, size_t tokens = batch_end - i; for (size_t t = 0; t < tokens; ++t) { - if (promptCtx.tokens.size() == promptCtx.n_ctx) + if (int32_t(promptCtx.tokens.size()) == promptCtx.n_ctx) promptCtx.tokens.erase(promptCtx.tokens.begin()); promptCtx.tokens.push_back(batch.at(t)); if (!promptCallback(batch.at(t))) @@ -900,10 +888,6 @@ void MPT::prompt(const std::string &prompt, promptCtx.n_past += batch.size(); i = batch_end; } - t_prompt_us += ggml_time_us() - t_start_prompt_us; - - int p_instructFound = 0; - int r_instructFound = 0; std::string cachedResponse; std::vector cachedTokens; @@ -911,24 +895,20 @@ void MPT::prompt(const std::string &prompt, = { "### Instruction", "### Prompt", "### Response", "### Human", "### Assistant", "### Context" }; // predict next tokens - int32_t totalPredictions = 0; for (int i = 0; i < promptCtx.n_predict; i++) { // sample next token const int n_vocab = d_ptr->model->hparams.n_vocab; int id = 0; { - const int64_t t_start_sample_us = ggml_time_us(); const size_t n_prev_toks = std::min((size_t) promptCtx.repeat_last_n, promptCtx.tokens.size()); - id = gpt_sample_top_k_top_p(d_ptr->vocab, n_vocab, + id = gpt_sample_top_k_top_p(n_vocab, promptCtx.tokens.data() + promptCtx.tokens.size() - n_prev_toks, n_prev_toks, promptCtx.logits, promptCtx.top_k, promptCtx.top_p, promptCtx.temp, promptCtx.repeat_penalty, d_ptr->rng); - - t_sample_us += ggml_time_us() - t_start_sample_us; } // Check if the context has run out... 
@@ -942,33 +922,28 @@ void MPT::prompt(const std::string &prompt, assert(promptCtx.n_past + 1 <= promptCtx.n_ctx); } - const int64_t t_start_predict_us = ggml_time_us(); if (!mpt_eval(*d_ptr->model, d_ptr->n_threads, promptCtx.n_past, { id }, promptCtx.logits, d_ptr->mem_per_token)) { std::cerr << "GPT-J ERROR: Failed to predict next token\n"; return; } - t_predict_us += ggml_time_us() - t_start_predict_us; promptCtx.n_past += 1; - // display text - ++totalPredictions; - + // display tex // mpt-7b-chat has special token for end if (d_ptr->has_im_end && id == d_ptr->vocab.token_to_id["<|im_end|>"]) - goto stop_generating; + return; if (id == 0 /*end of text*/) - goto stop_generating; + return; const std::string str = d_ptr->vocab.id_to_token[id]; // Check if the provided str is part of our reverse prompts bool foundPartialReversePrompt = false; const std::string completed = cachedResponse + str; - if (reversePrompts.find(completed) != reversePrompts.end()) { - goto stop_generating; - } + if (reversePrompts.find(completed) != reversePrompts.end()) + return; // Check if it partially matches our reverse prompts and if so, cache for (auto s : reversePrompts) { @@ -988,32 +963,14 @@ void MPT::prompt(const std::string &prompt, // Empty the cache for (auto t : cachedTokens) { - if (promptCtx.tokens.size() == promptCtx.n_ctx) + if (int32_t(promptCtx.tokens.size()) == promptCtx.n_ctx) promptCtx.tokens.erase(promptCtx.tokens.begin()); promptCtx.tokens.push_back(t); if (!responseCallback(t, d_ptr->vocab.id_to_token[t])) - goto stop_generating; + return; } cachedTokens.clear(); } - -stop_generating: - -#if 0 - // report timing - { - const int64_t t_main_end_us = ggml_time_us(); - - std::cout << "GPT-J INFO: mem per token = " << mem_per_token << " bytes\n"; - std::cout << "GPT-J INFO: sample time = " << t_sample_us/1000.0f << " ms\n"; - std::cout << "GPT-J INFO: prompt time = " << t_prompt_us/1000.0f << " ms\n"; - std::cout << "GPT-J INFO: predict time = " << t_predict_us/1000.0f << " ms / " << t_predict_us/1000.0f/totalPredictions << " ms per token\n"; - std::cout << "GPT-J INFO: total time = " << (t_main_end_us - t_main_start_us)/1000.0f << " ms\n"; - fflush(stdout); - } -#endif - - return; } void MPT::recalculateContext(PromptContext &promptCtx, std::function recalculate) @@ -1024,7 +981,7 @@ void MPT::recalculateContext(PromptContext &promptCtx, std::function size_t batch_end = std::min(i + promptCtx.n_batch, promptCtx.tokens.size()); std::vector batch(promptCtx.tokens.begin() + i, promptCtx.tokens.begin() + batch_end); - assert(promptCtx.n_past + batch.size() <= promptCtx.n_ctx); + assert(promptCtx.n_past + int32_t(batch.size()) <= promptCtx.n_ctx); if (!mpt_eval(*d_ptr->model, d_ptr->n_threads, promptCtx.n_past, batch, promptCtx.logits, d_ptr->mem_per_token)) { @@ -1036,8 +993,38 @@ void MPT::recalculateContext(PromptContext &promptCtx, std::function goto stop_generating; i = batch_end; } - assert(promptCtx.n_past == promptCtx.tokens.size()); + assert(promptCtx.n_past == int32_t(promptCtx.tokens.size())); stop_generating: recalculate(false); } + +#if defined(_WIN32) +#define DLL_EXPORT __declspec(dllexport) +#else +#define DLL_EXPORT __attribute__ ((visibility ("default"))) +#endif + +extern "C" { +DLL_EXPORT bool is_g4a_backend_model_implementation() { + return true; +} + +DLL_EXPORT const char *get_model_type() { + return modelType_; +} + +DLL_EXPORT const char *get_build_variant() { + return GGML_BUILD_VARIANT; +} + +DLL_EXPORT bool magic_match(std::istream& f) { + uint32_t magic = 0; + 
f.read(reinterpret_cast(&magic), sizeof(magic)); + return magic == 0x67676d6d; +} + +DLL_EXPORT LLModel *construct() { + return new MPT; +} +} diff --git a/gpt4all-backend/mpt.h b/gpt4all-backend/mpt_impl.h similarity index 76% rename from gpt4all-backend/mpt.h rename to gpt4all-backend/mpt_impl.h index 15122d6e..31095afd 100644 --- a/gpt4all-backend/mpt.h +++ b/gpt4all-backend/mpt_impl.h @@ -1,3 +1,6 @@ +#ifndef MPT_H_I_KNOW_WHAT_I_AM_DOING_WHEN_INCLUDING_THIS_FILE +#error This file is NOT meant to be included outside of mpt.cpp. Doing so is DANGEROUS. Be sure to know what you are doing before proceeding to #define MPT_H_I_KNOW_WHAT_I_AM_DOING_WHEN_INCLUDING_THIS_FILE +#endif #ifndef MPT_H #define MPT_H @@ -6,7 +9,7 @@ #include #include "llmodel.h" -class MPTPrivate; +struct MPTPrivate; class MPT : public LLModel { public: MPT(); diff --git a/gpt4all-backend/utils.cpp b/gpt4all-backend/utils.cpp index 783054f5..8769315e 100644 --- a/gpt4all-backend/utils.cpp +++ b/gpt4all-backend/utils.cpp @@ -218,7 +218,6 @@ bool gpt_vocab_init(const std::string & fname, gpt_vocab & vocab) { } gpt_vocab::id gpt_sample_top_k_top_p( - const gpt_vocab & vocab, const size_t actualVocabSize, const int32_t * last_n_tokens_data, int last_n_tokens_size, diff --git a/gpt4all-backend/utils.h b/gpt4all-backend/utils.h index 9c9f5c60..e3b90efe 100644 --- a/gpt4all-backend/utils.h +++ b/gpt4all-backend/utils.h @@ -79,7 +79,6 @@ bool gpt_vocab_init(const std::string & fname, gpt_vocab & vocab); // TODO: not sure if this implementation is correct // gpt_vocab::id gpt_sample_top_k_top_p( - const gpt_vocab & vocab, const size_t actualVocabSize, const int32_t * last_n_tokens_data, int last_n_tokens_size, diff --git a/gpt4all-chat/CMakeLists.txt b/gpt4all-chat/CMakeLists.txt index c654ae82..4eabe985 100644 --- a/gpt4all-chat/CMakeLists.txt +++ b/gpt4all-chat/CMakeLists.txt @@ -55,10 +55,10 @@ get_filename_component(Qt6_ROOT_DIR "${Qt6_ROOT_DIR}/.." 
ABSOLUTE) message(STATUS "qmake binary: ${QMAKE_EXECUTABLE}") message(STATUS "Qt 6 root directory: ${Qt6_ROOT_DIR}") -add_subdirectory(../gpt4all-backend llmodel) - set (CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin) +add_subdirectory(../gpt4all-backend llmodel) + qt_add_executable(chat main.cpp chat.h chat.cpp @@ -146,7 +146,6 @@ endif() install(TARGETS chat DESTINATION bin COMPONENT ${COMPONENT_NAME_MAIN}) install(TARGETS llmodel DESTINATION lib COMPONENT ${COMPONENT_NAME_MAIN}) -install(TARGETS llama DESTINATION lib COMPONENT ${COMPONENT_NAME_MAIN}) set(CPACK_GENERATOR "IFW") set(CPACK_VERBATIM_VARIABLES YES) diff --git a/gpt4all-chat/chatllm.cpp b/gpt4all-chat/chatllm.cpp index ff3c8331..1638a4f6 100644 --- a/gpt4all-chat/chatllm.cpp +++ b/gpt4all-chat/chatllm.cpp @@ -2,9 +2,7 @@ #include "chat.h" #include "download.h" #include "network.h" -#include "../gpt4all-backend/gptj.h" -#include "../gpt4all-backend/llamamodel.h" -#include "../gpt4all-backend/mpt.h" +#include "../gpt4all-backend/llmodel.h" #include "chatgpt.h" #include @@ -215,25 +213,15 @@ bool ChatLLM::loadModel(const QString &modelName) model->setAPIKey(apiKey); m_modelInfo.model = model; } else { - auto fin = std::ifstream(filePath.toStdString(), std::ios::binary); - uint32_t magic; - fin.read((char *) &magic, sizeof(magic)); - fin.seekg(0); - fin.close(); - const bool isGPTJ = magic == 0x67676d6c; - const bool isMPT = magic == 0x67676d6d; - if (isGPTJ) { - m_modelType = LLModelType::GPTJ_; - m_modelInfo.model = new GPTJ; - m_modelInfo.model->loadModel(filePath.toStdString()); - } else if (isMPT) { - m_modelType = LLModelType::MPT_; - m_modelInfo.model = new MPT; - m_modelInfo.model->loadModel(filePath.toStdString()); - } else { - m_modelType = LLModelType::LLAMA_; - m_modelInfo.model = new LLamaModel; + m_modelInfo.model = LLModel::construct(filePath.toStdString()); + if (m_modelInfo.model) { m_modelInfo.model->loadModel(filePath.toStdString()); + switch (m_modelInfo.model->getModelType()[0]) { + case 'L': m_modelType = LLModelType::LLAMA_; break; + case 'G': m_modelType = LLModelType::GPTJ_; break; + case 'M': m_modelType = LLModelType::MPT_; break; + default: delete std::exchange(m_modelInfo.model, nullptr); + } } } #if defined(DEBUG_MODEL_LOADING) diff --git a/gpt4all-chat/server.cpp b/gpt4all-chat/server.cpp index 9211bf8a..8ba59e67 100644 --- a/gpt4all-chat/server.cpp +++ b/gpt4all-chat/server.cpp @@ -112,7 +112,7 @@ void Server::start() ); m_server->route("/v1/completions", QHttpServerRequest::Method::Post, - [=](const QHttpServerRequest &request) { + [this](const QHttpServerRequest &request) { if (!LLM::globalInstance()->serverEnabled()) return QHttpServerResponse(QHttpServerResponder::StatusCode::Unauthorized); return handleCompletionRequest(request, false); @@ -120,7 +120,7 @@ void Server::start() ); m_server->route("/v1/chat/completions", QHttpServerRequest::Method::Post, - [=](const QHttpServerRequest &request) { + [this](const QHttpServerRequest &request) { if (!LLM::globalInstance()->serverEnabled()) return QHttpServerResponse(QHttpServerResponder::StatusCode::Unauthorized); return handleCompletionRequest(request, true);
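
For reference, a minimal usage sketch of the new llmodel_model_create2 / llmodel_error C API introduced in this patch. This is not part of the diff; the model path, thread count, and the main() scaffold are placeholder assumptions, and only functions visible above (llmodel_model_create2, llmodel_loadModel, llmodel_isModelLoaded, llmodel_setThreadCount, llmodel_model_destroy) are used:

    /* Hypothetical sketch: exercising the new create2 + error-reporting flow. */
    #include <stdio.h>
    #include "llmodel_c.h"

    int main(void)
    {
        llmodel_error err = { NULL, 0 };

        /* "auto" lets the loader pick the "default" or "avxonly" build variant
         * based on CPU features; the model path is a placeholder. */
        llmodel_model model = llmodel_model_create2("./ggml-model.bin", "auto", &err);
        if (model == NULL) {
            /* err.message is thread-local and valid until the next llmodel call. */
            fprintf(stderr, "create failed: %s (code %d)\n", err.message, err.code);
            return 1;
        }

        /* Construction only selects an implementation; loading is a separate step. */
        if (!llmodel_loadModel(model, "./ggml-model.bin") || !llmodel_isModelLoaded(model)) {
            fprintf(stderr, "failed to load model\n");
            llmodel_model_destroy(model);
            return 1;
        }

        llmodel_setThreadCount(model, 4);
        /* ... prompting would follow here ... */

        llmodel_model_destroy(model);
        return 0;
    }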