Implement Adaptive-P Sampler (#1100)

* initial implementation of adaptive-p sampler

* explicitly mark candidates unsorted + cleanup qualifiers

* cosmetic update

* reorg prototypes

* lockstep with mainline

* add _impl for _init + reorg

* add LLAMA_API to prototypes

* update sharpness to 10

* lockstep: rng seed

* delete llama_sampling member in llama_sampler_adaptive_p

* fix LLAMA_API return type

* lockstep: rng seed cont

* actually correct implementation

* lockstep: sorting behavior

* const -> constexpr for known constants

* add missing space

* fix softmax usage in adaptive p sampler

* cosmetic changes

* implement do-not-sort version of softmax

* simpify rng seed, add static to constexpr

* refactor: remove iface + use shared rng + use actually original probabilities

* adaptive-p: add dedicated rng back in

* fix initial max_logit + add float vector to adaptive p sampler context + stochastic sampling

* adaptive-p: fuse first softmax with transformation

* adaptive-p: implement binary search selection

* adaptive-p: update comment
This commit is contained in:
dungquixote42
2026-01-10 00:58:53 -05:00
committed by GitHub
parent dd3c3f72f2
commit 52ad1c6421
8 changed files with 226 additions and 10 deletions

View File

@@ -925,6 +925,16 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
}
return true;
}
if (arg == "--adaptive-target") {
CHECK_ARG
sparams.adaptive_target = std::stof(argv[i]);
return true;
}
if (arg == "--adaptive-decay") {
CHECK_ARG
sparams.adaptive_decay = std::stof(argv[i]);
return true;
}
if (arg == "--spec-replace") {
CHECK_ARG
std::string target = argv[i];
@@ -2201,6 +2211,8 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
options.push_back({ "*", " --xtc-probability p", "xtc probability (default: %.1f, 0.0 = disabled)", (double)sparams.xtc_probability });
options.push_back({ "*", " --xtc-threshold t", "xtc threshold (default: %.1f, >0.5 = disabled)", (double)sparams.xtc_threshold});
options.push_back({ "*", " --top-n-sigma t", "top-n-sigma parmeter (default: %.1f, 0.0 = disabled)", (double)sparams.top_n_sigma});
options.push_back({ "*", " --adaptive-target", "adaptive-p sampling: (default: %.2f, <0.0 = disabled)", (double)sparams.adaptive_target});
options.push_back({ "*", " --adaptive-decay", "adaptive-p sampling: (default: %.2f)", (double)sparams.adaptive_decay});
options.push_back({ "*", " -l TOKEN_ID(+/-)BIAS", "modifies the likelihood of token appearing in the completion,\n"
"i.e. `--logit-bias 15043+1` to increase likelihood of token ' Hello',\n"
"or `--logit-bias 15043-1` to decrease likelihood of token ' Hello'" });
@@ -4174,6 +4186,8 @@ void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const l
fprintf(stream, "top_p: %f # default: 0.95\n", sparams.top_p);
fprintf(stream, "min_p: %f # default: 0.0\n", sparams.min_p);
fprintf(stream, "typical_p: %f # default: 1.0\n", sparams.typical_p);
fprintf(stream, "adaptive_target: %f # default: -1.0\n", sparams.adaptive_target);
fprintf(stream, "adaptive_decay: %f # default: 0.9\n", sparams.adaptive_decay);
fprintf(stream, "verbose_prompt: %s # default: false\n", params.verbose_prompt ? "true" : "false");
fprintf(stream, "display_prompt: %s # default: true\n", params.display_prompt ? "true" : "false");
}

View File

