Speculative checkpoints for recurrent models (#1669)

* server: spec checkpoints for recurrent models

* fix: save/restore sampler state during speculative checkpoint

When speculative decoding rejects draft tokens and restores the
recurrent state checkpoint, the sampler (RNG, grammar, prev tokens)
must also be restored to maintain consistency. Without this, the
sampler state reflects the rejected draft tokens, leading to
potential divergence.

Uses common_sampler_clone() to snapshot the sampler before the
speculative batch decode, and restores it on rejection.

* server: snapshot recurrent state in tensor

* reset ngram mod state for rejected tokens

* server: refactor checkpoint state logic

* speculative: fix sampler for checkpoints

* recurrent model: implement recurrent kernel checkpoint

* recurrent model: refactor api

* spec: free rbudget before overwriting
This commit is contained in:
Samuel Oliveira Alves
2026-04-24 04:59:30 -03:00
committed by GitHub
parent 1c13288164
commit ea94afe777
15 changed files with 943 additions and 54 deletions

View File

@@ -218,6 +218,12 @@ void llama_sampling_set_rng_seed(struct common_sampler * ctx, uint32_t seed) {
}
void common_sampler_clone(common_sampler * src, common_sampler * dst) {
dst->params = src->params;
dst->mirostat_mu = src->mirostat_mu;
dst->n_valid = src->n_valid;
dst->rng = src->rng;
dst->server_biases = src->server_biases;
if (dst->grammar) {
llama_grammar_free(dst->grammar);
dst->grammar = nullptr;
@@ -230,7 +236,18 @@ void common_sampler_clone(common_sampler * src, common_sampler * dst) {
}
dst->prev = src->prev;
dst->smpl = llama_sampler_dry_clone(src->smpl);
if (dst->smpl) {
llama_sampler_dry_free(dst->smpl);
dst->smpl = nullptr;
}
if (src->smpl) {
dst->smpl = llama_sampler_dry_clone(src->smpl);
}
if (dst->rbudget) {
common_reasoning_budget_free(dst->rbudget);
dst->rbudget = nullptr;
}
if (src->rbudget) {
dst->rbudget = common_reasoning_budget_clone(src->rbudget);
}