mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-05-12 17:05:57 +00:00
fix: save and restore sampler state across speculative-decoding checkpoint restore
When speculative decoding rejects draft tokens and restores the recurrent-state checkpoint, the sampler state (RNG, grammar, previous tokens) must also be restored to stay consistent with the restored model state. Without this, the sampler state still reflects the rejected draft tokens, which can cause generation to diverge. This change uses common_sampler_clone() to snapshot the sampler before the speculative batch decode, and restores the snapshot when draft tokens are rejected.
This commit is contained in:
@@ -48,6 +48,9 @@ server_context::~server_context() {
|
||||
if (slot.ctx_sampling != nullptr) {
|
||||
common_sampler_free(slot.ctx_sampling);
|
||||
}
|
||||
if (slot.spec_ckpt_sampler != nullptr) {
|
||||
common_sampler_free(slot.spec_ckpt_sampler);
|
||||
}
|
||||
if (slot.ctx_dft) {
|
||||
llama_free(slot.ctx_dft);
|
||||
}
|
||||
@@ -375,6 +378,10 @@ void server_slot::reset() {
|
||||
drafted.clear();
|
||||
i_batch_dft.clear();
|
||||
spec_ckpt_valid = false;
|
||||
if (spec_ckpt_sampler) {
|
||||
common_sampler_free(spec_ckpt_sampler);
|
||||
spec_ckpt_sampler = nullptr;
|
||||
}
|
||||
n_sent_token_probs = 0;
|
||||
infill = false;
|
||||
ga_i = 0;
|
||||
@@ -3614,6 +3621,13 @@ void server_context::speculative_decoding_accept() {
|
||||
|
||||
llama_kv_cache_seq_rm(ctx, slot.id, slot.spec_ckpt_n_past, -1);
|
||||
|
||||
// restore sampler state (RNG, grammar, prev tokens)
|
||||
if (slot.spec_ckpt_sampler) {
|
||||
common_sampler_clone(slot.spec_ckpt_sampler, slot.ctx_sampling);
|
||||
common_sampler_free(slot.spec_ckpt_sampler);
|
||||
slot.spec_ckpt_sampler = nullptr;
|
||||
}
|
||||
|
||||
if (!ids.empty()) {
|
||||
const int n_accepted = (int)ids.size();
|
||||
llama_batch re_batch = llama_batch_init(n_accepted, 0, 1);
|
||||
@@ -3650,6 +3664,11 @@ void server_context::speculative_decoding_accept() {
|
||||
} else {
|
||||
llama_kv_cache_seq_rm(ctx, slot.id, slot.n_past, -1);
|
||||
slot.spec_ckpt_valid = false;
|
||||
// discard saved sampler on full acceptance
|
||||
if (slot.spec_ckpt_sampler) {
|
||||
common_sampler_free(slot.spec_ckpt_sampler);
|
||||
slot.spec_ckpt_sampler = nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < ids.size(); ++i) {
|
||||
@@ -4233,6 +4252,12 @@ void server_context::update_slots() {
|
||||
const size_t written = llama_state_seq_get_data(ctx, slot.spec_ckpt_data.data(), ckpt_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
|
||||
slot.spec_ckpt_valid = (written > 0);
|
||||
if (slot.spec_ckpt_valid) {
|
||||
// save sampler state so we can restore RNG/grammar on rejection
|
||||
if (slot.spec_ckpt_sampler) {
|
||||
common_sampler_free(slot.spec_ckpt_sampler);
|
||||
}
|
||||
slot.spec_ckpt_sampler = common_sampler_init(model, slot.sparams);
|
||||
common_sampler_clone(slot.ctx_sampling, slot.spec_ckpt_sampler);
|
||||
SLT_DBG(slot, "spec checkpoint saved: %zu bytes, n_past_pre_spec=%d\n", written, slot.spec_ckpt_n_past);
|
||||
} else {
|
||||
SLT_WRN(slot, "%s", "failed to save spec checkpoint\n");
|
||||
|
||||
Reference in New Issue
Block a user