From efd36d286348cbdbea42fc50eca0eae3c1093df4 Mon Sep 17 00:00:00 2001
From: Kawrakow
Date: Mon, 19 Jan 2026 14:04:16 +0000
Subject: [PATCH] sampling: refactor sorting

---
 src/llama-sampling.cpp | 179 +++++++++++++++++------------------
 1 file changed, 74 insertions(+), 105 deletions(-)

diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp
index acfcefd4..83819209 100644
--- a/src/llama-sampling.cpp
+++ b/src/llama-sampling.cpp
@@ -31,18 +31,82 @@ void llama_set_rng_seed_impl(struct llama_sampling * smpl, uint32_t seed) {
     smpl->rng.seed(seed);
 }
 
+static void llama_sort(llama_token_data_array * candidates, int32_t k) {
+    if (candidates->sorted || candidates->size < 2) {
+        return;
+    }
+    if (k < 0) {
+        k = candidates->size;
+    }
+    auto comp = [](const llama_token_data & a, const llama_token_data & b) {
+        return a.logit > b.logit;
+    };
+    if (k <= 1024) { // previously 128
+        if (k == int(candidates->size)) {
+            std::sort(candidates->data, candidates->data + candidates->size, comp);
+        } else {
+            std::partial_sort(candidates->data, candidates->data + k, candidates->data + candidates->size, comp);
+        }
+    } else {
+        constexpr int   nbuckets     = 128;
+        constexpr float bucket_low   = -10.0f;
+        constexpr float bucket_high  =  10.0f;
+        constexpr float bucket_scale = nbuckets/(bucket_high - bucket_low);
+        constexpr float bucket_inter = -bucket_low * bucket_scale;
+
+        std::vector<int> bucket_idx(candidates->size);
+        std::vector<int> histo(nbuckets, 0);
+
+        for (int i = 0; i < (int)candidates->size; ++i) {
+            const float val = candidates->data[i].logit;
+            int ib = int(bucket_scale * val + bucket_inter); //nbuckets * (val - bucket_low) / (bucket_high - bucket_low);
+            ib = std::max(0, std::min(nbuckets-1, ib));
+            bucket_idx[i] = ib;
+            ++histo[ib];
+        }
+        int nhave = 0;
+        int ib = nbuckets - 1;
+        for ( ; ib >= 0; --ib) {
+            nhave += histo[ib];
+            if (nhave >= k) break;
+        }
+        std::vector<llama_token_data> tmp_tokens(nhave);
+        auto ptr = tmp_tokens.data();
+        std::vector<llama_token_data *> bucket_ptrs;
+        bucket_ptrs.reserve(nbuckets - ib);
+        for (int j = nbuckets - 1; j >= ib; --j) {
+            bucket_ptrs.push_back(ptr);
+            ptr += histo[j];
+        }
+        for (int i = 0; i < (int)candidates->size; ++i) {
+            int j = bucket_idx[i];
+            if (j >= ib) {
+                *bucket_ptrs[nbuckets-1-j]++ = candidates->data[i];
+            }
+        }
+
+        ptr = tmp_tokens.data();
+        int ndone = 0;
+        for (int j = nbuckets-1; j > ib; --j) {
+            std::sort(ptr, ptr + histo[j], comp);
+            ptr += histo[j];
+            ndone += histo[j];
+        }
+        std::partial_sort(ptr, ptr + k - ndone, ptr + histo[ib], comp);
+
+        std::memcpy(candidates->data, tmp_tokens.data(), k*sizeof(llama_token_data));
+
+    }
+    candidates->sorted = true;
+}
+
 void llama_sample_softmax_impl(struct llama_sampling * smpl, llama_token_data_array * candidates) {
     GGML_ASSERT(candidates->size > 0);
 
     const int64_t t_start_sample_us = ggml_time_us();
 
-    // Sort the logits in descending order
-    if (!candidates->sorted) {
-        std::sort(candidates->data, candidates->data + candidates->size, [](const llama_token_data & a, const llama_token_data & b) {
-            return a.logit > b.logit;
-        });
-        candidates->sorted = true;
-    }
+    // Sort the logits in descending order if necessary
+    llama_sort(candidates, -1);
 
     float max_l = candidates->data[0].logit;
     float cum_sum = 0.0f;
@@ -61,10 +125,6 @@ void llama_sample_softmax_impl(struct llama_sampling * smpl, llama_token_data_ar
 }
 
 void llama_sample_top_k_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, int32_t k, size_t min_keep) {
-    // TODO: move bucket sort to separate function so that top_p/tail_free/typical/softmax first is equally fast
-    // if (k >= (int32_t)candidates->size) {
-    //     return;
-    // }
 
     const int64_t t_start_sample_us = ggml_time_us();
 
@@ -75,65 +135,8 @@ void llama_sample_top_k_impl(struct llama_sampling * smpl, llama_token_data_arra
     k = std::max(k, (int) min_keep);
     k = std::min(k, (int) candidates->size);
 
-    // Sort scores in descending order
-    if (!candidates->sorted) {
-        auto comp = [](const llama_token_data & a, const llama_token_data & b) {
-            return a.logit > b.logit;
-        };
-        if (k <= 128) {
-            std::partial_sort(candidates->data, candidates->data + k, candidates->data + candidates->size, comp);
-        } else {
-            constexpr int   nbuckets     = 128;
-            constexpr float bucket_low   = -10.0f;
-            constexpr float bucket_high  =  10.0f;
-            constexpr float bucket_scale = nbuckets/(bucket_high - bucket_low);
-            constexpr float bucker_inter = -bucket_low * bucket_scale;
+    llama_sort(candidates, k);
 
-            std::vector<int> bucket_idx(candidates->size);
-            std::vector<int> histo(nbuckets, 0);
-
-            for (int i = 0; i < (int)candidates->size; ++i) {
-                const float val = candidates->data[i].logit;
-                int ib = int(bucket_scale * val + bucker_inter); //nbuckets * (val - bucket_low) / (bucket_high - bucket_low);
-                ib = std::max(0, std::min(nbuckets-1, ib));
-                bucket_idx[i] = ib;
-                ++histo[ib];
-            }
-            int nhave = 0;
-            int ib = nbuckets - 1;
-            for ( ; ib >= 0; --ib) {
-                nhave += histo[ib];
-                if (nhave >= k) break;
-            }
-            std::vector<llama_token_data> tmp_tokens(nhave);
-            auto ptr = tmp_tokens.data();
-            std::vector<llama_token_data *> bucket_ptrs;
-            bucket_ptrs.reserve(nbuckets - ib);
-            for (int j = nbuckets - 1; j >= ib; --j) {
-                bucket_ptrs.push_back(ptr);
-                ptr += histo[j];
-            }
-            for (int i = 0; i < (int)candidates->size; ++i) {
-                int j = bucket_idx[i];
-                if (j >= ib) {
-                    *bucket_ptrs[nbuckets-1-j]++ = candidates->data[i];
-                }
-            }
-
-            ptr = tmp_tokens.data();
-            int ndone = 0;
-            for (int j = nbuckets-1; j > ib; --j) {
-                std::sort(ptr, ptr + histo[j], comp);
-                ptr += histo[j];
-                ndone += histo[j];
-            }
-            std::partial_sort(ptr, ptr + k - ndone, ptr + histo[ib], comp);
-
-            std::memcpy(candidates->data, tmp_tokens.data(), k*sizeof(llama_token_data));
-
-        }
-        candidates->sorted = true;
-    }
     candidates->size = k;
 
     if (smpl) {
@@ -208,13 +211,8 @@ void llama_sample_min_p_impl(struct llama_sampling * smpl, llama_token_data_arra
 
     // if the candidates are sorted or the unsorted implementation failed, use this implementation
     if (!min_p_applied) {
-        // Sort the logits in descending order
-        if (!candidates->sorted) {
-            std::sort(candidates->data, candidates->data + candidates->size, [](const llama_token_data & a, const llama_token_data & b) {
-                return a.logit > b.logit;
-            });
-            candidates->sorted = true;
-        }
+        // Sort the logits in descending order if needed
+        llama_sort(candidates, -1);
 
         const float min_logit = candidates->data[0].logit + logf(p); // min logit for p_i >= p * p_max
         size_t i = 1; // first token always matches
@@ -1178,35 +1176,6 @@ void llama_prep_adaptive_p_impl(struct llama_sampling * smpl,
     adapt_p_ctx->cum_orig_prob = cum_prob;
 
     if (smpl) smpl->t_sample_us += ggml_time_us() - t_start;
-    //if (!candidates->sorted) {
-    //    std::sort(candidates->data, candidates->data + candidates->size,
-    //              [](const llama_token_data & a, const llama_token_data & b) {
-    //        return a.logit > b.logit;
-    //    });
-    //    candidates->sorted = true;
-    //}
-    //const float max_logit = candidates->data[0].logit;
-
-    //// decide how many tokens to track based on logit delta
-    //// i.e. do not track unlikely tokens
-    //auto iter = std::lower_bound(
-    //    candidates->data,
-    //    candidates->data + candidates->size,
-    //    max_logit - kDelta, // delta
-    //    [](const llama_token_data & data, const float delta) {
-    //        return data.logit > delta;
-    //    });
-    //const size_t n_track = std::distance(candidates->data, iter);
-
-    //// store orig_prob_map and cum_orig_prob to estimate original probability later
-    //float cum_prob = 0.0f;
-    //adapt_p_ctx->orig_prob_map.clear();
-    //for (size_t i = 0; i < n_track; ++i) {
-    //    const float prob = expf(candidates->data[i].logit - max_logit);
-    //    cum_prob += prob;
-    //    adapt_p_ctx->orig_prob_map[candidates->data[i].id] = prob;
-    //}
-    //adapt_p_ctx->cum_orig_prob = cum_prob;
 }
 
 struct llama_sampler_adaptive_p * llama_init_adaptive_p_impl(
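
Note on the consolidated llama_sort: for k <= 1024 it delegates to std::sort (when k covers the whole array) or std::partial_sort; for larger k it histograms the logits into 128 equal-width buckets over [-10, 10], walks the buckets from the highest logits down until they cover at least k tokens, sorts only those candidates, and copies the top k back. Passing k = -1, as softmax and min-p now do, requests a full descending sort. The program below is a minimal standalone sketch of that bucketed selection; the token_data struct, the bucket_top_k name, and the verification harness are assumptions made for this example, not code from the patch.

// Standalone sketch of the bucketed top-k selection used by llama_sort above.
// token_data and the test harness are illustrative; only the bucketing logic
// mirrors the patch.
#include <algorithm>
#include <cstdio>
#include <cstring>
#include <random>
#include <vector>

struct token_data { int id; float logit; };

// Leaves the k largest-logit tokens, in descending order, at the front of data.
static void bucket_top_k(std::vector<token_data> & data, int k) {
    auto comp = [](const token_data & a, const token_data & b) { return a.logit > b.logit; };

    constexpr int   nbuckets     = 128;
    constexpr float bucket_low   = -10.0f;
    constexpr float bucket_high  =  10.0f;
    constexpr float bucket_scale = nbuckets/(bucket_high - bucket_low);
    constexpr float bucket_inter = -bucket_low * bucket_scale;

    const int n = (int)data.size();
    std::vector<int> bucket_idx(n);
    std::vector<int> histo(nbuckets, 0);

    // Histogram pass: map each logit to one of 128 equal-width buckets.
    for (int i = 0; i < n; ++i) {
        int ib = int(bucket_scale * data[i].logit + bucket_inter);
        ib = std::max(0, std::min(nbuckets - 1, ib));
        bucket_idx[i] = ib;
        ++histo[ib];
    }

    // Walk buckets from the highest logits down until they cover k tokens.
    int nhave = 0;
    int ib = nbuckets - 1;
    for ( ; ib >= 0; --ib) {
        nhave += histo[ib];
        if (nhave >= k) break;
    }

    // Scatter the surviving candidates into per-bucket slices of tmp.
    std::vector<token_data> tmp(nhave);
    std::vector<token_data *> bucket_ptrs;
    bucket_ptrs.reserve(nbuckets - ib);
    token_data * ptr = tmp.data();
    for (int j = nbuckets - 1; j >= ib; --j) {
        bucket_ptrs.push_back(ptr);
        ptr += histo[j];
    }
    for (int i = 0; i < n; ++i) {
        int j = bucket_idx[i];
        if (j >= ib) {
            *bucket_ptrs[nbuckets - 1 - j]++ = data[i];
        }
    }

    // Fully sort every bucket above the cutoff; the cutoff bucket only needs
    // its first k - ndone elements, so a partial sort suffices there.
    ptr = tmp.data();
    int ndone = 0;
    for (int j = nbuckets - 1; j > ib; --j) {
        std::sort(ptr, ptr + histo[j], comp);
        ptr   += histo[j];
        ndone += histo[j];
    }
    std::partial_sort(ptr, ptr + k - ndone, ptr + histo[ib], comp);

    std::memcpy(data.data(), tmp.data(), k*sizeof(token_data));
}

int main() {
    std::mt19937 rng(42);
    std::normal_distribution<float> dist(0.0f, 3.0f);

    const int n = 150000, k = 4000;   // vocabulary-sized input, k above the 1024 cutoff
    std::vector<token_data> a(n);
    for (int i = 0; i < n; ++i) a[i] = { i, dist(rng) };
    std::vector<token_data> b = a;

    bucket_top_k(a, k);
    std::partial_sort(b.begin(), b.begin() + k, b.end(),
                      [](const token_data & x, const token_data & y) { return x.logit > y.logit; });

    for (int i = 0; i < k; ++i) {
        if (a[i].logit != b[i].logit) { std::printf("mismatch at %d\n", i); return 1; }
    }
    std::printf("top-%d matches std::partial_sort\n", k);
    return 0;
}

For k well below the vocabulary size, only the nhave tokens that land at or above the cutoff bucket are ever sorted, so the cost is one O(n) histogram pass plus the sorting of a small tail, rather than an O(n log n) sort of the whole candidate array. For a full sort (k = -1) the decomposition still tends to help, since one large sort is replaced by up to 128 small, independent, cache-friendly ones.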