mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-05-12 17:05:57 +00:00
Speculative checkpoints for recurrent models (#1669)
* server: spec checkpoints for recurrent models
* fix: save/restore sampler state during speculative checkpoint

  When speculative decoding rejects draft tokens and restores the recurrent
  state checkpoint, the sampler (RNG, grammar, prev tokens) must also be
  restored to maintain consistency. Without this, the sampler state reflects
  the rejected draft tokens, leading to potential divergence. Uses
  common_sampler_clone() to snapshot the sampler before the speculative batch
  decode, and restores it on rejection.

* server: snapshot recurrent state in tensor
* reset ngram mod state for rejected tokens
* server: refactor checkpoint state logic
* speculative: fix sampler for checkpoints
* recurrent model: implement recurrent kernel checkpoint
* recurrent model: refactor api
* spec: free rbudget before overwriting
This commit is contained in:
committed by
GitHub
parent
1c13288164
commit
ea94afe777
@@ -22,6 +22,51 @@ static void log_text(const gpt_params & params_base, const std::string & text) {
|
||||
}
|
||||
}
|
||||
|
||||
void server_speculative_checkpoint::clear() {
|
||||
valid = false;
|
||||
per_step_enabled = false;
|
||||
n_past = 0;
|
||||
sampled = LLAMA_TOKEN_NULL;
|
||||
|
||||
if (sampler != nullptr) {
|
||||
common_sampler_free(sampler);
|
||||
sampler = nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
// Drop the slot's speculative checkpoint on both sides: the context-held
// recurrent-state snapshot and the server-side slot state (incl. the cloned
// sampler). The two teardowns touch disjoint state, so order is irrelevant.
static void discard_speculative_checkpoint(server_slot & slot, llama_context * ctx) {
    llama_spec_ckpt_discard(ctx);
    slot.spec_ckpt.clear();
}
|
||||
|
||||
// Snapshot the slot's pre-speculation state (recurrent state + sampler) so it
// can be rolled back if draft tokens get rejected.
// Returns true on success; on any failure every partially-saved piece is torn
// down again and false is returned.
static bool save_speculative_checkpoint(server_slot & slot, llama_model * model, llama_context * ctx, int ckpt_mode) {
    auto & ckpt = slot.spec_ckpt;

    ckpt.clear();
    ckpt.n_past  = slot.n_past - (int32_t)(slot.drafted.size() + 1);
    ckpt.sampled = slot.sampled;

    // ask the context which checkpoint mode it can actually provide for this
    // many speculative tokens
    const int n_tokens_max = (int)slot.drafted.size() + 1;
    const int mode = llama_spec_ckpt_init(ctx, ckpt_mode, n_tokens_max);
    if (mode == LLAMA_SPEC_CKPT_NONE) {
        return false;
    }
    ckpt.per_step_enabled = (mode == LLAMA_SPEC_CKPT_PER_STEP);

    ckpt.valid = llama_spec_ckpt_save(ctx, slot.id);
    if (!ckpt.valid) {
        llama_spec_ckpt_discard(ctx);
        return false;
    }

    // snapshot the sampler state (RNG, grammar, prev tokens) alongside the
    // recurrent state, so a rollback restores both consistently
    ckpt.sampler = common_sampler_init(model, slot.sparams);
    if (!ckpt.sampler) {
        discard_speculative_checkpoint(slot, ctx);
        return false;
    }
    common_sampler_clone(slot.ctx_sampling, ckpt.sampler);

    return true;
}
|
||||
|
||||
server_context::~server_context() {
|
||||
if (ctx) {
|
||||
llama_free(ctx);
|
||||
@@ -49,6 +94,7 @@ server_context::~server_context() {
|
||||
if (slot.ctx_sampling != nullptr) {
|
||||
common_sampler_free(slot.ctx_sampling);
|
||||
}
|
||||
slot.spec_ckpt.clear();
|
||||
if (slot.ctx_dft) {
|
||||
llama_free(slot.ctx_dft);
|
||||
}
|
||||
@@ -112,15 +158,6 @@ bool server_context::load_model(const gpt_params& params_) {
|
||||
}
|
||||
// Load draft model for speculative decoding if specified
|
||||
if (has_draft_model) {
|
||||
|
||||
if (llama_model_has_recurrent(model)) {
|
||||
LLAMA_LOG_WARN("\n=======================================================================\n");
|
||||
LLAMA_LOG_WARN(" Speculative decodong is not suported for recurrent/hybrid models\n");
|
||||
LLAMA_LOG_WARN(" --> bailing out\n");
|
||||
LLAMA_LOG_WARN("========================================================================\n\n");
|
||||
GGML_ABORT("Fatal error");
|
||||
}
|
||||
|
||||
LLAMA_LOG_INFO("\n\n==================================loading DRAFT model==================================\n\n");
|
||||
|
||||
gpt_params params_dft;
|
||||
@@ -387,6 +424,7 @@ void server_slot::reset() {
|
||||
n_sent_text = 0;
|
||||
drafted.clear();
|
||||
i_batch_dft.clear();
|
||||
spec_ckpt.clear();
|
||||
n_sent_token_probs = 0;
|
||||
infill = false;
|
||||
ga_i = 0;
|
||||
@@ -3679,6 +3717,72 @@ void server_context::extend_context(const int32_t n_tokens) {
|
||||
}
|
||||
}
|
||||
|
||||
// Restore recurrent state and re-decode accepted tokens after speculative-decode rejection.
// `ids` holds the tokens accepted this round (the last entry is the newly sampled token);
// `n_draft` is how many draft tokens were proposed.
// NOTE(review): `model` is not referenced in this body — presumably kept for signature
// symmetry with save_speculative_checkpoint; confirm before removing.
static void restore_speculative_checkpoint(
        server_slot & slot, llama_context * ctx, llama_model * model,
        const std::vector<llama_token> & ids, int n_draft) {
    if (slot.spec_ckpt.per_step_enabled) {
        // Per-step mode: the kernel kept a state snapshot per draft step, so we can jump
        // directly to the state after the last accepted token without re-decoding.
        const int step = (int)ids.size() - 1;
        llama_spec_ckpt_restore(ctx, slot.id, slot.spec_ckpt.n_past, step);

        // Roll the sampler back to its pre-speculation snapshot ...
        if (slot.spec_ckpt.sampler) {
            common_sampler_clone(slot.spec_ckpt.sampler, slot.ctx_sampling);
        }
        // ... then replay the accepted tokens so the sampler matches the restored context.
        for (llama_token id : ids) {
            common_sampler_accept(slot.ctx_sampling, ctx, id, true);
        }

        SLT_DBG(slot, "per-step restore: step=%d (rejected %d drafts)\n",
                step, (int)(n_draft - (ids.size() - 1)));
    } else {
        // Restore pre-speculation recurrent state then re-decode accepted tokens.
        llama_spec_ckpt_restore(ctx, slot.id, slot.spec_ckpt.n_past, 0);

        // Roll the sampler back to the pre-speculation snapshot as well.
        if (slot.spec_ckpt.sampler) {
            common_sampler_clone(slot.spec_ckpt.sampler, slot.ctx_sampling);
        }

        if (!ids.empty()) {
            // Re-decode to advance recurrent state to the accepted position.
            // The replay batch is: the pre-speculation sampled token at n_past, followed by
            // all accepted tokens except the last (the last one becomes the next input).
            const int n_re = (int)ids.size();
            llama_batch re_batch = llama_batch_init(n_re, 0, 1);
            // logits are requested only for the final entry of the replay batch
            common_batch_add(re_batch, slot.spec_ckpt.sampled, slot.spec_ckpt.n_past, { slot.id }, n_re == 1);
            for (int j = 0; j < n_re - 1; j++) {
                common_batch_add(re_batch, ids[j], slot.spec_ckpt.n_past + 1 + j, { slot.id }, j == n_re - 2);
            }

            // MTP needs the last token's hidden state: enable embeddings output for
            // this decode only, and switch it back off afterwards.
            if (slot.has_mtp) {
                llama_set_embeddings(ctx, true);
            }

            const int ret = llama_decode(ctx, re_batch);
            if (ret != 0) {
                SLT_ERR(slot, "failed to re-decode accepted tokens after checkpoint restore: %d\n", ret);
            }

            if (slot.has_mtp) {
                llama_set_embeddings(ctx, false);
                // Capture the last replayed token's hidden state for the MTP draft head.
                const int n_embd = llama_model_n_embd(llama_get_model(ctx));
                const float * emb = llama_get_embeddings_ith(ctx, -1);
                if (emb) {
                    slot.mtp_hidden_state.resize(n_embd);
                    memcpy(slot.mtp_hidden_state.data(), emb, n_embd * sizeof(float));
                }
            }

            // Replay accepted tokens into the restored sampler state.
            for (llama_token id : ids) {
                common_sampler_accept(slot.ctx_sampling, ctx, id, true);
            }

            llama_batch_free(re_batch);
            SLT_DBG(slot, "spec checkpoint restored: re-decoded %d tokens (rejected %d drafts)\n",
                    n_re, (int)(n_draft - (ids.size() - 1)));
        }
    }

    // Checkpoints are single-use: release both sides after the restore.
    discard_speculative_checkpoint(slot, ctx);
}
|
||||
|
||||
void server_context::speculative_decoding_accept() {
|
||||
for (auto& slot : slots) {
|
||||
if (slot.state != SLOT_STATE_PROCESSING || slot.i_batch_dft.empty()) {
|
||||
@@ -3739,7 +3843,14 @@ void server_context::speculative_decoding_accept() {
|
||||
slot.sampled = ids.back(); // last accepted token
|
||||
slot.n_past = slot.cache_tokens.n_tokens();
|
||||
|
||||
llama_kv_cache_seq_rm(ctx, slot.id, slot.n_past, -1);
|
||||
// for recurrent/hybrid models: if any drafts were rejected, restore recurrent state
|
||||
const bool any_rejected = (ids.size() - 1) < n_draft;
|
||||
if (any_rejected && slot.spec_ckpt.valid) {
|
||||
restore_speculative_checkpoint(slot, ctx, model, ids, n_draft);
|
||||
} else {
|
||||
llama_kv_cache_seq_rm(ctx, slot.id, slot.n_past, -1);
|
||||
discard_speculative_checkpoint(slot, ctx);
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < ids.size(); ++i) {
|
||||
completion_token_output result;
|
||||
@@ -4305,6 +4416,23 @@ void server_context::update_slots() {
|
||||
// make sure we're in the right embedding mode
|
||||
llama_set_embeddings(ctx, batch_type == 1);
|
||||
|
||||
if (llama_model_has_recurrent(model)) {
|
||||
const int ckpt_mode = params_base.speculative.recurrent_ckpt_mode;
|
||||
|
||||
for (auto & slot : slots) {
|
||||
if (slot.state != SLOT_STATE_PROCESSING || slot.i_batch_dft.empty()) {
|
||||
continue;
|
||||
}
|
||||
if (save_speculative_checkpoint(slot, model, ctx, ckpt_mode)) {
|
||||
const char * mode_name = slot.spec_ckpt.per_step_enabled ? "per-step" : "shadow/cpu";
|
||||
SLT_DBG(slot, "spec checkpoint saved (mode=%s), n_past_pre_spec=%d\n",
|
||||
mode_name, slot.spec_ckpt.n_past);
|
||||
} else {
|
||||
SLT_WRN(slot, "%s", "failed to save spec checkpoint\n");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// process the created batch of tokens
|
||||
process_batch_tokens(n_batch); // Decode with batch
|
||||
|
||||
|
||||
Reference in New Issue
Block a user