Adaptive P sampler: update review logic, delete old code comments, put prep stage after logit bias (#1386)

* simpler n_rewind logic, delete old comments

* use more consistent names, add updt_w_cur to json schema

* align comments

* refactor review logic, update struct/variable names

* revert cosmetic changes

* check enable/disable in llama_prep_adaptive_p_impl()

* delete extra whitespace after statement

* show target in debug prints

* more concise debug print

* delete old comments

* update with loop instead of move()

* comment out all adaptive p debug prints

* more debug prints

* move review() variables: common_sampler struct -> common_sampler_review() args

* match n_unsent type

* fix merge bugs, delete adaptive p references in buffer_and_check_string_ban()

* restore accidental erasure

* Revert "adaptive p: collect probability before logit bias"

This reverts commit 1434878461.
Author: dungquixote42
Date: 2026-03-14 07:34:12 -04:00
Committed by: GitHub
Parent: a6a1da9a28
Commit: be2940f57a
7 changed files with 48 additions and 91 deletions

@@ -882,6 +882,7 @@ bool server_context::launch_slot_with_task(server_slot& slot, server_task& task)
slot.sparams.mirostat_eta = json_value(data, "mirostat_eta", default_sparams.mirostat_eta);
slot.sparams.adaptive_target = json_value(data, "adaptive_target", default_sparams.adaptive_target);
slot.sparams.adaptive_decay = json_value(data, "adaptive_decay", default_sparams.adaptive_decay);
slot.sparams.adaptive_updt_w_cur = json_value(data, "adaptive_updt_w_cur", default_sparams.adaptive_updt_w_cur);
slot.sparams.penalize_nl = json_value(data, "penalize_nl", default_sparams.penalize_nl);
slot.params.n_keep = json_value(data, "n_keep", slot.params.n_keep);
slot.params.n_discard = json_value(data, "n_discard", defaults.n_discard);
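
The hunk above wires the new adaptive_updt_w_cur parameter into request parsing next to the existing adaptive_target and adaptive_decay fields. A minimal client-side sketch of a request payload using these fields follows; the field names come from the diff, while the values and the guess that adaptive_updt_w_cur is numeric like its neighbours are illustrative only.

// Illustrative only -- not part of the commit. Assumes the server parses
// request bodies with nlohmann::json and keeps the built-in default when a
// key is absent, as the json_value(...) calls above do.
#include <nlohmann/json.hpp>
#include <cstdio>

int main() {
    nlohmann::json data = {
        {"adaptive_target",     0.30},  // hypothetical target value
        {"adaptive_decay",      0.90},  // hypothetical decay value
        {"adaptive_updt_w_cur", 1.0}    // hypothetical; exact type/range not shown in the diff
    };
    // Read a key if present, otherwise fall back to a default. A negative
    // adaptive_target appears to disable the feature, judging by the
    // adaptive_target >= 0.0f guards later in this diff.
    float target = data.value("adaptive_target", -1.0f);
    std::printf("adaptive_target = %f\n", target);
    return 0;
}
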
@@ -1667,6 +1668,7 @@ json server_context::get_formated_generation(const server_slot& slot) const {
{"mirostat_eta", slot.sparams.mirostat_eta},
{"adaptive_target", slot.sparams.adaptive_target},
{"adaptive_decay", slot.sparams.adaptive_decay},
{"adaptive_updt_w_cur", slot.sparams.adaptive_updt_w_cur},
{"penalize_nl", slot.sparams.penalize_nl},
{"stop", slot.params.antiprompt},
{"max_tokens", slot.params.n_predict}, // User configured n_predict
@@ -3332,7 +3334,7 @@ void server_context::speculative_decoding_accept() {
buffer_and_check_string_ban(slot, result);
}
common_sampler_review(slot.ctx_sampling);
common_sampler_review(slot.ctx_sampling, slot.token_buffer.size(), slot.rewind_status);
}
SLT_DBG(slot, "accepted %d/%d draft tokens, new n_tokens = %d\n", (int)ids.size() - 1, (int)slot.drafted.size(), slot.n_past);
LOG_VERBOSE("speculative decoding result", {
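
The one-argument call above is replaced by one that passes the buffered-token count and the slot's rewind status explicitly, matching the commit note "move review() variables: common_sampler struct -> common_sampler_review() args". A generic sketch of that refactor pattern, with hypothetical type and function names rather than the project's actual API:

// Illustrative pattern only -- sampler_ctx, sampler_review and rewind_status
// are hypothetical stand-ins, not the project's real types.
#include <cstddef>

enum class rewind_status { none, pending };

struct sampler_ctx {
    // Persistent sampler state stays in the struct; the per-step review
    // inputs (previously stashed here, e.g. an n_rewind field) do not.
    float target = -1.0f;
};

// Per-step inputs now travel as explicit arguments, so each call site
// states exactly what the review step sees.
void sampler_review(sampler_ctx & ctx, size_t n_buffered, rewind_status status) {
    if (ctx.target < 0.0f) {
        return; // disabled, mirroring the adaptive_target >= 0.0f guards in the diff
    }
    (void)n_buffered;
    (void)status;
    // ... update adaptive state here ...
}
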
@@ -3545,7 +3547,6 @@ void server_context::buffer_and_check_string_ban(server_slot & slot, completion_
bool buffer_full = slot.token_buffer.size() >= slot.n_buffer;
int32_t ban_pos = -1;
int32_t n_rewind = 0;
bool sent_results = false;
// Always reset logit bias to base before checking bans
@@ -3553,12 +3554,6 @@ void server_context::buffer_and_check_string_ban(server_slot & slot, completion_
if (slot.ban_phrases.size() > 0 || slot.ban_regex.size() > 0 || slot.ban_regex_ci.size() > 0) {
ban_pos = check_ban_phrase(slot);
if (ban_pos >= 0 && slot.sparams.adaptive_target >= 0.0f) {
int32_t buffer_start_pos = slot.n_past - (int32_t)slot.token_buffer.size() + 1;
int32_t n_keep_buffer = ban_pos - buffer_start_pos;
if (n_keep_buffer < 0) n_keep_buffer = 0;
n_rewind = (int32_t)slot.token_buffer.size() - n_keep_buffer;
}
}
bool allow_rewind = true;
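
For reference, the block removed above derived the rewind length from the ban position and the absolute start position of the token buffer. With illustrative numbers (not taken from the source): n_past = 100 and a 5-token buffer give buffer_start_pos = 100 - 5 + 1 = 96, so a ban at ban_pos = 98 keeps n_keep_buffer = 98 - 96 = 2 buffered tokens and rewinds n_rewind = 5 - 2 = 3. After this change the computation is dropped here and the review step receives the buffer size and the slot's rewind status directly (see the common_sampler_review() call sites above and below).
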
@@ -3600,17 +3595,11 @@ void server_context::buffer_and_check_string_ban(server_slot & slot, completion_
// send 1 token from the front (FIFO)
send_token_results(slot.token_buffer, slot, 1);
}
if (slot.sparams.adaptive_target >= 0.0f) {
sent_results = true;
}
}
else {
// buffer the result, wait for more tokens to validate string
slot.sampled = result.tok;
}
if (slot.sparams.adaptive_target >= 0.0f) {
slot.ctx_sampling->n_rewind = sent_results ? -1 : n_rewind;
}
}
void server_context::process_batch_tokens(int32_t & n_batch) {
@@ -3761,7 +3750,7 @@ void server_context::process_batch_tokens(int32_t & n_batch) {
buffer_and_check_string_ban(slot, result);
}
common_sampler_review(slot.ctx_sampling);
common_sampler_review(slot.ctx_sampling, slot.token_buffer.size(), slot.rewind_status);
slot.i_batch = -1;
}