server : integrate speculative decoding

This commit is contained in:
T. M.
2025-07-25 02:51:00 +00:00
parent 4e9c78c039
commit de5ecab4fb
3 changed files with 533 additions and 5 deletions

View File

@@ -230,6 +230,13 @@ struct slot_params {
bool timings_per_token = false;
json input_prefix;
json input_suffix;
// speculative decoding parameters
struct {
int n_max = 0; // max drafted tokens
int n_min = 0; // min drafted tokens to accept
float p_min = 0.75f; // min probability required to accept a token in the draft
} speculative;
};
struct server_slot {
@@ -292,6 +299,15 @@ struct server_slot {
int32_t ga_i = 0; // group-attention state
int32_t ga_n = 1; // group-attention factor
int32_t ga_w = 512; // group-attention width
// speculative decoding
struct common_speculative * spec = nullptr;
llama_context * ctx_dft = nullptr;
llama_batch batch_spec = {};
// speculative decoding stats
int32_t n_draft_total = 0; // Total draft tokens generated
int32_t n_draft_accepted = 0; // Draft tokens actually accepted
int32_t n_past_se = 0; // self-extend
@@ -326,6 +342,10 @@ struct server_slot {
previous_msg = ik_chat_msg();
current_msg = ik_chat_msg();
tool_call_ids.clear();
// Reset speculative decoding stats
n_draft_total = 0;
n_draft_accepted = 0;
}
// Update chat message and compute diffs for streaming tool calls
@@ -419,11 +439,11 @@ struct server_slot {
timings.predicted_per_token_ms = t_token_generation / n_decoded;
timings.predicted_per_second = 1e3 / t_token_generation * n_decoded;
//// Add speculative metrics
//if (n_draft_total > 0) {
// timings.draft_n = n_draft_total;
// timings.draft_n_accepted = n_draft_accepted;
//}
// Add speculative metrics
if (n_draft_total > 0) {
timings.draft_n = n_draft_total;
timings.draft_n_accepted = n_draft_accepted;
}
return timings;
}
@@ -796,6 +816,10 @@ struct server_context {
bool clean_kv_cache = true;
bool add_bos_token = true;
// For speculative decoding
llama_init_result model_dft_owned;
llama_context_params cparams_dft;
int32_t n_ctx; // total context for all clients / slots
@@ -833,6 +857,13 @@ struct server_context {
if (slot.ctx_sampling != nullptr) {
llama_sampling_free(slot.ctx_sampling);
}
if (slot.ctx_dft) {
llama_free(slot.ctx_dft);
}
if (slot.spec) {
common_speculative_free(slot.spec);
}
llama_batch_free(slot.batch_spec);
}
llama_batch_free(batch);
@@ -860,6 +891,56 @@ struct server_context {
add_bos_token = llama_should_add_bos_token(model);
GGML_ASSERT(llama_add_eos_token(model) != 1);
// Load draft model for speculative decoding if specified
if (!params.speculative_model.empty()) {
LOG_INFO("loading draft model", {{"model", params.speculative_model}});
gpt_params params_dft = params;
params_dft.model = params.speculative_model;
params_dft.n_ctx = params.speculative_n_ctx == 0 ? params.n_ctx / params.n_parallel : params.speculative_n_ctx;
params_dft.n_gpu_layers = params.speculative_n_gpu_layers;
params_dft.n_parallel = 1;
params_dft.cache_type_k = params.speculative_cache_type_k;
params_dft.cache_type_v = params.speculative_cache_type_v;
llama_init_result llama_init_dft = llama_init_from_gpt_params(params_dft);
llama_model * model_dft = llama_init_dft.model;
if (model_dft == nullptr) {
LOG_ERROR("failed to load draft model", {{"model", params.speculative_model}});
return false;
}
if (!common_speculative_are_compatible(ctx, llama_init_dft.context)) {
LOG_ERROR("the draft model is not compatible with the target model", {});
return false;
}
// Store the draft context initialization parameters for later use
cparams_dft = llama_context_default_params();
cparams_dft.n_ctx = params_dft.n_ctx;
cparams_dft.n_batch = cparams_dft.n_ctx;
cparams_dft.n_ubatch = params_dft.n_ubatch;
cparams_dft.freq_base = params_dft.rope_freq_base;
cparams_dft.freq_scale = params_dft.rope_freq_scale;
cparams_dft.yarn_ext_factor = params_dft.yarn_ext_factor;
cparams_dft.yarn_attn_factor = params_dft.yarn_attn_factor;
cparams_dft.yarn_beta_fast = params_dft.yarn_beta_fast;
cparams_dft.yarn_beta_slow = params_dft.yarn_beta_slow;
cparams_dft.yarn_orig_ctx = params_dft.yarn_orig_ctx;
cparams_dft.clip_kqv = params_dft.clip_kqv;
cparams_dft.pooling_type = params_dft.pooling_type;
cparams_dft.defrag_thold = params_dft.defrag_thold;
cparams_dft.type_k = params_dft.type_k;
cparams_dft.type_v = params_dft.type_v;
cparams_dft.logits_all = false;
cparams_dft.embedding = false;
cparams_dft.offload_kqv = params_dft.offload_kqv;
// Keep the draft model alive
model_dft_owned = llama_init_dft;
}
return true;
}
@@ -909,6 +990,23 @@ struct server_context {
slot.ga_w = ga_w;
slot.sparams = params.sparams;
// Initialize speculative decoding if a draft model is loaded
if (model_dft_owned.context) {
slot.batch_spec = llama_batch_init(params.speculative_n_max + 1, 0, 1);
slot.ctx_dft = llama_init_from_model(model_dft_owned.model, cparams_dft);
if (slot.ctx_dft == nullptr) {
LOG_ERROR("failed to create draft context", {});
return;
}
slot.spec = common_speculative_init(slot.ctx_dft);
if (slot.spec == nullptr) {
LOG_ERROR("failed to create speculator", {});
return;
}
}
slot.reset();
@@ -1100,6 +1198,16 @@ struct server_context {
slot.sparams.seed = json_value(data, "seed", default_sparams.seed);
slot.sparams.n_probs = json_value(data, "n_probs", default_sparams.n_probs);
slot.sparams.min_keep = json_value(data, "min_keep", default_sparams.min_keep);
// speculative decoding parameters
slot.params.speculative.n_max = json_value(data, "speculative.n_max", 0);
slot.params.speculative.n_min = json_value(data, "speculative.n_min", 0);
slot.params.speculative.p_min = json_value(data, "speculative.p_min", 0.75f);
// Clamp speculative parameters
slot.params.speculative.n_min = std::min(slot.params.speculative.n_max, slot.params.speculative.n_min);
slot.params.speculative.n_min = std::max(slot.params.speculative.n_min, 0);
slot.params.speculative.n_max = std::max(slot.params.speculative.n_max, 0);
if (slot.sparams.penalty_last_n < -1) {
throw std::runtime_error("Error: repeat_last_n must be >= -1");
@@ -2704,6 +2812,118 @@ struct server_context {
slot.i_batch = -1;
}
// Do speculative decoding
for (auto & slot : slots) {
if (!slot.is_processing() || !slot.spec) {
continue;
}
if (slot.state != SLOT_STATE_PROCESSING) {
continue;
}
// determine the max draft that fits the current slot state
int n_draft_max = slot.params.speculative.n_max;
// note: n_past is not yet increased for the `id` token sampled above
// also, need to leave space for 1 extra token to allow context shifts
n_draft_max = std::min(n_draft_max, slot.n_ctx - slot.n_past - 2);
if (slot.n_predict > 0) {
n_draft_max = std::min(n_draft_max, slot.n_predict - slot.n_decoded - 1);
}
LOG_VERBOSE("max possible draft", {
{"id_slot", slot.id},
{"n_draft_max", n_draft_max}
});
if (n_draft_max < slot.params.speculative.n_min) {
LOG_VERBOSE("the max possible draft is too small", {
{"id_slot", slot.id},
{"n_draft_max", n_draft_max},
{"n_min", slot.params.speculative.n_min}
});
continue;
}
llama_token id = slot.sampled;
struct common_speculative_params params_spec;
params_spec.n_draft = n_draft_max;
params_spec.n_reuse = cparams_dft.n_ctx - slot.params.speculative.n_max;
params_spec.p_min = slot.params.speculative.p_min;
const std::vector<llama_token> & cached_text_tokens = slot.cache_tokens;
std::vector<llama_token> draft = common_speculative_gen_draft(slot.spec, params_spec, cached_text_tokens, id);
// ignore small drafts
if (slot.params.speculative.n_min > (int) draft.size()) {
LOG_VERBOSE("ignoring small draft", {
{"id_slot", slot.id},
{"draft_size", (int) draft.size()},
{"n_min", slot.params.speculative.n_min}
});
continue;
}
// keep track of total number of drafted tokens tested
slot.n_draft_total += draft.size();
// construct the speculation batch
llama_batch_clear(slot.batch_spec);
llama_batch_add(slot.batch_spec, id, slot.n_past, { slot.id + 1 }, true);
for (size_t i = 0; i < draft.size(); ++i) {
llama_batch_add(slot.batch_spec, draft[i], slot.n_past + 1 + i, { slot.id + 1 }, true);
}
LOG_VERBOSE("decoding speculative batch", {
{"id_slot", slot.id},
{"size", slot.batch_spec.n_tokens}
});
llama_decode(ctx, slot.batch_spec);
// the accepted tokens from the speculation
std::vector<llama_token> ids = llama_sampling_sample_and_accept_n(slot.ctx_sampling, ctx, draft);
slot.n_past += ids.size();
slot.n_decoded += ids.size();
// update how many tokens out of those tested were accepted
slot.n_draft_accepted += ids.size() - 1;
slot.cache_tokens.push_back(id);
slot.cache_tokens.insert(slot.cache_tokens.end(), ids.begin(), ids.end() - 1);
llama_kv_cache_seq_rm(ctx, slot.id + 1, slot.n_past, -1);
for (size_t i = 0; i < ids.size(); ++i) {
completion_token_output result;
result.tok = ids[i];
result.text_to_send = llama_token_to_piece(ctx, result.tok, params.special);
result.prob = 1.0f; // set later
if (!process_token(result, slot)) {
// release slot because of stop condition
slot.release();
slot.print_timings();
send_final_response(slot);
metrics.on_prediction(slot);
break;
}
}
LOG_VERBOSE("speculative decoding result", {
{"id_slot", slot.id},
{"accepted", (int) ids.size() - 1},
{"total", (int) draft.size()},
{"new_n_past", slot.n_past}
});
}
}
LOG_VERBOSE("run slots completed", {});