@@ -99,7 +99,7 @@ struct llama_sampling_context * llama_sampling_init(const struct llama_vocab* vo
result->n_valid = 0;
}
result->grammar = grmr;
// init DRY
llama_sampling_set_rng_seed(result, params.seed);
for (const auto& cnstr : params.samplers_sequence)
{
switch (cnstr)
@@ -116,11 +116,16 @@ struct llama_sampling_context * llama_sampling_init(const struct llama_vocab* vo
break;
}
case llama_sampler_type::ADAPTIVE_P:
{
result->adapt_p_ctx=llama_sampler_init_adaptive_p(params.adaptive_target, params.adaptive_decay, result->rng());
break;
}
default:
break;
}
}
llama_sampling_set_rng_seed(result, params.seed);
return result;
}
@@ -247,11 +252,13 @@ std::string llama_sampling_print(const llama_sampling_params & params) {
"\trepeat_last_n = %d, repeat_penalty = %.3f, frequency_penalty = %.3f, presence_penalty = %.3f\n"
"\ttop_k = %d, tfs_z = %.3f, top_p = %.3f, min_p = %.3f, typical_p = %.3f, temp = %.3f\n"
"\tmirostat = %d, mirostat_lr = %.3f, mirostat_ent = %.3f\n"
"\txtc_probability = %.3f, xtc_threshold = %.3f, top_n_sigma = %.3f",
"\txtc_probability = %.3f, xtc_threshold = %.3f, top_n_sigma = %.3f\n"
"\tadaptive_target = %.2f, adaptive_decay = %.2f",
params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present,
params.top_k, params.tfs_z, params.top_p, params.min_p, params.typical_p, params.temp,
params.mirostat, params.mirostat_eta, params.mirostat_tau,
params.xtc_probability, params.xtc_threshold, params.top_n_sigma);
params.xtc_probability, params.xtc_threshold, params.top_n_sigma,
params.adaptive_target, params.adaptive_decay);
return std::string(result);
}
@@ -283,6 +290,7 @@ std::string llama_sampling_type_to_str(llama_sampler_type sampler_type) {
case llama_sampler_type::TEMPERATURE: return "temperature";
case llama_sampler_type::XTC : return "xtc";
case llama_sampler_type::TOP_N_SIGMA: return "top_n_sigma";
case llama_sampler_type::ADAPTIVE_P : return "adaptive_p";
default : return "";
}
}
@@ -297,7 +305,8 @@ std::vector<llama_sampler_type> llama_sampling_types_from_names(const std::vecto
{"tfs_z", llama_sampler_type::TFS_Z},
{"xtc", llama_sampler_type::XTC},
{"top_n_sigma", llama_sampler_type::TOP_N_SIGMA},
{"temperature", llama_sampler_type::TEMPERATURE}
{"temperature", llama_sampler_type::TEMPERATURE},
{"adaptive_p", llama_sampler_type::ADAPTIVE_P},
};
// since samplers names are written multiple ways
@@ -314,7 +323,8 @@ std::vector<llama_sampler_type> llama_sampling_types_from_names(const std::vecto
{"tfs", llama_sampler_type::TFS_Z},
{"xtc", llama_sampler_type::XTC},
{"top-n-sigma", llama_sampler_type::TOP_N_SIGMA},
{"temp", llama_sampler_type::TEMPERATURE}
{"temp", llama_sampler_type::TEMPERATURE},
{"adaptive-p", llama_sampler_type::ADAPTIVE_P},
};
std::vector<llama_sampler_type> sampler_types;
@@ -351,7 +361,8 @@ std::vector<llama_sampler_type> llama_sampling_types_from_chars(const std::strin
{'f', llama_sampler_type::TFS_Z},
{'x', llama_sampler_type::XTC},
{'n', llama_sampler_type::TOP_N_SIGMA},
{'t', llama_sampler_type::TEMPERATURE}
{'t', llama_sampler_type::TEMPERATURE},
{'w', llama_sampler_type::ADAPTIVE_P},
};
std::vector<llama_sampler_type> sampler_types;
@@ -405,6 +416,7 @@ static void sampler_queue(
llama_sample_temp(ctx_main, &cur_p, temp);
}
break;
case llama_sampler_type::ADAPTIVE_P: llama_sample_adaptive_p(ctx_main, ctx_sampling->adapt_p_ctx, &cur_p); break;
default : break;
}
}
@@ -422,6 +434,7 @@ static llama_token llama_sampling_sample_impl(
const int mirostat = params.mirostat;
const float mirostat_tau = params.mirostat_tau;
const float mirostat_eta = params.mirostat_eta;
const float adaptive_target = params.adaptive_target;
std::vector<float> original_logits;
auto cur_p = llama_sampling_prepare(ctx_sampling, ctx_main, ctx_cfg, idx, /* apply_grammar= */ is_resampling, &original_logits);
@@ -451,6 +464,17 @@ static llama_token llama_sampling_sample_impl(
} else if (mirostat == 2) {
llama_sample_temp(ctx_main, &cur_p, temp);
id = llama_sample_token_mirostat_v2(ctx_main, &cur_p, mirostat_tau, mirostat_eta, &ctx_sampling->mirostat_mu);
} else if (adaptive_target >= 0.0f) {
// adaptive p sampling
static thread_local std::vector<float> orig_probs;
orig_probs.resize(cur_p.size);
// store original probabilities
for (size_t ii = 0; ii < cur_p.size; ++ii) {
orig_probs[ii] = cur_p.data[ii].p;
}
sampler_queue(ctx_main, params, ctx_sampling, cur_p, std::max(1, params.min_keep));
id = llama_sample_token_adaptive_p(ctx_main, &cur_p, ctx_sampling->adapt_p_ctx, orig_probs.data());
} else {
// temperature sampling
size_t min_keep = std::max(1, params.min_keep);

View File

@@ -18,7 +18,8 @@ enum class llama_sampler_type : char {
XTC = 'x',
TOP_N_SIGMA = 'n',
TYPICAL_P = 'y',
TEMPERATURE = 't'
TEMPERATURE = 't',
ADAPTIVE_P = 'w',
};
enum common_grammar_trigger_type {
@@ -66,6 +67,8 @@ typedef struct llama_sampling_params {
float xtc_probability = 0.0f; // xtc probability
float xtc_threshold = 1.0f; // xtc threshold, disabled if > 0.5
float top_n_sigma = 0.0f; // top-n-sigma
float adaptive_target = -1.0f; // select tokens near this probability (valid range 0.0 to 1.0; <0 = disabled)
float adaptive_decay = 0.90f; // decay rate for target adaptation over time. lower values -> faster but less stable adaptation. (valid range 0.0 to 1.0; ≤0 = no adaptation)
bool penalize_nl = false; // consider newlines as a repeatable token
uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampling_context
@@ -80,7 +83,8 @@ typedef struct llama_sampling_params {
llama_sampler_type::MIN_P,
llama_sampler_type::XTC,
llama_sampler_type::TOP_N_SIGMA,
llama_sampler_type::TEMPERATURE
llama_sampler_type::TEMPERATURE,
llama_sampler_type::ADAPTIVE_P,
};
@@ -118,6 +122,8 @@ struct llama_sampling_context {
std::vector<llama_token_data> cur;
llama_sampler_dry* smpl;
llama_sampler_adaptive_p * adapt_p_ctx; // adaptive p sampler
size_t n_valid; // Number of correct top tokens with correct probabilities.
llama_token_data_array cur_p; // current candidates

View File

@@ -807,6 +807,8 @@ bool server_context::launch_slot_with_task(server_slot& slot, server_task& task)
slot.sparams.mirostat = json_value(data, "mirostat", default_sparams.mirostat);
slot.sparams.mirostat_tau = json_value(data, "mirostat_tau", default_sparams.mirostat_tau);
slot.sparams.mirostat_eta = json_value(data, "mirostat_eta", default_sparams.mirostat_eta);
slot.sparams.adaptive_target = json_value(data, "adaptive_target", default_sparams.adaptive_target);
slot.sparams.adaptive_decay = json_value(data, "adaptive_decay", default_sparams.adaptive_decay);
slot.sparams.penalize_nl = json_value(data, "penalize_nl", default_sparams.penalize_nl);
slot.params.n_keep = json_value(data, "n_keep", slot.params.n_keep);
slot.params.n_discard = json_value(data, "n_discard", default_params.n_discard);
@@ -1405,6 +1407,8 @@ json server_context::get_formated_generation(const server_slot& slot) const {
{"mirostat", slot.sparams.mirostat},
{"mirostat_tau", slot.sparams.mirostat_tau},
{"mirostat_eta", slot.sparams.mirostat_eta},
{"adaptive_target", slot.sparams.adaptive_target},
{"adaptive_decay", slot.sparams.adaptive_decay},
{"penalize_nl", slot.sparams.penalize_nl},
{"stop", slot.params.antiprompt},
{"max_tokens", slot.params.n_predict}, // User configured n_predict

View File

@@ -1380,6 +1380,21 @@ LLAMA_API struct llama_grammar* llama_sampler_init_grammar_lazy_patterns(
/// @details DRY sampler, designed by p-e-w, as described in: https://github.com/oobabooga/text-generation-webui/pull/5677, porting Koboldcpp implementation authored by pi6am: https://github.com/LostRuins/koboldcpp/pull/982
/// @details Adaptive p sampler initializer
/// @param target Select tokens near this probability (valid range 0.0 to 1.0; <0 = disabled)
/// @param decay Decay rate for target adaptation over time. lower values -> faster but less stable adaptation. (valid range 0.0 to 1.0; ≤0 = no adaptation)
LLAMA_API struct llama_sampler_adaptive_p * llama_sampler_init_adaptive_p(
const float target,
const float decay,
const uint32_t seed);
/// @details Adaptive p sampler described in https://github.com/MrJackSpade/adaptive-p-docs/blob/main/README.md
void llama_sample_adaptive_p(
struct llama_context * ctx,
struct llama_sampler_adaptive_p * adapt_p_ctx,
llama_token_data_array * candidates);
/// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
/// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
/// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.
@@ -1417,6 +1432,13 @@ LLAMA_API struct llama_grammar* llama_sampler_init_grammar_lazy_patterns(
struct llama_context * ctx,
llama_token_data_array * candidates);
/// @details Randonly selects a token from the candidates following adaptive p sampler.
llama_token llama_sample_token_adaptive_p(
struct llama_context * ctx,
llama_token_data_array * candidates,
struct llama_sampler_adaptive_p * adapt_p_ctx,
float * orig_probs);
//
// Model split
//

View File

@@ -1033,6 +1033,111 @@ struct llama_sampler_dry* llama_sampler_init_dry_impl(const struct llama_vocab&
}
// adaptive p
llama_token llama_sample_token_adaptive_p_impl(
struct llama_sampling * smpl,
llama_token_data_array * candidates,
struct llama_sampler_adaptive_p * adapt_p_ctx,
float * orig_probs)
{
GGML_ASSERT(candidates->size > 0);
const int64_t t_start_sample_us = ggml_time_us();
const size_t count = candidates->size;
adapt_p_ctx->probs.resize(count);
// cumulative distribution
const float max_logit = adapt_p_ctx->max_logit;
float cum_prob = 0.0f;
for (size_t i = 0; i < count; ++i) {
cum_prob += expf(candidates->data[i].logit - max_logit);
adapt_p_ctx->probs[i] = cum_prob;
}
adapt_p_ctx->probs.back() += 1.0f; // safety margin in case rng() ~= rng.max()
// find token with cum_prob > target_cum_prob
const float target_cum_prob = cum_prob * (float)adapt_p_ctx->rng() / (float)adapt_p_ctx->rng.max();
auto iter = std::upper_bound(adapt_p_ctx->probs.begin(), adapt_p_ctx->probs.end(), target_cum_prob);
GGML_ASSERT(iter != adapt_p_ctx->probs.end());
llama_token id = candidates->data[std::distance(adapt_p_ctx->probs.begin(), iter)].id;
smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
smpl->n_sample++;
// update history with original probability of selected token
adapt_p_ctx->weighted_sum = adapt_p_ctx->decay * adapt_p_ctx->weighted_sum + orig_probs[id];
adapt_p_ctx->total_weight = adapt_p_ctx->decay * adapt_p_ctx->total_weight + 1.0f;
return id;
}
void llama_sampler_adaptive_p_apply(struct llama_sampler_adaptive_p * adapt_p_ctx, llama_token_data_array * candidates)
{
if (adapt_p_ctx->target < 0.0f) {
// sampler is disabled
llama_sample_softmax_impl(nullptr, candidates);
return;
}
// incomplete softmax because final division can be fused
float max_l = candidates->data[0].logit;
for (size_t i = 1; i < candidates->size; ++i) {
max_l = std::max(max_l, candidates->data[i].logit);
}
float cum_sum = 0.0f;
for (size_t i = 0; i < candidates->size; ++i) {
const float p = expf(candidates->data[i].logit - max_l);
candidates->data[i].p = p;
cum_sum += p;
}
// compute adapted target probability
const float target = std::clamp(adapt_p_ctx->target, 0.0f, 1.0f);
const float adapted_target = std::clamp(adapt_p_ctx->total_weight == 0.0f
? target
: 2.0f * target - (adapt_p_ctx->weighted_sum / adapt_p_ctx->total_weight),
0.0f, 1.0f);
// transformation constants
static constexpr float peak_logit_value = 5.0f;
static constexpr float inv_width = 1.0f / 0.3f;
static constexpr float sharpness = 10.0f;
const float fused_target = adapted_target * inv_width;
const float fused_width = inv_width / cum_sum;
// quadratic near target for finite differentiation, transitioning to linear decay in tails
// unbounded negative logits suppress far-from-target tokens after softmax
float max_logit = -INFINITY;
for (size_t i = 0; i < candidates->size; ++i) {
const float dist = std::abs(candidates->data[i].p * fused_width - fused_target);
const float logit = peak_logit_value - sharpness * dist * dist / (1.0f + dist);
candidates->data[i].logit = logit;
max_logit = std::max(max_logit, logit);
}
candidates->sorted = false;
adapt_p_ctx->max_logit = max_logit;
}
struct llama_sampler_adaptive_p * llama_sampler_init_adaptive_p_impl(
const float target,
const float decay,
const uint32_t seed)
{
const float clamped_decay = std::clamp(decay, 0.0f, 0.99f);
return new llama_sampler_adaptive_p {
/* .target = */ target,
/* .decay = */ clamped_decay,
/* .rng = */ std::mt19937(seed),
/* .weighted_sum = */ target / (1.0f - clamped_decay),
/* .total_weight = */ 1.0f / (1.0f - clamped_decay),
/* .max_logit = */ 0.0f,
/* .probs = */ {},
};
}
// grammar
struct llama_sampler_grammar {

View File

@@ -61,6 +61,24 @@ struct llama_sampler_dry * llama_sampler_init_dry_impl(
void llama_sampler_dry_apply(struct llama_sampler_dry* smpl, llama_token_data_array* cur_p);
// maintains an exponential moving average of the *ORIGINAL* probabilities of selected tokens
// used to compute an adapted target at each sampling step.
// see llama.h for a full description of the sampler
struct llama_sampler_adaptive_p {
const float target; // target probability (0.0 - 1.0; negative = disabled)
const float decay; // EMA decay; history ≈ 1/(1-decay) tokens (0.0 - 0.99)
std::mt19937 rng; // RNG
float weighted_sum; // sum(p_n * decay^N)
float total_weight; // sum(decay^i), converges to 1/(1-decay)
float max_logit; // maximum logit found during transform
std::vector<float> probs; // cumulative probabilities
};
void llama_sampler_adaptive_p_apply(
struct llama_sampler_adaptive_p * adapt_p_ctx,
llama_token_data_array * candidates);
struct llama_sampler_adaptive_p * llama_sampler_init_adaptive_p_impl(const float target, const float decay, const uint32_t seed);
void llama_sample_repetition_penalties_impl(
@@ -83,6 +101,6 @@ llama_token llama_sample_token_mirostat_v2_impl(struct llama_sampling * smpl, ll
llama_token llama_sample_token_greedy_impl (struct llama_sampling * smpl, llama_token_data_array * candidates);
llama_token llama_sample_token_with_rng_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, std::mt19937 & rng);
llama_token llama_sample_token_impl (struct llama_sampling * smpl, llama_token_data_array * candidates);
llama_token llama_sample_token_adaptive_p_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, struct llama_sampler_adaptive_p * adapt_p_ctx, float * orig_probs);

View File

@@ -7581,6 +7581,13 @@ void llama_sample_dry([[maybe_unused]] struct llama_context* ctx, struct llama_s
llama_sampler_dry_apply(smpl, candidates_p);
}
void llama_sample_adaptive_p(
[[maybe_unused]] struct llama_context * ctx,
struct llama_sampler_adaptive_p * adapt_p_ctx,
llama_token_data_array * candidates) {
llama_sampler_adaptive_p_apply(adapt_p_ctx, candidates);
}
void llama_sample_repetition_penalties(
struct llama_context * ctx,
llama_token_data_array * candidates,
@@ -7620,6 +7627,15 @@ llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_arra
return llama_sample_token_with_rng_impl(&ctx->sampling, candidates, ctx->sampling.rng);
}
llama_token llama_sample_token_adaptive_p(
struct llama_context * ctx,
llama_token_data_array * candidates,
struct llama_sampler_adaptive_p * adapt_p_ctx,
float * orig_probs)
{
return llama_sample_token_adaptive_p_impl(&ctx->sampling, candidates, adapt_p_ctx, orig_probs);
}
int llama_split_path(char * split_path, size_t maxlen, const char * path_prefix, int split_no, int split_count) {
static const char * const SPLIT_PATH_FORMAT = "%s-%05d-of-%05d.gguf";
if (snprintf(split_path, maxlen, SPLIT_PATH_FORMAT, path_prefix, split_no + 1, split_count)) {
@@ -7671,6 +7687,13 @@ void llama_sampler_dry_accept(struct llama_sampler_dry* smpl, llama_token token)
smpl->last_tokens.push_back(token);
}
struct llama_sampler_adaptive_p * llama_sampler_init_adaptive_p(const float target, const float decay, const uint32_t seed)
{
return llama_sampler_init_adaptive_p_impl(target, decay, seed);
}
int llama_split_prefix(char * dest, size_t maxlen, const char * split_path, int split_no, int split_count) {
std::string str_split_path(split_path);
char postfix[32];