diff --git a/common/common.cpp b/common/common.cpp index 0632b0a1..d308fe9d 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1033,6 +1033,23 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa params.speculative.p_min = std::stof(argv[i]); return true; } + if (arg == "--recurrent-ckpt-mode") { + CHECK_ARG + const std::string val = argv[i]; + if (val == "auto" || val == "AUTO") { + params.speculative.recurrent_ckpt_mode = LLAMA_SPEC_CKPT_AUTO; + } else if (val == "per-step" || val == "PER_STEP") { + params.speculative.recurrent_ckpt_mode = LLAMA_SPEC_CKPT_PER_STEP; + } else if (val == "gpu-fallback" || val == "GPU_FALLBACK") { + params.speculative.recurrent_ckpt_mode = LLAMA_SPEC_CKPT_GPU_FALLBACK; + } else if (val == "cpu" || val == "CPU") { + params.speculative.recurrent_ckpt_mode = LLAMA_SPEC_CKPT_CPU; + } else { + throw std::invalid_argument("unknown --recurrent-ckpt-mode value: " + val + + "; expected auto, per-step, gpu-fallback, or cpu"); + } + return true; + } if (arg == "--spec-autotune") { params.speculative.autotune = true; return true; @@ -2732,6 +2749,11 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param "number of tokens to draft for speculative decoding (default: %d)", params.speculative.n_max }); options.push_back({ "*", "--draft-min, --draft-n-min N", "minimum number of draft tokens to use for speculative decoding" }); options.push_back({ "*", "--draft-p-min P", "minimum speculative decoding probability (greedy) (default: %.1f)", (double)params.speculative.p_min }); + options.push_back({ "*", "--recurrent-ckpt-mode MODE", "checkpoint strategy for recurrent/hybrid speculative decoding\n" + " auto auto-select: per-step if CUDA full-GPU, gpu-fallback otherwise (default)\n" + " per-step save SSM state per draft step in VRAM; no re-decode on rejection\n" + " gpu-fallback copy state to GPU buffer; re-decode on rejection\n" + " cpu serialise state via llama_state_seq; re-decode on 
rejection" }); options.push_back({ "*", "--spec-type Name [none | ngram - cache | ngram - simple | ngram - map - k | ngram - map - k4v | ngram - mod | suffix]", "type of speculative decoding to use when no draft model is provided (default: %d)\n", (int)params.speculative.type}); options.push_back({ "*", "--spec-ngram-size-n N", "ngram size N for ngram-simple/ngram-map speculative decoding, length of lookup n-gram (default: %d)\n",params.speculative.ngram_size_n }); diff --git a/common/common.h b/common/common.h index 37cf6ec2..734d93de 100644 --- a/common/common.h +++ b/common/common.h @@ -164,6 +164,9 @@ struct common_ngram_mod; struct common_params_speculative { common_speculative_type type = COMMON_SPECULATIVE_TYPE_NONE; // type of speculative decoding + // Recurrent-model checkpoint strategy for speculative decoding. + int recurrent_ckpt_mode = LLAMA_SPEC_CKPT_AUTO; + std::string devices; std::string params; int32_t n_threads = -1; diff --git a/common/ngram-map.cpp b/common/ngram-map.cpp index 48699428..affb26c1 100644 --- a/common/ngram-map.cpp +++ b/common/ngram-map.cpp @@ -208,7 +208,7 @@ void common_ngram_map_begin( count_keys, count_keys_del, count_values_del, count_map_entries_upd); } - map.idx_last_check = (map.size_last_begin > 0) ? 
map.size_last_begin - 1 : 0; + map.idx_last_check = size_begin; map.size_last_begin = size_begin; } @@ -386,7 +386,7 @@ void common_ngram_map_draft(common_ngram_map & map, LOG_DBG("%s: key_idx = %zu, key_offset = %zu, key_num = %d, draft.size = %zu\n", __func__, curr_key.key_idx, key_offset, curr_key.key_num, draft.size()); - map.last_draft_created = false; + map.last_draft_created = true; map.last_draft_key_idx = key_offset; map.last_draft_value_idx = 0; // value 0 is used for simple mode return; diff --git a/common/sampling.cpp b/common/sampling.cpp index 3fb18358..d573b5e7 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -218,6 +218,12 @@ void llama_sampling_set_rng_seed(struct common_sampler * ctx, uint32_t seed) { } void common_sampler_clone(common_sampler * src, common_sampler * dst) { + dst->params = src->params; + dst->mirostat_mu = src->mirostat_mu; + dst->n_valid = src->n_valid; + dst->rng = src->rng; + dst->server_biases = src->server_biases; + if (dst->grammar) { llama_grammar_free(dst->grammar); dst->grammar = nullptr; @@ -230,7 +236,18 @@ void common_sampler_clone(common_sampler * src, common_sampler * dst) { } dst->prev = src->prev; - dst->smpl = llama_sampler_dry_clone(src->smpl); + if (dst->smpl) { + llama_sampler_dry_free(dst->smpl); + dst->smpl = nullptr; + } + if (src->smpl) { + dst->smpl = llama_sampler_dry_clone(src->smpl); + } + + if (dst->rbudget) { + common_reasoning_budget_free(dst->rbudget); + dst->rbudget = nullptr; + } if (src->rbudget) { dst->rbudget = common_reasoning_budget_clone(src->rbudget); } diff --git a/common/speculative.cpp b/common/speculative.cpp index 00696b65..f76d40ab 100644 --- a/common/speculative.cpp +++ b/common/speculative.cpp @@ -598,6 +598,7 @@ struct common_speculative_state_ngram_mod : public common_speculative_state { i_last = 0; n_draft_last = 0; + n_low = 0; const size_t n = mod.get_n(); @@ -1265,13 +1266,11 @@ void common_speculative_accept(common_speculative * spec, uint16_t n_accepted) { 
spec->t_step_start_us = 0; } - if (n_accepted == 0) { - return; - } - common_speculative_state * impl = spec->curr_impl; - GGML_ASSERT(impl); + if (!impl) { + return; + } { common_time_meas tm(impl->t_accept_us, !impl->gen_perf); diff --git a/examples/server/server-context.cpp b/examples/server/server-context.cpp index 71a818b3..e5e1ec1d 100644 --- a/examples/server/server-context.cpp +++ b/examples/server/server-context.cpp @@ -22,6 +22,51 @@ static void log_text(const gpt_params & params_base, const std::string & text) { } } +void server_speculative_checkpoint::clear() { + valid = false; + per_step_enabled = false; + n_past = 0; + sampled = LLAMA_TOKEN_NULL; + + if (sampler != nullptr) { + common_sampler_free(sampler); + sampler = nullptr; + } +} + +static void discard_speculative_checkpoint(server_slot & slot, llama_context * ctx) { + slot.spec_ckpt.clear(); + llama_spec_ckpt_discard(ctx); +} + +static bool save_speculative_checkpoint(server_slot & slot, llama_model * model, llama_context * ctx, int ckpt_mode) { + slot.spec_ckpt.clear(); + slot.spec_ckpt.n_past = slot.n_past - (int32_t)(slot.drafted.size() + 1); + slot.spec_ckpt.sampled = slot.sampled; + + const int max_tokens = (int)slot.drafted.size() + 1; + const int actual_mode = llama_spec_ckpt_init(ctx, ckpt_mode, max_tokens); + if (actual_mode == LLAMA_SPEC_CKPT_NONE) { + return false; + } + slot.spec_ckpt.per_step_enabled = (actual_mode == LLAMA_SPEC_CKPT_PER_STEP); + + slot.spec_ckpt.valid = llama_spec_ckpt_save(ctx, slot.id); + if (!slot.spec_ckpt.valid) { + llama_spec_ckpt_discard(ctx); + return false; + } + + slot.spec_ckpt.sampler = common_sampler_init(model, slot.sparams); + if (slot.spec_ckpt.sampler == nullptr) { + discard_speculative_checkpoint(slot, ctx); + return false; + } + + common_sampler_clone(slot.ctx_sampling, slot.spec_ckpt.sampler); + return true; +} + server_context::~server_context() { if (ctx) { llama_free(ctx); @@ -49,6 +94,7 @@ server_context::~server_context() { if 
(slot.ctx_sampling != nullptr) { common_sampler_free(slot.ctx_sampling); } + slot.spec_ckpt.clear(); if (slot.ctx_dft) { llama_free(slot.ctx_dft); } @@ -112,15 +158,6 @@ bool server_context::load_model(const gpt_params& params_) { } // Load draft model for speculative decoding if specified if (has_draft_model) { - - if (llama_model_has_recurrent(model)) { - LLAMA_LOG_WARN("\n=======================================================================\n"); - LLAMA_LOG_WARN(" Speculative decodong is not suported for recurrent/hybrid models\n"); - LLAMA_LOG_WARN(" --> bailing out\n"); - LLAMA_LOG_WARN("========================================================================\n\n"); - GGML_ABORT("Fatal error"); - } - LLAMA_LOG_INFO("\n\n==================================loading DRAFT model==================================\n\n"); gpt_params params_dft; @@ -387,6 +424,7 @@ void server_slot::reset() { n_sent_text = 0; drafted.clear(); i_batch_dft.clear(); + spec_ckpt.clear(); n_sent_token_probs = 0; infill = false; ga_i = 0; @@ -3679,6 +3717,72 @@ void server_context::extend_context(const int32_t n_tokens) { } } +// Restore recurrent state and re-decode accepted tokens after speculative-decode rejection. +static void restore_speculative_checkpoint( + server_slot & slot, llama_context * ctx, llama_model * model, + const std::vector & ids, int n_draft) { + if (slot.spec_ckpt.per_step_enabled) { + const int step = (int)ids.size() - 1; + llama_spec_ckpt_restore(ctx, slot.id, slot.spec_ckpt.n_past, step); + + if (slot.spec_ckpt.sampler) { + common_sampler_clone(slot.spec_ckpt.sampler, slot.ctx_sampling); + } + for (llama_token id : ids) { + common_sampler_accept(slot.ctx_sampling, ctx, id, true); + } + + SLT_DBG(slot, "per-step restore: step=%d (rejected %d drafts)\n", + step, (int)(n_draft - (ids.size() - 1))); + } else { + // Restore pre-speculation recurrent state then re-decode accepted tokens. 
+ llama_spec_ckpt_restore(ctx, slot.id, slot.spec_ckpt.n_past, 0); + + if (slot.spec_ckpt.sampler) { + common_sampler_clone(slot.spec_ckpt.sampler, slot.ctx_sampling); + } + + if (!ids.empty()) { + // Re-decode to advance recurrent state to the accepted position. + const int n_re = (int)ids.size(); + llama_batch re_batch = llama_batch_init(n_re, 0, 1); + common_batch_add(re_batch, slot.spec_ckpt.sampled, slot.spec_ckpt.n_past, { slot.id }, n_re == 1); + for (int j = 0; j < n_re - 1; j++) { + common_batch_add(re_batch, ids[j], slot.spec_ckpt.n_past + 1 + j, { slot.id }, j == n_re - 2); + } + + if (slot.has_mtp) { + llama_set_embeddings(ctx, true); + } + + const int ret = llama_decode(ctx, re_batch); + if (ret != 0) { + SLT_ERR(slot, "failed to re-decode accepted tokens after checkpoint restore: %d\n", ret); + } + + if (slot.has_mtp) { + llama_set_embeddings(ctx, false); + const int n_embd = llama_model_n_embd(llama_get_model(ctx)); + const float * emb = llama_get_embeddings_ith(ctx, -1); + if (emb) { + slot.mtp_hidden_state.resize(n_embd); + memcpy(slot.mtp_hidden_state.data(), emb, n_embd * sizeof(float)); + } + } + + for (llama_token id : ids) { + common_sampler_accept(slot.ctx_sampling, ctx, id, true); + } + + llama_batch_free(re_batch); + SLT_DBG(slot, "spec checkpoint restored: re-decoded %d tokens (rejected %d drafts)\n", + n_re, (int)(n_draft - (ids.size() - 1))); + } + } + + discard_speculative_checkpoint(slot, ctx); +} + void server_context::speculative_decoding_accept() { for (auto& slot : slots) { if (slot.state != SLOT_STATE_PROCESSING || slot.i_batch_dft.empty()) { @@ -3739,7 +3843,14 @@ void server_context::speculative_decoding_accept() { slot.sampled = ids.back(); // last accepted token slot.n_past = slot.cache_tokens.n_tokens(); - llama_kv_cache_seq_rm(ctx, slot.id, slot.n_past, -1); + // for recurrent/hybrid models: if any drafts were rejected, restore recurrent state + const bool any_rejected = (ids.size() - 1) < n_draft; + if (any_rejected && 
slot.spec_ckpt.valid) { + restore_speculative_checkpoint(slot, ctx, model, ids, n_draft); + } else { + llama_kv_cache_seq_rm(ctx, slot.id, slot.n_past, -1); + discard_speculative_checkpoint(slot, ctx); + } for (size_t i = 0; i < ids.size(); ++i) { completion_token_output result; @@ -4305,6 +4416,23 @@ void server_context::update_slots() { // make sure we're in the right embedding mode llama_set_embeddings(ctx, batch_type == 1); + if (llama_model_has_recurrent(model)) { + const int ckpt_mode = params_base.speculative.recurrent_ckpt_mode; + + for (auto & slot : slots) { + if (slot.state != SLOT_STATE_PROCESSING || slot.i_batch_dft.empty()) { + continue; + } + if (save_speculative_checkpoint(slot, model, ctx, ckpt_mode)) { + const char * mode_name = slot.spec_ckpt.per_step_enabled ? "per-step" : "shadow/cpu"; + SLT_DBG(slot, "spec checkpoint saved (mode=%s), n_past_pre_spec=%d\n", + mode_name, slot.spec_ckpt.n_past); + } else { + SLT_WRN(slot, "%s", "failed to save spec checkpoint\n"); + } + } + } + // process the created batch of tokens process_batch_tokens(n_batch); // Decode with batch diff --git a/examples/server/server-context.h b/examples/server/server-context.h index f42c6d46..074787b5 100644 --- a/examples/server/server-context.h +++ b/examples/server/server-context.h @@ -22,6 +22,16 @@ enum slot_command { SLOT_COMMAND_RELEASE, }; +struct server_speculative_checkpoint { + bool valid = false; + bool per_step_enabled = false; // per-step SSM checkpoints active + llama_pos n_past = 0; + llama_token sampled = LLAMA_TOKEN_NULL; + common_sampler * sampler = nullptr; // saved sampler state + + void clear(); +}; + struct server_slot { int id; int id_task = -1; @@ -160,6 +170,9 @@ struct server_slot { bool has_mtp = false; std::vector mtp_hidden_state; + // saves recurrent state before a speculative batch so it can be restored on rejection + server_speculative_checkpoint spec_ckpt; + // speculative decoding stats int32_t n_draft_total = 0; // Total draft tokens 
generated int32_t n_draft_accepted = 0; // Draft tokens actually accepted diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 0d164166..250d924e 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -2526,7 +2526,8 @@ extern "C" { struct ggml_tensor * v, struct ggml_tensor * g, struct ggml_tensor * beta, - struct ggml_tensor * state); + struct ggml_tensor * state, + bool save_all_steps); // custom operators diff --git a/ggml/src/ggml-cuda/delta-net.cu b/ggml/src/ggml-cuda/delta-net.cu index 35233255..d43a855c 100644 --- a/ggml/src/ggml-cuda/delta-net.cu +++ b/ggml/src/ggml-cuda/delta-net.cu @@ -35,13 +35,14 @@ __global__ void delta_net_recurrent_f32( const float * __restrict__ g, // [n_tokens, 1, n_heads, n_seqs] const float * __restrict__ beta_in, // [1, n_tokens, n_heads, n_seqs] const float * __restrict__ state_in, // [HEAD_DIM, HEAD_DIM*n_heads, 1, n_seqs] - float * __restrict__ dst, // output + new_state concatenated + float * __restrict__ dst, // output + new_state(s) concatenated const int64_t n_heads, const int64_t gqa_ratio, const int repeat_type, const int64_t n_tokens, const int64_t n_seqs, const int64_t output_offset, // offset where state starts in output + const int save_all_states, // 1 = save per-step states, 0 = final only size_t vnb1, size_t vnb2, size_t vnb3) { constexpr int warps_per_head = HEAD_DIM/WARP_SIZE; const int batch_idx = blockIdx.x / (warps_per_head*n_heads); @@ -69,6 +70,9 @@ __global__ void delta_net_recurrent_f32( const int64_t state_head_offset = head_idx * HEAD_DIM * HEAD_DIM; const int64_t state_batch_stride = HEAD_DIM * HEAD_DIM * n_heads; + // State step stride for save_all_states: HEAD_DIM^2 * n_heads * n_seqs + const int64_t state_step_stride = HEAD_DIM * HEAD_DIM * n_heads * n_seqs; + // Pointers for this batch/head const float * q_ptr = q + batch_idx * qkv_stride_batch_kq + head_idx_kq * qkv_stride_head; const float * k_ptr = k + batch_idx * qkv_stride_batch_kq + head_idx_kq * qkv_stride_head; @@ -155,6 
+159,15 @@ __global__ void delta_net_recurrent_f32( state_local[i] = new_state_val; } + // Save per-step state if requested + if (save_all_states) { + float * state_step_dst = dst + output_offset + t * state_step_stride + batch_idx * state_batch_stride + state_head_offset; + for (int i = 0; i < HEAD_DIM/num_warps; ++i) { + int col = num_warps*i + col_idx_0; + state_step_dst[col*HEAD_DIM + row_out] = state_local[i]; + } + } + // Barrier required: (a) sK reads in the state update above must complete // before next iteration overwrites sK at the top of the loop, and (b) this // single barrier also orders all_sum1/all_sum2 reads above vs. the next @@ -163,9 +176,11 @@ __global__ void delta_net_recurrent_f32( __syncthreads(); } // Copy the final state to its destination - for (int i = 0; i < HEAD_DIM/num_warps; ++i) { - int col = num_warps*i + col_idx_0; - state_dst[col*HEAD_DIM + row_out] = state_local[i]; + if (!save_all_states) { + for (int i = 0; i < HEAD_DIM/num_warps; ++i) { + int col = num_warps*i + col_idx_0; + state_dst[col*HEAD_DIM + row_out] = state_local[i]; + } } } @@ -183,9 +198,10 @@ static void delta_net_f32_cuda( const int64_t gqa_ratio, const int repeat_type, const int64_t n_seqs, + const int save_all_states, size_t vnb1, size_t vnb2, size_t vnb3, const int device_id, - const int cc, // compute capability (e.g., 890 for SM 8.9, 1200 for SM 12.0) + const int cc, // compute capability (e.g., 890 for SM 8.9, 1200 for SM 12.0) cudaStream_t stream) { GGML_UNUSED(device_id); GGML_UNUSED(cc); @@ -204,19 +220,19 @@ static void delta_net_f32_cuda( constexpr int threads_per_block = 256; if (head_dim == 64) { delta_net_recurrent_f32<64, threads_per_block><<>>( - q, k, v, g, beta, state_in, dst, n_heads, gqa_ratio, repeat_type, n_tokens, n_seqs, output_offset, vnb1, vnb2, vnb3); + q, k, v, g, beta, state_in, dst, n_heads, gqa_ratio, repeat_type, n_tokens, n_seqs, output_offset, save_all_states, vnb1, vnb2, vnb3); } else { delta_net_recurrent_f32<128, 
threads_per_block><<>>( - q, k, v, g, beta, state_in, dst, n_heads, gqa_ratio, repeat_type, n_tokens, n_seqs, output_offset, vnb1, vnb2, vnb3); + q, k, v, g, beta, state_in, dst, n_heads, gqa_ratio, repeat_type, n_tokens, n_seqs, output_offset, save_all_states, vnb1, vnb2, vnb3); } } else { constexpr int threads_per_block = 128; if (head_dim == 64) { delta_net_recurrent_f32<64, threads_per_block><<>>( - q, k, v, g, beta, state_in, dst, n_heads, gqa_ratio, repeat_type, n_tokens, n_seqs, output_offset, vnb1, vnb2, vnb3); + q, k, v, g, beta, state_in, dst, n_heads, gqa_ratio, repeat_type, n_tokens, n_seqs, output_offset, save_all_states, vnb1, vnb2, vnb3); } else { delta_net_recurrent_f32<128, threads_per_block><<>>( - q, k, v, g, beta, state_in, dst, n_heads, gqa_ratio, repeat_type, n_tokens, n_seqs, output_offset, vnb1, vnb2, vnb3); + q, k, v, g, beta, state_in, dst, n_heads, gqa_ratio, repeat_type, n_tokens, n_seqs, output_offset, save_all_states, vnb1, vnb2, vnb3); } } @@ -258,9 +274,12 @@ void ggml_cuda_op_delta_net(ggml_backend_cuda_context & ctx, ggml_tensor * dst) // Verify output tensor size const int64_t output_size = head_dim * n_tokens * n_heads * n_seqs; const int64_t state_size = head_dim * head_dim * n_heads * n_seqs; - GGML_ASSERT(ggml_nelements(dst) == output_size + state_size); int repeat_type = dst->op_params[0]; + int save_all_states = dst->op_params[1]; + + const int64_t expected_size = save_all_states ? 
(output_size + n_tokens * state_size) : (output_size + state_size); + GGML_ASSERT(ggml_nelements(dst) == expected_size); GGML_ASSERT(head_dim <= 256); // Reasonable limit for shared memory @@ -277,6 +296,7 @@ void ggml_cuda_op_delta_net(ggml_backend_cuda_context & ctx, ggml_tensor * dst) (const float *)src5->data, (float *)dst->data, head_dim, n_tokens, n_heads, gqa_ratio, repeat_type, n_seqs, + save_all_states, src2->nb[1]/sizeof(float), src2->nb[2]/sizeof(float), src2->nb[3]/sizeof(float), device_id, cc, ctx.stream()); diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 59842a2c..4454ffb9 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -9940,7 +9940,8 @@ struct ggml_tensor * ggml_delta_net( struct ggml_tensor * v, struct ggml_tensor * g, struct ggml_tensor * beta, - struct ggml_tensor * state) { + struct ggml_tensor * state, + bool save_all_steps) { GGML_ASSERT(ggml_is_contiguous(q)); GGML_ASSERT(ggml_is_contiguous(k)); GGML_ASSERT(ggml_is_contiguous(state)); @@ -9971,9 +9972,11 @@ struct ggml_tensor * ggml_delta_net( const int64_t output_size = S_v * H_v * n_tokens * n_seqs; const int64_t state_size = S_v * S_v * H_v * n_seqs; - struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, output_size + state_size); + const int64_t state_slots = save_all_steps ? n_tokens : 1; + struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, output_size + state_slots * state_size); result->op = GGML_OP_DELTA_NET; + result->op_params[1] = save_all_steps ? 
1 : 0; result->src[0] = q; result->src[1] = k; result->src[2] = v; @@ -22654,17 +22657,19 @@ static void ggml_compute_forward_delta_net_f32( const float * beta_data = (const float *) src4->data; const float * state_in = (const float *) src5->data; float * out_data = (float *) dst->data; - float * state_out = out_data + output_size; const int ith = params->ith; const int nth = params->nth; int repeat_type = dst->op_params[0]; + // save_all_steps is handled by the CUDA backend only; + // the CPU path always writes to the single state slot after the output. + float * state_working = out_data + output_size; if (iqk_fused_delta_net(head_dim, n_heads, gqa_ratio, repeat_type, n_tokens, n_seqs, src2->nb[1]/sizeof(float), src2->nb[2]/sizeof(float), src2->nb[3]/sizeof(float), q_data, k_data, v_data, g_data, beta_data, state_in, - out_data, state_out, ith, nth)) { + out_data, state_working, ith, nth)) { return; } @@ -22694,10 +22699,10 @@ static void ggml_compute_forward_delta_net_f32( const int64_t out_token_stride = head_dim * n_heads; for (int64_t i = 0; i < head_dim * head_dim; ++i) { - state_out[state_head_offset + i] = state_in[state_head_offset + i]; + state_working[state_head_offset + i] = state_in[state_head_offset + i]; } - float * state = state_out + state_head_offset; + float * state = state_working + state_head_offset; for (int64_t t = 0; t < n_tokens; ++t) { const float * q_t = q_data + qkv_head_offset_kq + t * qkv_token_stride; diff --git a/include/llama.h b/include/llama.h index 732bdde7..8d6cd295 100644 --- a/include/llama.h +++ b/include/llama.h @@ -800,6 +800,28 @@ extern "C" { LLAMA_API void llama_kv_cache_clear( struct llama_context * ctx); + // Unified checkpoint API for recurrent/hybrid speculative decoding. 
+ enum llama_spec_ckpt_mode { + LLAMA_SPEC_CKPT_NONE = -1, + LLAMA_SPEC_CKPT_AUTO = 0, + LLAMA_SPEC_CKPT_PER_STEP = 1, + LLAMA_SPEC_CKPT_GPU_FALLBACK = 2, + LLAMA_SPEC_CKPT_CPU = 3, + }; + + // Initialise the checkpoint system for the upcoming speculation window. + LLAMA_API int llama_spec_ckpt_init(struct llama_context * ctx, int mode, int max_tokens); + + // Save the current recurrent state as a speculative checkpoint. + LLAMA_API bool llama_spec_ckpt_save(struct llama_context * ctx, llama_seq_id seq_id); + + // Restore the recurrent state after speculative decode. + LLAMA_API bool llama_spec_ckpt_restore(struct llama_context * ctx, llama_seq_id seq_id, + llama_pos n_past, int accepted_step); + + // Discard the saved checkpoint and reset internal mode state. + LLAMA_API void llama_spec_ckpt_discard(struct llama_context * ctx); + // Removes all tokens that belong to the specified sequence and have positions in [p0, p1) // Returns false if a partial sequence cannot be removed. Removing a whole sequence never fails // seq_id < 0 : match any sequence diff --git a/src/llama-context.h b/src/llama-context.h index be563122..7b6e56cf 100644 --- a/src/llama-context.h +++ b/src/llama-context.h @@ -59,6 +59,9 @@ struct llama_kv_cache { std::vector v_l; std::vector s_l; // per layer recurrent state storage (Qwen3Next) + // When true, the delta_net graph builder will enable per-step SSM state saves + bool save_per_step_ssm = false; + std::vector split_k_l; std::vector split_v_l; std::vector split_s_l; @@ -74,6 +77,74 @@ struct llama_kv_cache { return size; } + // GPU-resident checkpoint for recurrent/hybrid speculative decoding + struct gpu_checkpoint { + std::vector cells_snapshot; + uint32_t head_snapshot = 0; + uint32_t used_snapshot = 0; + + std::vector s_l_shadow; + + std::vector> split_s_l_shadow; + + // Per-step SSM state checkpoints for speculative decoding. 
+ std::vector per_step_ssm; + + // Per-step conv feature buffer: stores qkv_mixed features from the + // verification forward pass so conv state can be reconstructed at any step. + // One tensor per recurrent layer, each sized [conv_dim * max_tokens]. + std::vector per_step_qkv; + + int32_t per_step_n_tokens = 0; + int32_t per_step_max_allocated = 0; + int64_t per_step_ssm_state_size = 0; + int64_t per_step_conv_state_dim = 0; + int64_t per_step_conv_dim = 0; + int32_t per_step_d_conv = 0; + + int selected_spec_mode = -1; + + // Serialised sequence state for CPU mode + std::vector cpu_state_data; + + // Separate storage for per-step allocations + std::vector per_step_ctxs; + std::vector per_step_bufs; + + std::vector shadow_ctxs; + std::vector shadow_bufs; + + bool allocated = false; + bool saved = false; + + ~gpu_checkpoint() { + for (struct ggml_context * ctx : shadow_ctxs) { + ggml_free(ctx); + } + for (ggml_backend_buffer_t buf : shadow_bufs) { + ggml_backend_buffer_free(buf); + } + for (struct ggml_context * ctx : per_step_ctxs) { + ggml_free(ctx); + } + for (ggml_backend_buffer_t buf : per_step_bufs) { + ggml_backend_buffer_free(buf); + } + } + }; + + gpu_checkpoint ckpt; + + bool checkpoint_alloc_shadows(); + bool checkpoint_supported() const; + bool checkpoint_save(); + bool checkpoint_restore(); + void checkpoint_delete(); + + // Per-step checkpoint: allocate, restore step k's full state (SSM + conv) to cache + bool per_step_alloc(int max_tokens); + bool per_step_restore(int step); + ~llama_kv_cache() { for (struct ggml_context * ctx : ctxs) { ggml_free(ctx); diff --git a/src/llama-delta-net.cpp b/src/llama-delta-net.cpp index 060294df..1c24fc2e 100644 --- a/src/llama-delta-net.cpp +++ b/src/llama-delta-net.cpp @@ -70,6 +70,7 @@ delta_net::delta_net(llama_context & _lctx, const llama_batch & _batch) : lctx(_ GGML_ASSERT((uint32_t) s < qnext_state_slots); } + save_per_step_states = lctx.kv_self.save_per_step_ssm && batch.n_tokens > 1; } 
delta_net::~delta_net() = default; @@ -77,7 +78,9 @@ delta_net::~delta_net() = default; std::pair delta_net::build_fused_delta_net(ggml_context * ctx0, ggml_tensor * q, ggml_tensor * k, ggml_tensor * v, ggml_tensor * g, ggml_tensor * beta, ggml_tensor * state, - int il, const llm_build_cb & cb, int repeat_type) { + int il, const llm_build_cb & cb, int repeat_type, + bool save_all_steps, + ggml_cgraph * gf, ggml_tensor * per_step_ckpt) { const int64_t S_k = q->ne[0]; const int64_t H_k = q->ne[2]; @@ -119,7 +122,7 @@ std::pair delta_net::build_fused_delta_net(ggml_co cb(beta, "beta_fused", il); cb(state_flat,"state_fused", il); - ggml_tensor * fused_result = ggml_delta_net(ctx0, q, k, v, g, beta, state_flat); + ggml_tensor * fused_result = ggml_delta_net(ctx0, q, k, v, g, beta, state_flat, save_all_steps); cb(fused_result, "delta_net_fused_raw", il); fused_result->op_params[0] = repeat_type; @@ -133,13 +136,34 @@ std::pair delta_net::build_fused_delta_net(ggml_co ggml_row_size(fused_result->type, S_v * H_v * n_tokens), 0); //output_tokens = ggml_cont_4d(ctx0, output_tokens, S_v, H_v, n_tokens, n_seqs); + // per-step states are at [output_size, output_size + n_tokens*state_size) + const int64_t last_state_offset = save_all_steps + ? 
(output_size + (n_tokens - 1) * state_size) + : output_size; + ggml_tensor * new_state_flat = ggml_view_1d(ctx0, fused_result, state_size, - output_size * ggml_element_size(fused_result)); + last_state_offset * ggml_element_size(fused_result)); ggml_tensor * new_state = ggml_reshape_4d(ctx0, new_state_flat, S_v, S_v, H_v, n_seqs); cb(output_tokens, "output_tokens", il); cb(new_state, "new_state", il); + // Copy all per-step SSM states to persistent checkpoint tensor + if (save_all_steps && per_step_ckpt != nullptr && gf != nullptr && n_tokens > 1) { + const int64_t per_step_total = n_tokens * state_size; + if (per_step_total <= ggml_nelements(per_step_ckpt)) { + ggml_tensor * all_steps_src = ggml_view_1d(ctx0, fused_result, per_step_total, + output_size * ggml_element_size(fused_result)); + ggml_tensor * ckpt_dst = ggml_view_1d(ctx0, per_step_ckpt, per_step_total, 0); + auto ckpt_cpy = ggml_cpy(ctx0, all_steps_src, ckpt_dst); + cb(ckpt_cpy, "per_step_ckpt_cpy", il); + ggml_build_forward_expand(gf, ckpt_cpy); + } else { + LLAMA_LOG_WARN("%s: per-step checkpoint tensor too small for %lld tokens (need %lld, have %lld), skipping per-step save\n", + __func__, (long long)n_tokens, (long long)per_step_total, (long long)ggml_nelements(per_step_ckpt)); + } + } + return {output_tokens, new_state}; } @@ -281,7 +305,8 @@ ggml_tensor * delta_net::build_qkv(ggml_context * ctx0, ggml_tensor * state_stor ggml_tensor * qkv_mixed, ggml_tensor * inp_s_seq_qnext, ggml_tensor * beta, ggml_tensor * gate, int64_t head_k_dim, int64_t num_k_heads, int64_t head_v_dim, int64_t num_v_heads, int64_t ssm_d_conv, int64_t state_seq_id_local, uint32_t qnext_state_slots, bool reset_state_local, - float eps_norm, int repeat_type, int il, const llm_build_cb & cb, ggml_cgraph * gf) { + float eps_norm, int repeat_type, int il, const llm_build_cb & cb, ggml_cgraph * gf, + bool save_per_step_states, ggml_tensor * per_step_ckpt) { const int64_t key_dim = head_k_dim * num_k_heads; const int64_t value_dim = 
head_v_dim * num_v_heads; const int64_t conv_dim = key_dim * 2 + value_dim; @@ -366,7 +391,8 @@ ggml_tensor * delta_net::build_qkv(ggml_context * ctx0, ggml_tensor * state_stor cb(q_conv, "q_conv_normed", il); cb(k_conv, "k_conv_normed", il); - auto [output, new_state] = build_fused_delta_net(ctx0, q_conv, k_conv, v_conv, gate, beta, state, il, cb, repeat_type); + auto [output, new_state] = build_fused_delta_net(ctx0, q_conv, k_conv, v_conv, gate, beta, state, il, cb, repeat_type, + save_per_step_states, gf, per_step_ckpt); cb(output, "attn_output", il); cb(new_state, "new_state", il); @@ -566,11 +592,29 @@ ggml_tensor * delta_net::build_layer_attn_linear_core(ggml_context * ctx0, ggml_ auto [beta, gate] = build_beta_gate(lctx, ctx0, model.layers[il].ssm_beta_alpha, model.layers[il].ssm_beta, model.layers[il].ssm_alpha, model.layers[il].ssm_dt, model.layers[il].ssm_a, num_k_heads, num_v_heads, n_seqs, cur, il, cb, gf); + // Get per-step checkpoint tensor if available + ggml_tensor * per_step_ckpt = nullptr; + if (save_per_step_states && il < (int)kv_self.ckpt.per_step_ssm.size()) { + per_step_ckpt = kv_self.ckpt.per_step_ssm[il]; + } + + // Save qkv_mixed features for per-step conv state reconstruction + if (save_per_step_states && il < (int)kv_self.ckpt.per_step_qkv.size() && kv_self.ckpt.per_step_qkv[il] != nullptr) { + const int64_t conv_dim = qkv_mixed->ne[0]; + const int64_t n_tok_qkv = qkv_mixed->ne[1] * qkv_mixed->ne[2]; + ggml_tensor * qkv_flat = ggml_reshape_2d(ctx0, qkv_mixed, conv_dim, n_tok_qkv); + ggml_tensor * qkv_dst = ggml_view_2d(ctx0, kv_self.ckpt.per_step_qkv[il], + conv_dim, n_tok_qkv, conv_dim * sizeof(float), 0); + auto qkv_cpy = ggml_cpy(ctx0, qkv_flat, qkv_dst); + ggml_build_forward_expand(gf, qkv_cpy); + } + auto output = build_qkv(ctx0, kv_self.s_l[il], model.layers[il].ssm_conv1d, qkv_mixed, inp_s_seq_qnext, beta, gate, head_k_dim, num_k_heads, head_v_dim, num_v_heads, hparams.ssm_d_conv, state_seq_id_local, qnext_state_slots, 
reset_state_local, hparams.f_norm_rms_eps, - model.layers[il].ssm_beta_alpha ? 0 : 1, il, cb, gf); + model.layers[il].ssm_beta_alpha ? 0 : 1, il, cb, gf, + save_per_step_states, per_step_ckpt); auto gated_output = build_gated_output(lctx, ctx0, model.layers[il].ssm_norm, model.layers[il].ssm_out, output, z, head_v_dim, num_v_heads, n_tok, il, cb); if (inp_out_ids) { diff --git a/src/llama-delta-net.h b/src/llama-delta-net.h index 58259633..2fe499ad 100644 --- a/src/llama-delta-net.h +++ b/src/llama-delta-net.h @@ -8,10 +8,15 @@ struct delta_net { delta_net(llama_context & lctx, const llama_batch & batch); ~delta_net(); + // Used for speculative decoding to enable per-step state checkpoint restoration. + bool save_per_step_states = false; + static std::pair build_fused_delta_net(ggml_context * ctx0, ggml_tensor * q, ggml_tensor * k, ggml_tensor * v, ggml_tensor * g, ggml_tensor * beta, ggml_tensor * state, - int il, const llm_build_cb & cb, int repeat_type); + int il, const llm_build_cb & cb, int repeat_type, + bool save_all_steps = false, + ggml_cgraph * gf = nullptr, ggml_tensor * per_step_ckpt = nullptr); ggml_tensor * build_layer_attn_linear_core(ggml_context * ctx0, ggml_cgraph * gf, ggml_tensor * cur, ggml_tensor * inp_s_seq_qnext, ggml_tensor * inp_out_ids, @@ -46,7 +51,8 @@ private: ggml_tensor * qkv_mixed, ggml_tensor * inp_s_seq_qnext, ggml_tensor * beta, ggml_tensor * gate, int64_t head_k_dim, int64_t num_k_heads, int64_t head_v_dim, int64_t num_v_heads, int64_t ssm_d_conv, int64_t state_seq_id_local, uint32_t qnext_state_slots, bool reset_state_local, - float eps_norm, int repeat_type, int il, const llm_build_cb & cb, ggml_cgraph * gf); + float eps_norm, int repeat_type, int il, const llm_build_cb & cb, ggml_cgraph * gf, + bool save_per_step_states = false, ggml_tensor * per_step_ckpt = nullptr); static ggml_tensor * build_gated_output(llama_context & lctx, ggml_context * ctx0, ggml_tensor * ssm_norm, ggml_tensor * ssm_out, ggml_tensor * output, 
ggml_tensor * z, int64_t head_v_dim, int64_t num_v_heads, int64_t n_tok, int il, const llm_build_cb & cb); diff --git a/src/llama.cpp b/src/llama.cpp index 3c3f1ef4..7a832716 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -1204,6 +1204,380 @@ static uint32_t llama_kv_cache_cell_max(const struct llama_kv_cache & cache, uin return 0; } +bool llama_kv_cache::checkpoint_supported() const { + for (const auto * s : s_l) { + if (s != nullptr) { + return true; + } + } + return false; +} + +bool llama_kv_cache::checkpoint_alloc_shadows() { + if (ckpt.allocated) { + return true; + } + + const uint32_t n_layer = (uint32_t)s_l.size(); + ckpt.s_l_shadow.resize(n_layer, nullptr); + + struct tensor_entry { + ggml_tensor * primary; + uint32_t il; + int split_idx; // -1 for non-split + }; + + const bool conv_only_shadow = save_per_step_ssm && ckpt.per_step_conv_state_dim > 0; + std::vector nonsplit_entries; + + std::map> split_buft_entries; + + uint32_t split_s_idx = 0; + for (uint32_t il = 0; il < n_layer; ++il) { + if (s_l[il] == nullptr) { + continue; + } + + auto * extra = s_l[il]->extra; + if (extra != nullptr) { + auto * split_info = (const ggml_split_tensor_t *)extra; + for (int d = 0; d < split_info->n_device; ++d) { + if (split_info->splits[d] == nullptr) continue; + ggml_backend_buffer_type_t buft = ggml_backend_buffer_get_type(split_info->splits[d]->buffer); + split_buft_entries[buft].push_back({split_info->splits[d], il, d}); + } + split_s_idx++; + } else { + nonsplit_entries.push_back({s_l[il], il, -1}); + } + } + + if (!nonsplit_entries.empty()) { + ggml_init_params params = { + /*.mem_size =*/ nonsplit_entries.size() * ggml_tensor_overhead(), + /*.mem_buffer =*/ NULL, + /*.no_alloc =*/ true, + }; + ggml_context * ctx = ggml_init(params); + if (!ctx) { + LLAMA_LOG_ERROR("%s: failed to create ggml context for shadow tensors\n", __func__); + return false; + } + + for (auto & entry : nonsplit_entries) { + // Only need the conv portion when per-step is active. 
+ const int64_t nelems = conv_only_shadow + ? ckpt.per_step_conv_state_dim + : (int64_t)ggml_nelements(entry.primary); + ggml_tensor * shadow = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nelems); + ggml_format_name(shadow, "shadow_s_l%d", entry.il); + ckpt.s_l_shadow[entry.il] = shadow; + } + + ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, ggml_backend_cpu_buffer_type()); + if (!buf) { + LLAMA_LOG_ERROR("%s: failed to allocate CPU buffer for shadow tensors\n", __func__); + ggml_free(ctx); + return false; + } + ggml_backend_buffer_clear(buf, 0); + LLAMA_LOG_INFO("%s: CPU shadow buffer = %8.2f MiB (%s)\n", __func__, + ggml_backend_buffer_get_size(buf) / 1024.0 / 1024.0, + conv_only_shadow ? "conv-state only" : "full recurrent state"); + ckpt.shadow_ctxs.push_back(ctx); + ckpt.shadow_bufs.push_back(buf); + } + + // Allocate split shadows on their respective devices + for (auto & [buft, entries] : split_buft_entries) { + ggml_init_params params = { + /*.mem_size =*/ entries.size() * ggml_tensor_overhead(), + /*.mem_buffer =*/ NULL, + /*.no_alloc =*/ true, + }; + ggml_context * ctx = ggml_init(params); + if (!ctx) { + LLAMA_LOG_ERROR("%s: failed to create ggml context for split shadow tensors\n", __func__); + return false; + } + + for (auto & entry : entries) { + ggml_tensor * shadow = ggml_dup_tensor(ctx, entry.primary); + ggml_format_name(shadow, "shadow_s_l%d_d%d", entry.il, entry.split_idx); + entry.primary = shadow; + } + + ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft); + if (!buf) { + LLAMA_LOG_ERROR("%s: failed to allocate buffer for split shadow tensors\n", __func__); + ggml_free(ctx); + return false; + } + ggml_backend_buffer_clear(buf, 0); + LLAMA_LOG_INFO("%s: %10s split shadow buffer = %8.2f MiB\n", __func__, + ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf) / 1024.0 / 1024.0); + ckpt.shadow_ctxs.push_back(ctx); + ckpt.shadow_bufs.push_back(buf); + } + + // Build split shadow lookup + 
ckpt.split_s_l_shadow.resize(split_s_l.size()); + split_s_idx = 0; + for (uint32_t il = 0; il < n_layer; ++il) { + if (s_l[il] == nullptr || s_l[il]->extra == nullptr) { + continue; + } + + auto * split_info = (const ggml_split_tensor_t *)s_l[il]->extra; + auto & shadow_split = ckpt.split_s_l_shadow[split_s_idx]; + shadow_split.resize(split_info->n_device, nullptr); + + for (int d = 0; d < split_info->n_device; ++d) { + if (split_info->splits[d] == nullptr) continue; + ggml_backend_buffer_type_t buft = ggml_backend_buffer_get_type(split_info->splits[d]->buffer); + for (auto & entry : split_buft_entries[buft]) { + if (entry.il == il && entry.split_idx == d) { + shadow_split[d] = entry.primary; + break; + } + } + } + split_s_idx++; + } + + ckpt.allocated = true; + return true; +} + +bool llama_kv_cache::checkpoint_save() { + if (!checkpoint_alloc_shadows()) { + return false; + } + + const uint32_t n_layer = (uint32_t)s_l.size(); + + ckpt.cells_snapshot = cells; + ckpt.head_snapshot = head; + ckpt.used_snapshot = used; + + uint32_t split_s_idx = 0; + for (uint32_t il = 0; il < n_layer; ++il) { + if (s_l[il] == nullptr) { + continue; + } + + if (s_l[il]->extra != nullptr) { + auto * split_info = (const ggml_split_tensor_t *)s_l[il]->extra; + auto & shadow_split = ckpt.split_s_l_shadow[split_s_idx]; + for (int d = 0; d < split_info->n_device; ++d) { + if (split_info->splits[d] && shadow_split[d]) { + ggml_backend_tensor_copy(split_info->splits[d], shadow_split[d]); + } + } + split_s_idx++; + } else { + const size_t nbytes = ggml_nbytes(ckpt.s_l_shadow[il]); + ggml_backend_tensor_get(s_l[il], ckpt.s_l_shadow[il]->data, 0, nbytes); + } + } + + ckpt.saved = true; + return true; +} + +bool llama_kv_cache::checkpoint_restore() { + if (!ckpt.saved) { + LLAMA_LOG_ERROR("%s: no checkpoint saved\n", __func__); + return false; + } + + const uint32_t n_layer = (uint32_t)s_l.size(); + + cells = ckpt.cells_snapshot; + head = ckpt.head_snapshot; + used = ckpt.used_snapshot; + + 
uint32_t split_s_idx = 0; + for (uint32_t il = 0; il < n_layer; ++il) { + if (s_l[il] == nullptr) { + continue; + } + + if (s_l[il]->extra != nullptr) { + auto * split_info = (const ggml_split_tensor_t *)s_l[il]->extra; + auto & shadow_split = ckpt.split_s_l_shadow[split_s_idx]; + for (int d = 0; d < split_info->n_device; ++d) { + if (split_info->splits[d] && shadow_split[d]) { + ggml_backend_tensor_copy(shadow_split[d], split_info->splits[d]); + } + } + split_s_idx++; + } else { + // Shadow may hold only the conv-state prefix (per-step mode); restore the + // same byte range checkpoint_save() captured instead of asserting full-size equality. + const size_t nbytes = ggml_nbytes(ckpt.s_l_shadow[il]); + GGML_ASSERT(nbytes <= ggml_nbytes(s_l[il])); + ggml_backend_tensor_set(s_l[il], ckpt.s_l_shadow[il]->data, 0, nbytes); + } + } + + return true; +} + +void llama_kv_cache::checkpoint_delete() { + ckpt.saved = false; +} + +bool llama_kv_cache::per_step_alloc(int max_tokens) { + if (ckpt.per_step_max_allocated >= max_tokens) { + return true; + } + + if (!ckpt.per_step_ssm.empty()) { + for (struct ggml_context * ctx : ckpt.per_step_ctxs) { + ggml_free(ctx); + } + for (ggml_backend_buffer_t buf : ckpt.per_step_bufs) { + ggml_backend_buffer_free(buf); + } + ckpt.per_step_ctxs.clear(); + ckpt.per_step_bufs.clear(); + ckpt.per_step_ssm.clear(); + ckpt.per_step_qkv.clear(); + ckpt.per_step_max_allocated = 0; + } + + const uint32_t n_layer = (uint32_t)s_l.size(); + ckpt.per_step_ssm.resize(n_layer, nullptr); + ckpt.per_step_qkv.resize(n_layer, nullptr); + + const int64_t ssm_state_dim = ckpt.per_step_ssm_state_size; + const int64_t conv_dim = ckpt.per_step_conv_dim; + if (ssm_state_dim <= 0 || conv_dim <= 0) { + LLAMA_LOG_ERROR("%s: per_step dimensions not set (ssm=%lld, conv_dim=%lld)\n", + __func__, (long long)ssm_state_dim, (long long)conv_dim); + return false; + } + + std::map<ggml_backend_buffer_type_t, std::vector<std::pair<uint32_t, ggml_backend_buffer_type_t>>> buft_layers; + + for (uint32_t il = 0; il < n_layer; ++il) { + if (s_l[il] == nullptr) continue; + if (s_l[il]->extra != nullptr) continue; // skip split tensors + + ggml_backend_buffer_type_t buft = ggml_backend_buffer_get_type(s_l[il]->buffer); + buft_layers[buft].push_back({il, buft}); + } + + for (auto & 
[buft, layers] : buft_layers) { + // 2 tensors per layer: SSM states + qkv features + ggml_init_params params = { + /*.mem_size =*/ layers.size() * 2 * ggml_tensor_overhead(), + /*.mem_buffer =*/ NULL, + /*.no_alloc =*/ true, + }; + ggml_context * ctx = ggml_init(params); + if (!ctx) { + LLAMA_LOG_ERROR("%s: failed to create ggml context for per-step checkpoints\n", __func__); + return false; + } + + for (auto & [il, bt] : layers) { + // SSM state: max_tokens * ssm_state_dim + ggml_tensor * t_ssm = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, (int64_t)max_tokens * ssm_state_dim); + ggml_format_name(t_ssm, "per_step_ssm_l%d", il); + ckpt.per_step_ssm[il] = t_ssm; + + // Conv features (qkv_mixed): max_tokens * conv_dim + ggml_tensor * t_qkv = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, (int64_t)max_tokens * conv_dim); + ggml_format_name(t_qkv, "per_step_qkv_l%d", il); + ckpt.per_step_qkv[il] = t_qkv; + } + + ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft); + if (!buf) { + LLAMA_LOG_ERROR("%s: failed to allocate buffer for per-step checkpoints\n", __func__); + ggml_free(ctx); + return false; + } + ggml_backend_buffer_clear(buf, 0); + LLAMA_LOG_INFO("%s: %10s per-step buffer = %8.2f MiB (max_tokens=%d)\n", __func__, + ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf) / 1024.0 / 1024.0, max_tokens); + ckpt.per_step_ctxs.push_back(ctx); + ckpt.per_step_bufs.push_back(buf); + } + + ckpt.per_step_max_allocated = max_tokens; + return true; +} + +bool llama_kv_cache::per_step_restore(int step) { + if (ckpt.per_step_ssm.empty() || step < 0) { + return false; + } + + const int64_t ssm_state_dim = ckpt.per_step_ssm_state_size; + const int64_t conv_state_dim = ckpt.per_step_conv_state_dim; + const int64_t conv_dim = ckpt.per_step_conv_dim; + const int32_t d_conv = ckpt.per_step_d_conv; + if (ssm_state_dim <= 0 || conv_dim <= 0 || d_conv <= 1) return false; + + const int64_t ssm_bytes = ssm_state_dim * sizeof(float); + const int64_t conv_bytes 
= conv_state_dim * sizeof(float); + const int32_t d_conv_m1 = d_conv - 1; // number of columns in conv state + + std::vector<float> ssm_buf(ssm_state_dim); + std::vector<float> conv_buf(conv_state_dim); // reconstructed conv state + std::vector<float> old_conv_buf(conv_state_dim); // pre-spec conv state from shadow + const int64_t qkv_needed = (int64_t)(step + 1) * conv_dim; + std::vector<float> qkv_buf(qkv_needed); + + const uint32_t n_layer = (uint32_t)s_l.size(); + int n_restored = 0; + for (uint32_t il = 0; il < n_layer; ++il) { + if (s_l[il] == nullptr || ckpt.per_step_ssm[il] == nullptr) continue; + if (s_l[il]->extra != nullptr) continue; + + ggml_backend_tensor_get(ckpt.per_step_ssm[il], ssm_buf.data(), + (size_t)step * ssm_bytes, ssm_bytes); + + if (ckpt.s_l_shadow[il] != nullptr) { + ggml_backend_tensor_get(ckpt.s_l_shadow[il], old_conv_buf.data(), 0, conv_bytes); + } else { + memset(old_conv_buf.data(), 0, conv_bytes); + } + + if (ckpt.per_step_qkv[il] != nullptr) { + ggml_backend_tensor_get(ckpt.per_step_qkv[il], qkv_buf.data(), 0, qkv_needed * sizeof(float)); + } else { + memset(qkv_buf.data(), 0, qkv_needed * sizeof(float)); + } + + for (int32_t col = 0; col < d_conv_m1; col++) { + int32_t src_token = step - (d_conv_m1 - 1) + col; // e.g., K-2, K-1, K for d_conv=4 + if (src_token >= 0) { + for (int64_t d = 0; d < conv_dim; d++) { + conv_buf[col + d * d_conv_m1] = qkv_buf[d + (int64_t)src_token * conv_dim]; + } + } else { + int32_t old_col = d_conv_m1 + src_token; // maps to 0, 1, ... 
for early steps + if (old_col >= 0 && old_col < d_conv_m1) { + for (int64_t d = 0; d < conv_dim; d++) { + conv_buf[col + d * d_conv_m1] = old_conv_buf[old_col + d * d_conv_m1]; + } + } else { + for (int64_t d = 0; d < conv_dim; d++) { + conv_buf[col + d * d_conv_m1] = 0.0f; + } + } + } + } + + ggml_backend_tensor_set(s_l[il], conv_buf.data(), 0, conv_bytes); + ggml_backend_tensor_set(s_l[il], ssm_buf.data(), conv_bytes, ssm_bytes); + n_restored++; + } + + // Report failure when no layer state was actually written (e.g. all layers + // were split tensors), so the caller can fall back to re-decoding. + return n_restored > 0; +} + static void llama_kv_cache_clear(struct llama_kv_cache & cache) { for (int32_t i = 0; i < (int32_t) cache.size; ++i) { cache.cells[i].pos = -1; @@ -6464,6 +6838,168 @@ void llama_kv_cache_clear(struct llama_context * ctx) { llama_kv_cache_clear(ctx->kv_self); } +// Unified speculative-checkpoint +static bool spec_ckpt_try_per_step(llama_kv_cache & kv, const llama_model & model, int max_tokens) { + // Graph-split tensors and mixed CPU/GPU configurations are not supported. + bool has_gpu = false; + bool has_cpu = false; + for (const auto * sl : kv.s_l) { + if (!sl) continue; + if (sl->extra) { + kv.save_per_step_ssm = false; + return false; + } + if (sl->buffer && !ggml_backend_buffer_is_host(sl->buffer)) { + has_gpu = true; + } else if (sl->buffer) { + has_cpu = true; + } + } + if (!has_gpu || has_cpu) { + if (has_cpu && has_gpu) { + LLAMA_LOG_INFO("%s: per-step disabled — mixed CPU/GPU recurrent layers\n", __func__); + } + kv.save_per_step_ssm = false; + return false; + } + + // Populate per-step dimensions from hparams + if (kv.ckpt.per_step_ssm_state_size <= 0) { + const auto & hp = model.hparams; + const int64_t nv = hp.ssm_dt_rank; + const int64_t head_v = hp.ssm_d_inner / nv; + const int64_t head_k = hp.ssm_d_state; + const int64_t nk = hp.ssm_n_group; + const int64_t key_dim = head_k * nk; + const int64_t val_dim = head_v * nv; + const int64_t conv_dim = key_dim * 2 + val_dim; + + kv.ckpt.per_step_ssm_state_size = head_v * head_v * nv; + kv.ckpt.per_step_conv_state_dim = (hp.ssm_d_conv - 1) 
* conv_dim; + kv.ckpt.per_step_conv_dim = conv_dim; + kv.ckpt.per_step_d_conv = hp.ssm_d_conv; + } + + if (!kv.per_step_alloc(max_tokens)) { + kv.save_per_step_ssm = false; + return false; + } + + return true; +} + +int llama_spec_ckpt_init(struct llama_context * ctx, int mode, int max_tokens) { + auto & kv = ctx->kv_self; + + kv.save_per_step_ssm = false; + kv.ckpt.selected_spec_mode = LLAMA_SPEC_CKPT_NONE; + + if (!kv.checkpoint_supported()) { + return (int)LLAMA_SPEC_CKPT_NONE; + } + + int requested = mode; + + // prefer PER_STEP → GPU_FALLBACK → CPU + if (requested == LLAMA_SPEC_CKPT_AUTO) { + requested = LLAMA_SPEC_CKPT_PER_STEP; + } + + if (requested == LLAMA_SPEC_CKPT_PER_STEP) { + if (spec_ckpt_try_per_step(kv, ctx->model, max_tokens)) { + kv.ckpt.selected_spec_mode = LLAMA_SPEC_CKPT_PER_STEP; + return (int)LLAMA_SPEC_CKPT_PER_STEP; + } + if (mode == LLAMA_SPEC_CKPT_PER_STEP) { + LLAMA_LOG_WARN("%s: per-step not available, falling back to GPU fallback mode\n", __func__); + } + requested = LLAMA_SPEC_CKPT_GPU_FALLBACK; + } + + if (requested == LLAMA_SPEC_CKPT_GPU_FALLBACK) { + kv.ckpt.selected_spec_mode = LLAMA_SPEC_CKPT_GPU_FALLBACK; + return (int)LLAMA_SPEC_CKPT_GPU_FALLBACK; + } + + kv.ckpt.selected_spec_mode = LLAMA_SPEC_CKPT_CPU; + return (int)LLAMA_SPEC_CKPT_CPU; +} + +bool llama_spec_ckpt_save(struct llama_context * ctx, llama_seq_id seq_id) { + auto & kv = ctx->kv_self; + + switch (kv.ckpt.selected_spec_mode) { + case LLAMA_SPEC_CKPT_PER_STEP: + kv.save_per_step_ssm = true; + return kv.checkpoint_save(); + + case LLAMA_SPEC_CKPT_GPU_FALLBACK: + return kv.checkpoint_save(); + + case LLAMA_SPEC_CKPT_CPU: { + const size_t need = llama_state_seq_get_size(ctx, seq_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); + kv.ckpt.cpu_state_data.resize(need); + const size_t written = llama_state_seq_get_data( + ctx, kv.ckpt.cpu_state_data.data(), need, seq_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); + kv.ckpt.cpu_state_data.resize(written); + return written > 0; + } + + 
default: + return false; + } +} + +bool llama_spec_ckpt_restore(struct llama_context * ctx, llama_seq_id seq_id, + llama_pos n_past, int accepted_step) { + auto & kv = ctx->kv_self; + + switch (kv.ckpt.selected_spec_mode) { + case LLAMA_SPEC_CKPT_PER_STEP: { + if (!kv.per_step_restore(accepted_step)) { + return false; + } + const llama_pos accepted_pos = n_past + accepted_step; + if (seq_id >= 0 && (uint32_t)seq_id < kv.size) { + kv.cells[seq_id].pos = accepted_pos; + } + llama_kv_cache_seq_rm(kv, seq_id, accepted_pos + 1, -1); + return true; + } + + case LLAMA_SPEC_CKPT_GPU_FALLBACK: + kv.checkpoint_restore(); + llama_kv_cache_seq_rm(kv, seq_id, n_past, -1); + return false; + + case LLAMA_SPEC_CKPT_CPU: + if (!kv.ckpt.cpu_state_data.empty()) { + llama_state_seq_set_data(ctx, kv.ckpt.cpu_state_data.data(), + kv.ckpt.cpu_state_data.size(), seq_id, + LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); + } + llama_kv_cache_seq_rm(kv, seq_id, n_past, -1); + return false; + + default: + return false; + } +} + +void llama_spec_ckpt_discard(struct llama_context * ctx) { + auto & kv = ctx->kv_self; + + if (kv.ckpt.selected_spec_mode == LLAMA_SPEC_CKPT_PER_STEP) { + kv.save_per_step_ssm = false; + kv.checkpoint_delete(); + } else if (kv.ckpt.selected_spec_mode == LLAMA_SPEC_CKPT_GPU_FALLBACK) { + kv.checkpoint_delete(); + } + + kv.ckpt.selected_spec_mode = LLAMA_SPEC_CKPT_NONE; + kv.ckpt.cpu_state_data.clear(); +} + bool llama_kv_cache_seq_rm(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1) { return llama_kv_cache_seq_rm(ctx->kv_self, seq_id, p0, p1); } @@ -9064,18 +9600,20 @@ void llama_sampler_dry_free(struct llama_sampler_dry* smpl) { } struct llama_sampler_dry* llama_sampler_dry_clone(struct llama_sampler_dry* smpl) { - // nullptr is passed as vocab because it is only needed for raw sequence breaker processing, which we have already done and will be copying - auto* result = llama_sampler_init_dry(nullptr, smpl->dry_multiplier, smpl->dry_base, 
smpl->dry_allowed_length, smpl->dry_penalty_last_n, NULL, 0); - // Copy the state, including the processed breakers - { - auto* result_ctx = smpl; - result_ctx->dry_processed_breakers = smpl->dry_processed_breakers; - result_ctx->dry_repeat_count = smpl->dry_repeat_count; - result_ctx->dry_max_token_repeat = smpl->dry_max_token_repeat; - result_ctx->last_tokens = smpl->last_tokens; + if (!smpl) { + return nullptr; } - - return result; + return new llama_sampler_dry { + /* .total_context_size = */ smpl->total_context_size, + /* .dry_multiplier = */ smpl->dry_multiplier, + /* .dry_base = */ smpl->dry_base, + /* .dry_allowed_length = */ smpl->dry_allowed_length, + /* .dry_penalty_last_n = */ smpl->dry_penalty_last_n, + /* .dry_processed_breakers = */ smpl->dry_processed_breakers, + /* .dry_repeat_count = */ smpl->dry_repeat_count, + /* .dry_max_token_repeat = */ smpl->dry_max_token_repeat, + /* .last_tokens = */ smpl->last_tokens, + }; } void llama_sampler_dry_accept(struct llama_sampler_dry* smpl, llama_token token) {