diff --git a/gpt4all-backend/utils.cpp b/gpt4all-backend/utils.cpp index 46827980..5dc44fad 100644 --- a/gpt4all-backend/utils.cpp +++ b/gpt4all-backend/utils.cpp @@ -232,6 +232,19 @@ gpt_vocab::id gpt_sample_top_k_top_p( const auto last_n_tokens = std::vector(last_n_tokens_data, last_n_tokens_data + last_n_tokens_size); const auto * plogits = logits.data(); + if (temp <= 0) { + // select the token with the highest logit directly + float max_logit = plogits[0]; + gpt_vocab::id max_id = 0; + + for (int i = 1; i < n_logits; ++i) { + if (plogits[i] > max_logit) { + max_logit = plogits[i]; + max_id = i; + } + } + return max_id; + } std::vector> logits_id; logits_id.reserve(n_logits);