Implement Adaptive-P Sampler (#1100)

* initial implementation of adaptive-p sampler

* explicitly mark candidates unsorted + cleanup qualifiers

* cosmetic update

* reorg prototypes

* lockstep with mainline

* add _impl for _init + reorg

* add LLAMA_API to prototypes

* update sharpness to 10

* lockstep: rng seed

* delete llama_sampling member in llama_sampler_adaptive_p

* fix LLAMA_API return type

* lockstep: rng seed cont

* actually correct implementation

* lockstep: sorting behavior

* const -> constexpr for known constants

* add missing space

* fix softmax usage in adaptive p sampler

* cosmetic changes

* implement do-not-sort version of softmax

* simplify rng seed, add static to constexpr

* refactor: remove iface + use shared rng + use the actual original probabilities

* adaptive-p: add dedicated rng back in

* fix initial max_logit + add float vector to adaptive p sampler context + stochastic sampling

* adaptive-p: fuse first softmax with transformation

* adaptive-p: implement binary search selection

* adaptive-p: update comment
This commit is contained in:
dungquixote42
2026-01-10 00:58:53 -05:00
committed by GitHub
parent dd3c3f72f2
commit 52ad1c6421
8 changed files with 226 additions and 10 deletions

View File

@@ -1033,6 +1033,111 @@ struct llama_sampler_dry* llama_sampler_init_dry_impl(const struct llama_vocab&
}
// adaptive p
llama_token llama_sample_token_adaptive_p_impl(
    struct llama_sampling * smpl,
    llama_token_data_array * candidates,
    struct llama_sampler_adaptive_p * adapt_p_ctx,
    float * orig_probs)
{
    GGML_ASSERT(candidates->size > 0);
    const int64_t t_start_sample_us = ggml_time_us();

    const size_t n = candidates->size;
    std::vector<float> & cdf = adapt_p_ctx->probs;
    cdf.resize(n);

    // build the (unnormalized) cumulative distribution over the transformed logits;
    // the shift by max_logit (cached by the preceding apply() call) keeps expf() stable
    const float shift = adapt_p_ctx->max_logit;
    float running = 0.0f;
    for (size_t i = 0; i < n; ++i) {
        running += expf(candidates->data[i].logit - shift);
        cdf[i] = running;
    }
    cdf[n - 1] += 1.0f; // safety margin in case rng() ~= rng.max()

    // draw a uniform point in [0, running] and binary-search for the first entry above it
    const float draw = running * (float)adapt_p_ctx->rng() / (float)adapt_p_ctx->rng.max();
    const auto hit = std::upper_bound(cdf.begin(), cdf.end(), draw);
    GGML_ASSERT(hit != cdf.end());
    const llama_token id = candidates->data[std::distance(cdf.begin(), hit)].id;

    smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
    smpl->n_sample++;

    // update history with the *original* probability of the selected token
    adapt_p_ctx->weighted_sum = adapt_p_ctx->decay * adapt_p_ctx->weighted_sum + orig_probs[id];
    adapt_p_ctx->total_weight = adapt_p_ctx->decay * adapt_p_ctx->total_weight + 1.0f;

    return id;
}
void llama_sampler_adaptive_p_apply(struct llama_sampler_adaptive_p * adapt_p_ctx, llama_token_data_array * candidates)
{
    // a negative target disables the sampler: fall back to a plain softmax
    if (adapt_p_ctx->target < 0.0f) {
        llama_sample_softmax_impl(nullptr, candidates);
        return;
    }

    const size_t n = candidates->size;

    // incomplete softmax: probabilities stay unnormalized because the final
    // division by the normalizer is fused into fused_width below
    float logit_max = candidates->data[0].logit;
    for (size_t i = 1; i < n; ++i) {
        if (candidates->data[i].logit > logit_max) {
            logit_max = candidates->data[i].logit;
        }
    }
    float norm = 0.0f;
    for (size_t i = 0; i < n; ++i) {
        const float e = expf(candidates->data[i].logit - logit_max);
        candidates->data[i].p = e;
        norm += e;
    }

    // compute adapted target probability: steer towards 2*target minus the
    // observed EMA of selected-token probabilities, clamped to [0, 1]
    const float target = std::clamp(adapt_p_ctx->target, 0.0f, 1.0f);
    const float adapted_target = std::clamp(adapt_p_ctx->total_weight == 0.0f
        ? target
        : 2.0f * target - (adapt_p_ctx->weighted_sum / adapt_p_ctx->total_weight),
        0.0f, 1.0f);

    // transformation constants
    static constexpr float peak_logit_value = 5.0f;
    static constexpr float inv_width = 1.0f / 0.3f;
    static constexpr float sharpness = 10.0f;

    const float fused_target = adapted_target * inv_width;
    const float fused_width  = inv_width / norm;

    // quadratic near target for finite differentiation, transitioning to linear decay in tails
    // unbounded negative logits suppress far-from-target tokens after softmax
    float logit_peak = -INFINITY;
    for (size_t i = 0; i < n; ++i) {
        const float dist  = std::abs(candidates->data[i].p * fused_width - fused_target);
        const float value = peak_logit_value - sharpness * dist * dist / (1.0f + dist);
        candidates->data[i].logit = value;
        logit_peak = std::max(logit_peak, value);
    }

    candidates->sorted = false;
    adapt_p_ctx->max_logit = logit_peak; // cached for the subsequent token-selection pass
}
struct llama_sampler_adaptive_p * llama_sampler_init_adaptive_p_impl(
    const float target,
    const float decay,
    const uint32_t seed)
{
    // keep decay strictly below 1 so the EMA weights below stay finite
    const float safe_decay = std::clamp(decay, 0.0f, 0.99f);

    // history is seeded as if an infinite run of tokens with original
    // probability exactly equal to `target` had already been observed,
    // so the initial EMA average is `target`
    return new llama_sampler_adaptive_p {
        /* .target       = */ target,
        /* .decay        = */ safe_decay,
        /* .rng          = */ std::mt19937(seed),
        /* .weighted_sum = */ target / (1.0f - safe_decay),
        /* .total_weight = */ 1.0f / (1.0f - safe_decay),
        /* .max_logit    = */ 0.0f,
        /* .probs        = */ {},
    };
}
// grammar
struct llama_sampler_grammar {

View File

@@ -61,6 +61,24 @@ struct llama_sampler_dry * llama_sampler_init_dry_impl(
void llama_sampler_dry_apply(struct llama_sampler_dry* smpl, llama_token_data_array* cur_p);
// state for the adaptive-p sampler.
// maintains an exponential moving average (EMA) of the *ORIGINAL* probabilities of selected tokens,
// used to compute an adapted target at each sampling step.
// see llama.h for a full description of the sampler
struct llama_sampler_adaptive_p {
const float target; // target probability (0.0 - 1.0; negative = disabled)
const float decay; // EMA decay; history ~= 1/(1-decay) tokens (0.0 - 0.99)
std::mt19937 rng; // RNG used for stochastic token selection
float weighted_sum; // EMA numerator: sum(p_n * decay^N) over previously selected tokens
float total_weight; // EMA denominator: sum(decay^i), converges to 1/(1-decay)
float max_logit; // maximum transformed logit, cached by apply() for the selection pass
std::vector<float> probs; // scratch buffer: cumulative (unnormalized) probabilities
};
void llama_sampler_adaptive_p_apply(
struct llama_sampler_adaptive_p * adapt_p_ctx,
llama_token_data_array * candidates);
struct llama_sampler_adaptive_p * llama_sampler_init_adaptive_p_impl(const float target, const float decay, const uint32_t seed);
void llama_sample_repetition_penalties_impl(
@@ -83,6 +101,6 @@ llama_token llama_sample_token_mirostat_v2_impl(struct llama_sampling * smpl, ll
llama_token llama_sample_token_greedy_impl (struct llama_sampling * smpl, llama_token_data_array * candidates);
llama_token llama_sample_token_with_rng_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, std::mt19937 & rng);
llama_token llama_sample_token_impl (struct llama_sampling * smpl, llama_token_data_array * candidates);
llama_token llama_sample_token_adaptive_p_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, struct llama_sampler_adaptive_p * adapt_p_ctx, float * orig_probs);

View File

@@ -7581,6 +7581,13 @@ void llama_sample_dry([[maybe_unused]] struct llama_context* ctx, struct llama_s
llama_sampler_dry_apply(smpl, candidates_p);
}
// public wrapper: applies the adaptive-p transformation to the candidate logits.
// ctx is unused; the sampler keeps all of its state in adapt_p_ctx.
void llama_sample_adaptive_p(
    [[maybe_unused]] struct llama_context * ctx,
    struct llama_sampler_adaptive_p * adapt_p_ctx,
    llama_token_data_array * candidates)
{
    llama_sampler_adaptive_p_apply(adapt_p_ctx, candidates);
}
void llama_sample_repetition_penalties(
struct llama_context * ctx,
llama_token_data_array * candidates,
@@ -7620,6 +7627,15 @@ llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_arra
return llama_sample_token_with_rng_impl(&ctx->sampling, candidates, ctx->sampling.rng);
}
// public wrapper: selects a token via the adaptive-p sampler and updates its
// EMA history using the caller-provided original probabilities (orig_probs).
llama_token llama_sample_token_adaptive_p(
    struct llama_context * ctx,
    llama_token_data_array * candidates,
    struct llama_sampler_adaptive_p * adapt_p_ctx,
    float * orig_probs)
{
    struct llama_sampling * smpl = &ctx->sampling;
    return llama_sample_token_adaptive_p_impl(smpl, candidates, adapt_p_ctx, orig_probs);
}
int llama_split_path(char * split_path, size_t maxlen, const char * path_prefix, int split_no, int split_count) {
static const char * const SPLIT_PATH_FORMAT = "%s-%05d-of-%05d.gguf";
if (snprintf(split_path, maxlen, SPLIT_PATH_FORMAT, path_prefix, split_no + 1, split_count)) {
@@ -7671,6 +7687,13 @@ void llama_sampler_dry_accept(struct llama_sampler_dry* smpl, llama_token token)
smpl->last_tokens.push_back(token);
}
// public wrapper: constructs an adaptive-p sampler context (forwards to _impl)
struct llama_sampler_adaptive_p * llama_sampler_init_adaptive_p(const float target, const float decay, const uint32_t seed) {
    return llama_sampler_init_adaptive_p_impl(target, decay, seed);
}
int llama_split_prefix(char * dest, size_t maxlen, const char * split_path, int split_no, int split_count) {
std::string str_split_path(split_path);
char postfix[32];