Speculative checkpoints for recurrent models (#1669)

* server: spec checkpoints for recurrent models

* fix: save/restore sampler state during speculative checkpoint

When speculative decoding rejects draft tokens and restores the
recurrent state checkpoint, the sampler (RNG, grammar, prev tokens)
must also be restored to maintain consistency. Without this, the
sampler state reflects the rejected draft tokens, leading to
potential divergence.

Uses common_sampler_clone() to snapshot the sampler before the
speculative batch decode, and restores it on rejection.

* server: snapshot recurrent state in tensor

* reset ngram mod state for rejected tokens

* server: refactor checkpoint state logic

* speculative: fix sampler for checkpoints

* recurrent model: implement recurrent kernel checkpoint

* recurrent model: refactor api

* spec: free rbudget before overwriting
This commit is contained in:
Samuel Oliveira Alves
2026-04-24 04:59:30 -03:00
committed by GitHub
parent 1c13288164
commit ea94afe777
15 changed files with 943 additions and 54 deletions

View File

@@ -1033,6 +1033,23 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
params.speculative.p_min = std::stof(argv[i]);
return true;
}
if (arg == "--recurrent-ckpt-mode") {
CHECK_ARG
const std::string val = argv[i];
if (val == "auto" || val == "AUTO") {
params.speculative.recurrent_ckpt_mode = LLAMA_SPEC_CKPT_AUTO;
} else if (val == "per-step" || val == "PER_STEP") {
params.speculative.recurrent_ckpt_mode = LLAMA_SPEC_CKPT_PER_STEP;
} else if (val == "gpu-fallback" || val == "GPU_FALLBACK") {
params.speculative.recurrent_ckpt_mode = LLAMA_SPEC_CKPT_GPU_FALLBACK;
} else if (val == "cpu" || val == "CPU") {
params.speculative.recurrent_ckpt_mode = LLAMA_SPEC_CKPT_CPU;
} else {
throw std::invalid_argument("unknown --recurrent-ckpt-mode value: " + val +
"; expected auto, per-step, gpu-fallback, or cpu");
}
return true;
}
if (arg == "--spec-autotune") {
params.speculative.autotune = true;
return true;
@@ -2732,6 +2749,11 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
"number of tokens to draft for speculative decoding (default: %d)", params.speculative.n_max });
options.push_back({ "*", "--draft-min, --draft-n-min N", "minimum number of draft tokens to use for speculative decoding" });
options.push_back({ "*", "--draft-p-min P", "minimum speculative decoding probability (greedy) (default: %.1f)", (double)params.speculative.p_min });
options.push_back({ "*", "--recurrent-ckpt-mode MODE", "checkpoint strategy for recurrent/hybrid speculative decoding\n"
" auto auto-select: per-step if CUDA full-GPU, gpu-fallback otherwise (default)\n"
" per-step save SSM state per draft step in VRAM; no re-decode on rejection\n"
" gpu-fallback copy state to GPU buffer; re-decode on rejection\n"
" cpu serialise state via llama_state_seq; re-decode on rejection" });
options.push_back({ "*", "--spec-type Name [none | ngram - cache | ngram - simple | ngram - map - k | ngram - map - k4v | ngram - mod | suffix]", "type of speculative decoding to use when no draft model is provided (default: %d)\n", (int)params.speculative.type});
options.push_back({ "*", "--spec-ngram-size-n N", "ngram size N for ngram-simple/ngram-map speculative decoding, length of lookup n-gram (default: %d)\n",params.speculative.ngram_size_n });