Fix adaptive p sampler bug with string ban (#1287)

* adaptive p: upadte internal state only if not rewinding

* adaptive p: conditional update for speculative decoding

* adaptive p: refactor to rewind instead of update

* adaptive p fix: better comments

* fix rewind check

* add record to handle multi-token rewind

* better comment
This commit is contained in:
dungquixote42
2026-02-20 01:11:36 -05:00
committed by GitHub
parent b855bf92de
commit 0f411b02e2
7 changed files with 71 additions and 16 deletions

View File

@@ -1053,10 +1053,27 @@ struct llama_sampler_dry* llama_sampler_init_dry_impl(const struct llama_vocab&
// adaptive p
void llama_review_adaptive_p_impl(llama_sampler_adaptive_p * adapt_p_ctx, const bool record, const bool rewind) {
if (record && rewind) {
LLAMA_LOG_WARN("%s: record AND rewind is invalid\n", __func__);
return;
}
if (record) {
adapt_p_ctx->recd_weighted_sum = adapt_p_ctx->weighted_sum;
adapt_p_ctx->recd_total_weight = adapt_p_ctx->total_weight;
return;
}
if (rewind) {
adapt_p_ctx->weighted_sum = adapt_p_ctx->recd_weighted_sum;
adapt_p_ctx->total_weight = adapt_p_ctx->recd_total_weight;
return;
}
}
llama_token llama_sample_token_adaptive_p_impl(
struct llama_sampling * smpl,
llama_token_data_array * candidates,
struct llama_sampler_adaptive_p * adapt_p_ctx) {
struct llama_sampling * smpl,
llama_token_data_array * candidates,
struct llama_sampler_adaptive_p * adapt_p_ctx) {
GGML_ASSERT(candidates->size > 0);
const int64_t t_start_sample_us = ggml_time_us();
@@ -1082,8 +1099,8 @@ llama_token llama_sample_token_adaptive_p_impl(
// update history
const float update_prob = ctx->updt_w_cur
? candidates->data[idx].p / ctx->cum_cur_p
: ctx->orig_prob[id] / ctx->cum_orig_prob;
? candidates->data[idx].p / ctx->cum_cur_p
: ctx->orig_prob[id] / ctx->cum_orig_prob;
if (update_prob > 0) {
ctx->weighted_sum = ctx->decay * ctx->weighted_sum + update_prob;
ctx->total_weight = ctx->decay * ctx->total_weight + 1.0f;
@@ -1186,17 +1203,19 @@ struct llama_sampler_adaptive_p * llama_init_adaptive_p_impl(int n_vocab,
GGML_ASSERT(n_vocab > 0);
const float clamped_decay = std::clamp(decay, 0.0f, 0.99f);
auto result = new llama_sampler_adaptive_p {
/* .target = */ target,
/* .decay = */ clamped_decay,
/* .updt_w_cur = */ updt_w_cur,
/* .rng = */ std::mt19937(seed),
/* .weighted_sum = */ target / (1.0f - clamped_decay),
/* .total_weight = */ 1.0f / (1.0f - clamped_decay),
/* .orig_prob = */ {},
/* .cum_orig_prob = */ 1.0f,
/* .cum_cur_p = */ 1.0f,
/* .max_xform_logit = */ -INFINITY,
/* .cum_probs = */ {},
/* .target = */ target,
/* .decay = */ clamped_decay,
/* .updt_w_cur = */ updt_w_cur,
/* .rng = */ std::mt19937(seed),
/* .weighted_sum = */ target / (1.0f - clamped_decay),
/* .total_weight = */ 1.0f / (1.0f - clamped_decay),
/* .orig_prob = */ {},
/* .cum_orig_prob = */ 1.0f,
/* .cum_cur_p = */ 1.0f,
/* .max_xform_logit = */ -INFINITY,
/* .cum_probs = */ {},
/* .recd_weighted_sum = */ target / (1.0f - clamped_decay),
/* .recd_total_weight = */ 1.0f / (1.0f - clamped_decay),
};
result->orig_prob.resize(n_vocab);
return result;