fix adaptive p sampler rewinding too far back (#1359)

* fix adaptive p sampler rewinding too far back

* update comments

* correct default value for total_weight, more comments

* new variables/names

* update comment for n_rewind

* move null pointer check back to common_sampler_review()

* refactor weighted_sum and total_weight to vector<pair>, better boundary check in llama_review_adaptive_p_impl()
This commit is contained in:
dungquixote42
2026-03-04 07:26:25 -05:00
committed by GitHub
parent f27678d39b
commit a903409a5e
7 changed files with 75 additions and 43 deletions

View File

@@ -106,6 +106,8 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
}
}
result->n_rewind = -1;
return result;
}
@@ -143,16 +145,12 @@ void common_sampler_reset(common_sampler * ctx) {
}
void common_sampler_review(common_sampler * ctx) {
if (!ctx->adapt_p_ctx) {
return;
const int32_t n_rewind = ctx->n_rewind;
// add stateful samplers here
if (ctx->adapt_p_ctx != nullptr) {
llama_review_adaptive_p(ctx->adapt_p_ctx, n_rewind);
}
const bool record = ctx->record_samplers;
const bool rewind = ctx->rewind_samplers;
llama_review_adaptive_p(ctx->adapt_p_ctx, record, rewind);
ctx->record_samplers = false;
ctx->rewind_samplers = false;
}
void llama_sampling_set_rng_seed(struct common_sampler * ctx, uint32_t seed) {

View File

@@ -135,8 +135,7 @@ struct common_sampler {
std::mt19937 rng;
bool record_samplers = false; // record current state for stateful samplers
bool rewind_samplers = false; // rewind stateful samplers to last recorded
int32_t n_rewind; // number of tokens to rewind
};
@@ -152,8 +151,7 @@ void common_sampler_free(struct common_sampler * ctx);
void common_sampler_reset(common_sampler * ctx);
// Review stateful samplers
// | record current state for rewinding
// | rewind to last recorded state
// - rewind internal states (maybe)
void common_sampler_review(common_sampler * ctx);
// Set the sampler seed

View File

@@ -3332,6 +3332,7 @@ void server_context::buffer_and_check_string_ban(server_slot & slot, completion_
bool next_token = has_next_token(result, slot);
bool send_result = slot.token_buffer.size() >= slot.n_buffer || !next_token;
int32_t n_rewind = 0;
bool sent_results = false;
// don't restore if last time was also rewind
if (!slot.rewind_status) {
slot.ctx_sampling->params.logit_bias = slot.logit_bias; // restore logit bias
@@ -3343,7 +3344,6 @@ void server_context::buffer_and_check_string_ban(server_slot & slot, completion_
if (n_rewind > 0 && (slot.rewind_count <20 || slot.rewind_count <= 2 * slot.ban_phrases.size())) {
rewind_context(slot, n_rewind);
slot.rewind_status = true;
slot.ctx_sampling->rewind_samplers = true;
}
else if (send_result) {
slot.rewind_status = false;
@@ -3356,12 +3356,14 @@ void server_context::buffer_and_check_string_ban(server_slot & slot, completion_
// send 1 token
send_token_results(slot.token_buffer, slot, 1);
}
slot.ctx_sampling->record_samplers = true;
sent_results = true;
}
else {
// buffer the result
slot.sampled = result.tok; // for common batch add
}
slot.ctx_sampling->n_rewind = sent_results ? -1 : n_rewind;
}
void server_context::process_batch_tokens(int32_t & n_batch) {

View File

@@ -1415,7 +1415,7 @@ LLAMA_API struct llama_grammar* llama_sampler_init_grammar_lazy_patterns(
llama_token_data_array * candidates,
struct llama_sampler_adaptive_p * adapt_p_ctx);
void llama_review_adaptive_p(struct llama_sampler_adaptive_p * adapt_p_ctx, const bool record, const bool rewind);
void llama_review_adaptive_p(struct llama_sampler_adaptive_p * adapt_p_ctx, const int32_t n_rewind);
/// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.

View File

@@ -1053,20 +1053,48 @@ struct llama_sampler_dry* llama_sampler_init_dry_impl(const struct llama_vocab&
// adaptive p
void llama_review_adaptive_p_impl(llama_sampler_adaptive_p * adapt_p_ctx, const bool record, const bool rewind) {
if (record && rewind) {
LLAMA_LOG_WARN("%s: record AND rewind is invalid\n", __func__);
void llama_review_adaptive_p_impl(llama_sampler_adaptive_p * adapt_p_ctx, const int32_t n_rewind) {
if ((n_rewind == 0) || (adapt_p_ctx->target < 0.0f)) {
return;
}
if (record) {
adapt_p_ctx->recd_weighted_sum = adapt_p_ctx->weighted_sum;
adapt_p_ctx->recd_total_weight = adapt_p_ctx->total_weight;
// auto & weighted_sum = adapt_p_ctx->weighted_sum;
// auto & total_weight = adapt_p_ctx->total_weight;
const int32_t sz = adapt_p_ctx->history.size();
if ((sz <= 0) || (sz <= n_rewind)) {
// critically short history. reset to initial state
LLAMA_LOG_WARN("%s: sz=%d, n_rewind=%d should not be possible\n", __func__, sz, n_rewind);
adapt_p_ctx->history.clear();
adapt_p_ctx->history.push_back({
adapt_p_ctx->target / adapt_p_ctx->decay, // weighted_sum
1.0f / adapt_p_ctx->decay }); // total_weight
return;
}
if (rewind) {
adapt_p_ctx->weighted_sum = adapt_p_ctx->recd_weighted_sum;
adapt_p_ctx->total_weight = adapt_p_ctx->recd_total_weight;
return;
if (n_rewind < 0) {
// clear history except most recent
adapt_p_ctx->history.front() = adapt_p_ctx->history.back();
adapt_p_ctx->history.resize(1);
} else {
// rewind
adapt_p_ctx->history.resize(sz - n_rewind);
// int32_t sz = weighted_sum.size() - n_rewind;
// if (sz > 0) {
// weighted_sum.resize(sz);
// } else {
// LLAMA_LOG_WARN("%s: n_rewind=%d, sz=%d should not be possible\n", __func__, n_rewind, sz);
// weighted_sum.clear();
// weighted_sum.push_back(adapt_p_ctx->target / adapt_p_ctx->decay); // set to default value
// }
// sz = total_weight.size() - n_rewind;
// if (sz > 0) {
// total_weight.resize(sz);
// } else {
// LLAMA_LOG_WARN("%s: n_rewind=%d, sz=%d should not be possible\n", __func__, n_rewind, sz);
// total_weight.clear();
// total_weight.push_back(1.0f / adapt_p_ctx->decay); // set to default value
// }
}
}
@@ -1102,8 +1130,11 @@ llama_token llama_sample_token_adaptive_p_impl(
? candidates->data[idx].p / ctx->cum_cur_p
: ctx->orig_prob[id] / ctx->cum_orig_prob;
if (update_prob > 0) {
ctx->weighted_sum = ctx->decay * ctx->weighted_sum + update_prob;
ctx->total_weight = ctx->decay * ctx->total_weight + 1.0f;
ctx->history.push_back({
ctx->decay * ctx->history.back().first + update_prob, // weighted_sum
ctx->decay * ctx->history.back().second + 1.0f }); // total_weight
// ctx->weighted_sum.push_back(ctx->decay * ctx->weighted_sum.back() + update_prob);
// ctx->total_weight.push_back(ctx->decay * ctx->total_weight.back() + 1.0f);
}
smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
@@ -1138,10 +1169,12 @@ void llama_sample_adaptive_p_impl(struct llama_sampling * ctx, llama_token_data_
adapt_p_ctx->cum_cur_p = cum_sum;
// compute adapted target probability
const float weighted_sum = adapt_p_ctx->history.back().first;
const float total_weight = adapt_p_ctx->history.back().second;
const float target = std::clamp(adapt_p_ctx->target, 0.0f, 1.0f);
const float adapted_target = std::clamp(adapt_p_ctx->total_weight == 0.0f
const float adapted_target = std::clamp(total_weight == 0.0f
? target
: 2.0f * target - (adapt_p_ctx->weighted_sum / adapt_p_ctx->total_weight),
: 2.0f * target - (weighted_sum / total_weight),
0.0f, 1.0f);
// transformation constants
@@ -1202,16 +1235,20 @@ struct llama_sampler_adaptive_p * llama_init_adaptive_p_impl(int n_vocab,
/* .decay = */ clamped_decay,
/* .updt_w_cur = */ updt_w_cur,
/* .rng = */ std::mt19937(seed),
/* .weighted_sum = */ target / (1.0f - clamped_decay),
/* .total_weight = */ 1.0f / (1.0f - clamped_decay),
// /* .weighted_sum = */ {},
// /* .total_weight = */ {},
/* .history = */ {},
/* .orig_prob = */ {},
/* .cum_orig_prob = */ 1.0f,
/* .cum_cur_p = */ 1.0f,
/* .max_xform_logit = */ -INFINITY,
/* .cum_probs = */ {},
/* .recd_weighted_sum = */ target / (1.0f - clamped_decay),
/* .recd_total_weight = */ 1.0f / (1.0f - clamped_decay),
};
// result->weighted_sum.push_back(target / (1.0f - clamped_decay));
// result->total_weight.push_back(1.0f / (1.0f - clamped_decay));
result->history.push_back({
target / (1.0f - clamped_decay), // weighted_sum
1.0f / (1.0f - clamped_decay) }); // total_weight
result->orig_prob.resize(n_vocab);
return result;
}

View File

@@ -70,8 +70,9 @@ struct llama_sampler_adaptive_p {
const float decay; // EMA decay; history ≈ 1/(1-decay) tokens (0.0 - 0.99)
const bool updt_w_cur; // false=original, true=current
std::mt19937 rng; // RNG
float weighted_sum; // sum(p_n * decay^N)
float total_weight; // sum(decay^i), converges to 1/(1-decay)
// std::vector<float> weighted_sum; // [0] = sum(p_n * decay^N)
// std::vector<float> total_weight; // [0] = sum(decay^i), converges to 1/(1-decay)
std::vector<std::pair<float, float>> history; // <weighted_sum, total_weight>
// first referenced in prep
std::vector<float> orig_prob; // for storing the original probabilities
@@ -83,10 +84,6 @@ struct llama_sampler_adaptive_p {
// first referenced in sample_token
std::vector<float> cum_probs; // cumulative probability distribution
// recorded states for rewinding
float recd_weighted_sum;
float recd_total_weight;
};
struct llama_sampler_adaptive_p * llama_init_adaptive_p_impl(int n_vocab,
@@ -105,7 +102,7 @@ void llama_sample_adaptive_p_impl(
llama_token_data_array * candidates,
struct llama_sampler_adaptive_p * adapt_p_ctx);
void llama_review_adaptive_p_impl(llama_sampler_adaptive_p * adapt_p_ctx, const bool record, const bool rewind);
void llama_review_adaptive_p_impl(llama_sampler_adaptive_p * adapt_p_ctx, const int32_t n_rewind);
void llama_sample_repetition_penalties_impl(

View File

@@ -8304,8 +8304,8 @@ struct llama_sampler_adaptive_p * llama_init_adaptive_p(int n_vocab, const float
return llama_init_adaptive_p_impl(n_vocab, target, decay, updt_w_cur, seed);
}
void llama_review_adaptive_p(struct llama_sampler_adaptive_p * adapt_p_ctx, const bool record, const bool rewind) {
llama_review_adaptive_p_impl(adapt_p_ctx, record, rewind);
void llama_review_adaptive_p(struct llama_sampler_adaptive_p * adapt_p_ctx, const int32_t n_rewind) {
llama_review_adaptive_p_impl(adapt_p_ctx, n_rewind);
}