diff --git a/examples/server/server-context.cpp b/examples/server/server-context.cpp
index df6d285e..6b0462a7 100644
--- a/examples/server/server-context.cpp
+++ b/examples/server/server-context.cpp
@@ -48,6 +48,9 @@ server_context::~server_context() {
         if (slot.ctx_sampling != nullptr) {
             common_sampler_free(slot.ctx_sampling);
         }
+        if (slot.spec_ckpt_sampler != nullptr) {
+            common_sampler_free(slot.spec_ckpt_sampler);
+        }
         if (slot.ctx_dft) {
             llama_free(slot.ctx_dft);
         }
@@ -375,6 +378,10 @@ void server_slot::reset() {
     drafted.clear();
     i_batch_dft.clear();
     spec_ckpt_valid = false;
+    if (spec_ckpt_sampler) {
+        common_sampler_free(spec_ckpt_sampler);
+        spec_ckpt_sampler = nullptr;
+    }
     n_sent_token_probs = 0;
     infill = false;
     ga_i = 0;
@@ -3614,6 +3621,13 @@ void server_context::speculative_decoding_accept() {
 
         llama_kv_cache_seq_rm(ctx, slot.id, slot.spec_ckpt_n_past, -1);
 
+        // restore sampler state (RNG, grammar, prev tokens)
+        if (slot.spec_ckpt_sampler) {
+            common_sampler_clone(slot.spec_ckpt_sampler, slot.ctx_sampling);
+            common_sampler_free(slot.spec_ckpt_sampler);
+            slot.spec_ckpt_sampler = nullptr;
+        }
+
         if (!ids.empty()) {
             const int n_accepted = (int)ids.size();
             llama_batch re_batch = llama_batch_init(n_accepted, 0, 1);
@@ -3650,6 +3664,11 @@
     } else {
         llama_kv_cache_seq_rm(ctx, slot.id, slot.n_past, -1);
         slot.spec_ckpt_valid = false;
+        // discard saved sampler on full acceptance
+        if (slot.spec_ckpt_sampler) {
+            common_sampler_free(slot.spec_ckpt_sampler);
+            slot.spec_ckpt_sampler = nullptr;
+        }
     }
 
     for (size_t i = 0; i < ids.size(); ++i) {
@@ -4233,6 +4252,12 @@ void server_context::update_slots() {
         const size_t written = llama_state_seq_get_data(ctx, slot.spec_ckpt_data.data(), ckpt_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
         slot.spec_ckpt_valid = (written > 0);
         if (slot.spec_ckpt_valid) {
+            // save sampler state so we can restore RNG/grammar on rejection
+            if (slot.spec_ckpt_sampler) {
+                common_sampler_free(slot.spec_ckpt_sampler);
+            }
+            slot.spec_ckpt_sampler = common_sampler_init(model, slot.sparams);
+            common_sampler_clone(slot.ctx_sampling, slot.spec_ckpt_sampler);
             SLT_DBG(slot, "spec checkpoint saved: %zu bytes, n_past_pre_spec=%d\n", written, slot.spec_ckpt_n_past);
         } else {
             SLT_WRN(slot, "%s", "failed to save spec checkpoint\n");
diff --git a/examples/server/server-context.h b/examples/server/server-context.h
index aba66624..dbe34d2b 100644
--- a/examples/server/server-context.h
+++ b/examples/server/server-context.h
@@ -162,6 +162,7 @@ struct server_slot {
     bool spec_ckpt_valid = false;
     llama_pos spec_ckpt_n_past = 0;
     std::vector<uint8_t> spec_ckpt_data;
+    common_sampler * spec_ckpt_sampler = nullptr; // saved sampler state for checkpoint restore
 
     // speculative decoding stats
     int32_t n_draft_total = 0; // Total draft tokens generated