From aaa545c3dc1d832546a0065f7f3a74cee3e77f07 Mon Sep 17 00:00:00 2001 From: dungquixote42 <62397442+dungquixote42@users.noreply.github.com> Date: Tue, 24 Feb 2026 09:39:17 -0500 Subject: [PATCH] adaptive p: collect probability before logit bias (#1314) --- common/sampling.cpp | 6 +++++- include/llama.h | 2 +- src/llama-sampling.cpp | 13 ++++--------- src/llama-sampling.h | 2 +- src/llama.cpp | 4 ++-- 5 files changed, 13 insertions(+), 14 deletions(-) diff --git a/common/sampling.cpp b/common/sampling.cpp index e631e690..4c4a1371 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -424,7 +424,6 @@ static llama_token llama_sampling_sample_impl( id = llama_sample_token_mirostat_v2(ctx_main, &cur_p, mirostat_tau, mirostat_eta, &ctx_sampling->mirostat_mu); } else if (adaptive_target >= 0.0f && ctx_sampling->adapt_p_ctx!=nullptr) { // adaptive p sampling - llama_prep_adaptive_p(ctx_main, &cur_p, ctx_sampling->adapt_p_ctx); sampler_queue(ctx_main, params, ctx_sampling, cur_p, std::max(1, params.min_keep)); id = llama_sample_token_adaptive_p(ctx_main, &cur_p, ctx_sampling->adapt_p_ctx); } else { @@ -496,6 +495,11 @@ static llama_token_data_array llama_sampling_prepare_impl( *original_logits = {logits, logits + n_vocab}; } + if ((params.temp > 0) && (params.mirostat == 0) && (params.adaptive_target >= 0) && (ctx_sampling->adapt_p_ctx != nullptr)) { + // collect original probability before logit bias is applied + llama_prep_adaptive_p(ctx_main, logits, ctx_sampling->adapt_p_ctx); + } + // apply params.logit_bias map for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) { logits[it->first] += it->second; diff --git a/include/llama.h b/include/llama.h index a8feef50..104f5d40 100644 --- a/include/llama.h +++ b/include/llama.h @@ -1390,7 +1390,7 @@ LLAMA_API struct llama_grammar* llama_sampler_init_grammar_lazy_patterns( const uint32_t seed); void llama_prep_adaptive_p(struct llama_context * ctx, - llama_token_data_array * candidates, + float * logits, struct llama_sampler_adaptive_p * adapt_p_ctx); /// @details Adaptive p sampler described in https://github.com/MrJackSpade/adaptive-p-docs/blob/main/README.md diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp index c442f356..9191ba41 100644 --- a/src/llama-sampling.cpp +++ b/src/llama-sampling.cpp @@ -1169,7 +1169,7 @@ void llama_sample_adaptive_p_impl(struct llama_sampling * ctx, llama_token_data_ void llama_prep_adaptive_p_impl( struct llama_sampling * smpl, - llama_token_data_array * candidates, + float * logits, struct llama_sampler_adaptive_p * adapt_p_ctx) { if (adapt_p_ctx->updt_w_cur) { // update with current probability, original not needed @@ -1178,16 +1178,11 @@ void llama_prep_adaptive_p_impl( constexpr float kDelta = 30.0f; //16.6f; auto t_start = ggml_time_us(); auto & orig_prob = adapt_p_ctx->orig_prob; - if (candidates->size != orig_prob.size() || candidates->sorted) { - LLAMA_LOG_ERROR("%s: this function must be called before any other sampler has been applied\n", __func__); - LLAMA_LOG_ERROR("%s: the sampler has been initialized with a vocabulary of %zu, but is being called with %zu candidates\n", - __func__, orig_prob.size(), candidates->size); - GGML_ABORT("Bad candidates in adaptive_p sampler"); - } + + std::copy(logits, logits + orig_prob.size(), orig_prob.begin()); float max_logit = -INFINITY; - for (int j = 0; j < int(candidates->size); ++j) { - orig_prob[j] = candidates->data[j].logit; + for (int j = 0; j < int(orig_prob.size()); ++j) { max_logit = std::max(max_logit, orig_prob[j]); } adapt_p_ctx->cum_orig_prob = iqk_exp_with_thresh(orig_prob.size(), orig_prob.data(), max_logit, max_logit - kDelta); diff --git a/src/llama-sampling.h b/src/llama-sampling.h index 0a52cca5..2b52a412 100644 --- a/src/llama-sampling.h +++ b/src/llama-sampling.h @@ -97,7 +97,7 @@ struct llama_sampler_adaptive_p * llama_init_adaptive_p_impl(int n_vocab, void llama_prep_adaptive_p_impl( struct llama_sampling * smpl, - llama_token_data_array * candidates, + float * logits, struct llama_sampler_adaptive_p * adapt_p_ctx); void llama_sample_adaptive_p_impl( diff --git a/src/llama.cpp b/src/llama.cpp index 71d27dd3..4442e2dd 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -8171,8 +8171,8 @@ void llama_sample_adaptive_p(llama_context * ctx, llama_sample_adaptive_p_impl(&ctx->sampling, candidates, adapt_p_ctx); } -void llama_prep_adaptive_p(struct llama_context * ctx, llama_token_data_array * candidates, struct llama_sampler_adaptive_p * adapt_p_ctx) { - llama_prep_adaptive_p_impl(&ctx->sampling, candidates, adapt_p_ctx); +void llama_prep_adaptive_p(struct llama_context * ctx, float * logits, struct llama_sampler_adaptive_p * adapt_p_ctx) { + llama_prep_adaptive_p_impl(&ctx->sampling, logits, adapt_p_ctx); }