fix adaptive p sampler rewinding too far back (#1359)

* fix adaptive p sampler rewinding too far back

* update comments

* correct default value for total_weight, more comments

* new variables/names

* update comment for n_rewind

* move null pointer check back to common_sampler_review()

* refactor weighted_sum and total_weight to vector<pair>, better boundary check in llama_review_adaptive_p_impl()
This commit is contained in:
dungquixote42
2026-03-04 07:26:25 -05:00
committed by GitHub
parent f27678d39b
commit a903409a5e
7 changed files with 75 additions and 43 deletions

View File

@@ -1053,20 +1053,48 @@ struct llama_sampler_dry* llama_sampler_init_dry_impl(const struct llama_vocab&
// adaptive p
// NOTE(review): this hunk appears to show the removed (bool record / bool rewind) and the
// added (int32_t n_rewind) versions of the function interleaved without +/- markers, so the
// braces below do not balance as a single body.
//
// Added version, as far as the visible lines establish:
//   - n_rewind == 0, or target < 0: return without touching history.
//   - history is empty or has no more than n_rewind entries: log a warning and replace
//     history with a single {weighted_sum, total_weight} entry.
//   - n_rewind < 0: keep only the most recent history entry.
//   - n_rewind > 0: drop the last n_rewind history entries.
void llama_review_adaptive_p_impl(llama_sampler_adaptive_p * adapt_p_ctx, const bool record, const bool rewind) {
if (record && rewind) {
LLAMA_LOG_WARN("%s: record AND rewind is invalid\n", __func__);
void llama_review_adaptive_p_impl(llama_sampler_adaptive_p * adapt_p_ctx, const int32_t n_rewind) {
// nothing to rewind, or a negative target: leave the history untouched
if ((n_rewind == 0) || (adapt_p_ctx->target < 0.0f)) {
return;
}
if (record) {
adapt_p_ctx->recd_weighted_sum = adapt_p_ctx->weighted_sum;
adapt_p_ctx->recd_total_weight = adapt_p_ctx->total_weight;
// auto & weighted_sum = adapt_p_ctx->weighted_sum;
// auto & total_weight = adapt_p_ctx->total_weight;
// history.size() is converted to a signed 32-bit value so it can be compared with n_rewind
const int32_t sz = adapt_p_ctx->history.size();
if ((sz <= 0) || (sz <= n_rewind)) {
// critically short history. reset to initial state
LLAMA_LOG_WARN("%s: sz=%d, n_rewind=%d should not be possible\n", __func__, sz, n_rewind);
adapt_p_ctx->history.clear();
// NOTE(review): the initializer later in this file seeds history with
// target / (1.0f - clamped_decay) and 1.0f / (1.0f - clamped_decay), while this reset
// divides by decay itself; confirm which is intended. decay == 0 would divide by zero here.
adapt_p_ctx->history.push_back({
adapt_p_ctx->target / adapt_p_ctx->decay, // weighted_sum
1.0f / adapt_p_ctx->decay }); // total_weight
return;
}
if (rewind) {
adapt_p_ctx->weighted_sum = adapt_p_ctx->recd_weighted_sum;
adapt_p_ctx->total_weight = adapt_p_ctx->recd_total_weight;
return;
if (n_rewind < 0) {
// clear history except most recent
// copy the newest entry into slot 0, then shrink the vector to that single element
adapt_p_ctx->history.front() = adapt_p_ctx->history.back();
adapt_p_ctx->history.resize(1);
} else {
// rewind
// the size check above already returned when sz <= n_rewind, so at least one entry remains
adapt_p_ctx->history.resize(sz - n_rewind);
// commented-out code below operates on separate weighted_sum / total_weight vectors
// (presumably the pre-refactor form kept for reference; confirm before deleting)
// int32_t sz = weighted_sum.size() - n_rewind;
// if (sz > 0) {
// weighted_sum.resize(sz);
// } else {
// LLAMA_LOG_WARN("%s: n_rewind=%d, sz=%d should not be possible\n", __func__, n_rewind, sz);
// weighted_sum.clear();
// weighted_sum.push_back(adapt_p_ctx->target / adapt_p_ctx->decay); // set to default value
// }
// sz = total_weight.size() - n_rewind;
// if (sz > 0) {
// total_weight.resize(sz);
// } else {
// LLAMA_LOG_WARN("%s: n_rewind=%d, sz=%d should not be possible\n", __func__, n_rewind, sz);
// total_weight.clear();
// total_weight.push_back(1.0f / adapt_p_ctx->decay); // set to default value
// }
}
}
@@ -1102,8 +1130,11 @@ llama_token llama_sample_token_adaptive_p_impl(
? candidates->data[idx].p / ctx->cum_cur_p
: ctx->orig_prob[id] / ctx->cum_orig_prob;
if (update_prob > 0) {
ctx->weighted_sum = ctx->decay * ctx->weighted_sum + update_prob;
ctx->total_weight = ctx->decay * ctx->total_weight + 1.0f;
ctx->history.push_back({
ctx->decay * ctx->history.back().first + update_prob, // weighted_sum
ctx->decay * ctx->history.back().second + 1.0f }); // total_weight
// ctx->weighted_sum.push_back(ctx->decay * ctx->weighted_sum.back() + update_prob);
// ctx->total_weight.push_back(ctx->decay * ctx->total_weight.back() + 1.0f);
}
smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
@@ -1138,10 +1169,12 @@ void llama_sample_adaptive_p_impl(struct llama_sampling * ctx, llama_token_data_
adapt_p_ctx->cum_cur_p = cum_sum;
// compute adapted target probability
const float weighted_sum = adapt_p_ctx->history.back().first;
const float total_weight = adapt_p_ctx->history.back().second;
const float target = std::clamp(adapt_p_ctx->target, 0.0f, 1.0f);
const float adapted_target = std::clamp(adapt_p_ctx->total_weight == 0.0f
const float adapted_target = std::clamp(total_weight == 0.0f
? target
: 2.0f * target - (adapt_p_ctx->weighted_sum / adapt_p_ctx->total_weight),
: 2.0f * target - (weighted_sum / total_weight),
0.0f, 1.0f);
// transformation constants
@@ -1202,16 +1235,20 @@ struct llama_sampler_adaptive_p * llama_init_adaptive_p_impl(int n_vocab,
/* .decay = */ clamped_decay,
/* .updt_w_cur = */ updt_w_cur,
/* .rng = */ std::mt19937(seed),
/* .weighted_sum = */ target / (1.0f - clamped_decay),
/* .total_weight = */ 1.0f / (1.0f - clamped_decay),
// /* .weighted_sum = */ {},
// /* .total_weight = */ {},
/* .history = */ {},
/* .orig_prob = */ {},
/* .cum_orig_prob = */ 1.0f,
/* .cum_cur_p = */ 1.0f,
/* .max_xform_logit = */ -INFINITY,
/* .cum_probs = */ {},
/* .recd_weighted_sum = */ target / (1.0f - clamped_decay),
/* .recd_total_weight = */ 1.0f / (1.0f - clamped_decay),
};
// result->weighted_sum.push_back(target / (1.0f - clamped_decay));
// result->total_weight.push_back(1.0f / (1.0f - clamped_decay));
result->history.push_back({
target / (1.0f - clamped_decay), // weighted_sum
1.0f / (1.0f - clamped_decay) }); // total_weight
result->orig_prob.resize(n_vocab);
return result;
}