diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp
index b3106937..b6e929c9 100644
--- a/src/llama-sampling.cpp
+++ b/src/llama-sampling.cpp
@@ -33,18 +33,82 @@ void llama_set_rng_seed_impl(struct llama_sampling * smpl, uint32_t seed) {
     smpl->rng.seed(seed);
 }
 
+static void llama_sort(llama_token_data_array * candidates, int32_t k) {
+    if (candidates->sorted || candidates->size < 2) {
+        return;
+    }
+    if (k < 0) {
+        k = candidates->size;
+    }
+    auto comp = [](const llama_token_data & a, const llama_token_data & b) {
+        return a.logit > b.logit;
+    };
+    if (k <= 1024) { //128) {
+        if (k == int(candidates->size)) {
+            std::sort(candidates->data, candidates->data + candidates->size, comp);
+        } else {
+            std::partial_sort(candidates->data, candidates->data + k, candidates->data + candidates->size, comp);
+        }
+    } else {
+        constexpr int   nbuckets     = 128;
+        constexpr float bucket_low   = -10.0f;
+        constexpr float bucket_high  =  10.0f;
+        constexpr float bucket_scale = nbuckets/(bucket_high - bucket_low);
+        constexpr float bucket_inter = -bucket_low * bucket_scale;
+
+        std::vector<int> bucket_idx(candidates->size);
+        std::vector<int> histo(nbuckets, 0);
+
+        for (int i = 0; i < (int)candidates->size; ++i) {
+            const float val = candidates->data[i].logit;
+            int ib = int(bucket_scale * val + bucket_inter); //nbuckets * (val - bucket_low) / (bucket_high - bucket_low);
+            ib = std::max(0, std::min(nbuckets-1, ib));
+            bucket_idx[i] = ib;
+            ++histo[ib];
+        }
+        int nhave = 0;
+        int ib = nbuckets - 1;
+        for ( ; ib >= 0; --ib) {
+            nhave += histo[ib];
+            if (nhave >= k) break;
+        }
+        std::vector<llama_token_data> tmp_tokens(nhave);
+        auto ptr = tmp_tokens.data();
+        std::vector<llama_token_data*> bucket_ptrs;
+        bucket_ptrs.reserve(nbuckets - ib);
+        for (int j = nbuckets - 1; j >= ib; --j) {
+            bucket_ptrs.push_back(ptr);
+            ptr += histo[j];
+        }
+        for (int i = 0; i < (int)candidates->size; ++i) {
+            int j = bucket_idx[i];
+            if (j >= ib) {
+                *bucket_ptrs[nbuckets-1-j]++ = candidates->data[i];
+            }
+        }
+
+        ptr = tmp_tokens.data();
+        int ndone = 0;
+        for (int j = nbuckets-1; j > ib; --j) {
+            std::sort(ptr, ptr + histo[j], comp);
+            ptr += histo[j];
+            ndone += histo[j];
+        }
+        std::partial_sort(ptr, ptr + k - ndone, ptr + histo[ib], comp);
+
+        std::memcpy(candidates->data, tmp_tokens.data(), k*sizeof(llama_token_data));
+
+    }
+    candidates->sorted = true;
+}
+
 void llama_sample_softmax_impl(struct llama_sampling * smpl, llama_token_data_array * candidates) {
     GGML_ASSERT(candidates->size > 0);
 
     const int64_t t_start_sample_us = ggml_time_us();
 
-    // Sort the logits in descending order
-    if (!candidates->sorted) {
-        std::sort(candidates->data, candidates->data + candidates->size, [](const llama_token_data & a, const llama_token_data & b) {
-            return a.logit > b.logit;
-        });
-        candidates->sorted = true;
-    }
+    // Sort the logits in descending order if necessary
+    llama_sort(candidates, -1);
 
     float max_l = candidates->data[0].logit;
     float cum_sum = 0.0f;
@@ -63,10 +127,6 @@ void llama_sample_softmax_impl(struct llama_sampling * smpl, llama_token_data_ar
 }
 
 void llama_sample_top_k_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, int32_t k, size_t min_keep) {
-    // TODO: move bucket sort to separate function so that top_p/tail_free/typical/softmax first is equally fast
-    // if (k >= (int32_t)candidates->size) {
-    //     return;
-    // }
 
     const int64_t t_start_sample_us = ggml_time_us();
 
@@ -77,65 +137,8 @@ void llama_sample_top_k_impl(struct llama_sampling * smpl, llama_token_data_arra
     k = std::max(k, (int) min_keep);
     k = std::min(k, (int) candidates->size);
 
-    // Sort scores in descending order
-    if (!candidates->sorted) {
-        auto comp = [](const llama_token_data & a, const llama_token_data & b) {
-            return a.logit > b.logit;
-        };
-        if (k <= 128) {
-            std::partial_sort(candidates->data, candidates->data + k, candidates->data + candidates->size, comp);
-        } else {
-            constexpr int   nbuckets     = 128;
-            constexpr float bucket_low   = -10.0f;
-            constexpr float bucket_high  =  10.0f;
-            constexpr float bucket_scale = nbuckets/(bucket_high - bucket_low);
-            constexpr float bucket_inter = -bucket_low * bucket_scale;
+    llama_sort(candidates, k);
 
-            std::vector<int> bucket_idx(candidates->size);
-            std::vector<int> histo(nbuckets, 0);
-
-            for (int i = 0; i < (int)candidates->size; ++i) {
-                const float val = candidates->data[i].logit;
-                int ib = int(bucket_scale * val + bucket_inter); //nbuckets * (val - bucket_low) / (bucket_high - bucket_low);
-                ib = std::max(0, std::min(nbuckets-1, ib));
-                bucket_idx[i] = ib;
-                ++histo[ib];
-            }
-            int nhave = 0;
-            int ib = nbuckets - 1;
-            for ( ; ib >= 0; --ib) {
-                nhave += histo[ib];
-                if (nhave >= k) break;
-            }
-            std::vector<llama_token_data> tmp_tokens(nhave);
-            auto ptr = tmp_tokens.data();
-            std::vector<llama_token_data*> bucket_ptrs;
-            bucket_ptrs.reserve(nbuckets - ib);
-            for (int j = nbuckets - 1; j >= ib; --j) {
-                bucket_ptrs.push_back(ptr);
-                ptr += histo[j];
-            }
-            for (int i = 0; i < (int)candidates->size; ++i) {
-                int j = bucket_idx[i];
-                if (j >= ib) {
-                    *bucket_ptrs[nbuckets-1-j]++ = candidates->data[i];
-                }
-            }
-
-            ptr = tmp_tokens.data();
-            int ndone = 0;
-            for (int j = nbuckets-1; j > ib; --j) {
-                std::sort(ptr, ptr + histo[j], comp);
-                ptr += histo[j];
-                ndone += histo[j];
-            }
-            std::partial_sort(ptr, ptr + k - ndone, ptr + histo[ib], comp);
-
-            std::memcpy(candidates->data, tmp_tokens.data(), k*sizeof(llama_token_data));
-
-        }
-        candidates->sorted = true;
-    }
     candidates->size = k;
 
     if (smpl) {
@@ -210,13 +213,8 @@ void llama_sample_min_p_impl(struct llama_sampling * smpl, llama_token_data_arra
 
     // if the candidates are sorted or the unsorted implementation failed, use this implementation
    if (!min_p_applied) {
-        // Sort the logits in descending order
-        if (!candidates->sorted) {
-            std::sort(candidates->data, candidates->data + candidates->size, [](const llama_token_data & a, const llama_token_data & b) {
-                return a.logit > b.logit;
-            });
-            candidates->sorted = true;
-        }
+        // Sort the logits in descending order if needed
+        llama_sort(candidates, -1);
 
         const float min_logit = candidates->data[0].logit + logf(p); // min logit for p_i >= p * p_max
         size_t i = 1; // first token always matches
@@ -313,10 +311,9 @@ void llama_sample_typical_impl(struct llama_sampling * smpl, llama_token_data_ar
     }
 
     // Compute the absolute difference between negative log probability and entropy for each candidate
-    std::vector<float> shifted_scores;
+    std::vector<float> shifted_scores(candidates->size);
     for (size_t i = 0; i < candidates->size; ++i) {
-        float shifted_score = fabsf(-logf(candidates->data[i].p) - entropy);
-        shifted_scores.push_back(shifted_score);
+        shifted_scores[i] = fabsf(-logf(candidates->data[i].p) - entropy);
     }
 
     // Sort tokens based on the shifted_scores and their corresponding indices
@@ -343,10 +340,10 @@ void llama_sample_typical_impl(struct llama_sampling * smpl, llama_token_data_ar
     }
 
     // Resize the output vector to keep only the locally typical tokens
-    std::vector<llama_token_data> new_candidates;
+    std::vector<llama_token_data> new_candidates(last_idx);
     for (size_t i = 0; i < last_idx; ++i) {
         size_t idx = indices[i];
-        new_candidates.push_back(candidates->data[idx]);
+        new_candidates[i] = candidates->data[idx];
     }
 
     // Replace the data in candidates with the new_candidates data
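Note: the patch factors the bucket-based partial sort out of llama_sample_top_k_impl into a shared llama_sort, so softmax, min_p, and top_k all reuse the same fast path while candidates->sorted still short-circuits redundant work. For readers unfamiliar with the bucket path, below is a minimal standalone sketch of the same idea, independent of the llama.cpp types; the names Tok and bucket_topk are illustrative only, and it assumes 0 < k <= cand.size(). Logits are binned into 128 buckets over [-10, 10], buckets are consumed from the highest downward until they cover k entries, fully covered buckets are sorted outright, and only the boundary bucket needs a partial_sort.

// Standalone sketch of the histogram/bucket partial sort (illustrative, not the patch itself).
#include <algorithm>
#include <cstdio>
#include <random>
#include <vector>

struct Tok { int id; float logit; };

static void bucket_topk(std::vector<Tok> & cand, int k) {
    auto comp = [](const Tok & a, const Tok & b) { return a.logit > b.logit; };

    constexpr int   nbuckets     = 128;
    constexpr float bucket_low   = -10.0f;
    constexpr float bucket_high  =  10.0f;
    constexpr float bucket_scale = nbuckets/(bucket_high - bucket_low);
    constexpr float bucket_inter = -bucket_low * bucket_scale;

    // histogram pass: map each candidate to a bucket by its logit (outliers clamp to the edge buckets)
    std::vector<int> bucket_idx(cand.size());
    std::vector<int> histo(nbuckets, 0);
    for (size_t i = 0; i < cand.size(); ++i) {
        int ib = int(bucket_scale * cand[i].logit + bucket_inter);
        ib = std::max(0, std::min(nbuckets - 1, ib));
        bucket_idx[i] = ib;
        ++histo[ib];
    }

    // walk buckets from the highest logits down until they cover at least k entries
    int nhave = 0;
    int ib = nbuckets - 1;
    for (; ib >= 0; --ib) {
        nhave += histo[ib];
        if (nhave >= k) break;
    }

    // scatter the nhave surviving candidates into a buffer, grouped by bucket, highest bucket first
    std::vector<Tok> tmp(nhave);
    std::vector<Tok *> bucket_ptrs;
    Tok * ptr = tmp.data();
    for (int j = nbuckets - 1; j >= ib; --j) {
        bucket_ptrs.push_back(ptr);
        ptr += histo[j];
    }
    for (size_t i = 0; i < cand.size(); ++i) {
        if (bucket_idx[i] >= ib) {
            *bucket_ptrs[nbuckets - 1 - bucket_idx[i]]++ = cand[i];
        }
    }

    // fully sort the buckets that lie entirely inside the top k;
    // only the boundary bucket ib needs a partial_sort for the remaining slots
    ptr = tmp.data();
    int ndone = 0;
    for (int j = nbuckets - 1; j > ib; --j) {
        std::sort(ptr, ptr + histo[j], comp);
        ptr   += histo[j];
        ndone += histo[j];
    }
    std::partial_sort(ptr, ptr + (k - ndone), ptr + histo[ib], comp);

    std::copy(tmp.begin(), tmp.begin() + k, cand.begin());
    cand.resize(k);
}

int main() {
    std::mt19937 rng(1234);
    std::normal_distribution<float> dist(0.0f, 3.0f);
    std::vector<Tok> cand(50000);
    for (int i = 0; i < (int)cand.size(); ++i) cand[i] = { i, dist(rng) };

    bucket_topk(cand, 2000); // k > 1024, i.e. the regime where llama_sort takes the bucket path
    printf("top logit %.3f, 2000th logit %.3f\n", cand.front().logit, cand.back().logit);
}

The win over a plain std::partial_sort over the whole array is that only the nhave candidates that can still reach the top k are copied and compared; for vocabulary-sized arrays with large k, the bulk of the distribution is discarded after a single linear histogram pass.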