From ea94afe777e8dbfbd7ffddd3b33966db68eca20b Mon Sep 17 00:00:00 2001 From: Samuel Oliveira Alves <107287165+SamuelOliveirads@users.noreply.github.com> Date: Fri, 24 Apr 2026 04:59:30 -0300 Subject: [PATCH] Speculative checkpoints for recurrent models (#1669) * server: spec checkpoints for recurrent models * fix: save/restore sampler state during speculative checkpoint When speculative decoding rejects draft tokens and restores the recurrent state checkpoint, the sampler (RNG, grammar, prev tokens) must also be restored to maintain consistency. Without this, the sampler state reflects the rejected draft tokens, leading to potential divergence. Uses common_sampler_clone() to snapshot the sampler before the speculative batch decode, and restores it on rejection. * server: snapshot recurrent state in tensor * reset ngram mod state for rejected tokens * server: refactor checkpoint state logic * speculative: fix sampler for checkpoints * recurrent model: implement recurrent kernel checkpoint * recurrent model: refactor api * spec: free rbudget before overwriting --- common/common.cpp | 22 ++ common/common.h | 3 + common/ngram-map.cpp | 4 +- common/sampling.cpp | 19 +- common/speculative.cpp | 9 +- examples/server/server-context.cpp | 148 +++++++- examples/server/server-context.h | 13 + ggml/include/ggml.h | 3 +- ggml/src/ggml-cuda/delta-net.cu | 40 ++- ggml/src/ggml.c | 17 +- include/llama.h | 22 ++ src/llama-context.h | 71 ++++ src/llama-delta-net.cpp | 56 ++- src/llama-delta-net.h | 10 +- src/llama.cpp | 560 ++++++++++++++++++++++++++++- 15 files changed, 943 insertions(+), 54 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index 0632b0a1..d308fe9d 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1033,6 +1033,23 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa params.speculative.p_min = std::stof(argv[i]); return true; } + if (arg == "--recurrent-ckpt-mode") { + CHECK_ARG + const std::string val = argv[i]; + if (val == "auto" || val == "AUTO") { + params.speculative.recurrent_ckpt_mode = LLAMA_SPEC_CKPT_AUTO; + } else if (val == "per-step" || val == "PER_STEP") { + params.speculative.recurrent_ckpt_mode = LLAMA_SPEC_CKPT_PER_STEP; + } else if (val == "gpu-fallback" || val == "GPU_FALLBACK") { + params.speculative.recurrent_ckpt_mode = LLAMA_SPEC_CKPT_GPU_FALLBACK; + } else if (val == "cpu" || val == "CPU") { + params.speculative.recurrent_ckpt_mode = LLAMA_SPEC_CKPT_CPU; + } else { + throw std::invalid_argument("unknown --recurrent-ckpt-mode value: " + val + + "; expected auto, per-step, gpu-fallback, or cpu"); + } + return true; + } if (arg == "--spec-autotune") { params.speculative.autotune = true; return true; @@ -2732,6 +2749,11 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param "number of tokens to draft for speculative decoding (default: %d)", params.speculative.n_max }); options.push_back({ "*", "--draft-min, --draft-n-min N", "minimum number of draft tokens to use for speculative decoding" }); options.push_back({ "*", "--draft-p-min P", "minimum speculative decoding probability (greedy) (default: %.1f)", (double)params.speculative.p_min }); + options.push_back({ "*", "--recurrent-ckpt-mode MODE", "checkpoint strategy for recurrent/hybrid speculative decoding\n" + " auto auto-select: per-step if CUDA full-GPU, gpu-fallback otherwise (default)\n" + " per-step save SSM state per draft step in VRAM; no re-decode on rejection\n" + " gpu-fallback copy state to GPU buffer; re-decode on rejection\n" 
+ " cpu serialise state via llama_state_seq; re-decode on rejection" }); options.push_back({ "*", "--spec-type Name [none | ngram - cache | ngram - simple | ngram - map - k | ngram - map - k4v | ngram - mod | suffix]", "type of speculative decoding to use when no draft model is provided (default: %d)\n", (int)params.speculative.type}); options.push_back({ "*", "--spec-ngram-size-n N", "ngram size N for ngram-simple/ngram-map speculative decoding, length of lookup n-gram (default: %d)\n",params.speculative.ngram_size_n }); diff --git a/common/common.h b/common/common.h index 37cf6ec2..734d93de 100644 --- a/common/common.h +++ b/common/common.h @@ -164,6 +164,9 @@ struct common_ngram_mod; struct common_params_speculative { common_speculative_type type = COMMON_SPECULATIVE_TYPE_NONE; // type of speculative decoding + // Recurrent-model checkpoint strategy for speculative decoding. + int recurrent_ckpt_mode = LLAMA_SPEC_CKPT_AUTO; + std::string devices; std::string params; int32_t n_threads = -1; diff --git a/common/ngram-map.cpp b/common/ngram-map.cpp index 48699428..affb26c1 100644 --- a/common/ngram-map.cpp +++ b/common/ngram-map.cpp @@ -208,7 +208,7 @@ void common_ngram_map_begin( count_keys, count_keys_del, count_values_del, count_map_entries_upd); } - map.idx_last_check = (map.size_last_begin > 0) ? map.size_last_begin - 1 : 0; + map.idx_last_check = size_begin; map.size_last_begin = size_begin; } @@ -386,7 +386,7 @@ void common_ngram_map_draft(common_ngram_map & map, LOG_DBG("%s: key_idx = %zu, key_offset = %zu, key_num = %d, draft.size = %zu\n", __func__, curr_key.key_idx, key_offset, curr_key.key_num, draft.size()); - map.last_draft_created = false; + map.last_draft_created = true; map.last_draft_key_idx = key_offset; map.last_draft_value_idx = 0; // value 0 is used for simple mode return; diff --git a/common/sampling.cpp b/common/sampling.cpp index 3fb18358..d573b5e7 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -218,6 +218,12 @@ void llama_sampling_set_rng_seed(struct common_sampler * ctx, uint32_t seed) { } void common_sampler_clone(common_sampler * src, common_sampler * dst) { + dst->params = src->params; + dst->mirostat_mu = src->mirostat_mu; + dst->n_valid = src->n_valid; + dst->rng = src->rng; + dst->server_biases = src->server_biases; + if (dst->grammar) { llama_grammar_free(dst->grammar); dst->grammar = nullptr; @@ -230,7 +236,18 @@ void common_sampler_clone(common_sampler * src, common_sampler * dst) { } dst->prev = src->prev; - dst->smpl = llama_sampler_dry_clone(src->smpl); + if (dst->smpl) { + llama_sampler_dry_free(dst->smpl); + dst->smpl = nullptr; + } + if (src->smpl) { + dst->smpl = llama_sampler_dry_clone(src->smpl); + } + + if (dst->rbudget) { + common_reasoning_budget_free(dst->rbudget); + dst->rbudget = nullptr; + } if (src->rbudget) { dst->rbudget = common_reasoning_budget_clone(src->rbudget); } diff --git a/common/speculative.cpp b/common/speculative.cpp index 00696b65..f76d40ab 100644 --- a/common/speculative.cpp +++ b/common/speculative.cpp @@ -598,6 +598,7 @@ struct common_speculative_state_ngram_mod : public common_speculative_state { i_last = 0; n_draft_last = 0; + n_low = 0; const size_t n = mod.get_n(); @@ -1265,13 +1266,11 @@ void common_speculative_accept(common_speculative * spec, uint16_t n_accepted) { spec->t_step_start_us = 0; } - if (n_accepted == 0) { - return; - } - common_speculative_state * impl = spec->curr_impl; - GGML_ASSERT(impl); + if (!impl) { + return; + } { common_time_meas tm(impl->t_accept_us, !impl->gen_perf); diff 
--git a/examples/server/server-context.cpp b/examples/server/server-context.cpp index 71a818b3..e5e1ec1d 100644 --- a/examples/server/server-context.cpp +++ b/examples/server/server-context.cpp @@ -22,6 +22,51 @@ static void log_text(const gpt_params & params_base, const std::string & text) { } } +void server_speculative_checkpoint::clear() { + valid = false; + per_step_enabled = false; + n_past = 0; + sampled = LLAMA_TOKEN_NULL; + + if (sampler != nullptr) { + common_sampler_free(sampler); + sampler = nullptr; + } +} + +static void discard_speculative_checkpoint(server_slot & slot, llama_context * ctx) { + slot.spec_ckpt.clear(); + llama_spec_ckpt_discard(ctx); +} + +static bool save_speculative_checkpoint(server_slot & slot, llama_model * model, llama_context * ctx, int ckpt_mode) { + slot.spec_ckpt.clear(); + slot.spec_ckpt.n_past = slot.n_past - (int32_t)(slot.drafted.size() + 1); + slot.spec_ckpt.sampled = slot.sampled; + + const int max_tokens = (int)slot.drafted.size() + 1; + const int actual_mode = llama_spec_ckpt_init(ctx, ckpt_mode, max_tokens); + if (actual_mode == LLAMA_SPEC_CKPT_NONE) { + return false; + } + slot.spec_ckpt.per_step_enabled = (actual_mode == LLAMA_SPEC_CKPT_PER_STEP); + + slot.spec_ckpt.valid = llama_spec_ckpt_save(ctx, slot.id); + if (!slot.spec_ckpt.valid) { + llama_spec_ckpt_discard(ctx); + return false; + } + + slot.spec_ckpt.sampler = common_sampler_init(model, slot.sparams); + if (slot.spec_ckpt.sampler == nullptr) { + discard_speculative_checkpoint(slot, ctx); + return false; + } + + common_sampler_clone(slot.ctx_sampling, slot.spec_ckpt.sampler); + return true; +} + server_context::~server_context() { if (ctx) { llama_free(ctx); @@ -49,6 +94,7 @@ server_context::~server_context() { if (slot.ctx_sampling != nullptr) { common_sampler_free(slot.ctx_sampling); } + slot.spec_ckpt.clear(); if (slot.ctx_dft) { llama_free(slot.ctx_dft); } @@ -112,15 +158,6 @@ bool server_context::load_model(const gpt_params& params_) { } // Load draft model for speculative decoding if specified if (has_draft_model) { - - if (llama_model_has_recurrent(model)) { - LLAMA_LOG_WARN("\n=======================================================================\n"); - LLAMA_LOG_WARN(" Speculative decodong is not suported for recurrent/hybrid models\n"); - LLAMA_LOG_WARN(" --> bailing out\n"); - LLAMA_LOG_WARN("========================================================================\n\n"); - GGML_ABORT("Fatal error"); - } - LLAMA_LOG_INFO("\n\n==================================loading DRAFT model==================================\n\n"); gpt_params params_dft; @@ -387,6 +424,7 @@ void server_slot::reset() { n_sent_text = 0; drafted.clear(); i_batch_dft.clear(); + spec_ckpt.clear(); n_sent_token_probs = 0; infill = false; ga_i = 0; @@ -3679,6 +3717,72 @@ void server_context::extend_context(const int32_t n_tokens) { } } +// Restore recurrent state and re-decode accepted tokens after speculative-decode rejection. 
+static void restore_speculative_checkpoint(
+        server_slot & slot, llama_context * ctx, llama_model * model,
+        const std::vector<llama_token> & ids, int n_draft) {
+    if (slot.spec_ckpt.per_step_enabled) {
+        const int step = (int)ids.size() - 1;
+        llama_spec_ckpt_restore(ctx, slot.id, slot.spec_ckpt.n_past, step);
+
+        if (slot.spec_ckpt.sampler) {
+            common_sampler_clone(slot.spec_ckpt.sampler, slot.ctx_sampling);
+        }
+        for (llama_token id : ids) {
+            common_sampler_accept(slot.ctx_sampling, ctx, id, true);
+        }
+
+        SLT_DBG(slot, "per-step restore: step=%d (rejected %d drafts)\n",
+                step, (int)(n_draft - (ids.size() - 1)));
+    } else {
+        // Restore pre-speculation recurrent state then re-decode accepted tokens.
+        llama_spec_ckpt_restore(ctx, slot.id, slot.spec_ckpt.n_past, 0);
+
+        if (slot.spec_ckpt.sampler) {
+            common_sampler_clone(slot.spec_ckpt.sampler, slot.ctx_sampling);
+        }
+
+        if (!ids.empty()) {
+            // Re-decode to advance recurrent state to the accepted position.
+            const int n_re = (int)ids.size();
+            llama_batch re_batch = llama_batch_init(n_re, 0, 1);
+            common_batch_add(re_batch, slot.spec_ckpt.sampled, slot.spec_ckpt.n_past, { slot.id }, n_re == 1);
+            for (int j = 0; j < n_re - 1; j++) {
+                common_batch_add(re_batch, ids[j], slot.spec_ckpt.n_past + 1 + j, { slot.id }, j == n_re - 2);
+            }
+
+            if (slot.has_mtp) {
+                llama_set_embeddings(ctx, true);
+            }
+
+            const int ret = llama_decode(ctx, re_batch);
+            if (ret != 0) {
+                SLT_ERR(slot, "failed to re-decode accepted tokens after checkpoint restore: %d\n", ret);
+            }
+
+            if (slot.has_mtp) {
+                llama_set_embeddings(ctx, false);
+                const int n_embd = llama_model_n_embd(llama_get_model(ctx));
+                const float * emb = llama_get_embeddings_ith(ctx, -1);
+                if (emb) {
+                    slot.mtp_hidden_state.resize(n_embd);
+                    memcpy(slot.mtp_hidden_state.data(), emb, n_embd * sizeof(float));
+                }
+            }
+
+            for (llama_token id : ids) {
+                common_sampler_accept(slot.ctx_sampling, ctx, id, true);
+            }
+
+            llama_batch_free(re_batch);
+            SLT_DBG(slot, "spec checkpoint restored: re-decoded %d tokens (rejected %d drafts)\n",
+                    n_re, (int)(n_draft - (ids.size() - 1)));
+        }
+    }
+
+    discard_speculative_checkpoint(slot, ctx);
+}
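+
+// Illustrative walk-through of the restore paths above (a sketch; the
+// concrete numbers are hypothetical). Suppose the checkpoint was taken at
+// n_past_pre_spec = 100 and the speculative batch held the last sampled
+// token S plus 3 drafts d0..d2 at positions 100..103:
+//
+//   verification accepts d0 only -> ids = { d0, t_new }, ids.size() - 1 = 1
+//   per-step mode : restore the SSM state captured after batch index 1
+//                   (S then d0) and trim the KV cache past position 101;
+//                   nothing is re-decoded
+//   fallback/cpu  : rewind to the pre-batch state, then re-decode S at 100
+//                   and d0 at 101 (n_re = 2) to advance the recurrent state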
"per-step" : "shadow/cpu"; + SLT_DBG(slot, "spec checkpoint saved (mode=%s), n_past_pre_spec=%d\n", + mode_name, slot.spec_ckpt.n_past); + } else { + SLT_WRN(slot, "%s", "failed to save spec checkpoint\n"); + } + } + } + // process the created batch of tokens process_batch_tokens(n_batch); // Decode with batch diff --git a/examples/server/server-context.h b/examples/server/server-context.h index f42c6d46..074787b5 100644 --- a/examples/server/server-context.h +++ b/examples/server/server-context.h @@ -22,6 +22,16 @@ enum slot_command { SLOT_COMMAND_RELEASE, }; +struct server_speculative_checkpoint { + bool valid = false; + bool per_step_enabled = false; // per-step SSM checkpoints active + llama_pos n_past = 0; + llama_token sampled = LLAMA_TOKEN_NULL; + common_sampler * sampler = nullptr; // saved sampler state + + void clear(); +}; + struct server_slot { int id; int id_task = -1; @@ -160,6 +170,9 @@ struct server_slot { bool has_mtp = false; std::vector mtp_hidden_state; + // saves recurrent state before a speculative batch so it can be restored on rejection + server_speculative_checkpoint spec_ckpt; + // speculative decoding stats int32_t n_draft_total = 0; // Total draft tokens generated int32_t n_draft_accepted = 0; // Draft tokens actually accepted diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 0d164166..250d924e 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -2526,7 +2526,8 @@ extern "C" { struct ggml_tensor * v, struct ggml_tensor * g, struct ggml_tensor * beta, - struct ggml_tensor * state); + struct ggml_tensor * state, + bool save_all_steps); // custom operators diff --git a/ggml/src/ggml-cuda/delta-net.cu b/ggml/src/ggml-cuda/delta-net.cu index 35233255..d43a855c 100644 --- a/ggml/src/ggml-cuda/delta-net.cu +++ b/ggml/src/ggml-cuda/delta-net.cu @@ -35,13 +35,14 @@ __global__ void delta_net_recurrent_f32( const float * __restrict__ g, // [n_tokens, 1, n_heads, n_seqs] const float * __restrict__ beta_in, // [1, n_tokens, n_heads, n_seqs] const float * __restrict__ state_in, // [HEAD_DIM, HEAD_DIM*n_heads, 1, n_seqs] - float * __restrict__ dst, // output + new_state concatenated + float * __restrict__ dst, // output + new_state(s) concatenated const int64_t n_heads, const int64_t gqa_ratio, const int repeat_type, const int64_t n_tokens, const int64_t n_seqs, const int64_t output_offset, // offset where state starts in output + const int save_all_states, // 1 = save per-step states, 0 = final only size_t vnb1, size_t vnb2, size_t vnb3) { constexpr int warps_per_head = HEAD_DIM/WARP_SIZE; const int batch_idx = blockIdx.x / (warps_per_head*n_heads); @@ -69,6 +70,9 @@ __global__ void delta_net_recurrent_f32( const int64_t state_head_offset = head_idx * HEAD_DIM * HEAD_DIM; const int64_t state_batch_stride = HEAD_DIM * HEAD_DIM * n_heads; + // State step stride for save_all_states: HEAD_DIM^2 * n_heads * n_seqs + const int64_t state_step_stride = HEAD_DIM * HEAD_DIM * n_heads * n_seqs; + // Pointers for this batch/head const float * q_ptr = q + batch_idx * qkv_stride_batch_kq + head_idx_kq * qkv_stride_head; const float * k_ptr = k + batch_idx * qkv_stride_batch_kq + head_idx_kq * qkv_stride_head; @@ -155,6 +159,15 @@ __global__ void delta_net_recurrent_f32( state_local[i] = new_state_val; } + // Save per-step state if requested + if (save_all_states) { + float * state_step_dst = dst + output_offset + t * state_step_stride + batch_idx * state_batch_stride + state_head_offset; + for (int i = 0; i < HEAD_DIM/num_warps; ++i) { + int col = num_warps*i + 
     // Barrier required: (a) sK reads in the state update above must complete
     // before next iteration overwrites sK at the top of the loop, and (b) this
     // single barrier also orders all_sum1/all_sum2 reads above vs. the next
@@ -163,9 +176,11 @@
         __syncthreads();
     }
     // Copy the final state to its destination
-    for (int i = 0; i < HEAD_DIM/num_warps; ++i) {
-        int col = num_warps*i + col_idx_0;
-        state_dst[col*HEAD_DIM + row_out] = state_local[i];
+    if (!save_all_states) {
+        for (int i = 0; i < HEAD_DIM/num_warps; ++i) {
+            int col = num_warps*i + col_idx_0;
+            state_dst[col*HEAD_DIM + row_out] = state_local[i];
+        }
     }
 }
@@ -183,9 +198,10 @@ static void delta_net_f32_cuda(
     const int64_t gqa_ratio,
     const int repeat_type,
     const int64_t n_seqs,
+    const int save_all_states,
     size_t vnb1, size_t vnb2, size_t vnb3,
     const int device_id,
-    const int cc, // compute capability (e.g., 890 for SM 8.9, 1200 for SM 12.0)
+    const int cc,  // compute capability (e.g., 890 for SM 8.9, 1200 for SM 12.0)
     cudaStream_t stream) {
     GGML_UNUSED(device_id);
     GGML_UNUSED(cc);
@@ -204,19 +220,19 @@ static void delta_net_f32_cuda(
         constexpr int threads_per_block = 256;
         if (head_dim == 64) {
             delta_net_recurrent_f32<64, threads_per_block><<<grid_dims, threads_per_block, 0, stream>>>(
-                q, k, v, g, beta, state_in, dst, n_heads, gqa_ratio, repeat_type, n_tokens, n_seqs, output_offset, vnb1, vnb2, vnb3);
+                q, k, v, g, beta, state_in, dst, n_heads, gqa_ratio, repeat_type, n_tokens, n_seqs, output_offset, save_all_states, vnb1, vnb2, vnb3);
         } else {
             delta_net_recurrent_f32<128, threads_per_block><<<grid_dims, threads_per_block, 0, stream>>>(
-                q, k, v, g, beta, state_in, dst, n_heads, gqa_ratio, repeat_type, n_tokens, n_seqs, output_offset, vnb1, vnb2, vnb3);
+                q, k, v, g, beta, state_in, dst, n_heads, gqa_ratio, repeat_type, n_tokens, n_seqs, output_offset, save_all_states, vnb1, vnb2, vnb3);
         }
     } else {
         constexpr int threads_per_block = 128;
         if (head_dim == 64) {
             delta_net_recurrent_f32<64, threads_per_block><<<grid_dims, threads_per_block, 0, stream>>>(
-                q, k, v, g, beta, state_in, dst, n_heads, gqa_ratio, repeat_type, n_tokens, n_seqs, output_offset, vnb1, vnb2, vnb3);
+                q, k, v, g, beta, state_in, dst, n_heads, gqa_ratio, repeat_type, n_tokens, n_seqs, output_offset, save_all_states, vnb1, vnb2, vnb3);
         } else {
             delta_net_recurrent_f32<128, threads_per_block><<<grid_dims, threads_per_block, 0, stream>>>(
-                q, k, v, g, beta, state_in, dst, n_heads, gqa_ratio, repeat_type, n_tokens, n_seqs, output_offset, vnb1, vnb2, vnb3);
+                q, k, v, g, beta, state_in, dst, n_heads, gqa_ratio, repeat_type, n_tokens, n_seqs, output_offset, save_all_states, vnb1, vnb2, vnb3);
         }
     }
@@ -258,9 +274,12 @@ void ggml_cuda_op_delta_net(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
     // Verify output tensor size
     const int64_t output_size = head_dim * n_tokens * n_heads * n_seqs;
     const int64_t state_size = head_dim * head_dim * n_heads * n_seqs;
-    GGML_ASSERT(ggml_nelements(dst) == output_size + state_size);
     int repeat_type = dst->op_params[0];
+    int save_all_states = dst->op_params[1];
+
+    const int64_t expected_size = save_all_states ?
(output_size + n_tokens * state_size) : (output_size + state_size); + GGML_ASSERT(ggml_nelements(dst) == expected_size); GGML_ASSERT(head_dim <= 256); // Reasonable limit for shared memory @@ -277,6 +296,7 @@ void ggml_cuda_op_delta_net(ggml_backend_cuda_context & ctx, ggml_tensor * dst) (const float *)src5->data, (float *)dst->data, head_dim, n_tokens, n_heads, gqa_ratio, repeat_type, n_seqs, + save_all_states, src2->nb[1]/sizeof(float), src2->nb[2]/sizeof(float), src2->nb[3]/sizeof(float), device_id, cc, ctx.stream()); diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 59842a2c..4454ffb9 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -9940,7 +9940,8 @@ struct ggml_tensor * ggml_delta_net( struct ggml_tensor * v, struct ggml_tensor * g, struct ggml_tensor * beta, - struct ggml_tensor * state) { + struct ggml_tensor * state, + bool save_all_steps) { GGML_ASSERT(ggml_is_contiguous(q)); GGML_ASSERT(ggml_is_contiguous(k)); GGML_ASSERT(ggml_is_contiguous(state)); @@ -9971,9 +9972,11 @@ struct ggml_tensor * ggml_delta_net( const int64_t output_size = S_v * H_v * n_tokens * n_seqs; const int64_t state_size = S_v * S_v * H_v * n_seqs; - struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, output_size + state_size); + const int64_t state_slots = save_all_steps ? n_tokens : 1; + struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, output_size + state_slots * state_size); result->op = GGML_OP_DELTA_NET; + result->op_params[1] = save_all_steps ? 1 : 0; result->src[0] = q; result->src[1] = k; result->src[2] = v; @@ -22654,17 +22657,19 @@ static void ggml_compute_forward_delta_net_f32( const float * beta_data = (const float *) src4->data; const float * state_in = (const float *) src5->data; float * out_data = (float *) dst->data; - float * state_out = out_data + output_size; const int ith = params->ith; const int nth = params->nth; int repeat_type = dst->op_params[0]; + // save_all_steps is handled by the CUDA backend only; + // the CPU path always writes to the single state slot after the output. + float * state_working = out_data + output_size; if (iqk_fused_delta_net(head_dim, n_heads, gqa_ratio, repeat_type, n_tokens, n_seqs, src2->nb[1]/sizeof(float), src2->nb[2]/sizeof(float), src2->nb[3]/sizeof(float), q_data, k_data, v_data, g_data, beta_data, state_in, - out_data, state_out, ith, nth)) { + out_data, state_working, ith, nth)) { return; } @@ -22694,10 +22699,10 @@ static void ggml_compute_forward_delta_net_f32( const int64_t out_token_stride = head_dim * n_heads; for (int64_t i = 0; i < head_dim * head_dim; ++i) { - state_out[state_head_offset + i] = state_in[state_head_offset + i]; + state_working[state_head_offset + i] = state_in[state_head_offset + i]; } - float * state = state_out + state_head_offset; + float * state = state_working + state_head_offset; for (int64_t t = 0; t < n_tokens; ++t) { const float * q_t = q_data + qkv_head_offset_kq + t * qkv_token_stride; diff --git a/include/llama.h b/include/llama.h index 732bdde7..8d6cd295 100644 --- a/include/llama.h +++ b/include/llama.h @@ -800,6 +800,28 @@ extern "C" { LLAMA_API void llama_kv_cache_clear( struct llama_context * ctx); + // Unified checkpoint API for recurrent/hybrid speculative decoding. + enum llama_spec_ckpt_mode { + LLAMA_SPEC_CKPT_NONE = -1, + LLAMA_SPEC_CKPT_AUTO = 0, + LLAMA_SPEC_CKPT_PER_STEP = 1, + LLAMA_SPEC_CKPT_GPU_FALLBACK = 2, + LLAMA_SPEC_CKPT_CPU = 3, + }; + + // Initialise the checkpoint system for the upcoming speculation window. 
+    LLAMA_API int llama_spec_ckpt_init(struct llama_context * ctx, int mode, int max_tokens);
+
+    // Save the current recurrent state as a speculative checkpoint.
+    LLAMA_API bool llama_spec_ckpt_save(struct llama_context * ctx, llama_seq_id seq_id);
+
+    // Restore the recurrent state after speculative decode.
+    LLAMA_API bool llama_spec_ckpt_restore(struct llama_context * ctx, llama_seq_id seq_id,
+                                           llama_pos n_past, int accepted_step);
+
+    // Discard the saved checkpoint and reset internal mode state.
+    LLAMA_API void llama_spec_ckpt_discard(struct llama_context * ctx);
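+
+    // Typical call sequence (illustrative sketch based on the server
+    // integration in this patch; names like spec_batch / n_accepted /
+    // n_past_pre_spec are placeholders and error handling is omitted):
+    //
+    //   int mode = llama_spec_ckpt_init(ctx, LLAMA_SPEC_CKPT_AUTO, n_draft + 1);
+    //   if (mode != LLAMA_SPEC_CKPT_NONE && llama_spec_ckpt_save(ctx, seq_id)) {
+    //       llama_decode(ctx, spec_batch);   // verify the drafted tokens
+    //       if (any_rejected) {
+    //           // returns true in per-step mode (state already advanced);
+    //           // in gpu-fallback/cpu mode it returns false and the caller
+    //           // must re-decode the accepted tokens itself
+    //           llama_spec_ckpt_restore(ctx, seq_id, n_past_pre_spec, n_accepted);
+    //       }
+    //       llama_spec_ckpt_discard(ctx);
+    //   }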
+
     // Removes all tokens that belong to the specified sequence and have positions in [p0, p1)
     // Returns false if a partial sequence cannot be removed. Removing a whole sequence never fails
     // seq_id < 0 : match any sequence
diff --git a/src/llama-context.h b/src/llama-context.h
index be563122..7b6e56cf 100644
--- a/src/llama-context.h
+++ b/src/llama-context.h
@@ -59,6 +59,9 @@ struct llama_kv_cache {
     std::vector<ggml_tensor *> v_l;
     std::vector<ggml_tensor *> s_l; // per layer recurrent state storage (Qwen3Next)
+    // When true, the delta_net graph builder will enable per-step SSM state saves
+    bool save_per_step_ssm = false;
+
     std::vector<ggml_tensor *> split_k_l;
     std::vector<ggml_tensor *> split_v_l;
     std::vector<ggml_tensor *> split_s_l;
@@ -74,6 +77,74 @@
         return size;
     }
+    // GPU-resident checkpoint for recurrent/hybrid speculative decoding
+    struct gpu_checkpoint {
+        std::vector<llama_kv_cell> cells_snapshot;
+        uint32_t head_snapshot = 0;
+        uint32_t used_snapshot = 0;
+
+        std::vector<ggml_tensor *> s_l_shadow;
+
+        std::vector<std::vector<ggml_tensor *>> split_s_l_shadow;
+
+        // Per-step SSM state checkpoints for speculative decoding.
+        std::vector<ggml_tensor *> per_step_ssm;
+
+        // Per-step conv feature buffer: stores qkv_mixed features from the
+        // verification forward pass so conv state can be reconstructed at any step.
+        // One tensor per recurrent layer, each sized [conv_dim * max_tokens].
+        std::vector<ggml_tensor *> per_step_qkv;
+
+        int32_t per_step_n_tokens = 0;
+        int32_t per_step_max_allocated = 0;
+        int64_t per_step_ssm_state_size = 0;
+        int64_t per_step_conv_state_dim = 0;
+        int64_t per_step_conv_dim = 0;
+        int32_t per_step_d_conv = 0;
+
+        int selected_spec_mode = -1;
+
+        // Serialised sequence state for CPU mode
+        std::vector<uint8_t> cpu_state_data;
+
+        // Separate storage for per-step allocations
+        std::vector<ggml_context *> per_step_ctxs;
+        std::vector<ggml_backend_buffer_t> per_step_bufs;
+
+        std::vector<ggml_context *> shadow_ctxs;
+        std::vector<ggml_backend_buffer_t> shadow_bufs;
+
+        bool allocated = false;
+        bool saved = false;
+
+        ~gpu_checkpoint() {
+            for (struct ggml_context * ctx : shadow_ctxs) {
+                ggml_free(ctx);
+            }
+            for (ggml_backend_buffer_t buf : shadow_bufs) {
+                ggml_backend_buffer_free(buf);
+            }
+            for (struct ggml_context * ctx : per_step_ctxs) {
+                ggml_free(ctx);
+            }
+            for (ggml_backend_buffer_t buf : per_step_bufs) {
+                ggml_backend_buffer_free(buf);
+            }
+        }
+    };
+
+    gpu_checkpoint ckpt;
+
+    bool checkpoint_alloc_shadows();
+    bool checkpoint_supported() const;
+    bool checkpoint_save();
+    bool checkpoint_restore();
+    void checkpoint_delete();
+
+    // Per-step checkpoint: allocate, restore step k's full state (SSM + conv) to cache
+    bool per_step_alloc(int max_tokens);
+    bool per_step_restore(int step);
+
     ~llama_kv_cache() {
         for (struct ggml_context * ctx : ctxs) {
             ggml_free(ctx);
diff --git a/src/llama-delta-net.cpp b/src/llama-delta-net.cpp
index 060294df..1c24fc2e 100644
--- a/src/llama-delta-net.cpp
+++ b/src/llama-delta-net.cpp
@@ -70,6 +70,7 @@ delta_net::delta_net(llama_context & _lctx, const llama_batch & _batch) : lctx(_
         GGML_ASSERT((uint32_t) s < qnext_state_slots);
     }
+    save_per_step_states = lctx.kv_self.save_per_step_ssm && batch.n_tokens > 1;
 }
 delta_net::~delta_net() = default;
@@ -77,7 +78,9 @@
 std::pair<ggml_tensor *, ggml_tensor *> delta_net::build_fused_delta_net(ggml_context * ctx0,
         ggml_tensor * q, ggml_tensor * k, ggml_tensor * v,
         ggml_tensor * g, ggml_tensor * beta, ggml_tensor * state,
-        int il, const llm_build_cb & cb, int repeat_type) {
+        int il, const llm_build_cb & cb, int repeat_type,
+        bool save_all_steps,
+        ggml_cgraph * gf, ggml_tensor * per_step_ckpt) {
     const int64_t S_k = q->ne[0];
     const int64_t H_k = q->ne[2];
@@ -119,7 +122,7 @@
     cb(beta, "beta_fused", il);
     cb(state_flat,"state_fused", il);
-    ggml_tensor * fused_result = ggml_delta_net(ctx0, q, k, v, g, beta, state_flat);
+    ggml_tensor * fused_result = ggml_delta_net(ctx0, q, k, v, g, beta, state_flat, save_all_steps);
     cb(fused_result, "delta_net_fused_raw", il);
     fused_result->op_params[0] = repeat_type;
@@ -133,13 +136,34 @@
         ggml_row_size(fused_result->type, S_v * H_v * n_tokens), 0);
     //output_tokens = ggml_cont_4d(ctx0, output_tokens, S_v, H_v, n_tokens, n_seqs);
+    // per-step states are at [output_size, output_size + n_tokens*state_size)
+    const int64_t last_state_offset = save_all_steps
+        ?
(output_size + (n_tokens - 1) * state_size) + : output_size; + ggml_tensor * new_state_flat = ggml_view_1d(ctx0, fused_result, state_size, - output_size * ggml_element_size(fused_result)); + last_state_offset * ggml_element_size(fused_result)); ggml_tensor * new_state = ggml_reshape_4d(ctx0, new_state_flat, S_v, S_v, H_v, n_seqs); cb(output_tokens, "output_tokens", il); cb(new_state, "new_state", il); + // Copy all per-step SSM states to persistent checkpoint tensor + if (save_all_steps && per_step_ckpt != nullptr && gf != nullptr && n_tokens > 1) { + const int64_t per_step_total = n_tokens * state_size; + if (per_step_total <= ggml_nelements(per_step_ckpt)) { + ggml_tensor * all_steps_src = ggml_view_1d(ctx0, fused_result, per_step_total, + output_size * ggml_element_size(fused_result)); + ggml_tensor * ckpt_dst = ggml_view_1d(ctx0, per_step_ckpt, per_step_total, 0); + auto ckpt_cpy = ggml_cpy(ctx0, all_steps_src, ckpt_dst); + cb(ckpt_cpy, "per_step_ckpt_cpy", il); + ggml_build_forward_expand(gf, ckpt_cpy); + } else { + LLAMA_LOG_WARN("%s: per-step checkpoint tensor too small for %lld tokens (need %lld, have %lld), skipping per-step save\n", + __func__, (long long)n_tokens, (long long)per_step_total, (long long)ggml_nelements(per_step_ckpt)); + } + } + return {output_tokens, new_state}; } @@ -281,7 +305,8 @@ ggml_tensor * delta_net::build_qkv(ggml_context * ctx0, ggml_tensor * state_stor ggml_tensor * qkv_mixed, ggml_tensor * inp_s_seq_qnext, ggml_tensor * beta, ggml_tensor * gate, int64_t head_k_dim, int64_t num_k_heads, int64_t head_v_dim, int64_t num_v_heads, int64_t ssm_d_conv, int64_t state_seq_id_local, uint32_t qnext_state_slots, bool reset_state_local, - float eps_norm, int repeat_type, int il, const llm_build_cb & cb, ggml_cgraph * gf) { + float eps_norm, int repeat_type, int il, const llm_build_cb & cb, ggml_cgraph * gf, + bool save_per_step_states, ggml_tensor * per_step_ckpt) { const int64_t key_dim = head_k_dim * num_k_heads; const int64_t value_dim = head_v_dim * num_v_heads; const int64_t conv_dim = key_dim * 2 + value_dim; @@ -366,7 +391,8 @@ ggml_tensor * delta_net::build_qkv(ggml_context * ctx0, ggml_tensor * state_stor cb(q_conv, "q_conv_normed", il); cb(k_conv, "k_conv_normed", il); - auto [output, new_state] = build_fused_delta_net(ctx0, q_conv, k_conv, v_conv, gate, beta, state, il, cb, repeat_type); + auto [output, new_state] = build_fused_delta_net(ctx0, q_conv, k_conv, v_conv, gate, beta, state, il, cb, repeat_type, + save_per_step_states, gf, per_step_ckpt); cb(output, "attn_output", il); cb(new_state, "new_state", il); @@ -566,11 +592,29 @@ ggml_tensor * delta_net::build_layer_attn_linear_core(ggml_context * ctx0, ggml_ auto [beta, gate] = build_beta_gate(lctx, ctx0, model.layers[il].ssm_beta_alpha, model.layers[il].ssm_beta, model.layers[il].ssm_alpha, model.layers[il].ssm_dt, model.layers[il].ssm_a, num_k_heads, num_v_heads, n_seqs, cur, il, cb, gf); + // Get per-step checkpoint tensor if available + ggml_tensor * per_step_ckpt = nullptr; + if (save_per_step_states && il < (int)kv_self.ckpt.per_step_ssm.size()) { + per_step_ckpt = kv_self.ckpt.per_step_ssm[il]; + } + + // Save qkv_mixed features for per-step conv state reconstruction + if (save_per_step_states && il < (int)kv_self.ckpt.per_step_qkv.size() && kv_self.ckpt.per_step_qkv[il] != nullptr) { + const int64_t conv_dim = qkv_mixed->ne[0]; + const int64_t n_tok_qkv = qkv_mixed->ne[1] * qkv_mixed->ne[2]; + ggml_tensor * qkv_flat = ggml_reshape_2d(ctx0, qkv_mixed, conv_dim, n_tok_qkv); + ggml_tensor * 
qkv_dst = ggml_view_2d(ctx0, kv_self.ckpt.per_step_qkv[il],
+                conv_dim, n_tok_qkv, conv_dim * sizeof(float), 0);
+        auto qkv_cpy = ggml_cpy(ctx0, qkv_flat, qkv_dst);
+        ggml_build_forward_expand(gf, qkv_cpy);
+    }
+
     auto output = build_qkv(ctx0, kv_self.s_l[il], model.layers[il].ssm_conv1d, qkv_mixed, inp_s_seq_qnext,
         beta, gate, head_k_dim, num_k_heads, head_v_dim, num_v_heads, hparams.ssm_d_conv,
         state_seq_id_local, qnext_state_slots, reset_state_local, hparams.f_norm_rms_eps,
-        model.layers[il].ssm_beta_alpha ? 0 : 1, il, cb, gf);
+        model.layers[il].ssm_beta_alpha ? 0 : 1, il, cb, gf,
+        save_per_step_states, per_step_ckpt);
     auto gated_output = build_gated_output(lctx, ctx0, model.layers[il].ssm_norm, model.layers[il].ssm_out,
         output, z, head_v_dim, num_v_heads, n_tok, il, cb);
     if (inp_out_ids) {
diff --git a/src/llama-delta-net.h b/src/llama-delta-net.h
index 58259633..2fe499ad 100644
--- a/src/llama-delta-net.h
+++ b/src/llama-delta-net.h
@@ -8,10 +8,15 @@ struct delta_net {
     delta_net(llama_context & lctx, const llama_batch & batch);
     ~delta_net();
+    // Used for speculative decoding to enable per-step state checkpoint restoration.
+    bool save_per_step_states = false;
+
     static std::pair<ggml_tensor *, ggml_tensor *> build_fused_delta_net(ggml_context * ctx0,
             ggml_tensor * q, ggml_tensor * k, ggml_tensor * v,
             ggml_tensor * g, ggml_tensor * beta, ggml_tensor * state,
-            int il, const llm_build_cb & cb, int repeat_type);
+            int il, const llm_build_cb & cb, int repeat_type,
+            bool save_all_steps = false,
+            ggml_cgraph * gf = nullptr, ggml_tensor * per_step_ckpt = nullptr);
     ggml_tensor * build_layer_attn_linear_core(ggml_context * ctx0, ggml_cgraph * gf,
             ggml_tensor * cur, ggml_tensor * inp_s_seq_qnext, ggml_tensor * inp_out_ids,
@@ -46,7 +51,8 @@ private:
             ggml_tensor * qkv_mixed, ggml_tensor * inp_s_seq_qnext, ggml_tensor * beta, ggml_tensor * gate,
             int64_t head_k_dim, int64_t num_k_heads, int64_t head_v_dim, int64_t num_v_heads,
             int64_t ssm_d_conv, int64_t state_seq_id_local, uint32_t qnext_state_slots, bool reset_state_local,
-            float eps_norm, int repeat_type, int il, const llm_build_cb & cb, ggml_cgraph * gf);
+            float eps_norm, int repeat_type, int il, const llm_build_cb & cb, ggml_cgraph * gf,
+            bool save_per_step_states = false, ggml_tensor * per_step_ckpt = nullptr);
     static ggml_tensor * build_gated_output(llama_context & lctx, ggml_context * ctx0, ggml_tensor * ssm_norm,
             ggml_tensor * ssm_out, ggml_tensor * output, ggml_tensor * z,
             int64_t head_v_dim, int64_t num_v_heads, int64_t n_tok, int il, const llm_build_cb & cb);
diff --git a/src/llama.cpp b/src/llama.cpp
index 3c3f1ef4..7a832716 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -1204,6 +1204,380 @@ static uint32_t llama_kv_cache_cell_max(const struct llama_kv_cache & cache, uin
     return 0;
 }
+bool llama_kv_cache::checkpoint_supported() const {
+    for (const auto * s : s_l) {
+        if (s != nullptr) {
+            return true;
+        }
+    }
+    return false;
+}
+
+bool llama_kv_cache::checkpoint_alloc_shadows() {
+    if (ckpt.allocated) {
+        return true;
+    }
+
+    const uint32_t n_layer = (uint32_t)s_l.size();
+    ckpt.s_l_shadow.resize(n_layer, nullptr);
+
+    struct tensor_entry {
+        ggml_tensor * primary;
+        uint32_t il;
+        int split_idx; // -1 for non-split
+    };
+
+    const bool conv_only_shadow = save_per_step_ssm && ckpt.per_step_conv_state_dim > 0;
+    std::vector<tensor_entry> nonsplit_entries;
+
+    std::map<ggml_backend_buffer_type_t, std::vector<tensor_entry>> split_buft_entries;
+
+    uint32_t split_s_idx = 0;
+    for (uint32_t il = 0; il < n_layer; ++il) {
+        if (s_l[il] == nullptr) {
+            continue;
+        }
+
+        auto * extra = s_l[il]->extra;
+        if (extra != nullptr) {
auto * split_info = (const ggml_split_tensor_t *)extra; + for (int d = 0; d < split_info->n_device; ++d) { + if (split_info->splits[d] == nullptr) continue; + ggml_backend_buffer_type_t buft = ggml_backend_buffer_get_type(split_info->splits[d]->buffer); + split_buft_entries[buft].push_back({split_info->splits[d], il, d}); + } + split_s_idx++; + } else { + nonsplit_entries.push_back({s_l[il], il, -1}); + } + } + + if (!nonsplit_entries.empty()) { + ggml_init_params params = { + /*.mem_size =*/ nonsplit_entries.size() * ggml_tensor_overhead(), + /*.mem_buffer =*/ NULL, + /*.no_alloc =*/ true, + }; + ggml_context * ctx = ggml_init(params); + if (!ctx) { + LLAMA_LOG_ERROR("%s: failed to create ggml context for shadow tensors\n", __func__); + return false; + } + + for (auto & entry : nonsplit_entries) { + // Only need the conv portion when per-step is active. + const int64_t nelems = conv_only_shadow + ? ckpt.per_step_conv_state_dim + : (int64_t)ggml_nelements(entry.primary); + ggml_tensor * shadow = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nelems); + ggml_format_name(shadow, "shadow_s_l%d", entry.il); + ckpt.s_l_shadow[entry.il] = shadow; + } + + ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, ggml_backend_cpu_buffer_type()); + if (!buf) { + LLAMA_LOG_ERROR("%s: failed to allocate CPU buffer for shadow tensors\n", __func__); + ggml_free(ctx); + return false; + } + ggml_backend_buffer_clear(buf, 0); + LLAMA_LOG_INFO("%s: CPU shadow buffer = %8.2f MiB (%s)\n", __func__, + ggml_backend_buffer_get_size(buf) / 1024.0 / 1024.0, + conv_only_shadow ? "conv-state only" : "full recurrent state"); + ckpt.shadow_ctxs.push_back(ctx); + ckpt.shadow_bufs.push_back(buf); + } + + // Allocate split shadows on their respective devices + for (auto & [buft, entries] : split_buft_entries) { + ggml_init_params params = { + /*.mem_size =*/ entries.size() * ggml_tensor_overhead(), + /*.mem_buffer =*/ NULL, + /*.no_alloc =*/ true, + }; + ggml_context * ctx = ggml_init(params); + if (!ctx) { + LLAMA_LOG_ERROR("%s: failed to create ggml context for split shadow tensors\n", __func__); + return false; + } + + for (auto & entry : entries) { + ggml_tensor * shadow = ggml_dup_tensor(ctx, entry.primary); + ggml_format_name(shadow, "shadow_s_l%d_d%d", entry.il, entry.split_idx); + entry.primary = shadow; + } + + ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft); + if (!buf) { + LLAMA_LOG_ERROR("%s: failed to allocate buffer for split shadow tensors\n", __func__); + ggml_free(ctx); + return false; + } + ggml_backend_buffer_clear(buf, 0); + LLAMA_LOG_INFO("%s: %10s split shadow buffer = %8.2f MiB\n", __func__, + ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf) / 1024.0 / 1024.0); + ckpt.shadow_ctxs.push_back(ctx); + ckpt.shadow_bufs.push_back(buf); + } + + // Build split shadow lookup + ckpt.split_s_l_shadow.resize(split_s_l.size()); + split_s_idx = 0; + for (uint32_t il = 0; il < n_layer; ++il) { + if (s_l[il] == nullptr || s_l[il]->extra == nullptr) { + continue; + } + + auto * split_info = (const ggml_split_tensor_t *)s_l[il]->extra; + auto & shadow_split = ckpt.split_s_l_shadow[split_s_idx]; + shadow_split.resize(split_info->n_device, nullptr); + + for (int d = 0; d < split_info->n_device; ++d) { + if (split_info->splits[d] == nullptr) continue; + ggml_backend_buffer_type_t buft = ggml_backend_buffer_get_type(split_info->splits[d]->buffer); + for (auto & entry : split_buft_entries[buft]) { + if (entry.il == il && entry.split_idx == d) { + shadow_split[d] = 
entry.primary;
+                    break;
+                }
+            }
+        }
+        split_s_idx++;
+    }
+
+    ckpt.allocated = true;
+    return true;
+}
+
+bool llama_kv_cache::checkpoint_save() {
+    if (!checkpoint_alloc_shadows()) {
+        return false;
+    }
+
+    const uint32_t n_layer = (uint32_t)s_l.size();
+
+    ckpt.cells_snapshot = cells;
+    ckpt.head_snapshot = head;
+    ckpt.used_snapshot = used;
+
+    uint32_t split_s_idx = 0;
+    for (uint32_t il = 0; il < n_layer; ++il) {
+        if (s_l[il] == nullptr) {
+            continue;
+        }
+
+        if (s_l[il]->extra != nullptr) {
+            auto * split_info = (const ggml_split_tensor_t *)s_l[il]->extra;
+            auto & shadow_split = ckpt.split_s_l_shadow[split_s_idx];
+            for (int d = 0; d < split_info->n_device; ++d) {
+                if (split_info->splits[d] && shadow_split[d]) {
+                    ggml_backend_tensor_copy(split_info->splits[d], shadow_split[d]);
+                }
+            }
+            split_s_idx++;
+        } else {
+            const size_t nbytes = ggml_nbytes(ckpt.s_l_shadow[il]);
+            ggml_backend_tensor_get(s_l[il], ckpt.s_l_shadow[il]->data, 0, nbytes);
+        }
+    }
+
+    ckpt.saved = true;
+    return true;
+}
+
+bool llama_kv_cache::checkpoint_restore() {
+    if (!ckpt.saved) {
+        LLAMA_LOG_ERROR("%s: no checkpoint saved\n", __func__);
+        return false;
+    }
+
+    const uint32_t n_layer = (uint32_t)s_l.size();
+
+    cells = ckpt.cells_snapshot;
+    head = ckpt.head_snapshot;
+    used = ckpt.used_snapshot;
+
+    uint32_t split_s_idx = 0;
+    for (uint32_t il = 0; il < n_layer; ++il) {
+        if (s_l[il] == nullptr) {
+            continue;
+        }
+
+        if (s_l[il]->extra != nullptr) {
+            auto * split_info = (const ggml_split_tensor_t *)s_l[il]->extra;
+            auto & shadow_split = ckpt.split_s_l_shadow[split_s_idx];
+            for (int d = 0; d < split_info->n_device; ++d) {
+                if (split_info->splits[d] && shadow_split[d]) {
+                    ggml_backend_tensor_copy(shadow_split[d], split_info->splits[d]);
+                }
+            }
+            split_s_idx++;
+        } else {
+            GGML_ASSERT(ggml_nbytes(ckpt.s_l_shadow[il]) == ggml_nbytes(s_l[il]));
+            ggml_backend_tensor_copy(ckpt.s_l_shadow[il], s_l[il]);
+        }
+    }
+
+    return true;
+}
+
+void llama_kv_cache::checkpoint_delete() {
+    ckpt.saved = false;
+}
+
+bool llama_kv_cache::per_step_alloc(int max_tokens) {
+    if (ckpt.per_step_max_allocated >= max_tokens) {
+        return true;
+    }
+
+    if (!ckpt.per_step_ssm.empty()) {
+        for (struct ggml_context * ctx : ckpt.per_step_ctxs) {
+            ggml_free(ctx);
+        }
+        for (ggml_backend_buffer_t buf : ckpt.per_step_bufs) {
+            ggml_backend_buffer_free(buf);
+        }
+        ckpt.per_step_ctxs.clear();
+        ckpt.per_step_bufs.clear();
+        ckpt.per_step_ssm.clear();
+        ckpt.per_step_qkv.clear();
+        ckpt.per_step_max_allocated = 0;
+    }
+
+    const uint32_t n_layer = (uint32_t)s_l.size();
+    ckpt.per_step_ssm.resize(n_layer, nullptr);
+    ckpt.per_step_qkv.resize(n_layer, nullptr);
+
+    const int64_t ssm_state_dim = ckpt.per_step_ssm_state_size;
+    const int64_t conv_dim = ckpt.per_step_conv_dim;
+    if (ssm_state_dim <= 0 || conv_dim <= 0) {
+        LLAMA_LOG_ERROR("%s: per_step dimensions not set (ssm=%lld, conv_dim=%lld)\n",
+                __func__, (long long)ssm_state_dim, (long long)conv_dim);
+        return false;
+    }
+
+    std::map<ggml_backend_buffer_type_t, std::vector<std::pair<uint32_t, ggml_backend_buffer_type_t>>> buft_layers;
+
+    for (uint32_t il = 0; il < n_layer; ++il) {
+        if (s_l[il] == nullptr) continue;
+        if (s_l[il]->extra != nullptr) continue; // skip split tensors
+
+        ggml_backend_buffer_type_t buft = ggml_backend_buffer_get_type(s_l[il]->buffer);
+        buft_layers[buft].push_back({il, buft});
+    }
+
+    for (auto & [buft, layers] : buft_layers) {
+        // 2 tensors per layer: SSM states + qkv features
+        ggml_init_params params = {
+            /*.mem_size   =*/ layers.size() * 2 * ggml_tensor_overhead(),
+            /*.mem_buffer =*/ NULL,
+            /*.no_alloc   =*/ true,
+        };
+        ggml_context * ctx = ggml_init(params);
+        if (!ctx) {
+            LLAMA_LOG_ERROR("%s: failed to create ggml context for per-step checkpoints\n", __func__);
+            return false;
+        }
+
+        for (auto & [il, bt] : layers) {
+            // SSM state: max_tokens * ssm_state_dim
+            ggml_tensor * t_ssm = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, (int64_t)max_tokens * ssm_state_dim);
+            ggml_format_name(t_ssm, "per_step_ssm_l%d", il);
+            ckpt.per_step_ssm[il] = t_ssm;
+
+            // Conv features (qkv_mixed): max_tokens * conv_dim
+            ggml_tensor * t_qkv = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, (int64_t)max_tokens * conv_dim);
+            ggml_format_name(t_qkv, "per_step_qkv_l%d", il);
+            ckpt.per_step_qkv[il] = t_qkv;
+        }
+
+        ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft);
+        if (!buf) {
+            LLAMA_LOG_ERROR("%s: failed to allocate buffer for per-step checkpoints\n", __func__);
+            ggml_free(ctx);
+            return false;
+        }
+        ggml_backend_buffer_clear(buf, 0);
+        LLAMA_LOG_INFO("%s: %10s per-step buffer = %8.2f MiB (max_tokens=%d)\n", __func__,
+                ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf) / 1024.0 / 1024.0, max_tokens);
+        ckpt.per_step_ctxs.push_back(ctx);
+        ckpt.per_step_bufs.push_back(buf);
+    }
+
+    ckpt.per_step_max_allocated = max_tokens;
+    return true;
+}
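+
+// Worked example of the conv-state rebuild in per_step_restore below (a
+// sketch, assuming d_conv = 4, i.e. d_conv_m1 = 3 columns, oldest..newest):
+//
+//   restore step K >= 2: cols 0..2 <- qkv features of tokens K-2, K-1, K
+//   restore step K = 0 : cols 0..1 <- pre-spec shadow conv state cols 1..2
+//                        col  2    <- qkv features of token 0
+//
+// In general src_token = K - (d_conv_m1 - 1) + col; a negative src_token
+// falls back to the shadowed pre-spec conv state at column d_conv_m1 + src_token.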
+
+bool llama_kv_cache::per_step_restore(int step) {
+    if (ckpt.per_step_ssm.empty() || step < 0) {
+        return false;
+    }
+
+    const int64_t ssm_state_dim = ckpt.per_step_ssm_state_size;
+    const int64_t conv_state_dim = ckpt.per_step_conv_state_dim;
+    const int64_t conv_dim = ckpt.per_step_conv_dim;
+    const int32_t d_conv = ckpt.per_step_d_conv;
+    if (ssm_state_dim <= 0 || conv_dim <= 0 || d_conv <= 1) return false;
+
+    const int64_t ssm_bytes = ssm_state_dim * sizeof(float);
+    const int64_t conv_bytes = conv_state_dim * sizeof(float);
+    const int32_t d_conv_m1 = d_conv - 1; // number of columns in conv state
+
+    std::vector<float> ssm_buf(ssm_state_dim);
+    std::vector<float> conv_buf(conv_state_dim);     // reconstructed conv state
+    std::vector<float> old_conv_buf(conv_state_dim); // pre-spec conv state from shadow
+    const int64_t qkv_needed = (int64_t)(step + 1) * conv_dim;
+    std::vector<float> qkv_buf(qkv_needed);
+
+    const uint32_t n_layer = (uint32_t)s_l.size();
+    int n_restored = 0;
+    for (uint32_t il = 0; il < n_layer; ++il) {
+        if (s_l[il] == nullptr || ckpt.per_step_ssm[il] == nullptr) continue;
+        if (s_l[il]->extra != nullptr) continue;
+
+        ggml_backend_tensor_get(ckpt.per_step_ssm[il], ssm_buf.data(),
+                (size_t)step * ssm_bytes, ssm_bytes);
+
+        if (ckpt.s_l_shadow[il] != nullptr) {
+            ggml_backend_tensor_get(ckpt.s_l_shadow[il], old_conv_buf.data(), 0, conv_bytes);
+        } else {
+            memset(old_conv_buf.data(), 0, conv_bytes);
+        }
+
+        if (ckpt.per_step_qkv[il] != nullptr) {
+            ggml_backend_tensor_get(ckpt.per_step_qkv[il], qkv_buf.data(), 0, qkv_needed * sizeof(float));
+        } else {
+            memset(qkv_buf.data(), 0, qkv_needed * sizeof(float));
+        }
+
+        for (int32_t col = 0; col < d_conv_m1; col++) {
+            int32_t src_token = step - (d_conv_m1 - 1) + col; // e.g., K-2, K-1, K for d_conv=4
+            if (src_token >= 0) {
+                for (int64_t d = 0; d < conv_dim; d++) {
+                    conv_buf[col + d * d_conv_m1] = qkv_buf[d + (int64_t)src_token * conv_dim];
+                }
+            } else {
+                int32_t old_col = d_conv_m1 + src_token; // maps to 0, 1, ...
for early steps + if (old_col >= 0 && old_col < d_conv_m1) { + for (int64_t d = 0; d < conv_dim; d++) { + conv_buf[col + d * d_conv_m1] = old_conv_buf[old_col + d * d_conv_m1]; + } + } else { + for (int64_t d = 0; d < conv_dim; d++) { + conv_buf[col + d * d_conv_m1] = 0.0f; + } + } + } + } + + ggml_backend_tensor_set(s_l[il], conv_buf.data(), 0, conv_bytes); + ggml_backend_tensor_set(s_l[il], ssm_buf.data(), conv_bytes, ssm_bytes); + n_restored++; + } + + return true; +} + static void llama_kv_cache_clear(struct llama_kv_cache & cache) { for (int32_t i = 0; i < (int32_t) cache.size; ++i) { cache.cells[i].pos = -1; @@ -6464,6 +6838,168 @@ void llama_kv_cache_clear(struct llama_context * ctx) { llama_kv_cache_clear(ctx->kv_self); } +// Unified speculative-checkpoint +static bool spec_ckpt_try_per_step(llama_kv_cache & kv, const llama_model & model, int max_tokens) { + // Graph-split tensors and mixed CPU/GPU configurations are not supported. + bool has_gpu = false; + bool has_cpu = false; + for (const auto * sl : kv.s_l) { + if (!sl) continue; + if (sl->extra) { + kv.save_per_step_ssm = false; + return false; + } + if (sl->buffer && !ggml_backend_buffer_is_host(sl->buffer)) { + has_gpu = true; + } else if (sl->buffer) { + has_cpu = true; + } + } + if (!has_gpu || has_cpu) { + if (has_cpu && has_gpu) { + LLAMA_LOG_INFO("%s: per-step disabled — mixed CPU/GPU recurrent layers\n", __func__); + } + kv.save_per_step_ssm = false; + return false; + } + + // Populate per-step dimensions from hparams + if (kv.ckpt.per_step_ssm_state_size <= 0) { + const auto & hp = model.hparams; + const int64_t nv = hp.ssm_dt_rank; + const int64_t head_v = hp.ssm_d_inner / nv; + const int64_t head_k = hp.ssm_d_state; + const int64_t nk = hp.ssm_n_group; + const int64_t key_dim = head_k * nk; + const int64_t val_dim = head_v * nv; + const int64_t conv_dim = key_dim * 2 + val_dim; + + kv.ckpt.per_step_ssm_state_size = head_v * head_v * nv; + kv.ckpt.per_step_conv_state_dim = (hp.ssm_d_conv - 1) * conv_dim; + kv.ckpt.per_step_conv_dim = conv_dim; + kv.ckpt.per_step_d_conv = hp.ssm_d_conv; + } + + if (!kv.per_step_alloc(max_tokens)) { + kv.save_per_step_ssm = false; + return false; + } + + return true; +} + +int llama_spec_ckpt_init(struct llama_context * ctx, int mode, int max_tokens) { + auto & kv = ctx->kv_self; + + kv.save_per_step_ssm = false; + kv.ckpt.selected_spec_mode = LLAMA_SPEC_CKPT_NONE; + + if (!kv.checkpoint_supported()) { + return (int)LLAMA_SPEC_CKPT_NONE; + } + + int requested = mode; + + // prefer PER_STEP → GPU_FALLBACK → CPU + if (requested == LLAMA_SPEC_CKPT_AUTO) { + requested = LLAMA_SPEC_CKPT_PER_STEP; + } + + if (requested == LLAMA_SPEC_CKPT_PER_STEP) { + if (spec_ckpt_try_per_step(kv, ctx->model, max_tokens)) { + kv.ckpt.selected_spec_mode = LLAMA_SPEC_CKPT_PER_STEP; + return (int)LLAMA_SPEC_CKPT_PER_STEP; + } + if (mode == LLAMA_SPEC_CKPT_PER_STEP) { + LLAMA_LOG_WARN("%s: per-step not available, falling back to GPU fallback mode\n", __func__); + } + requested = LLAMA_SPEC_CKPT_GPU_FALLBACK; + } + + if (requested == LLAMA_SPEC_CKPT_GPU_FALLBACK) { + kv.ckpt.selected_spec_mode = LLAMA_SPEC_CKPT_GPU_FALLBACK; + return (int)LLAMA_SPEC_CKPT_GPU_FALLBACK; + } + + kv.ckpt.selected_spec_mode = LLAMA_SPEC_CKPT_CPU; + return (int)LLAMA_SPEC_CKPT_CPU; +} + +bool llama_spec_ckpt_save(struct llama_context * ctx, llama_seq_id seq_id) { + auto & kv = ctx->kv_self; + + switch (kv.ckpt.selected_spec_mode) { + case LLAMA_SPEC_CKPT_PER_STEP: + kv.save_per_step_ssm = true; + return kv.checkpoint_save(); + + case 
LLAMA_SPEC_CKPT_GPU_FALLBACK: + return kv.checkpoint_save(); + + case LLAMA_SPEC_CKPT_CPU: { + const size_t need = llama_state_seq_get_size(ctx, seq_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); + kv.ckpt.cpu_state_data.resize(need); + const size_t written = llama_state_seq_get_data( + ctx, kv.ckpt.cpu_state_data.data(), need, seq_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); + kv.ckpt.cpu_state_data.resize(written); + return written > 0; + } + + default: + return false; + } +} + +bool llama_spec_ckpt_restore(struct llama_context * ctx, llama_seq_id seq_id, + llama_pos n_past, int accepted_step) { + auto & kv = ctx->kv_self; + + switch (kv.ckpt.selected_spec_mode) { + case LLAMA_SPEC_CKPT_PER_STEP: { + if (!kv.per_step_restore(accepted_step)) { + return false; + } + const llama_pos accepted_pos = n_past + accepted_step; + if (seq_id >= 0 && (uint32_t)seq_id < kv.size) { + kv.cells[seq_id].pos = accepted_pos; + } + llama_kv_cache_seq_rm(kv, seq_id, accepted_pos + 1, -1); + return true; + } + + case LLAMA_SPEC_CKPT_GPU_FALLBACK: + kv.checkpoint_restore(); + llama_kv_cache_seq_rm(kv, seq_id, n_past, -1); + return false; + + case LLAMA_SPEC_CKPT_CPU: + if (!kv.ckpt.cpu_state_data.empty()) { + llama_state_seq_set_data(ctx, kv.ckpt.cpu_state_data.data(), + kv.ckpt.cpu_state_data.size(), seq_id, + LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); + } + llama_kv_cache_seq_rm(kv, seq_id, n_past, -1); + return false; + + default: + return false; + } +} + +void llama_spec_ckpt_discard(struct llama_context * ctx) { + auto & kv = ctx->kv_self; + + if (kv.ckpt.selected_spec_mode == LLAMA_SPEC_CKPT_PER_STEP) { + kv.save_per_step_ssm = false; + kv.checkpoint_delete(); + } else if (kv.ckpt.selected_spec_mode == LLAMA_SPEC_CKPT_GPU_FALLBACK) { + kv.checkpoint_delete(); + } + + kv.ckpt.selected_spec_mode = LLAMA_SPEC_CKPT_NONE; + kv.ckpt.cpu_state_data.clear(); +} + bool llama_kv_cache_seq_rm(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1) { return llama_kv_cache_seq_rm(ctx->kv_self, seq_id, p0, p1); } @@ -9064,18 +9600,20 @@ void llama_sampler_dry_free(struct llama_sampler_dry* smpl) { } struct llama_sampler_dry* llama_sampler_dry_clone(struct llama_sampler_dry* smpl) { - // nullptr is passed as vocab because it is only needed for raw sequence breaker processing, which we have already done and will be copying - auto* result = llama_sampler_init_dry(nullptr, smpl->dry_multiplier, smpl->dry_base, smpl->dry_allowed_length, smpl->dry_penalty_last_n, NULL, 0); - // Copy the state, including the processed breakers - { - auto* result_ctx = smpl; - result_ctx->dry_processed_breakers = smpl->dry_processed_breakers; - result_ctx->dry_repeat_count = smpl->dry_repeat_count; - result_ctx->dry_max_token_repeat = smpl->dry_max_token_repeat; - result_ctx->last_tokens = smpl->last_tokens; + if (!smpl) { + return nullptr; } - - return result; + return new llama_sampler_dry { + /* .total_context_size = */ smpl->total_context_size, + /* .dry_multiplier = */ smpl->dry_multiplier, + /* .dry_base = */ smpl->dry_base, + /* .dry_allowed_length = */ smpl->dry_allowed_length, + /* .dry_penalty_last_n = */ smpl->dry_penalty_last_n, + /* .dry_processed_breakers = */ smpl->dry_processed_breakers, + /* .dry_repeat_count = */ smpl->dry_repeat_count, + /* .dry_max_token_repeat = */ smpl->dry_max_token_repeat, + /* .last_tokens = */ smpl->last_tokens, + }; } void llama_sampler_dry_accept(struct llama_sampler_dry* smpl, llama_token token) {