diff --git a/gpt4all-backend/gptj.cpp b/gpt4all-backend/gptj.cpp
index 69b97fdf..5031ecdf 100644
--- a/gpt4all-backend/gptj.cpp
+++ b/gpt4all-backend/gptj.cpp
@@ -345,6 +345,13 @@ bool gptj_eval(
     struct ggml_context * ctx0 = ggml_init(params);
     struct ggml_cgraph gf = {};
 
+    // KQ_pos - contains the positions
+    struct ggml_tensor * KQ_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
+    int * data = (int *) KQ_pos->data;
+    for (int i = 0; i < N; ++i) {
+        data[i] = n_past + i;
+    }
+
     struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
     memcpy(embd->data, embd_inp.data(), N*ggml_element_size(embd));
 
@@ -370,8 +377,14 @@ bool gptj_eval(
 
         // self-attention
         {
-            struct ggml_tensor * Qcur = ggml_rope(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model.layers[il].c_attn_q_proj_w, cur), n_embd/n_head, n_head, N), n_past, n_rot, 0, 0);
-            struct ggml_tensor * Kcur = ggml_rope(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model.layers[il].c_attn_k_proj_w, cur), n_embd/n_head, n_head, N), n_past, n_rot, 0, 0);
+            struct ggml_tensor * Qcur = ggml_rope(
+                ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model.layers[il].c_attn_q_proj_w, cur), n_embd/n_head, n_head, N),
+                KQ_pos, n_rot, 0, 0
+            );
+            struct ggml_tensor * Kcur = ggml_rope(
+                ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model.layers[il].c_attn_k_proj_w, cur), n_embd/n_head, n_head, N),
+                KQ_pos, n_rot, 0, 0
+            );
 
             // store key and value to memory
             {
diff --git a/gpt4all-backend/llama.cpp-mainline b/gpt4all-backend/llama.cpp-mainline
index 74f977c1..abd7dc4e 160000
--- a/gpt4all-backend/llama.cpp-mainline
+++ b/gpt4all-backend/llama.cpp-mainline
@@ -1 +1 @@
-Subproject commit 74f977c196286e937fc3a40af9f1638f018761a8
+Subproject commit abd7dc4e89c92384017cc1ddb772e1d092055b3e
diff --git a/gpt4all-backend/llama.cpp.cmake b/gpt4all-backend/llama.cpp.cmake
index 72c3fee1..62b57226 100644
--- a/gpt4all-backend/llama.cpp.cmake
+++ b/gpt4all-backend/llama.cpp.cmake
@@ -77,7 +77,6 @@ option(LLAMA_OPENBLAS                        "llama: use OpenBLAS"
 #option(LLAMA_CUBLAS                          "llama: use cuBLAS"                                OFF)
 #option(LLAMA_CLBLAST                         "llama: use CLBlast"                               OFF)
 #option(LLAMA_METAL                           "llama: use Metal"                                 OFF)
-#option(LLAMA_K_QUANTS                        "llama: use k-quants"                              ON)
 set(LLAMA_BLAS_VENDOR "Generic" CACHE STRING "llama: BLAS library vendor")
 set(LLAMA_CUDA_DMMV_X "32" CACHE STRING "llama: x stride for dmmv CUDA kernels")
 set(LLAMA_CUDA_DMMV_Y "1" CACHE STRING "llama: y block size for dmmv CUDA kernels")
@@ -228,6 +227,7 @@ if (LLAMA_KOMPUTE)
     # Compile our shaders
     compile_shader(SOURCES
         kompute/op_scale.comp
+        kompute/op_scale_8.comp
         kompute/op_add.comp
         kompute/op_addrow.comp
         kompute/op_mul.comp
@@ -249,7 +249,8 @@ if (LLAMA_KOMPUTE)
         kompute/op_getrows_q4_0.comp
         kompute/op_getrows_q4_1.comp
         kompute/op_getrows_q6_k.comp
-        kompute/op_rope.comp
+        kompute/op_rope_f16.comp
+        kompute/op_rope_f32.comp
         kompute/op_cpy_f16_f16.comp
         kompute/op_cpy_f16_f32.comp
         kompute/op_cpy_f32_f16.comp
@@ -259,6 +260,7 @@ if (LLAMA_KOMPUTE)
     # Create a custom target for our generated shaders
     add_custom_target(generated_shaders DEPENDS
         shaderop_scale.h
+        shaderop_scale_8.h
         shaderop_add.h
         shaderop_addrow.h
         shaderop_mul.h
@@ -280,7 +282,8 @@ if (LLAMA_KOMPUTE)
         shaderop_getrows_q4_0.h
         shaderop_getrows_q4_1.h
         shaderop_getrows_q6_k.h
-        shaderop_rope.h
+        shaderop_rope_f16.h
+        shaderop_rope_f32.h
         shaderop_cpy_f16_f16.h
         shaderop_cpy_f16_f32.h
         shaderop_cpy_f32_f16.h
@@ -564,33 +567,26 @@ function(include_ggml DIRECTORY SUFFIX WITH_LLAMA)
         endif()
     endif()
 
-    set(GGML_SOURCES_QUANT_K )
-    set(GGML_METAL_SOURCES )
-    if (LLAMA_K_QUANTS)
-        set(GGML_SOURCES_QUANT_K
-            ${DIRECTORY}/k_quants.h
-            ${DIRECTORY}/k_quants.c)
+    set(GGML_METAL_SOURCES)
+    if (LLAMA_METAL)
+        find_library(FOUNDATION_LIBRARY Foundation REQUIRED)
+        find_library(METAL_FRAMEWORK    Metal      REQUIRED)
+        find_library(METALKIT_FRAMEWORK MetalKit   REQUIRED)
+        find_library(METALPERFORMANCE_FRAMEWORK MetalPerformanceShaders REQUIRED)
 
-        if (LLAMA_METAL)
-            find_library(FOUNDATION_LIBRARY Foundation REQUIRED)
-            find_library(METAL_FRAMEWORK    Metal      REQUIRED)
-            find_library(METALKIT_FRAMEWORK MetalKit   REQUIRED)
-            find_library(METALPERFORMANCE_FRAMEWORK MetalPerformanceShaders REQUIRED)
+        set(GGML_METAL_SOURCES ${DIRECTORY}/ggml-metal.m ${DIRECTORY}/ggml-metal.h)
+        # get full path to the file
+        #add_compile_definitions(GGML_METAL_DIR_KERNELS="${CMAKE_CURRENT_SOURCE_DIR}/")
 
-            set(GGML_METAL_SOURCES ${DIRECTORY}/ggml-metal.m ${DIRECTORY}/ggml-metal.h)
-            # get full path to the file
-            #add_compile_definitions(GGML_METAL_DIR_KERNELS="${CMAKE_CURRENT_SOURCE_DIR}/")
+        # copy ggml-metal.metal to bin directory
+        configure_file(${DIRECTORY}/ggml-metal.metal bin/ggml-metal.metal COPYONLY)
 
-            # copy ggml-metal.metal to bin directory
-            configure_file(${DIRECTORY}/ggml-metal.metal bin/ggml-metal.metal COPYONLY)
-
-            set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS}
-                ${FOUNDATION_LIBRARY}
-                ${METAL_FRAMEWORK}
-                ${METALKIT_FRAMEWORK}
-                ${METALPERFORMANCE_FRAMEWORK}
-            )
-        endif()
+        set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS}
+            ${FOUNDATION_LIBRARY}
+            ${METAL_FRAMEWORK}
+            ${METALKIT_FRAMEWORK}
+            ${METALPERFORMANCE_FRAMEWORK}
+        )
     endif()
 
     add_library(ggml${SUFFIX} OBJECT
@@ -598,16 +594,15 @@ function(include_ggml DIRECTORY SUFFIX WITH_LLAMA)
                 ${DIRECTORY}/ggml.h
                 ${DIRECTORY}/ggml-alloc.c
                 ${DIRECTORY}/ggml-alloc.h
-                ${GGML_SOURCES_QUANT_K}
+                ${DIRECTORY}/ggml-backend.c
+                ${DIRECTORY}/ggml-backend.h
+                ${DIRECTORY}/ggml-quants.h
+                ${DIRECTORY}/ggml-quants.c
                 ${GGML_SOURCES_CUDA}
                 ${GGML_METAL_SOURCES}
                 ${GGML_OPENCL_SOURCES}
                 ${GGML_SOURCES_KOMPUTE})
 
