Faster adaptive_p sampling (#1165)

* A hopefully more efficient adaptive_p sampling

* While at it, let's fix the formatting too

* More formatting

* Hopefully better

* This should be better

* Correctly accumulate adaptive_p sampling time

* AVX2
Kawrakow
2026-01-19 16:03:09 +02:00
committed by GitHub
parent fa58c20c42
commit 98b30e5e81
7 changed files with 87 additions and 91 deletions

View File

@@ -118,7 +118,9 @@ struct llama_sampling_context * common_sampler_init(const struct llama_vocab* vocab,
         }
         case llama_sampler_type::ADAPTIVE_P:
             {
-                result->adapt_p_ctx = llama_init_adaptive_p(params.adaptive_target, params.adaptive_decay, result->rng());
+                GGML_ASSERT(vocab);
+                auto n_vocab = llama_vocab_n_tokens(vocab);
+                result->adapt_p_ctx = llama_init_adaptive_p(n_vocab, params.adaptive_target, params.adaptive_decay, result->rng());
                 break;
             }
         default:

View File

@@ -16,6 +16,7 @@
 #include <algorithm>
 #include <cmath>
 #include <cstring>
+//#include <thread>
 #ifdef __ARM_NEON
 #include <arm_neon.h>
@@ -503,3 +504,49 @@ void iqk_hadamard(struct ggml_tensor * dst, int ith, int nth) {
         fast_ht(nh, y);
     }
 }
+namespace {
+float iqk_exp_with_thresh_impl(int n, float * logits, float max, float min) {
+    float sum = 0;
+#ifdef __AVX2__
+    auto vmax = _mm256_set1_ps(max);
+    auto vmin = _mm256_set1_ps(min);
+    auto vsum = _mm256_setzero_ps();
+    for (int j = 0; j < n/8; ++j) {
+        auto x = _mm256_loadu_ps(logits);
+        auto mask = _mm256_cmp_ps(x, vmin, _CMP_GE_OQ);
+        auto exp_x = v_expf(_mm256_sub_ps(x, vmax));
+        exp_x = _mm256_and_ps(exp_x, mask);
+        vsum = _mm256_add_ps(vsum, exp_x);
+        _mm256_storeu_ps(logits, exp_x);
+        logits += 8;
+    }
+    sum = hsum_float_8(vsum);
+    for (int j = 0; j < n - 8*(n/8); ++j) {
+        float p = logits[j] > min ? expf(logits[j] - max) : 0;
+        sum += p;
+        logits[j] = p;
+    }
+#else
+    for (int j = 0; j < n; ++j) {
+        float p = logits[j] > min ? expf(logits[j] - max) : 0;
+        sum += p;
+        logits[j] = p;
+    }
+#endif
+    return sum;
+}
+}
+float iqk_exp_with_thresh(int n, float * logits, float max, float min) {
+    return iqk_exp_with_thresh_impl(n, logits, max, min);
+    //if (n < (1 << 16)) return iqk_exp_with_thresh_impl(n, logits, max, min);
+    //std::array<float, 2> result;
+    //auto compute = [logits, max, min, &result] (int first, int last, int ith) {
+    //    result[ith] = iqk_exp_with_thresh_impl(last - first, logits + first, max, min);
+    //};
+    //auto t = std::thread(compute, 0, n/2, 0);
+    //compute(n/2, n, 1);
+    //t.join();
+    //return result[0] + result[1];
+}
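What the new kernel computes, in one line: each logit x is overwritten in place with exp(x - max) when x clears the threshold min, and with 0 otherwise, and the return value is the sum, i.e. the softmax normalizer over the surviving tokens. A minimal usage sketch; the main() and sample values are illustrative, only the declaration comes from this commit:

    #include <cstdio>
    #include <vector>

    extern "C" float iqk_exp_with_thresh(int n, float * logits, float max, float min);

    int main() {
        std::vector<float> logits = {2.0f, 1.0f, -50.0f, 0.5f};
        float max_logit = 2.0f;  // the caller supplies the largest logit
        // In place: logits[j] becomes exp(logits[j] - max_logit) if it clears
        // max_logit - 30, and 0 otherwise; sum is the normalizer.
        float sum = iqk_exp_with_thresh((int) logits.size(), logits.data(),
                                        max_logit, max_logit - 30.0f);
        for (float p : logits) printf("%g\n", p / sum);  // softmax over survivors
        return 0;
    }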

View File

@@ -30,6 +30,8 @@ void iqk_mul_multi_add(struct ggml_tensor * dst, int ith, int nth);
 void iqk_hadamard(struct ggml_tensor * dst, int ith, int nth);
+float iqk_exp_with_thresh(int n, float * logits, float max, float min);
 #ifdef __cplusplus
 }
 #endif

View File

@@ -1384,7 +1384,7 @@ LLAMA_API struct llama_grammar* llama_sampler_init_grammar_lazy_patterns(
 /// @details Adaptive p sampler initializer
 /// @param target Select tokens near this probability (valid range 0.0 to 1.0; <0 = disabled)
 /// @param decay  Decay rate for target adaptation over time. Lower values -> faster but less stable adaptation. (valid range 0.0 to 1.0; ≤0 = no adaptation)
-LLAMA_API struct llama_sampler_adaptive_p * llama_init_adaptive_p(
+LLAMA_API struct llama_sampler_adaptive_p * llama_init_adaptive_p(int n_vocab,
         const float target,
         const float decay,
         const uint32_t seed);
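How target and decay interact, read off the initializer and the update elsewhere in this commit (a sketch of the mechanism, not separate documentation; the sample values are made up): the sampler maintains an exponentially weighted average of the chosen tokens' pre-sampler probabilities, seeded so that the average starts exactly at target.

    #include <cstdio>

    int main() {
        const float target = 0.35f, decay = 0.90f;     // illustrative values
        // Seeds from llama_init_adaptive_p_impl later in this commit: since
        // sum(decay^i) = 1/(1-decay), the initial estimate is exactly target.
        float weighted_sum = target / (1.0f - decay);
        float total_weight = 1.0f / (1.0f - decay);
        const float picked[] = {0.60f, 0.50f, 0.20f};  // pretend per-step probabilities
        for (float p : picked) {
            // Update from llama_sample_token_adaptive_p_impl in this commit:
            weighted_sum = decay * weighted_sum + p;
            total_weight = decay * total_weight + 1.0f;
            printf("adaptive estimate = %.4f\n", weighted_sum / total_weight);
        }
        return 0;
    }

Smaller decay discounts history faster, which is the "faster but less stable adaptation" the comment above describes.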
@@ -1394,8 +1394,7 @@ LLAMA_API struct llama_grammar* llama_sampler_init_grammar_lazy_patterns(
         struct llama_sampler_adaptive_p * adapt_p_ctx);
 /// @details Adaptive p sampler described in https://github.com/MrJackSpade/adaptive-p-docs/blob/main/README.md
-void llama_sample_adaptive_p(
-        struct llama_context * ctx,
+void llama_sample_adaptive_p(struct llama_context * ctx,
         llama_token_data_array * candidates,
         struct llama_sampler_adaptive_p * adapt_p_ctx);

View File

@@ -2,6 +2,8 @@
 #include "llama-vocab.h"
 #include "llama-grammar.h"
+#include "iqk/iqk_cpu_ops.h"
 #include <algorithm>
 #include <cstring>
 #include <ctime>
@@ -1061,8 +1063,8 @@ llama_token llama_sample_token_adaptive_p_impl(
     const size_t idx = std::distance(ctx->cum_probs.begin(), iter);
     llama_token id = candidates->data[idx].id;
-    if (auto it = ctx->orig_prob_map.find(id); it != ctx->orig_prob_map.end()) {
-        float update_prob = it->second / ctx->cum_orig_prob;
+    GGML_ASSERT(id < int(ctx->orig_prob.size()));
+    if (auto update_prob = ctx->orig_prob[id]; update_prob > 0) {
         ctx->weighted_sum = ctx->decay * ctx->weighted_sum + update_prob;
         ctx->total_weight = ctx->decay * ctx->total_weight + 1.0f;
     }
@@ -1070,16 +1072,6 @@
     smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
     smpl->n_sample++;
-    //float update_prob = candidates->data[idx].p; // not ideal
-    //if (ctx->orig_prob_map.contains(id)) {
-    //    // selected token id is among tracked ids
-    //    update_prob = ctx->orig_prob_map[id] / ctx->cum_orig_prob;
-    //}
-    //// update history with original probability of selected token
-    //ctx->weighted_sum = ctx->decay * ctx->weighted_sum + update_prob;
-    //ctx->total_weight = ctx->decay * ctx->total_weight + 1.0f;
     return id;
 }
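The design change in the two hunks above: a dense vector indexed by token id replaces the hash-map lookup. The prep hunk below zeroes every entry whose logit falls under the threshold, so the p > 0 test doubles as the old find(id) != end() membership check, and the lookup is a single array load instead of a hash probe. A minimal sketch of the pattern (illustrative, not a repository excerpt):

    #include <vector>

    // Returns the token's tracked pre-sampler weight, or 0 if the token fell
    // below the threshold ("not tracked", in the old map's terms).
    float tracked_prob(const std::vector<float> & orig_prob, int id) {
        return (id >= 0 && id < (int) orig_prob.size()) ? orig_prob[id] : 0.0f;
    }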
@@ -1137,94 +1129,49 @@ void llama_sample_adaptive_p_impl(struct llama_sampling * ctx, llama_token_data_array
     ctx->t_sample_us += ggml_time_us() - t_start;
 }
-void llama_prep_adaptive_p_impl(struct llama_sampling * smpl,
+void llama_prep_adaptive_p_impl(
+        struct llama_sampling * smpl,
         llama_token_data_array * candidates,
         struct llama_sampler_adaptive_p * adapt_p_ctx) {
-    constexpr float kDelta = 16.6f;
+    constexpr float kDelta = 30.0f; //16.6f;
     auto t_start = ggml_time_us();
-    if (!candidates->sorted) {
-        float max_logit = candidates->data[0].logit;
-        for (int j = 1; j < int(candidates->size); ++j) {
-            max_logit = std::max(max_logit, candidates->data[j].logit);
+    auto & orig_prob = adapt_p_ctx->orig_prob;
+    if (candidates->size != orig_prob.size() || candidates->sorted) {
+        LLAMA_LOG_ERROR("%s: this function must be called before any other sampler has been applied\n", __func__);
+        LLAMA_LOG_ERROR("%s: the sampler has been initialized with a vocabulary of %zu, but is being called with %zu candidates\n",
+                __func__, orig_prob.size(), candidates->size);
+        GGML_ABORT("Bad candidates in adaptive_p sampler");
     }
-        float min_logit = max_logit - kDelta;
-        float cum_prob = 0.0f;
-        adapt_p_ctx->orig_prob_map.clear();
+    float max_logit = -INFINITY;
     for (int j = 0; j < int(candidates->size); ++j) {
-            if (candidates->data[j].logit > min_logit) {
-                float prob = expf(candidates->data[j].logit - max_logit);
-                cum_prob += prob;
-                adapt_p_ctx->orig_prob_map[candidates->data[j].id] = prob;
+        orig_prob[j] = candidates->data[j].logit;
+        max_logit = std::max(max_logit, orig_prob[j]);
-        }
     }
-        adapt_p_ctx->cum_orig_prob = cum_prob;
+    adapt_p_ctx->cum_orig_prob = iqk_exp_with_thresh(orig_prob.size(), orig_prob.data(), max_logit, max_logit - kDelta);
     if (smpl) smpl->t_sample_us += ggml_time_us() - t_start;
-        return;
-    }
-    float max_logit = candidates->data[0].logit;
-    float min_logit = max_logit - kDelta;
-    float cum_prob = 0.0f;
-    adapt_p_ctx->orig_prob_map.clear();
-    for (int j = 0; j < int(candidates->size); ++j) {
-        auto logit = candidates->data[j].logit;
-        if (logit <= min_logit) {
-            break;
-        }
-        float prob = expf(logit - max_logit);
-        cum_prob += prob;
-        adapt_p_ctx->orig_prob_map[candidates->data[j].id] = prob;
-    }
-    adapt_p_ctx->cum_orig_prob = cum_prob;
-    if (smpl) smpl->t_sample_us += ggml_time_us() - t_start;
-    //if (!candidates->sorted) {
-    //    std::sort(candidates->data, candidates->data + candidates->size,
-    //            [](const llama_token_data & a, const llama_token_data & b) {
-    //        return a.logit > b.logit;
-    //    });
-    //    candidates->sorted = true;
-    //}
-    //const float max_logit = candidates->data[0].logit;
-    //// decide how many tokens to track based on logit delta
-    //// i.e. do not track unlikely tokens
-    //auto iter = std::lower_bound(
-    //        candidates->data,
-    //        candidates->data + candidates->size,
-    //        max_logit - kDelta, // delta
-    //        [](const llama_token_data & data, const float delta) {
-    //    return data.logit > delta;
-    //    });
-    //const size_t n_track = std::distance(candidates->data, iter);
-    //// store orig_prob_map and cum_orig_prob to estimate original probability later
-    //float cum_prob = 0.0f;
-    //adapt_p_ctx->orig_prob_map.clear();
-    //for (size_t i = 0; i < n_track; ++i) {
-    //    const float prob = expf(candidates->data[i].logit - max_logit);
-    //    cum_prob += prob;
-    //    adapt_p_ctx->orig_prob_map[candidates->data[i].id] = prob;
-    //}
-    //adapt_p_ctx->cum_orig_prob = cum_prob;
 }
-struct llama_sampler_adaptive_p * llama_init_adaptive_p_impl(
+struct llama_sampler_adaptive_p * llama_init_adaptive_p_impl(int n_vocab,
         const float target,
         const float decay,
         const uint32_t seed) {
+    GGML_ASSERT(n_vocab > 0);
     const float clamped_decay = std::clamp(decay, 0.0f, 0.99f);
-    return new llama_sampler_adaptive_p {
+    auto result = new llama_sampler_adaptive_p {
         /* .target          = */ target,
         /* .decay           = */ clamped_decay,
         /* .rng             = */ std::mt19937(seed),
         /* .weighted_sum    = */ target / (1.0f - clamped_decay),
         /* .total_weight    = */ 1.0f / (1.0f - clamped_decay),
-        /* .orig_logit_map  = */ {},
+        /* .orig_prob       = */ {},
         /* .cum_orig_prob   = */ 0.0f,
         /* .max_xform_logit = */ -INFINITY,
        /* .cum_probs       = */ {},
     };
+    result->orig_prob.resize(n_vocab);
+    return result;
 }
 // grammar
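Two notes on the rewritten prep. First, it now insists on the full, unsorted candidate list, because orig_prob is indexed positionally and sized to the vocabulary at init. Second, kDelta grew from 16.6 to 30: a token trailing the maximum logit by d contributes at most e^-d of the peak's unnormalized mass, so the tracked tail widens from roughly 6e-8 to 9e-14 of the peak, and the single vectorized pass makes the extra coverage essentially free. A standalone check of those cutoffs (illustrative):

    #include <cmath>
    #include <cstdio>

    int main() {
        // Smallest surviving unnormalized probability relative to the top token:
        printf("old cutoff e^-16.6 = %.3g\n", std::exp(-16.6)); // ~6.1e-08
        printf("new cutoff e^-30   = %.3g\n", std::exp(-30.0)); // ~9.4e-14
        return 0;
    }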

View File

@@ -73,7 +73,7 @@ struct llama_sampler_adaptive_p {
     float total_weight; // sum(decay^i), converges to 1/(1-decay)
     // first referenced in prep
-    std::unordered_map<llama_token, float> orig_prob_map; // probabilities before sampler_queue
+    std::vector<float> orig_prob; // for storing the original probabilities
     float cum_orig_prob; // for normalizing orig_prob in sample_token
// first referenced in sample
@@ -83,7 +83,7 @@ struct llama_sampler_adaptive_p {
     std::vector<float> cum_probs; // cumulative probability distribution
 };
-struct llama_sampler_adaptive_p * llama_init_adaptive_p_impl(
+struct llama_sampler_adaptive_p * llama_init_adaptive_p_impl(int n_vocab,
         const float target,
         const float decay,
         const uint32_t seed);

View File

@@ -7687,10 +7687,9 @@ void llama_sample_dry([[maybe_unused]] struct llama_context* ctx, struct llama_s
     llama_sampler_dry_apply(smpl, candidates_p);
 }
-void llama_sample_adaptive_p(
-        [[maybe_unused]] struct llama_context * ctx,
+void llama_sample_adaptive_p(llama_context * ctx,
         llama_token_data_array * candidates,
-        struct llama_sampler_adaptive_p * adapt_p_ctx) {
+        llama_sampler_adaptive_p * adapt_p_ctx) {
     llama_sample_adaptive_p_impl(&ctx->sampling, candidates, adapt_p_ctx);
 }
@@ -7797,8 +7796,8 @@ void llama_sampler_dry_accept(struct llama_sampler_dry* smpl, llama_token token) {
 }
-struct llama_sampler_adaptive_p * llama_init_adaptive_p(const float target, const float decay, const uint32_t seed) {
-    return llama_init_adaptive_p_impl(target, decay, seed);
+struct llama_sampler_adaptive_p * llama_init_adaptive_p(int n_vocab, const float target, const float decay, const uint32_t seed) {
+    return llama_init_adaptive_p_impl(n_vocab, target, decay, seed);
 }