From d93dfb5e6b78a822e7331b33feedbdc47eb5ec79 Mon Sep 17 00:00:00 2001
From: SamuelOliveirads
Date: Thu, 16 Apr 2026 22:36:37 -0300
Subject: [PATCH] fix: save/restore sampler state during speculative checkpoint

When speculative decoding rejects draft tokens and restores the recurrent
state checkpoint, the sampler (RNG, grammar, prev tokens) must also be
restored to maintain consistency. Without this, the sampler state reflects
the rejected draft tokens, leading to potential divergence.

Uses common_sampler_clone() to snapshot the sampler before the speculative
batch decode, and restores it on rejection.
---
 examples/server/server-context.cpp | 25 +++++++++++++++++++++++++
 examples/server/server-context.h   |  1 +
 2 files changed, 26 insertions(+)

diff --git a/examples/server/server-context.cpp b/examples/server/server-context.cpp
index df6d285e..6b0462a7 100644
--- a/examples/server/server-context.cpp
+++ b/examples/server/server-context.cpp
@@ -48,6 +48,9 @@ server_context::~server_context() {
         if (slot.ctx_sampling != nullptr) {
             common_sampler_free(slot.ctx_sampling);
         }
+        if (slot.spec_ckpt_sampler != nullptr) {
+            common_sampler_free(slot.spec_ckpt_sampler);
+        }
         if (slot.ctx_dft) {
             llama_free(slot.ctx_dft);
         }
@@ -375,6 +378,10 @@ void server_slot::reset() {
     drafted.clear();
     i_batch_dft.clear();
     spec_ckpt_valid = false;
+    if (spec_ckpt_sampler) {
+        common_sampler_free(spec_ckpt_sampler);
+        spec_ckpt_sampler = nullptr;
+    }
     n_sent_token_probs = 0;
     infill = false;
     ga_i = 0;
@@ -3614,6 +3621,13 @@ void server_context::speculative_decoding_accept() {
 
     llama_kv_cache_seq_rm(ctx, slot.id, slot.spec_ckpt_n_past, -1);
 
+    // restore sampler state (RNG, grammar, prev tokens)
+    if (slot.spec_ckpt_sampler) {
+        common_sampler_clone(slot.spec_ckpt_sampler, slot.ctx_sampling);
+        common_sampler_free(slot.spec_ckpt_sampler);
+        slot.spec_ckpt_sampler = nullptr;
+    }
+
     if (!ids.empty()) {
         const int n_accepted = (int)ids.size();
         llama_batch re_batch = llama_batch_init(n_accepted, 0, 1);
@@ -3650,6 +3664,11 @@ void server_context::speculative_decoding_accept() {
     } else {
         llama_kv_cache_seq_rm(ctx, slot.id, slot.n_past, -1);
         slot.spec_ckpt_valid = false;
+        // discard saved sampler on full acceptance
+        if (slot.spec_ckpt_sampler) {
+            common_sampler_free(slot.spec_ckpt_sampler);
+            slot.spec_ckpt_sampler = nullptr;
+        }
     }
 
     for (size_t i = 0; i < ids.size(); ++i) {
@@ -4233,6 +4252,12 @@ void server_context::update_slots() {
                     const size_t written = llama_state_seq_get_data(ctx, slot.spec_ckpt_data.data(), ckpt_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
                     slot.spec_ckpt_valid = (written > 0);
                     if (slot.spec_ckpt_valid) {
+                        // save sampler state so we can restore RNG/grammar on rejection
+                        if (slot.spec_ckpt_sampler) {
+                            common_sampler_free(slot.spec_ckpt_sampler);
+                        }
+                        slot.spec_ckpt_sampler = common_sampler_init(model, slot.sparams);
+                        common_sampler_clone(slot.ctx_sampling, slot.spec_ckpt_sampler);
                         SLT_DBG(slot, "spec checkpoint saved: %zu bytes, n_past_pre_spec=%d\n", written, slot.spec_ckpt_n_past);
                     } else {
                         SLT_WRN(slot, "%s", "failed to save spec checkpoint\n");
diff --git a/examples/server/server-context.h b/examples/server/server-context.h
index aba66624..dbe34d2b 100644
--- a/examples/server/server-context.h
+++ b/examples/server/server-context.h
@@ -162,6 +162,7 @@ struct server_slot {
     bool spec_ckpt_valid = false;
     llama_pos spec_ckpt_n_past = 0;
     std::vector<uint8_t> spec_ckpt_data;
+    common_sampler * spec_ckpt_sampler = nullptr; // saved sampler state for checkpoint restore
 
     // speculative decoding stats
     int32_t n_draft_total = 0; // Total draft tokens generated