Adaptive p: bugfix + optimization + refactor (#1155)

* adaptive-p sampler: fix zeroed orig_probs bug and refactor

- Fix bug where original probabilities were captured as zero by calculating
  them from logits in llama_prep_adaptive_p (new).
- Replace vector with unordered_map to track candidate probabilities,
  filtering for relevance via logit delta (16.6f).
- Standardize API naming: llama_<action/verb>_<focus/name/topic>_<extra/info>
- Update function signatures to follow most other samplers.

* resolve merge bug

* adaptive-p: revert reordering function definitions
This commit is contained in:
dungquixote42
2026-01-18 01:26:06 -05:00
committed by GitHub
parent d71a3ec315
commit 6dfbef27ec
5 changed files with 121 additions and 58 deletions

View File

@@ -118,7 +118,7 @@ struct llama_sampling_context * common_sampler_init(const struct llama_vocab* vo
}
case llama_sampler_type::ADAPTIVE_P:
{
result->adapt_p_ctx=llama_sampler_init_adaptive_p(params.adaptive_target, params.adaptive_decay, result->rng());
result->adapt_p_ctx = llama_init_adaptive_p(params.adaptive_target, params.adaptive_decay, result->rng());
break;
}
default:
@@ -423,7 +423,7 @@ static void sampler_queue(
}
if (use_adaptive_p) {
// adaptive p should be put to the last, so we ignore the order in the sampler
llama_sample_adaptive_p(ctx_main, ctx_sampling->adapt_p_ctx, &cur_p);
llama_sample_adaptive_p(ctx_main, &cur_p, ctx_sampling->adapt_p_ctx);
}
}
@@ -471,15 +471,9 @@ static llama_token llama_sampling_sample_impl(
id = llama_sample_token_mirostat_v2(ctx_main, &cur_p, mirostat_tau, mirostat_eta, &ctx_sampling->mirostat_mu);
} else if (adaptive_target >= 0.0f && ctx_sampling->adapt_p_ctx!=nullptr) {
// adaptive p sampling
static thread_local std::vector<float> orig_probs;
orig_probs.resize(cur_p.size);
// store original probabilities
for (size_t ii = 0; ii < cur_p.size; ++ii) {
orig_probs[ii] = cur_p.data[ii].p;
}
llama_prep_adaptive_p(&cur_p, ctx_sampling->adapt_p_ctx);
sampler_queue(ctx_main, params, ctx_sampling, cur_p, std::max(1, params.min_keep));
id = llama_sample_token_adaptive_p(ctx_main, &cur_p, ctx_sampling->adapt_p_ctx, orig_probs.data());
id = llama_sample_token_adaptive_p(ctx_main, &cur_p, ctx_sampling->adapt_p_ctx);
} else {
// temperature sampling
size_t min_keep = std::max(1, params.min_keep);