backend: fix buffer overrun in repeat penalty code

Caught with AddressSanitizer while running a basic prompt test against llmodel
standalone. This fix allows ASan builds to complete a simple prompt
without illegal accesses, but notably there are still several leaks.
aaron miller 2023-05-16 23:20:08 -07:00 committed by AT
parent 4f2b7f7be4
commit 08f3bd2a82
3 changed files with 9 additions and 6 deletions
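The core of the fix, shown as a minimal standalone C++ sketch (the `tokens` vector, `repeat_last_n`, and the `penalize_last_n` helper are illustrative stand-ins for the prompt context and sampler used in the diffs below, not code from the repository): the repeat-penalty window has to be clamped to the number of tokens actually accumulated so far, otherwise the start pointer and count handed to the sampler walk outside the token buffer.

#include <algorithm>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for the sampler: reads `count` tokens starting at `last_n_tokens`.
static void penalize_last_n(const int32_t *last_n_tokens, size_t count) {
    for (size_t i = 0; i < count; ++i) {
        (void)last_n_tokens[i]; // ASan flags this read if it falls outside the buffer
    }
}

int main() {
    std::vector<int32_t> tokens = {1, 2, 3};   // only 3 tokens generated so far
    const int32_t repeat_last_n = 64;          // window requested by the prompt context

    // Overrun (old behaviour): assumes `tokens` already holds `repeat_last_n` entries,
    // so the start pointer lands before the allocation and the count over-reads it.
    // penalize_last_n(tokens.data() + tokens.size() - repeat_last_n, repeat_last_n);

    // Fixed behaviour: clamp the window to the tokens that actually exist.
    const size_t n_prev_toks = std::min((size_t) repeat_last_n, tokens.size());
    penalize_last_n(tokens.data() + tokens.size() - n_prev_toks, n_prev_toks);

    return 0;
}

Building a sketch like this with -fsanitize=address and enabling the old-style call reproduces the kind of out-of-bounds read that motivated the change.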


@@ -993,9 +993,10 @@ void GPTJ::prompt(const std::string &prompt,
         gpt_vocab::id id = 0;
         {
             const int64_t t_start_sample_us = ggml_time_us();
+            const size_t n_prev_toks = std::min((size_t) promptCtx.repeat_last_n, promptCtx.tokens.size());
             id = gpt_sample_top_k_top_p(d_ptr->vocab, n_vocab,
-                promptCtx.tokens.data() + promptCtx.n_ctx - promptCtx.n_ctx,
-                promptCtx.n_ctx,
+                promptCtx.tokens.data() + promptCtx.tokens.size() - n_prev_toks,
+                n_prev_toks,
                 promptCtx.logits,
                 promptCtx.top_k, promptCtx.top_p, promptCtx.temp,
                 promptCtx.repeat_penalty,


@@ -180,9 +180,10 @@ void LLamaModel::prompt(const std::string &prompt,
     int32_t totalPredictions = 0;
     for (int i = 0; i < promptCtx.n_predict; i++) {
         // sample next token
+        const size_t n_prev_toks = std::min((size_t) promptCtx.repeat_last_n, promptCtx.tokens.size());
         llama_token id = llama_sample_top_p_top_k(d_ptr->ctx,
-            promptCtx.tokens.data() + promptCtx.n_ctx - promptCtx.repeat_last_n,
-            promptCtx.repeat_last_n, promptCtx.top_k, promptCtx.top_p, promptCtx.temp,
+            promptCtx.tokens.data() + promptCtx.tokens.size() - n_prev_toks,
+            n_prev_toks, promptCtx.top_k, promptCtx.top_p, promptCtx.temp,
             promptCtx.repeat_penalty);
         // Check if the context has run out...


@@ -918,9 +918,10 @@ void MPT::prompt(const std::string &prompt,
         int id = 0;
         {
             const int64_t t_start_sample_us = ggml_time_us();
+            const size_t n_prev_toks = std::min((size_t) promptCtx.repeat_last_n, promptCtx.tokens.size());
             id = gpt_sample_top_k_top_p(d_ptr->vocab, n_vocab,
-                promptCtx.tokens.data() + promptCtx.n_ctx - promptCtx.n_ctx,
-                promptCtx.n_ctx,
+                promptCtx.tokens.data() + promptCtx.tokens.size() - n_prev_toks,
+                n_prev_toks,
                 promptCtx.logits,
                 promptCtx.top_k, promptCtx.top_p, promptCtx.temp,
                 promptCtx.repeat_penalty,