mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-03-12 23:10:01 +00:00
Implement Adaptive-P Sampler (#1100)
* initial implementation of adaptive-p sampler * explicitly mark candidates unsorted + cleanup qualifiers * cosmetic update * reorg prototypes * lockstep with mainline * add _impl for _init + reorg * add LLAMA_API to prototypes * update sharpness to 10 * lockstep: rng seed * delete llama_sampling member in llama_sampler_adaptive_p * fix LLAMA_API return type * lockstep: rng seed cont * actually correct implementation * lockstep: sorting behavior * const -> constexpr for known constants * add missing space * fix softmax usage in adaptive p sampler * cosmetic changes * implement do-not-sort version of softmax * simpify rng seed, add static to constexpr * refactor: remove iface + use shared rng + use actually original probabilities * adaptive-p: add dedicated rng back in * fix initial max_logit + add float vector to adaptive p sampler context + stochastic sampling * adaptive-p: fuse first softmax with transformation * adaptive-p: implement binary search selection * adaptive-p: update comment
This commit is contained in:
@@ -1033,6 +1033,111 @@ struct llama_sampler_dry* llama_sampler_init_dry_impl(const struct llama_vocab&
|
||||
}
|
||||
|
||||
|
||||
// adaptive p
|
||||
|
||||
llama_token llama_sample_token_adaptive_p_impl(
|
||||
struct llama_sampling * smpl,
|
||||
llama_token_data_array * candidates,
|
||||
struct llama_sampler_adaptive_p * adapt_p_ctx,
|
||||
float * orig_probs)
|
||||
{
|
||||
GGML_ASSERT(candidates->size > 0);
|
||||
const int64_t t_start_sample_us = ggml_time_us();
|
||||
|
||||
const size_t count = candidates->size;
|
||||
adapt_p_ctx->probs.resize(count);
|
||||
|
||||
// cumulative distribution
|
||||
const float max_logit = adapt_p_ctx->max_logit;
|
||||
float cum_prob = 0.0f;
|
||||
for (size_t i = 0; i < count; ++i) {
|
||||
cum_prob += expf(candidates->data[i].logit - max_logit);
|
||||
adapt_p_ctx->probs[i] = cum_prob;
|
||||
}
|
||||
adapt_p_ctx->probs.back() += 1.0f; // safety margin in case rng() ~= rng.max()
|
||||
|
||||
// find token with cum_prob > target_cum_prob
|
||||
const float target_cum_prob = cum_prob * (float)adapt_p_ctx->rng() / (float)adapt_p_ctx->rng.max();
|
||||
auto iter = std::upper_bound(adapt_p_ctx->probs.begin(), adapt_p_ctx->probs.end(), target_cum_prob);
|
||||
GGML_ASSERT(iter != adapt_p_ctx->probs.end());
|
||||
llama_token id = candidates->data[std::distance(adapt_p_ctx->probs.begin(), iter)].id;
|
||||
|
||||
smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
|
||||
smpl->n_sample++;
|
||||
|
||||
// update history with original probability of selected token
|
||||
adapt_p_ctx->weighted_sum = adapt_p_ctx->decay * adapt_p_ctx->weighted_sum + orig_probs[id];
|
||||
adapt_p_ctx->total_weight = adapt_p_ctx->decay * adapt_p_ctx->total_weight + 1.0f;
|
||||
|
||||
return id;
|
||||
}
|
||||
|
||||
void llama_sampler_adaptive_p_apply(struct llama_sampler_adaptive_p * adapt_p_ctx, llama_token_data_array * candidates)
|
||||
{
|
||||
if (adapt_p_ctx->target < 0.0f) {
|
||||
// sampler is disabled
|
||||
llama_sample_softmax_impl(nullptr, candidates);
|
||||
return;
|
||||
}
|
||||
|
||||
// incomplete softmax because final division can be fused
|
||||
float max_l = candidates->data[0].logit;
|
||||
for (size_t i = 1; i < candidates->size; ++i) {
|
||||
max_l = std::max(max_l, candidates->data[i].logit);
|
||||
}
|
||||
float cum_sum = 0.0f;
|
||||
for (size_t i = 0; i < candidates->size; ++i) {
|
||||
const float p = expf(candidates->data[i].logit - max_l);
|
||||
candidates->data[i].p = p;
|
||||
cum_sum += p;
|
||||
}
|
||||
|
||||
// compute adapted target probability
|
||||
const float target = std::clamp(adapt_p_ctx->target, 0.0f, 1.0f);
|
||||
const float adapted_target = std::clamp(adapt_p_ctx->total_weight == 0.0f
|
||||
? target
|
||||
: 2.0f * target - (adapt_p_ctx->weighted_sum / adapt_p_ctx->total_weight),
|
||||
0.0f, 1.0f);
|
||||
|
||||
// transformation constants
|
||||
static constexpr float peak_logit_value = 5.0f;
|
||||
static constexpr float inv_width = 1.0f / 0.3f;
|
||||
static constexpr float sharpness = 10.0f;
|
||||
|
||||
const float fused_target = adapted_target * inv_width;
|
||||
const float fused_width = inv_width / cum_sum;
|
||||
|
||||
// quadratic near target for finite differentiation, transitioning to linear decay in tails
|
||||
// unbounded negative logits suppress far-from-target tokens after softmax
|
||||
float max_logit = -INFINITY;
|
||||
for (size_t i = 0; i < candidates->size; ++i) {
|
||||
const float dist = std::abs(candidates->data[i].p * fused_width - fused_target);
|
||||
const float logit = peak_logit_value - sharpness * dist * dist / (1.0f + dist);
|
||||
candidates->data[i].logit = logit;
|
||||
max_logit = std::max(max_logit, logit);
|
||||
}
|
||||
candidates->sorted = false;
|
||||
adapt_p_ctx->max_logit = max_logit;
|
||||
}
|
||||
|
||||
struct llama_sampler_adaptive_p * llama_sampler_init_adaptive_p_impl(
|
||||
const float target,
|
||||
const float decay,
|
||||
const uint32_t seed)
|
||||
{
|
||||
const float clamped_decay = std::clamp(decay, 0.0f, 0.99f);
|
||||
return new llama_sampler_adaptive_p {
|
||||
/* .target = */ target,
|
||||
/* .decay = */ clamped_decay,
|
||||
/* .rng = */ std::mt19937(seed),
|
||||
/* .weighted_sum = */ target / (1.0f - clamped_decay),
|
||||
/* .total_weight = */ 1.0f / (1.0f - clamped_decay),
|
||||
/* .max_logit = */ 0.0f,
|
||||
/* .probs = */ {},
|
||||
};
|
||||
}
|
||||
|
||||
|
||||
// grammar
|
||||
|
||||
struct llama_sampler_grammar {
|
||||
|
||||
Reference in New Issue
Block a user