-    if (LLAMA_K_QUANTS)
-        target_compile_definitions(ggml${SUFFIX} PUBLIC GGML_USE_K_QUANTS)
-    endif()
-
     if (LLAMA_METAL AND GGML_METAL_SOURCES)
         target_compile_definitions(ggml${SUFFIX} PUBLIC GGML_USE_METAL GGML_METAL_NDEBUG)
     endif()
diff --git a/gpt4all-backend/llamamodel.cpp b/gpt4all-backend/llamamodel.cpp
index 491a80c6..65374854 100644
--- a/gpt4all-backend/llamamodel.cpp
+++ b/gpt4all-backend/llamamodel.cpp
@@ -71,9 +71,10 @@ static int llama_sample_top_p_top_k(
         int top_k,
         float top_p,
         float temp,
-        float repeat_penalty) {
-    auto logits = llama_get_logits(ctx);
-    auto n_vocab = llama_n_vocab(ctx);
+        float repeat_penalty,
+        int32_t pos) {
+    auto logits = llama_get_logits_ith(ctx, pos);
+    auto n_vocab = llama_n_vocab(llama_get_model(ctx));
     // Populate initial list of all candidates
     std::vector<llama_token_data> candidates;
     candidates.reserve(n_vocab);
@@ -82,21 +83,23 @@
     }
     llama_token_data_array candidates_p = {candidates.data(), candidates.size(), false};
     // Sample repeat penalty
