fix adaptive p sampler rewinding too far back (#1359)

* fix adaptive p sampler rewinding too far back

* update comments

* correct default value for total_weight, more comments

* new variables/names

* update comment for n_rewind

* move null pointer check back to common_sampler_review()

* refactor weighted_sum and total_weight to vector<pair>, better boundary check in llama_review_adaptive_p_impl()
This commit is contained in:
dungquixote42
2026-03-04 07:26:25 -05:00
committed by GitHub
parent f27678d39b
commit a903409a5e
7 changed files with 75 additions and 43 deletions

View File

@@ -1053,20 +1053,48 @@ struct llama_sampler_dry* llama_sampler_init_dry_impl(const struct llama_vocab&
// adaptive p
void llama_review_adaptive_p_impl(llama_sampler_adaptive_p * adapt_p_ctx, const bool record, const bool rewind) {
if (record && rewind) {
LLAMA_LOG_WARN("%s: record AND rewind is invalid\n", __func__);
void llama_review_adaptive_p_impl(llama_sampler_adaptive_p * adapt_p_ctx, const int32_t n_rewind) {
if ((n_rewind == 0) || (adapt_p_ctx->target < 0.0f)) {
return;
}
if (record) {
adapt_p_ctx->recd_weighted_sum = adapt_p_ctx->weighted_sum;
adapt_p_ctx->recd_total_weight = adapt_p_ctx->total_weight;
// auto & weighted_sum = adapt_p_ctx->weighted_sum;
// auto & total_weight = adapt_p_ctx->total_weight;
const int32_t sz = adapt_p_ctx->history.size();
if ((sz <= 0) || (sz <= n_rewind)) {
// critically short history. reset to initial state
LLAMA_LOG_WARN("%s: sz=%d, n_rewind=%d should not be possible\n", __func__, sz, n_rewind);
adapt_p_ctx->history.clear();
adapt_p_ctx->history.push_back({
adapt_p_ctx->target / adapt_p_ctx->decay, // weighted_sum
1.0f / adapt_p_ctx->decay }); // total_weight
return;
}
if (rewind) {
adapt_p_ctx->weighted_sum = adapt_p_ctx->recd_weighted_sum;
adapt_p_ctx->total_weight = adapt_p_ctx->recd_total_weight;
return;
if (n_rewind < 0) {
// clear history except most recent
adapt_p_ctx->history.front() = adapt_p_ctx->history.back();
adapt_p_ctx->history.resize(1);
} else {
// rewind
adapt_p_ctx->history.resize(sz - n_rewind);
// int32_t sz = weighted_sum.size() - n_rewind;
// if (sz > 0) {
// weighted_sum.resize(sz);
// } else {
// LLAMA_LOG_WARN("%s: n_rewind=%d, sz=%d should not be possible\n", __func__, n_rewind, sz);
// weighted_sum.clear();
// weighted_sum.push_back(adapt_p_ctx->target / adapt_p_ctx->decay); // set to default value
// }
// sz = total_weight.size() - n_rewind;
// if (sz > 0) {
// total_weight.resize(sz);
// } else {
// LLAMA_LOG_WARN("%s: n_rewind=%d, sz=%d should not be possible\n", __func__, n_rewind, sz);
// total_weight.clear();
// total_weight.push_back(1.0f / adapt_p_ctx->decay); // set to default value
// }
}
}
@@ -1102,8 +1130,11 @@ llama_token llama_sample_token_adaptive_p_impl(
? candidates->data[idx].p / ctx->cum_cur_p
: ctx->orig_prob[id] / ctx->cum_orig_prob;
if (update_prob > 0) {
ctx->weighted_sum = ctx->decay * ctx->weighted_sum + update_prob;
ctx->total_weight = ctx->decay * ctx->total_weight + 1.0f;
ctx->history.push_back({
ctx->decay * ctx->history.back().first + update_prob, // weighted_sum
ctx->decay * ctx->history.back().second + 1.0f }); // total_weight
// ctx->weighted_sum.push_back(ctx->decay * ctx->weighted_sum.back() + update_prob);
// ctx->total_weight.push_back(ctx->decay * ctx->total_weight.back() + 1.0f);
}
smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
@@ -1138,10 +1169,12 @@ void llama_sample_adaptive_p_impl(struct llama_sampling * ctx, llama_token_data_
adapt_p_ctx->cum_cur_p = cum_sum;
// compute adapted target probability
const float weighted_sum = adapt_p_ctx->history.back().first;
const float total_weight = adapt_p_ctx->history.back().second;
const float target = std::clamp(adapt_p_ctx->target, 0.0f, 1.0f);
const float adapted_target = std::clamp(adapt_p_ctx->total_weight == 0.0f
const float adapted_target = std::clamp(total_weight == 0.0f
? target
: 2.0f * target - (adapt_p_ctx->weighted_sum / adapt_p_ctx->total_weight),
: 2.0f * target - (weighted_sum / total_weight),
0.0f, 1.0f);
// transformation constants
@@ -1202,16 +1235,20 @@ struct llama_sampler_adaptive_p * llama_init_adaptive_p_impl(int n_vocab,
/* .decay = */ clamped_decay,
/* .updt_w_cur = */ updt_w_cur,
/* .rng = */ std::mt19937(seed),
/* .weighted_sum = */ target / (1.0f - clamped_decay),
/* .total_weight = */ 1.0f / (1.0f - clamped_decay),
// /* .weighted_sum = */ {},
// /* .total_weight = */ {},
/* .history = */ {},
/* .orig_prob = */ {},
/* .cum_orig_prob = */ 1.0f,
/* .cum_cur_p = */ 1.0f,
/* .max_xform_logit = */ -INFINITY,
/* .cum_probs = */ {},
/* .recd_weighted_sum = */ target / (1.0f - clamped_decay),
/* .recd_total_weight = */ 1.0f / (1.0f - clamped_decay),
};
// result->weighted_sum.push_back(target / (1.0f - clamped_decay));
// result->total_weight.push_back(1.0f / (1.0f - clamped_decay));
result->history.push_back({
target / (1.0f - clamped_decay), // weighted_sum
1.0f / (1.0f - clamped_decay) }); // total_weight
result->orig_prob.resize(n_vocab);
return result;
}

View File

@@ -70,8 +70,9 @@ struct llama_sampler_adaptive_p {
const float decay; // EMA decay; history ≈ 1/(1-decay) tokens (0.0 - 0.99)
const bool updt_w_cur; // false=original, true=current
std::mt19937 rng; // RNG
float weighted_sum; // sum(p_n * decay^N)
float total_weight; // sum(decay^i), converges to 1/(1-decay)
// std::vector<float> weighted_sum; // [0] = sum(p_n * decay^N)
// std::vector<float> total_weight; // [0] = sum(decay^i), converges to 1/(1-decay)
std::vector<std::pair<float, float>> history; // <weighted_sum, total_weight>
// first referenced in prep
std::vector<float> orig_prob; // for storing the original probabilities
@@ -83,10 +84,6 @@ struct llama_sampler_adaptive_p {
// first referenced in sample_token
std::vector<float> cum_probs; // cumulative probability distribution
// recorded states for rewinding
float recd_weighted_sum;
float recd_total_weight;
};
struct llama_sampler_adaptive_p * llama_init_adaptive_p_impl(int n_vocab,
@@ -105,7 +102,7 @@ void llama_sample_adaptive_p_impl(
llama_token_data_array * candidates,
struct llama_sampler_adaptive_p * adapt_p_ctx);
void llama_review_adaptive_p_impl(llama_sampler_adaptive_p * adapt_p_ctx, const bool record, const bool rewind);
void llama_review_adaptive_p_impl(llama_sampler_adaptive_p * adapt_p_ctx, const int32_t n_rewind);
void llama_sample_repetition_penalties_impl(

View File

@@ -8304,8 +8304,8 @@ struct llama_sampler_adaptive_p * llama_init_adaptive_p(int n_vocab, const float
return llama_init_adaptive_p_impl(n_vocab, target, decay, updt_w_cur, seed);
}
void llama_review_adaptive_p(struct llama_sampler_adaptive_p * adapt_p_ctx, const bool record, const bool rewind) {
llama_review_adaptive_p_impl(adapt_p_ctx, record, rewind);
void llama_review_adaptive_p(struct llama_sampler_adaptive_p * adapt_p_ctx, const int32_t n_rewind) {
llama_review_adaptive_p_impl(adapt_p_ctx, n_rewind);
}