Implement Adaptive-P Sampler (#1100)

* initial implementation of adaptive-p sampler

* explicitly mark candidates unsorted + cleanup qualifiers

* cosmetic update

* reorg prototypes

* lockstep with mainline

* add _impl for _init + reorg

* add LLAMA_API to prototypes

* update sharpness to 10

* lockstep: rng seed

* delete llama_sampling member in llama_sampler_adaptive_p

* fix LLAMA_API return type

* lockstep: rng seed cont

* actually correct implementation

* lockstep: sorting behavior

* const -> constexpr for known constants

* add missing space

* fix softmax usage in adaptive p sampler

* cosmetic changes

* implement do-not-sort version of softmax

* simpify rng seed, add static to constexpr

* refactor: remove iface + use shared rng + use actually original probabilities

* adaptive-p: add dedicated rng back in

* fix initial max_logit + add float vector to adaptive p sampler context + stochastic sampling

* adaptive-p: fuse first softmax with transformation

* adaptive-p: implement binary search selection

* adaptive-p: update comment
This commit is contained in:
dungquixote42
2026-01-10 00:58:53 -05:00
committed by GitHub
parent dd3c3f72f2
commit 52ad1c6421
8 changed files with 226 additions and 10 deletions

View File

@@ -1033,6 +1033,111 @@ struct llama_sampler_dry* llama_sampler_init_dry_impl(const struct llama_vocab&
}
// adaptive p
llama_token llama_sample_token_adaptive_p_impl(
struct llama_sampling * smpl,
llama_token_data_array * candidates,
struct llama_sampler_adaptive_p * adapt_p_ctx,
float * orig_probs)
{
GGML_ASSERT(candidates->size > 0);
const int64_t t_start_sample_us = ggml_time_us();
const size_t count = candidates->size;
adapt_p_ctx->probs.resize(count);
// cumulative distribution
const float max_logit = adapt_p_ctx->max_logit;
float cum_prob = 0.0f;
for (size_t i = 0; i < count; ++i) {
cum_prob += expf(candidates->data[i].logit - max_logit);
adapt_p_ctx->probs[i] = cum_prob;
}
adapt_p_ctx->probs.back() += 1.0f; // safety margin in case rng() ~= rng.max()
// find token with cum_prob > target_cum_prob
const float target_cum_prob = cum_prob * (float)adapt_p_ctx->rng() / (float)adapt_p_ctx->rng.max();
auto iter = std::upper_bound(adapt_p_ctx->probs.begin(), adapt_p_ctx->probs.end(), target_cum_prob);
GGML_ASSERT(iter != adapt_p_ctx->probs.end());
llama_token id = candidates->data[std::distance(adapt_p_ctx->probs.begin(), iter)].id;
smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
smpl->n_sample++;
// update history with original probability of selected token
adapt_p_ctx->weighted_sum = adapt_p_ctx->decay * adapt_p_ctx->weighted_sum + orig_probs[id];
adapt_p_ctx->total_weight = adapt_p_ctx->decay * adapt_p_ctx->total_weight + 1.0f;
return id;
}
void llama_sampler_adaptive_p_apply(struct llama_sampler_adaptive_p * adapt_p_ctx, llama_token_data_array * candidates)
{
if (adapt_p_ctx->target < 0.0f) {
// sampler is disabled
llama_sample_softmax_impl(nullptr, candidates);
return;
}
// incomplete softmax because final division can be fused
float max_l = candidates->data[0].logit;
for (size_t i = 1; i < candidates->size; ++i) {
max_l = std::max(max_l, candidates->data[i].logit);
}
float cum_sum = 0.0f;
for (size_t i = 0; i < candidates->size; ++i) {
const float p = expf(candidates->data[i].logit - max_l);
candidates->data[i].p = p;
cum_sum += p;
}
// compute adapted target probability
const float target = std::clamp(adapt_p_ctx->target, 0.0f, 1.0f);
const float adapted_target = std::clamp(adapt_p_ctx->total_weight == 0.0f
? target
: 2.0f * target - (adapt_p_ctx->weighted_sum / adapt_p_ctx->total_weight),
0.0f, 1.0f);
// transformation constants
static constexpr float peak_logit_value = 5.0f;
static constexpr float inv_width = 1.0f / 0.3f;
static constexpr float sharpness = 10.0f;
const float fused_target = adapted_target * inv_width;
const float fused_width = inv_width / cum_sum;
// quadratic near target for finite differentiation, transitioning to linear decay in tails
// unbounded negative logits suppress far-from-target tokens after softmax
float max_logit = -INFINITY;
for (size_t i = 0; i < candidates->size; ++i) {
const float dist = std::abs(candidates->data[i].p * fused_width - fused_target);
const float logit = peak_logit_value - sharpness * dist * dist / (1.0f + dist);
candidates->data[i].logit = logit;
max_logit = std::max(max_logit, logit);
}
candidates->sorted = false;
adapt_p_ctx->max_logit = max_logit;
}
struct llama_sampler_adaptive_p * llama_sampler_init_adaptive_p_impl(
const float target,
const float decay,
const uint32_t seed)
{
const float clamped_decay = std::clamp(decay, 0.0f, 0.99f);
return new llama_sampler_adaptive_p {
/* .target = */ target,
/* .decay = */ clamped_decay,
/* .rng = */ std::mt19937(seed),
/* .weighted_sum = */ target / (1.0f - clamped_decay),
/* .total_weight = */ 1.0f / (1.0f - clamped_decay),
/* .max_logit = */ 0.0f,
/* .probs = */ {},
};
}
// grammar
struct llama_sampler_grammar {