fix: save/restore sampler state during speculative checkpoint

When speculative decoding rejects draft tokens and restores the
recurrent state checkpoint, the sampler (RNG, grammar, prev tokens)
must also be restored to maintain consistency. Without this, the
sampler state reflects the rejected draft tokens, leading to
potential divergence.

Uses common_sampler_clone() to snapshot the sampler before the
speculative batch decode, and restores it on rejection.
This commit is contained in:
SamuelOliveirads
2026-04-16 22:36:37 -03:00
parent d670cf85cd
commit d93dfb5e6b
2 changed files with 26 additions and 0 deletions

View File

@@ -48,6 +48,9 @@ server_context::~server_context() {
if (slot.ctx_sampling != nullptr) {
common_sampler_free(slot.ctx_sampling);
}
if (slot.spec_ckpt_sampler != nullptr) {
common_sampler_free(slot.spec_ckpt_sampler);
}
if (slot.ctx_dft) {
llama_free(slot.ctx_dft);
}
@@ -375,6 +378,10 @@ void server_slot::reset() {
drafted.clear();
i_batch_dft.clear();
spec_ckpt_valid = false;
if (spec_ckpt_sampler) {
common_sampler_free(spec_ckpt_sampler);
spec_ckpt_sampler = nullptr;
}
n_sent_token_probs = 0;
infill = false;
ga_i = 0;
@@ -3614,6 +3621,13 @@ void server_context::speculative_decoding_accept() {
llama_kv_cache_seq_rm(ctx, slot.id, slot.spec_ckpt_n_past, -1);
// restore sampler state (RNG, grammar, prev tokens)
if (slot.spec_ckpt_sampler) {
common_sampler_clone(slot.spec_ckpt_sampler, slot.ctx_sampling);
common_sampler_free(slot.spec_ckpt_sampler);
slot.spec_ckpt_sampler = nullptr;
}
if (!ids.empty()) {
const int n_accepted = (int)ids.size();
llama_batch re_batch = llama_batch_init(n_accepted, 0, 1);
@@ -3650,6 +3664,11 @@ void server_context::speculative_decoding_accept() {
} else {
llama_kv_cache_seq_rm(ctx, slot.id, slot.n_past, -1);
slot.spec_ckpt_valid = false;
// discard saved sampler on full acceptance
if (slot.spec_ckpt_sampler) {
common_sampler_free(slot.spec_ckpt_sampler);
slot.spec_ckpt_sampler = nullptr;
}
}
for (size_t i = 0; i < ids.size(); ++i) {
@@ -4233,6 +4252,12 @@ void server_context::update_slots() {
const size_t written = llama_state_seq_get_data(ctx, slot.spec_ckpt_data.data(), ckpt_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
slot.spec_ckpt_valid = (written > 0);
if (slot.spec_ckpt_valid) {
// save sampler state so we can restore RNG/grammar on rejection
if (slot.spec_ckpt_sampler) {
common_sampler_free(slot.spec_ckpt_sampler);
}
slot.spec_ckpt_sampler = common_sampler_init(model, slot.sparams);
common_sampler_clone(slot.ctx_sampling, slot.spec_ckpt_sampler);
SLT_DBG(slot, "spec checkpoint saved: %zu bytes, n_past_pre_spec=%d\n", written, slot.spec_ckpt_n_past);
} else {
SLT_WRN(slot, "%s", "failed to save spec checkpoint\n");