From e6fd0a240df084e0a93ae8ef5daa013d85dd87c1 Mon Sep 17 00:00:00 2001
From: aaron miller
Date: Tue, 16 May 2023 23:20:08 -0700
Subject: [PATCH] backend: fix buffer overrun in repeat penalty code

Caught with AddressSanitizer running a basic prompt test against llmodel
standalone. This fix allows ASan builds to complete a simple prompt
without illegal accesses, but notably there are still several leaks.
---
 gpt4all-backend/gptj.cpp       | 5 +++--
 gpt4all-backend/llamamodel.cpp | 5 +++--
 gpt4all-backend/mpt.cpp        | 5 +++--
 3 files changed, 9 insertions(+), 6 deletions(-)

diff --git a/gpt4all-backend/gptj.cpp b/gpt4all-backend/gptj.cpp
index 28164318..f31b4c60 100644
--- a/gpt4all-backend/gptj.cpp
+++ b/gpt4all-backend/gptj.cpp
@@ -993,9 +993,10 @@ void GPTJ::prompt(const std::string &prompt,
         gpt_vocab::id id = 0;
         {
             const int64_t t_start_sample_us = ggml_time_us();
+            const size_t n_prev_toks = std::min((size_t) promptCtx.repeat_last_n, promptCtx.tokens.size());
             id = gpt_sample_top_k_top_p(d_ptr->vocab, n_vocab,
-                    promptCtx.tokens.data() + promptCtx.n_ctx - promptCtx.n_ctx,
-                    promptCtx.n_ctx,
+                    promptCtx.tokens.data() + promptCtx.tokens.size() - n_prev_toks,
+                    n_prev_toks,
                     promptCtx.logits,
                     promptCtx.top_k, promptCtx.top_p, promptCtx.temp,
                     promptCtx.repeat_penalty,
diff --git a/gpt4all-backend/llamamodel.cpp b/gpt4all-backend/llamamodel.cpp
index cccef1f9..05fcd5e1 100644
--- a/gpt4all-backend/llamamodel.cpp
+++ b/gpt4all-backend/llamamodel.cpp
@@ -180,9 +180,10 @@ void LLamaModel::prompt(const std::string &prompt,
     int32_t totalPredictions = 0;
     for (int i = 0; i < promptCtx.n_predict; i++) {
         // sample next token
+        const size_t n_prev_toks = std::min((size_t) promptCtx.repeat_last_n, promptCtx.tokens.size());
         llama_token id = llama_sample_top_p_top_k(d_ptr->ctx,
-            promptCtx.tokens.data() + promptCtx.n_ctx - promptCtx.repeat_last_n,
-            promptCtx.repeat_last_n, promptCtx.top_k, promptCtx.top_p, promptCtx.temp,
+            promptCtx.tokens.data() + promptCtx.tokens.size() - n_prev_toks,
+            n_prev_toks, promptCtx.top_k, promptCtx.top_p, promptCtx.temp,
             promptCtx.repeat_penalty);

         // Check if the context has run out...
diff --git a/gpt4all-backend/mpt.cpp b/gpt4all-backend/mpt.cpp
index 0eeb9211..42a2aaae 100644
--- a/gpt4all-backend/mpt.cpp
+++ b/gpt4all-backend/mpt.cpp
@@ -918,9 +918,10 @@ void MPT::prompt(const std::string &prompt,
         int id = 0;
         {
             const int64_t t_start_sample_us = ggml_time_us();
+            const size_t n_prev_toks = std::min((size_t) promptCtx.repeat_last_n, promptCtx.tokens.size());
             id = gpt_sample_top_k_top_p(d_ptr->vocab, n_vocab,
-                    promptCtx.tokens.data() + promptCtx.n_ctx - promptCtx.n_ctx,
-                    promptCtx.n_ctx,
+                    promptCtx.tokens.data() + promptCtx.tokens.size() - n_prev_toks,
+                    n_prev_toks,
                     promptCtx.logits,
                     promptCtx.top_k, promptCtx.top_p, promptCtx.temp,
                     promptCtx.repeat_penalty,
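
The pattern applied in all three backends is the same: clamp the repeat-penalty window to the number of tokens actually generated so far, so the pointer handed to the sampler never points outside the token buffer. The following standalone sketch illustrates that pattern under simplified assumptions; sample_stub and Token are hypothetical stand-ins for the backend sampler (gpt_sample_top_k_top_p / llama_sample_top_p_top_k) and token type, and are not part of the patch.

    #include <algorithm>
    #include <cstddef>
    #include <cstdio>
    #include <vector>

    using Token = int;

    // Stand-in for the real sampler: it only reads `count` tokens starting at `prev`.
    static Token sample_stub(const Token *prev, size_t count) {
        return count ? prev[count - 1] : 0;
    }

    int main() {
        std::vector<Token> tokens = {11, 22, 33};   // only 3 tokens generated so far
        const size_t repeat_last_n = 64;            // penalty window requested by the caller

        // The fix: never hand the sampler more history than actually exists, so
        // tokens.data() + tokens.size() - n_prev_toks always stays inside the vector.
        const size_t n_prev_toks = std::min(repeat_last_n, tokens.size());
        Token id = sample_stub(tokens.data() + tokens.size() - n_prev_toks, n_prev_toks);

        std::printf("sampled %d using %zu previous tokens\n", id, n_prev_toks);
        return 0;
    }

Without the std::min clamp, a window larger than the current token count would make the pointer arithmetic land before the start of the vector, which is the out-of-bounds read AddressSanitizer flagged.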