Speculative checkpoints for recurrent models (#1669)

* server: spec checkpoints for recurrent models

* fix: save/restore sampler state during speculative checkpoint

When speculative decoding rejects draft tokens and restores the
recurrent state checkpoint, the sampler (RNG, grammar, prev tokens)
must also be restored to maintain consistency. Without this, the
sampler state reflects the rejected draft tokens, leading to
potential divergence.

Uses common_sampler_clone() to snapshot the sampler before the speculative
batch decode, and restores it on rejection (a minimal sketch of this pattern
follows the commit list below).

* server: snapshot recurrent state in tensor

* reset ngram mod state for rejected tokens

* server: refactor checkpoint state logic

* speculative: fix sampler for checkpoints

* recurrent model: implement recurrent kernel checkpoint

* recurrent model: refactor api

* spec: free rbudget before overwriting
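
A minimal sketch of the sampler snapshot/restore pattern described in the sampler-state commit above, using the helpers this change relies on (common_sampler_init, common_sampler_clone, common_sampler_accept, common_sampler_free). The wrapper function, the `accepted` vector, and the elided speculative decode are illustrative placeholders, not the server's actual control flow:

// Illustrative sketch only: snapshot the sampler before the speculative decode,
// restore it when draft tokens are rejected. `accepted` and this wrapper are placeholders.
static void speculate_with_sampler_rollback(server_slot & slot, llama_context * ctx, llama_model * model, int n_draft) {
    // snapshot the live sampler state (RNG, grammar, prev tokens)
    common_sampler * snapshot = common_sampler_init(model, slot.sparams);
    common_sampler_clone(slot.ctx_sampling, snapshot);

    // ... decode the speculative batch here and collect the accepted tokens ...
    std::vector<llama_token> accepted; // accepted draft tokens + the final sampled token

    if ((int) accepted.size() - 1 < n_draft) {
        // some drafts were rejected: roll the sampler back, then replay only the accepted tokens
        common_sampler_clone(snapshot, slot.ctx_sampling);
        for (llama_token id : accepted) {
            common_sampler_accept(slot.ctx_sampling, ctx, id, true);
        }
    }

    common_sampler_free(snapshot);
}

In the actual change the snapshot lives in the slot's spec_ckpt alongside the recurrent-state checkpoint, so the same restore path also rewinds the model state (see restore_speculative_checkpoint in the diff below).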
Samuel Oliveira Alves
2026-04-24 04:59:30 -03:00
committed by GitHub
parent 1c13288164
commit ea94afe777
15 changed files with 943 additions and 54 deletions


@@ -22,6 +22,51 @@ static void log_text(const gpt_params & params_base, const std::string & text) {
    }
}

void server_speculative_checkpoint::clear() {
    valid = false;
    per_step_enabled = false;
    n_past = 0;
    sampled = LLAMA_TOKEN_NULL;

    if (sampler != nullptr) {
        common_sampler_free(sampler);
        sampler = nullptr;
    }
}

static void discard_speculative_checkpoint(server_slot & slot, llama_context * ctx) {
    slot.spec_ckpt.clear();
    llama_spec_ckpt_discard(ctx);
}
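
// Snapshot the recurrent state and the sampler (RNG, grammar, prev tokens) before
// the speculative batch is decoded, so both can be rolled back if drafts are rejected.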
static bool save_speculative_checkpoint(server_slot & slot, llama_model * model, llama_context * ctx, int ckpt_mode) {
    slot.spec_ckpt.clear();
    slot.spec_ckpt.n_past = slot.n_past - (int32_t)(slot.drafted.size() + 1);
    slot.spec_ckpt.sampled = slot.sampled;
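
    // The checkpoint must be able to roll back across every drafted token plus the sampled token that seeded them.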
    const int max_tokens  = (int)slot.drafted.size() + 1;
    const int actual_mode = llama_spec_ckpt_init(ctx, ckpt_mode, max_tokens);
    if (actual_mode == LLAMA_SPEC_CKPT_NONE) {
        return false;
    }
    slot.spec_ckpt.per_step_enabled = (actual_mode == LLAMA_SPEC_CKPT_PER_STEP);

    slot.spec_ckpt.valid = llama_spec_ckpt_save(ctx, slot.id);
    if (!slot.spec_ckpt.valid) {
        llama_spec_ckpt_discard(ctx);
        return false;
    }

    slot.spec_ckpt.sampler = common_sampler_init(model, slot.sparams);
    if (slot.spec_ckpt.sampler == nullptr) {
        discard_speculative_checkpoint(slot, ctx);
        return false;
    }
    common_sampler_clone(slot.ctx_sampling, slot.spec_ckpt.sampler);

    return true;
}

server_context::~server_context() {
    if (ctx) {
        llama_free(ctx);
@@ -49,6 +94,7 @@ server_context::~server_context() {
        if (slot.ctx_sampling != nullptr) {
            common_sampler_free(slot.ctx_sampling);
        }
        slot.spec_ckpt.clear();
        if (slot.ctx_dft) {
            llama_free(slot.ctx_dft);
        }
@@ -112,15 +158,6 @@ bool server_context::load_model(const gpt_params& params_) {
    }
    // Load draft model for speculative decoding if specified
    if (has_draft_model) {
        if (llama_model_has_recurrent(model)) {
            LLAMA_LOG_WARN("\n=======================================================================\n");
            LLAMA_LOG_WARN(" Speculative decoding is not supported for recurrent/hybrid models\n");
            LLAMA_LOG_WARN(" --> bailing out\n");
            LLAMA_LOG_WARN("=======================================================================\n\n");
            GGML_ABORT("Fatal error");
        }
        LLAMA_LOG_INFO("\n\n==================================loading DRAFT model==================================\n\n");
        gpt_params params_dft;
@@ -387,6 +424,7 @@ void server_slot::reset() {
    n_sent_text = 0;
    drafted.clear();
    i_batch_dft.clear();
    spec_ckpt.clear();
    n_sent_token_probs = 0;
    infill = false;
    ga_i = 0;
@@ -3679,6 +3717,72 @@ void server_context::extend_context(const int32_t n_tokens) {
    }
}

// Restore recurrent state and re-decode accepted tokens after speculative-decode rejection.
static void restore_speculative_checkpoint(
        server_slot & slot, llama_context * ctx, llama_model * model,
        const std::vector<llama_token> & ids, int n_draft) {
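    // Per-step mode keeps a checkpoint for every draft position, so the state at the
    // last accepted token can be restored directly, with no re-decode needed.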
    if (slot.spec_ckpt.per_step_enabled) {
        const int step = (int)ids.size() - 1;
        llama_spec_ckpt_restore(ctx, slot.id, slot.spec_ckpt.n_past, step);

        if (slot.spec_ckpt.sampler) {
            common_sampler_clone(slot.spec_ckpt.sampler, slot.ctx_sampling);
        }
        for (llama_token id : ids) {
            common_sampler_accept(slot.ctx_sampling, ctx, id, true);
        }

        SLT_DBG(slot, "per-step restore: step=%d (rejected %d drafts)\n",
                step, (int)(n_draft - (ids.size() - 1)));
    } else {
        // Restore pre-speculation recurrent state then re-decode accepted tokens.
        llama_spec_ckpt_restore(ctx, slot.id, slot.spec_ckpt.n_past, 0);

        if (slot.spec_ckpt.sampler) {
            common_sampler_clone(slot.spec_ckpt.sampler, slot.ctx_sampling);
        }

        if (!ids.empty()) {
            // Re-decode to advance recurrent state to the accepted position.
            const int n_re = (int)ids.size();
            llama_batch re_batch = llama_batch_init(n_re, 0, 1);
            common_batch_add(re_batch, slot.spec_ckpt.sampled, slot.spec_ckpt.n_past, { slot.id }, n_re == 1);
            for (int j = 0; j < n_re - 1; j++) {
                common_batch_add(re_batch, ids[j], slot.spec_ckpt.n_past + 1 + j, { slot.id }, j == n_re - 2);
            }

            if (slot.has_mtp) {
                llama_set_embeddings(ctx, true);
            }

            const int ret = llama_decode(ctx, re_batch);
            if (ret != 0) {
                SLT_ERR(slot, "failed to re-decode accepted tokens after checkpoint restore: %d\n", ret);
            }

            if (slot.has_mtp) {
                llama_set_embeddings(ctx, false);
                const int n_embd = llama_model_n_embd(llama_get_model(ctx));
                const float * emb = llama_get_embeddings_ith(ctx, -1);
                if (emb) {
                    slot.mtp_hidden_state.resize(n_embd);
                    memcpy(slot.mtp_hidden_state.data(), emb, n_embd * sizeof(float));
                }
            }

            for (llama_token id : ids) {
                common_sampler_accept(slot.ctx_sampling, ctx, id, true);
            }

            llama_batch_free(re_batch);

            SLT_DBG(slot, "spec checkpoint restored: re-decoded %d tokens (rejected %d drafts)\n",
                    n_re, (int)(n_draft - (ids.size() - 1)));
        }
    }

    discard_speculative_checkpoint(slot, ctx);
}

void server_context::speculative_decoding_accept() {
    for (auto & slot : slots) {
        if (slot.state != SLOT_STATE_PROCESSING || slot.i_batch_dft.empty()) {
@@ -3739,7 +3843,14 @@ void server_context::speculative_decoding_accept() {
        slot.sampled = ids.back(); // last accepted token
        slot.n_past = slot.cache_tokens.n_tokens();
        llama_kv_cache_seq_rm(ctx, slot.id, slot.n_past, -1);

        // for recurrent/hybrid models: if any drafts were rejected, restore recurrent state
        const bool any_rejected = (ids.size() - 1) < n_draft;
        if (any_rejected && slot.spec_ckpt.valid) {
            restore_speculative_checkpoint(slot, ctx, model, ids, n_draft);
        } else {
            llama_kv_cache_seq_rm(ctx, slot.id, slot.n_past, -1);
            discard_speculative_checkpoint(slot, ctx);
        }

        for (size_t i = 0; i < ids.size(); ++i) {
            completion_token_output result;
@@ -4305,6 +4416,23 @@ void server_context::update_slots() {
    // make sure we're in the right embedding mode
    llama_set_embeddings(ctx, batch_type == 1);

    if (llama_model_has_recurrent(model)) {
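        // Snapshot each processing slot's recurrent state (and sampler) before the
        // speculative batch is decoded, so rejected drafts can be rolled back later.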
        const int ckpt_mode = params_base.speculative.recurrent_ckpt_mode;

        for (auto & slot : slots) {
            if (slot.state != SLOT_STATE_PROCESSING || slot.i_batch_dft.empty()) {
                continue;
            }

            if (save_speculative_checkpoint(slot, model, ctx, ckpt_mode)) {
                const char * mode_name = slot.spec_ckpt.per_step_enabled ? "per-step" : "shadow/cpu";
                SLT_DBG(slot, "spec checkpoint saved (mode=%s), n_past_pre_spec=%d\n",
                        mode_name, slot.spec_ckpt.n_past);
            } else {
                SLT_WRN(slot, "%s", "failed to save spec checkpoint\n");
            }
        }
    }

    // process the created batch of tokens
    process_batch_tokens(n_batch); // Decode with batch