-    llama_sample_repetition_penalty(nullptr, &candidates_p, last_n_tokens_data, last_n_tokens_size, repeat_penalty);
+    llama_sample_repetition_penalties(nullptr, &candidates_p, last_n_tokens_data, last_n_tokens_size, repeat_penalty, 0.0f, 0.0f);
     // Temperature sampling
     llama_sample_top_k(ctx, &candidates_p, top_k, 1);
     llama_sample_tail_free(ctx, &candidates_p, 1.0f, 1);
     llama_sample_typical(ctx, &candidates_p, 1.0f, 1);
     llama_sample_top_p(ctx, &candidates_p, top_p, 1);
-    llama_sample_temperature(ctx, &candidates_p, temp);
+    llama_sample_temp(ctx, &candidates_p, temp);
     return llama_sample_token(ctx, &candidates_p);
 }
 
 struct LLamaPrivate {
     const std::string modelPath;
     bool modelLoaded;
+    llama_model *model = nullptr;
     llama_context *ctx = nullptr;
-    llama_context_params params;
+    llama_model_params model_params;
+    llama_context_params ctx_params;
     int64_t n_threads = 0;
     std::vector<LLModel::Token> end_tokens;
 };
@@ -142,37 +145,46 @@ size_t LLamaModel::requiredMem(const std::string &modelPath) {
 
 bool LLamaModel::loadModel(const std::string &modelPath)
 {
-    // load the model
-    d_ptr->params = llama_context_default_params();
-
     gpt_params params;
-    d_ptr->params.n_ctx      = 2048;
-    d_ptr->params.seed       = params.seed;
-    d_ptr->params.f16_kv     = params.memory_f16;
-    d_ptr->params.use_mmap   = params.use_mmap;
+
+    // load the model
+    d_ptr->model_params = llama_model_default_params();
+
+    d_ptr->model_params.use_mmap  = params.use_mmap;
 #if defined (__APPLE__)
-    d_ptr->params.use_mlock  = true;
+    d_ptr->model_params.use_mlock = true;
 #else
-    d_ptr->params.use_mlock  = params.use_mlock;
+    d_ptr->model_params.use_mlock = params.use_mlock;
 #endif
+
+    d_ptr->ctx_params = llama_context_default_params();
+
+    d_ptr->ctx_params.n_ctx   = 2048;
+    d_ptr->ctx_params.seed    = params.seed;
+    d_ptr->ctx_params.f16_kv  = params.memory_f16;
+
+    d_ptr->n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
+    d_ptr->ctx_params.n_threads       = d_ptr->n_threads;
+    d_ptr->ctx_params.n_threads_batch = d_ptr->n_threads;
+
 #ifdef GGML_USE_METAL
     if (llama_verbose()) {
         std::cerr << "llama.cpp: using Metal" << std::endl;
     }
     // metal always runs the whole model if n_gpu_layers is not 0, at least
     // currently
-    d_ptr->params.n_gpu_layers = 1;
+    d_ptr->model_params.n_gpu_layers = 1;
 #endif
 #ifdef GGML_USE_KOMPUTE
     if (ggml_vk_has_device()) {
         // vulkan always runs the whole model if n_gpu_layers is not 0, at least
         // currently
-        d_ptr->params.n_gpu_layers = 1;
+        d_ptr->model_params.n_gpu_layers = 1;
     }
 #endif
 
-    d_ptr->ctx = llama_init_from_file(modelPath.c_str(), d_ptr->params);
-    if (!d_ptr->ctx) {
+    d_ptr->model = llama_load_model_from_file_gpt4all(modelPath.c_str(), &d_ptr->model_params);
+    if (!d_ptr->model) {
 #ifdef GGML_USE_KOMPUTE
         // Explicitly free the device so next load it doesn't use it
         ggml_vk_free_device();
@@ -181,7 +193,17 @@ bool LLamaModel::loadModel(const std::string &modelPath)
         return false;
     }
 
-    d_ptr->end_tokens = {llama_token_eos(d_ptr->ctx)};
+    d_ptr->ctx = llama_new_context_with_model(d_ptr->model, d_ptr->ctx_params);
+    if (!d_ptr->ctx) {
+#ifdef GGML_USE_KOMPUTE
+        // Explicitly free the device so next load it doesn't use it
+        ggml_vk_free_device();
+#endif
+        std::cerr << "LLAMA ERROR: failed to init context for model " << modelPath << std::endl;
+        return false;
+    }
+
+    d_ptr->end_tokens = {llama_token_eos(d_ptr->model)};
 
 #ifdef GGML_USE_KOMPUTE
     if (ggml_vk_has_device()) {
@@ -189,7 +211,6 @@ bool LLamaModel::loadModel(const std::string &modelPath)
     }
 #endif
 
-    d_ptr->n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
     d_ptr->modelLoaded = true;
     fflush(stderr);
     return true;
@@ -197,6 +218,7 @@ bool LLamaModel::loadModel(const std::string &modelPath)
 
 void LLamaModel::setThreadCount(int32_t n_threads) {
     d_ptr->n_threads = n_threads;
+    llama_set_n_threads(d_ptr->ctx, n_threads, n_threads);
 }
 
 int32_t LLamaModel::threadCount() const {
@@ -208,6 +230,7 @@ LLamaModel::~LLamaModel()
     if (d_ptr->ctx) {
         llama_free(d_ptr->ctx);
     }
+    llama_free_model(d_ptr->model);
 }
 
 bool LLamaModel::isModelLoaded() const
@@ -233,16 +256,17 @@ size_t LLamaModel::restoreState(const uint8_t *src)
 
 std::vector<LLModel::Token> LLamaModel::tokenize(PromptContext &ctx, const std::string &str) const
 {
-    const bool useBOS = ctx.n_past == 0 && (ctx.tokens.empty() || ctx.tokens.front() != llama_token_bos(d_ptr->ctx));
+    const bool useBOS = ctx.n_past == 0 && (ctx.tokens.empty() || ctx.tokens.front() != llama_token_bos(d_ptr->model));
     std::vector<LLModel::Token> fres(str.size()+4);
-    auto fres_len = llama_tokenize(d_ptr->ctx, str.c_str(), str.length(), fres.data(), fres.size(), useBOS);
+    // TODO(cebtenzzre): we may want to use special=true here to process special tokens
+    auto fres_len = llama_tokenize(d_ptr->model, str.c_str(), str.length(), fres.data(), fres.size(), useBOS, false);
     fres.resize(fres_len);
     return fres;
 }
 
 std::string LLamaModel::tokenToString(Token id) const
 {
-    return llama_token_to_str(d_ptr->ctx, id);
+    return llama_token_to_piece(d_ptr->ctx, id);
 }
 
 LLModel::Token LLamaModel::sampleToken(PromptContext &promptCtx) const
@@ -251,12 +275,30 @@ LLModel::Token LLamaModel::sampleToken(PromptContext &promptCtx) const
     return llama_sample_top_p_top_k(d_ptr->ctx,
         promptCtx.tokens.data() + promptCtx.tokens.size() - n_prev_toks, n_prev_toks,
         promptCtx.top_k, promptCtx.top_p, promptCtx.temp,
-        promptCtx.repeat_penalty);
+        promptCtx.repeat_penalty, promptCtx.n_last_batch_tokens - 1);
 }
 
 bool LLamaModel::evalTokens(PromptContext &ctx, const std::vector<int32_t> &tokens) const
 {
-    return llama_eval(d_ptr->ctx, tokens.data(), tokens.size(), ctx.n_past, d_ptr->n_threads) == 0;
+    llama_batch batch = llama_batch_init(tokens.size(), 0, 1);
+
+    batch.n_tokens = tokens.size();
+    ctx.n_last_batch_tokens = tokens.size();
+
+    for (int32_t i = 0; i < batch.n_tokens; i++) {
+        batch.token   [i] = tokens[i];
+        batch.pos     [i] = ctx.n_past + i;
+        batch.n_seq_id[i] = 1;
+        batch.seq_id  [i][0] = 0;
+        batch.logits  [i] = false;
+    }
+
+    // llama_decode will output logits only for the last token of the prompt
+    batch.logits[batch.n_tokens - 1] = true;
+
+    int res = llama_decode(d_ptr->ctx, batch);
+    llama_batch_free(batch);
+    return res == 0;
 }
 
 int32_t LLamaModel::contextLength() const
diff --git a/gpt4all-backend/llmodel.h b/gpt4all-backend/llmodel.h
index 3001281b..5fdabc30 100644
--- a/gpt4all-backend/llmodel.h
+++ b/gpt4all-backend/llmodel.h
@@ -54,8 +54,8 @@ public:
         int32_t n_batch = 9;
         float repeat_penalty = 1.10f;
         int32_t repeat_last_n = 64; // last n tokens to penalize
-        float contextErase = 0.75f; // percent of context to erase if we exceed the context
-                                    // window
+        float contextErase = 0.75f; // percent of context to erase if we exceed the context window
+        int32_t n_last_batch_tokens = 0;
     };
 
     struct GPUDevice {