From 9afbaee94ed931aa6f84b467971fc0e142c6427f Mon Sep 17 00:00:00 2001
From: Aaron Miller
Date: Thu, 8 Jun 2023 11:08:30 -0700
Subject: [PATCH] non-llama: explicitly greedy sampling for temp<=0 (#901)

copied directly from llama.cpp - without this temp=0.0 will just scale all
the logits to infinity and give bad output
---
 gpt4all-backend/utils.cpp | 13 +++++++++++++
 1 file changed, 13 insertions(+)

diff --git a/gpt4all-backend/utils.cpp b/gpt4all-backend/utils.cpp
index 46827980..5dc44fad 100644
--- a/gpt4all-backend/utils.cpp
+++ b/gpt4all-backend/utils.cpp
@@ -232,6 +232,19 @@ gpt_vocab::id gpt_sample_top_k_top_p(
     const auto last_n_tokens = std::vector<gpt_vocab::id>(last_n_tokens_data, last_n_tokens_data + last_n_tokens_size);
     const auto * plogits = logits.data();
 
+    if (temp <= 0) {
+        // select the token with the highest logit directly
+        float max_logit = plogits[0];
+        gpt_vocab::id max_id = 0;
+
+        for (int i = 1; i < n_logits; ++i) {
+            if (plogits[i] > max_logit) {
+                max_logit = plogits[i];
+                max_id = i;
+            }
+        }
+        return max_id;
+    }
 
     std::vector<std::pair<double, gpt_vocab::id>> logits_id;
     logits_id.reserve(n_logits);
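
Note: below is a minimal standalone sketch of why the guard is needed, kept
separate from the patch itself. The logits vector and the greedy_argmax helper
are illustrative names, not code from gpt4all-backend/utils.cpp. The point it
demonstrates: dividing logits by a temperature of 0 produces infinities (and
NaNs once softmax normalizes them), so the patch's greedy path skips the
division entirely and takes the argmax directly.

// greedy_sketch.cpp - illustration only; greedy_argmax is a hypothetical
// helper mirroring the logic the patch adds, not a repo function.
#include <algorithm>
#include <cstdio>
#include <vector>

// Greedy path, as in the patch: pick the highest logit, no division by temp.
static int greedy_argmax(const std::vector<float> & logits) {
    return static_cast<int>(
        std::max_element(logits.begin(), logits.end()) - logits.begin());
}

int main() {
    const std::vector<float> logits = {1.0f, 3.5f, 2.0f};
    float temp = 0.0f;

    // Without the guard: logits[i] / temp is +inf for a positive logit, and a
    // softmax over infinities yields NaNs, so sampling picks garbage tokens.
    std::printf("scaled logit: %f\n", logits[1] / temp); // prints inf

    // With the guard (temp <= 0): deterministic greedy selection, token id 1.
    if (temp <= 0.0f) {
        std::printf("greedy token id: %d\n", greedy_argmax(logits));
    }
    return 0;
}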