server: improve speed of speculative decoding (#1119)

* server: improve speed of speculative decoding

change logs

rpc: add recompute

spec dec fix

* Fix n_batch_size not set to context size for draft model

---------

Co-authored-by: firecoperana <firecoperana>
This commit is contained in:
firecoperana
2026-01-10 00:01:22 -06:00
committed by GitHub
parent 52ad1c6421
commit c03ee1a4d2
7 changed files with 164 additions and 135 deletions

View File

@@ -484,7 +484,7 @@ bool server_sent_event(httplib::DataSink& sink, const json& data) {
data.dump(-1, ' ', false, json::error_handler_t::replace) +
"\n\n"; // required by RFC 8895 - A message is terminated by a blank line (two line terminators in a row).
LOG_VERBOSE("data stream, to_send: %s", str.c_str());
//LOG_VERBOSE("data stream, to_send: %s", str.c_str());
return sink.write(str.c_str(), str.size());
}

View File

@@ -336,6 +336,10 @@ public:
llama_pos pos_next() const;
int n_tokens() const {
return tokens.size();
}
// for debugging
std::string str() const;

View File

@@ -117,12 +117,13 @@ bool server_context::load_model(const gpt_params& params_) {
LLAMA_LOG_INFO("\n\n==================================loading DRAFT model==================================\n\n");
gpt_params params_dft;
params_dft.devices = params.devices_draft;
params_dft.model = params.model_draft;
params_dft.devices = params.devices_draft;
params_dft.model = params.model_draft;
params_dft.n_gpu_layers = params.n_gpu_layers_draft;
params_dft.rpc_servers = params.rpc_servers;
params_dft.cache_type_k = params.cache_type_k_draft.empty() ? params.cache_type_k : params.cache_type_k_draft;
params_dft.cache_type_v = params.cache_type_v_draft.empty() ? params.cache_type_v : params.cache_type_v_draft;
params_dft.flash_attn = params.flash_attn;
params_dft.flash_attn = params.flash_attn;
if (!params.draft_params.empty()) {
auto [argc, argv] = parse_command_line("llama-server " + params.draft_params);
if (!gpt_params_parse(argc, argv, params_dft)) {
@@ -138,7 +139,7 @@ bool server_context::load_model(const gpt_params& params_) {
}
params_dft.n_ctx = params_dft.n_ctx == 0 ? params.n_ctx / params.n_parallel : params_dft.n_ctx;
params_dft.n_parallel = 1;
params_dft.n_batch = params_dft.n_ctx;
llama_init_result llama_init_dft = llama_init_from_gpt_params(params_dft);
llama_model* model_dft = llama_init_dft.model;
@@ -154,7 +155,6 @@ bool server_context::load_model(const gpt_params& params_) {
const int n_ctx_dft = llama_n_ctx(llama_init_dft.context);
cparams_dft = llama_context_params_from_gpt_params(params_dft);
cparams_dft.n_batch = n_ctx_dft;
model_draft = llama_init_dft.model;
ctx_draft = llama_init_dft.context;
@@ -317,6 +317,10 @@ void server_slot::reset() {
stopping_word = "";
n_past = 0;
n_sent_text = 0;
drafted.clear();
i_batch_dft.clear();
n_sent_token_probs = 0;
infill = false;
ga_i = 0;
@@ -368,6 +372,31 @@ void server_slot::add_token_string(const completion_token_output& token) {
generated_token_probs.push_back(token);
}
int server_slot::get_n_draft_max() const {
if (!ctx_dft) {
return 0;
}
// determine the max draft that fits the current slot state
int n_draft_max = params.speculative.n_max;
// note: slot.prompt is not yet expanded with the `id` token sampled above
// also, need to leave space for 1 extra token to allow context shifts
n_draft_max = std::min(n_draft_max, n_ctx - n_past - 2);
if (n_remaining > 0) {
n_draft_max = std::min(n_draft_max, n_remaining - 1);
}
SLT_DBG(*this, "max possible draft: %d\n", n_draft_max);
if (n_draft_max < params.speculative.n_min) {
SLT_DBG(*this, "the max possible draft is too small: %d < %d - skipping speculative decoding\n", n_draft_max, params.speculative.n_min);
n_draft_max = 0;
}
return n_draft_max;
}
void server_slot::release() {
if (state == SLOT_STATE_PROCESSING) {
t_token_generation = (ggml_time_us() - t_start_generation) / 1e3;
@@ -468,48 +497,43 @@ size_t server_slot::find_stopping_strings(const std::string& text, const size_t
void server_slot::print_timings() const {
char buffer[512];
double t_token = t_prompt_processing / n_prompt_tokens_processed;
double n_tokens_second = 1e3 / t_prompt_processing * n_prompt_tokens_processed;
snprintf(buffer, 512, "prompt eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)",
t_prompt_processing, n_prompt_tokens_processed,
t_token, n_tokens_second);
//snprintf(buffer, 512, "prompt eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)",
// t_prompt_processing, n_prompt_tokens_processed,
// t_token, n_tokens_second);
LOG_INFO(buffer, {
{"id_slot", id},
{"id_task", id_task},
{"t_prompt_processing", t_prompt_processing},
{"n_prompt_tokens_processed", n_prompt_tokens_processed},
{"t_token", t_token},
{"n_tokens_second", n_tokens_second},
});
//LOG_INFO(buffer, {});
t_token = t_token_generation / n_decoded;
n_tokens_second = 1e3 / t_token_generation * n_decoded;
double t_token_gen = t_token_generation / n_decoded;
double n_tokens_second_gen = 1e3 / t_token_generation * n_decoded;
snprintf(buffer, 512, "generation eval time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)",
t_token_generation, n_decoded,
t_token, n_tokens_second);
//snprintf(buffer, 512, "generation eval time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)",
// t_token_generation, n_decoded,
// t_token, n_tokens_second);
LOG_INFO(buffer, {
{"id_slot", id},
{"id_task", id_task},
{"t_token_generation", t_token_generation},
{"n_decoded", n_decoded},
{"t_token", t_token},
{"n_tokens_second", n_tokens_second},
});
//LOG_INFO(buffer, {});
snprintf(buffer, 512, " total time = %10.2f ms", t_prompt_processing + t_token_generation);
//snprintf(buffer, 512, " total time = %10.2f ms", t_prompt_processing + t_token_generation);
LOG_INFO(buffer, {
{"id_slot", id},
{"id_task", id_task},
{"t_prompt_processing", t_prompt_processing},
{"t_token_generation", t_token_generation},
{"t_total", t_prompt_processing + t_token_generation},
});
//LOG_INFO(buffer, {});
SLT_INF(*this,
"\n"
"prompt eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n"
" eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n"
" total time = %10.2f ms / %5d tokens\n",
t_prompt_processing, n_prompt_tokens_processed, t_token, n_tokens_second,
t_token_generation, n_decoded, t_token_gen, n_tokens_second_gen,
t_prompt_processing + t_token_generation, n_prompt_tokens_processed + n_decoded);
if (n_draft_total > 0) {
const float draft_ratio = (float)n_draft_accepted / n_draft_total;
SLT_CNT(*this,
"draft acceptance rate = %0.5f (%5d accepted / %5d generated)\n",
draft_ratio, n_draft_accepted, n_draft_total
);
}
}
void server_metrics::init() {
@@ -2173,31 +2197,62 @@ void server_context::update_slots() {
continue;
}
slot.i_batch = batch.n_tokens;
// generate draft tokens in speculative decoding mode
// TODO: rework to have a single draft llama_context shared across all slots [TAG_SERVER_SPEC_REWORK]
// perform the speculative drafting for all sequences at the same time in a single batch
int n_draft_max = slot.get_n_draft_max();
if (n_draft_max > 0) {
if (mctx) {
// we should never reach this, as speculative is automatically disabled if mmproj is loaded
GGML_ABORT("not supported by multimodal");
}
const int32_t slot_npast = slot.n_past_se > 0 ? slot.n_past_se : slot.n_past;
struct llama_speculative_params params_spec;
params_spec.n_draft = n_draft_max;
params_spec.n_reuse = llama_n_ctx(slot.ctx_dft) - slot.params.speculative.n_max;
params_spec.p_min = slot.params.speculative.p_min;
const llama_tokens& cached_text_tokens = slot.cache_tokens.get_text_tokens();
llama_tokens draft = llama_speculative_gen_draft(slot.spec, params_spec, cached_text_tokens, slot.sampled);
// TODO: we always have to take into account the "system_tokens"
// this is not great and needs to be improved somehow
llama_batch_add(batch, slot.sampled, system_tokens.size() + slot.cache_tokens.pos_next(), { slot.id }, true);
slot.n_past += 1;
if (slot.params.cache_prompt) {
// add the sampled token to the batch
slot.i_batch_dft.push_back(batch.n_tokens);
llama_batch_add(batch, slot.sampled, slot.cache_tokens.pos_next(), { slot.id }, true);
slot.cache_tokens.push_back(slot.sampled);
if (slot.params.speculative.n_min > (int)draft.size()) {
SLT_DBG(slot, "ignoring small draft: %d < %d\n", (int)draft.size(), slot.params.speculative.n_min);
// fallback to normal decoding
slot.i_batch = slot.i_batch_dft[0];
slot.drafted.clear();
slot.i_batch_dft.clear();
}
else {
// keep track of total number of drafted tokens tested
slot.n_draft_total += draft.size();
// add all drafted tokens to the batch
for (size_t i = 0; i < draft.size(); i++) {
slot.i_batch_dft.push_back(batch.n_tokens);
llama_batch_add(batch, draft[i], slot.cache_tokens.pos_next(), { slot.id }, true);
slot.cache_tokens.push_back(draft[i]);
}
slot.drafted = std::move(draft);
}
}
else {
// no speculative decoding
slot.i_batch = batch.n_tokens;
LOG_VERBOSE("slot decode token", {
{"id_slot", slot.id},
{"id_task", slot.id_task},
{"n_ctx", n_ctx},
{"n_past", slot.n_past},
{"n_system_tokens", system_tokens.size()},
{"n_cache_tokens", slot.cache_tokens.size()},
{"truncated", slot.truncated}
});
llama_batch_add(batch, slot.sampled, slot.cache_tokens.pos_next(), { slot.id }, true);
slot.cache_tokens.push_back(slot.sampled);
SLT_DBG(slot, "slot decode token, n_ctx = %d, n_tokens = %d, truncated = %d\n",
slot.n_ctx, slot.cache_tokens.size(), slot.truncated);
}
slot.n_past = slot.cache_tokens.n_tokens();
}
// process in chunks of params.n_batch
int32_t n_batch = llama_n_batch(ctx);
int32_t n_ubatch = llama_n_ubatch(ctx);
@@ -2638,6 +2693,10 @@ void server_context::update_slots() {
continue; // continue loop of n_batch
}
// technically, measuring the time here excludes the sampling time for the last batch
// but on the other hand, we don't want to do too many system calls to measure the time, so it's ok
const int64_t t_current = ggml_time_us();
for (auto& slot : slots) {
if (slot.state != SLOT_STATE_PROCESSING || slot.i_batch < (int)i || slot.i_batch >= (int)(i + n_tokens)) {
continue; // continue loop of slots
@@ -2652,6 +2711,9 @@ void server_context::update_slots() {
}
completion_token_output result;
if (slot.i_batch_dft.size() > 0) {
continue; // sample using speculative decoding
}
const int tok_idx = slot.i_batch - i;
const llama_token id = llama_sampling_sample(slot.ctx_sampling, ctx, NULL, tok_idx);
@@ -2667,7 +2729,8 @@ void server_context::update_slots() {
metrics.on_prompt_eval(slot);
}
slot.t_token_generation = (t_current - slot.t_start_generation) / 1e3;
//slot.t_token_generation = (t_current - slot.t_start_generation) / 1e3;
slot.t_token_generation = std::max<int64_t>(1, t_current - slot.t_start_generation) / 1e3;
result.tok = id;
result.prob = 1.0f; // TODO: set it here instead of doing inside populate_token_probs
@@ -2687,92 +2750,34 @@ void server_context::update_slots() {
slot.i_batch = -1;
}
// Do speculative decoding
// speculative decoding - main model sample and accept
for (auto& slot : slots) {
if (!slot.is_processing() || !slot.spec) {
if (slot.state != SLOT_STATE_PROCESSING || slot.i_batch_dft.empty()) {
continue;
}
if (slot.state != SLOT_STATE_PROCESSING) {
continue;
}
if (mctx) {
// we should never reach this, as speculative is automatically disabled if mmproj is loaded
GGML_ABORT("not supported by multimodal");
}
// determine the max draft that fits the current slot state
int n_draft_max = slot.params.speculative.n_max;
// note: n_past is not yet increased for the `id` token sampled above
// also, need to leave space for 1 extra token to allow context shifts
n_draft_max = std::min(n_draft_max, slot.n_ctx - slot.n_past - 2);
if (slot.n_predict > 0) {
n_draft_max = std::min(n_draft_max, slot.n_predict - slot.n_decoded - 1);
}
LOG_VERBOSE("max possible draft", {
{"id_slot", slot.id},
{"n_draft_max", n_draft_max}
});
if (n_draft_max < slot.params.speculative.n_min) {
LOG_VERBOSE("the max possible draft is too small", {
{"id_slot", slot.id},
{"n_draft_max", n_draft_max},
{"n_min", slot.params.speculative.n_min}
});
continue;
}
llama_token id = slot.sampled;
struct llama_speculative_params params_spec;
params_spec.n_draft = n_draft_max;
params_spec.n_reuse = cparams_dft.n_ctx - slot.params.speculative.n_max;
params_spec.p_min = slot.params.speculative.p_min;
const std::vector<llama_token>& cached_text_tokens = slot.cache_tokens.tokens_data();
std::vector<llama_token> draft = llama_speculative_gen_draft(slot.spec, params_spec, cached_text_tokens, id);
// ignore small drafts
if (slot.params.speculative.n_min > (int)draft.size()) {
SLT_DBG(slot, "ignoring small draft: %d < %d\n", (int)draft.size(), slot.params.speculative.n_min);
continue;
}
// keep track of total number of drafted tokens tested
slot.n_draft_total += draft.size();
// construct the speculation batch
llama_batch_clear(slot.batch_spec);
llama_batch_add(slot.batch_spec, id, slot.cache_tokens.pos_next(), { slot.id }, true);
for (size_t i = 0; i < draft.size(); ++i) {
llama_batch_add(slot.batch_spec, draft[i], slot.cache_tokens.pos_next() + 1 + i, { slot.id }, true);
}
LOG_VERBOSE("decoding speculative batch", {
{"id_slot", slot.id},
{"size", slot.batch_spec.n_tokens}
});
llama_decode(ctx, slot.batch_spec);
size_t n_draft = slot.drafted.size();
// the accepted tokens from the speculation
std::vector<llama_token> ids = llama_sampling_sample_and_accept_n(slot.ctx_sampling, ctx, draft);
const auto ids = llama_sampling_sample_and_accept_n(slot.ctx_sampling, ctx, slot.i_batch_dft, slot.drafted);
slot.i_batch_dft.clear();
slot.drafted.clear();
slot.n_past += ids.size();
slot.n_decoded += ids.size();
slot.t_token_generation = std::max<int64_t>(1, t_current - slot.t_start_generation) / 1e3;
// update how many tokens out of those tested were accepted
slot.n_draft_accepted += ids.size() - 1;
slot.cache_tokens.push_back(id);
// rollback to the state before sampling the draft tokens
slot.cache_tokens.keep_first(slot.cache_tokens.n_tokens() - n_draft);
// slot.n_past -= n_draft;
// add accepted tokens to the prompt
slot.cache_tokens.insert({ ids.begin(), ids.end() - 1 });
slot.sampled = ids.back(); // last accepted token
slot.n_past = slot.cache_tokens.n_tokens();
llama_kv_cache_seq_rm(ctx, slot.id, slot.n_past, -1);
for (size_t i = 0; i < ids.size(); ++i) {
@@ -2795,11 +2800,11 @@ void server_context::update_slots() {
break;
}
}
SLT_DBG(slot, "accepted %d/%d draft tokens, new n_tokens = %d\n", (int)ids.size() - 1, (int)slot.drafted.size(), slot.n_past);
LOG_VERBOSE("speculative decoding result", {
{"id_slot", slot.id},
{"accepted", (int)ids.size() - 1},
{"total", (int)draft.size()},
{"total", (int)slot.drafted.size()},
{"new_n_past", slot.n_past}
});
}

View File

@@ -98,6 +98,11 @@ struct server_slot {
std::string generated_text;
// idx of draft tokens in the main batch
// non-empty if we went to evaluate draft tokens
// ref: https://github.com/ggml-org/llama.cpp/pull/17808
std::vector<int32_t> i_batch_dft;
std::vector<completion_token_output> generated_token_probs;
common_chat_msg chat_msg;
@@ -122,7 +127,9 @@ struct server_slot {
void prompt_load(server_prompt_cache& prompt_cache, const server_tokens& tokens);
// sampling
llama_token sampled;
llama_token sampled; // in speculative mode, this is the last accepted token
llama_tokens drafted;
struct llama_sampling_params sparams;
llama_sampling_context* ctx_sampling = nullptr;
json json_schema;
@@ -168,6 +175,8 @@ struct server_slot {
void add_token_string(const completion_token_output& token);
int get_n_draft_max() const;
void release();
json get_formated_timings() const;