Fix adaptive p sampler bug with string ban (#1287)

* adaptive p: upadte internal state only if not rewinding

* adaptive p: conditional update for speculative decoding

* adaptive p: refactor to rewind instead of update

* adaptive p fix: better comments

* fix rewind check

* add record to handle multi-token rewind

* better comment
This commit is contained in:
dungquixote42
2026-02-20 01:11:36 -05:00
committed by GitHub
parent b855bf92de
commit 0f411b02e2
7 changed files with 71 additions and 16 deletions

View File

@@ -142,6 +142,16 @@ void common_sampler_reset(common_sampler * ctx) {
llama_sampler_dry_reset(ctx->smpl);
}
void common_sampler_review(common_sampler * ctx) {
const bool record = ctx->record_samplers;
const bool rewind = ctx->rewind_samplers;
llama_review_adaptive_p(ctx->adapt_p_ctx, record, rewind);
ctx->record_samplers = false;
ctx->rewind_samplers = false;
}
void llama_sampling_set_rng_seed(struct common_sampler * ctx, uint32_t seed) {
if (seed == LLAMA_DEFAULT_SEED) {
seed = std::random_device{}();

View File

@@ -134,6 +134,9 @@ struct common_sampler {
llama_token_data_array cur_p; // current candidates
std::mt19937 rng;
bool record_samplers = false; // record current state for stateful samplers
bool rewind_samplers = false; // rewind stateful samplers to last recorded
};
@@ -148,6 +151,11 @@ void common_sampler_free(struct common_sampler * ctx);
// - reset grammar
void common_sampler_reset(common_sampler * ctx);
// Review stateful samplers
// | record current state for rewinding
// | rewind to last recorded state
void common_sampler_review(common_sampler * ctx);
// Set the sampler seed
void llama_sampling_set_rng_seed(struct common_sampler * ctx, uint32_t seed);

View File

@@ -3020,6 +3020,8 @@ void server_context::speculative_decoding_accept() {
} else {
buffer_and_check_string_ban(slot, result);
}
common_sampler_review(slot.ctx_sampling);
}
SLT_DBG(slot, "accepted %d/%d draft tokens, new n_tokens = %d\n", (int)ids.size() - 1, (int)slot.drafted.size(), slot.n_past);
LOG_VERBOSE("speculative decoding result", {
@@ -3135,6 +3137,7 @@ void server_context::buffer_and_check_string_ban(server_slot & slot, completion_
if (n_rewind > 0 && (slot.rewind_count <20 || slot.rewind_count <= 2 * slot.ban_phrases.size())) {
rewind_context(slot, n_rewind);
slot.rewind_status = true;
slot.ctx_sampling->rewind_samplers = true;
}
else if (send_result) {
slot.rewind_status = false;
@@ -3147,6 +3150,7 @@ void server_context::buffer_and_check_string_ban(server_slot & slot, completion_
// send 1 token
send_token_results(slot.token_buffer, slot, 1);
}
slot.ctx_sampling->record_samplers = true;
}
else {
// buffer the result
@@ -3264,6 +3268,8 @@ void server_context::process_batch_tokens(int32_t & n_batch) {
buffer_and_check_string_ban(slot, result);
}
common_sampler_review(slot.ctx_sampling);
slot.i_batch = -1;
}

View File

@@ -1382,6 +1382,8 @@ LLAMA_API struct llama_grammar* llama_sampler_init_grammar_lazy_patterns(
llama_token_data_array * candidates,
struct llama_sampler_adaptive_p * adapt_p_ctx);
void llama_review_adaptive_p(struct llama_sampler_adaptive_p * adapt_p_ctx, const bool record, const bool rewind);
/// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
/// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.

View File

@@ -1053,10 +1053,27 @@ struct llama_sampler_dry* llama_sampler_init_dry_impl(const struct llama_vocab&
// adaptive p
void llama_review_adaptive_p_impl(llama_sampler_adaptive_p * adapt_p_ctx, const bool record, const bool rewind) {
if (record && rewind) {
LLAMA_LOG_WARN("%s: record AND rewind is invalid\n", __func__);
return;
}
if (record) {
adapt_p_ctx->recd_weighted_sum = adapt_p_ctx->weighted_sum;
adapt_p_ctx->recd_total_weight = adapt_p_ctx->total_weight;
return;
}
if (rewind) {
adapt_p_ctx->weighted_sum = adapt_p_ctx->recd_weighted_sum;
adapt_p_ctx->total_weight = adapt_p_ctx->recd_total_weight;
return;
}
}
llama_token llama_sample_token_adaptive_p_impl(
struct llama_sampling * smpl,
llama_token_data_array * candidates,
struct llama_sampler_adaptive_p * adapt_p_ctx) {
struct llama_sampling * smpl,
llama_token_data_array * candidates,
struct llama_sampler_adaptive_p * adapt_p_ctx) {
GGML_ASSERT(candidates->size > 0);
const int64_t t_start_sample_us = ggml_time_us();
@@ -1082,8 +1099,8 @@ llama_token llama_sample_token_adaptive_p_impl(
// update history
const float update_prob = ctx->updt_w_cur
? candidates->data[idx].p / ctx->cum_cur_p
: ctx->orig_prob[id] / ctx->cum_orig_prob;
? candidates->data[idx].p / ctx->cum_cur_p
: ctx->orig_prob[id] / ctx->cum_orig_prob;
if (update_prob > 0) {
ctx->weighted_sum = ctx->decay * ctx->weighted_sum + update_prob;
ctx->total_weight = ctx->decay * ctx->total_weight + 1.0f;
@@ -1186,17 +1203,19 @@ struct llama_sampler_adaptive_p * llama_init_adaptive_p_impl(int n_vocab,
GGML_ASSERT(n_vocab > 0);
const float clamped_decay = std::clamp(decay, 0.0f, 0.99f);
auto result = new llama_sampler_adaptive_p {
/* .target = */ target,
/* .decay = */ clamped_decay,
/* .updt_w_cur = */ updt_w_cur,
/* .rng = */ std::mt19937(seed),
/* .weighted_sum = */ target / (1.0f - clamped_decay),
/* .total_weight = */ 1.0f / (1.0f - clamped_decay),
/* .orig_prob = */ {},
/* .cum_orig_prob = */ 1.0f,
/* .cum_cur_p = */ 1.0f,
/* .max_xform_logit = */ -INFINITY,
/* .cum_probs = */ {},
/* .target = */ target,
/* .decay = */ clamped_decay,
/* .updt_w_cur = */ updt_w_cur,
/* .rng = */ std::mt19937(seed),
/* .weighted_sum = */ target / (1.0f - clamped_decay),
/* .total_weight = */ 1.0f / (1.0f - clamped_decay),
/* .orig_prob = */ {},
/* .cum_orig_prob = */ 1.0f,
/* .cum_cur_p = */ 1.0f,
/* .max_xform_logit = */ -INFINITY,
/* .cum_probs = */ {},
/* .recd_weighted_sum = */ target / (1.0f - clamped_decay),
/* .recd_total_weight = */ 1.0f / (1.0f - clamped_decay),
};
result->orig_prob.resize(n_vocab);
return result;

View File

@@ -83,6 +83,10 @@ struct llama_sampler_adaptive_p {
// first referenced in sample_token
std::vector<float> cum_probs; // cumulative probability distribution
// recorded states for rewinding
float recd_weighted_sum;
float recd_total_weight;
};
struct llama_sampler_adaptive_p * llama_init_adaptive_p_impl(int n_vocab,
@@ -101,6 +105,8 @@ void llama_sample_adaptive_p_impl(
llama_token_data_array * candidates,
struct llama_sampler_adaptive_p * adapt_p_ctx);
void llama_review_adaptive_p_impl(llama_sampler_adaptive_p * adapt_p_ctx, const bool record, const bool rewind);
void llama_sample_repetition_penalties_impl(
struct llama_sampling * smpl,

View File

@@ -8160,6 +8160,10 @@ struct llama_sampler_adaptive_p * llama_init_adaptive_p(int n_vocab, const float
return llama_init_adaptive_p_impl(n_vocab, target, decay, updt_w_cur, seed);
}
void llama_review_adaptive_p(struct llama_sampler_adaptive_p * adapt_p_ctx, const bool record, const bool rewind) {
llama_review_adaptive_p_impl(adapt_p_ctx, record, rewind);
}
int llama_split_prefix(char * dest, size_t maxlen, const char * split_path, int split_no, int split_count) {
std::string str_split_path(split_path);