Port speculative decoding from upstream to llama-server (#645)

* server : integrate speculative decoding

* server: Fix field names

* server: fix include, whitespace

* fix compile errors in speculative.cpp

* add llama_sampling_sample_and_accept_n to sampling

* finish porting speculative decoding in server

* port functions from common/speculative, common/sampling

* remove arg

* fix function names

* init params_dft to none

* correct value for n_ctx

* prefix kv cache tensors with model name to avoid conflict

* fix call arguments

* fix spec decoding args

* correct slot.id

* use n_max

* port the rest of sampling funcs

* fix func arguments

* slot.id starts at 1?

* Revert "prefix kv cache tensors with model name to avoid conflict"

This reverts commit fbd5dfd866.

* disable draft logging

* disable logging in speculative.cpp

in mainline, these would be LOG_DEBUG, but since ik_llama doesnt support
it, logging is disabled entirely

* add more draft model parameters

* fix

* pass flash_attn

* add speculative params for parity

* set speculative params in launch_slot_with_task instead
This commit is contained in:
g2mt
2025-08-15 21:26:44 -07:00
committed by GitHub
parent 2e2abddaa8
commit b6bc5eedad
8 changed files with 655 additions and 41 deletions

View File

@@ -2,6 +2,8 @@
#include "utils.hpp"
#include "common.h"
#include "speculative.h"
#include "sampling.h"
#include "json-schema-to-grammar.h"
#include "llama.h"
#include "grammar-parser.h"
@@ -148,14 +150,14 @@ static std::string remove_simple_function_calls(const std::string& content) {
size_t pos = 0;
while ((pos = cleaned.find(func_pattern, pos)) != std::string::npos) {
size_t func_start = pos;
// Find the opening brace for arguments
size_t brace_pos = cleaned.find('{', pos);
if (brace_pos == std::string::npos) {
pos += func_pattern.length();
continue;
}
// Find the matching closing brace
int brace_count = 1;
size_t end_pos = brace_pos + 1;
@@ -164,7 +166,7 @@ static std::string remove_simple_function_calls(const std::string& content) {
else if (cleaned[end_pos] == '}') brace_count--;
end_pos++;
}
if (brace_count == 0) {
// Remove the entire function call
cleaned.erase(func_start, end_pos - func_start);
@@ -186,7 +188,7 @@ static std::string remove_xml_function_calls(const std::string& content) {
pos = tool_call_start + 11;
continue;
}
// Remove the entire XML tool call block
cleaned.erase(tool_call_start, tool_call_end - tool_call_start + 12);
pos = tool_call_start;
@@ -196,17 +198,17 @@ static std::string remove_xml_function_calls(const std::string& content) {
static std::string clean_all_function_call_formats(const std::string& content) {
std::string cleaned = content;
// Remove XML format first
cleaned = remove_xml_function_calls(cleaned);
// Then remove simple format
cleaned = remove_simple_function_calls(cleaned);
// Trim whitespace from cleaned content
cleaned.erase(0, cleaned.find_first_not_of(" \t\n\r"));
cleaned.erase(cleaned.find_last_not_of(" \t\n\r") + 1);
return cleaned;
}
@@ -230,6 +232,13 @@ struct slot_params {
bool timings_per_token = false;
json input_prefix;
json input_suffix;
// speculative decoding parameters
struct {
int n_max = 16; // max drafted tokens
int n_min = 0; // min drafted tokens to accept
float p_min = 0.75f; // min probability required to accept a token in the draft
} speculative;
};
struct server_slot {
@@ -293,6 +302,15 @@ struct server_slot {
int32_t ga_n = 1; // group-attention factor
int32_t ga_w = 512; // group-attention width
// speculative decoding
struct llama_speculative * spec = nullptr;
llama_context * ctx_dft = nullptr;
llama_batch batch_spec = {};
// speculative decoding stats
int32_t n_draft_total = 0; // Total draft tokens generated
int32_t n_draft_accepted = 0; // Draft tokens actually accepted
int32_t n_past_se = 0; // self-extend
// stats
@@ -321,28 +339,32 @@ struct server_slot {
n_past_se = 0;
generated_token_probs.clear();
// Reset streaming tool call state
previous_msg = ik_chat_msg();
current_msg = ik_chat_msg();
tool_call_ids.clear();
// Reset speculative decoding stats
n_draft_total = 0;
n_draft_accepted = 0;
}
// Update chat message and compute diffs for streaming tool calls
// Based on original llama.cpp update_chat_msg pattern
const ik_chat_msg & update_chat_msg(std::vector<ik_chat_msg_diff> & diffs) {
ik_chat_msg previous = current_msg;
try {
// Parse generated text incrementally (is_partial = true during generation)
bool is_partial = !stopped_eos && !stopped_word && !stopped_limit;
ik_chat_msg new_msg = parse_chat_message_incremental(generated_text, is_partial, oaicompat_model);
if (!new_msg.empty()) {
// Ensure tool call IDs are set consistently across streaming chunks
new_msg.ensure_tool_call_ids_set(tool_call_ids, generate_tool_call_id);
current_msg = new_msg;
// Compute diffs for streaming
diffs = ik_chat_msg_diff::compute_diffs(previous, current_msg);
}
@@ -350,7 +372,7 @@ struct server_slot {
// If parsing fails, don't update current_msg and return empty diffs
diffs.clear();
}
return current_msg;
}
@@ -413,17 +435,17 @@ struct server_slot {
timings.prompt_per_token_ms = t_prompt_processing / n_prompt_tokens_processed;
timings.prompt_per_second = 1e3 / t_prompt_processing * n_prompt_tokens_processed;
timings.predicted_n = n_decoded;
timings.predicted_ms = (ggml_time_us() - t_start_generation) / 1e3;
timings.predicted_per_token_ms = t_token_generation / n_decoded;
timings.predicted_per_second = 1e3 / t_token_generation * n_decoded;
//// Add speculative metrics
//if (n_draft_total > 0) {
// timings.draft_n = n_draft_total;
// timings.draft_n_accepted = n_draft_accepted;
//}
// Add speculative metrics
if (n_draft_total > 0) {
timings.draft_n = n_draft_total;
timings.draft_n_accepted = n_draft_accepted;
}
return timings;
}
@@ -797,6 +819,11 @@ struct server_context {
bool clean_kv_cache = true;
bool add_bos_token = true;
// For speculative decoding
llama_model * model_draft = nullptr;
llama_context * ctx_draft = nullptr;
llama_context_params cparams_dft;
int32_t n_ctx; // total context for all clients / slots
// system prompt
@@ -829,11 +856,28 @@ struct server_context {
model = nullptr;
}
// Free draft model and context if they exist
if (ctx_draft) {
llama_free(ctx_draft);
ctx_draft = nullptr;
}
if (model_draft) {
llama_free_model(model_draft);
model_draft = nullptr;
}
// Clear any sampling context
for (server_slot & slot : slots) {
if (slot.ctx_sampling != nullptr) {
llama_sampling_free(slot.ctx_sampling);
}
if (slot.ctx_dft) {
llama_free(slot.ctx_dft);
}
if (slot.spec) {
llama_speculative_free(slot.spec);
}
llama_batch_free(slot.batch_spec);
}
llama_batch_free(batch);
@@ -869,6 +913,41 @@ struct server_context {
chat_templates = llama_chat_templates_from_model(model, params.chat_template);
}
GGML_ASSERT(chat_templates.template_default.get() != nullptr);
// Load draft model for speculative decoding if specified
if (!params.model_draft.empty()) {
LOG_INFO("loading draft model", {{"model", params.model_draft}});
gpt_params params_dft;
params_dft.model = params.model_draft;
params_dft.n_ctx = params.n_ctx_draft == 0 ? params.n_ctx / params.n_parallel : params.n_ctx_draft;
params_dft.n_gpu_layers = params.n_gpu_layers_draft;
params_dft.n_parallel = 1;
params_dft.cache_type_k = params.cache_type_k_draft.empty() ? params.cache_type_k : params.cache_type_k_draft;
params_dft.cache_type_v = params.cache_type_v_draft.empty() ? params.cache_type_v : params.cache_type_v_draft;
params_dft.flash_attn = params.flash_attn;
llama_init_result llama_init_dft = llama_init_from_gpt_params(params_dft);
llama_model * model_dft = llama_init_dft.model;
if (model_dft == nullptr) {
LOG_ERROR("failed to load draft model", {{"model", params.model_draft}});
return false;
}
if (!llama_speculative_are_compatible(ctx, llama_init_dft.context)) {
LOG_ERROR("the draft model is not compatible with the target model", {});
return false;
}
const int n_ctx_dft = llama_n_ctx(llama_init_dft.context);
cparams_dft = llama_context_params_from_gpt_params(params_dft);
cparams_dft.n_batch = n_ctx_dft;
model_draft = llama_init_dft.model;
ctx_draft = llama_init_dft.context;
}
return true;
}
@@ -943,6 +1022,23 @@ struct server_context {
slot.sparams = params.sparams;
// Initialize speculative decoding if a draft model is loaded
if (ctx_draft) {
slot.batch_spec = llama_batch_init(slot.params.speculative.n_max + 1, 0, 1);
slot.ctx_dft = llama_new_context_with_model(model_draft, cparams_dft);
if (slot.ctx_dft == nullptr) {
LOG_ERROR("failed to create draft context", {});
return;
}
slot.spec = llama_speculative_init(slot.ctx_dft);
if (slot.spec == nullptr) {
LOG_ERROR("failed to create speculator", {});
return;
}
}
slot.reset();
slots.push_back(slot);
@@ -1134,6 +1230,16 @@ struct server_context {
slot.sparams.n_probs = json_value(data, "n_probs", default_sparams.n_probs);
slot.sparams.min_keep = json_value(data, "min_keep", default_sparams.min_keep);
// speculative decoding parameters
slot.params.speculative.n_max = json_value(data, "speculative.n_max", params.n_draft);
slot.params.speculative.n_min = json_value(data, "speculative.n_min", params.n_draft_min);
slot.params.speculative.p_min = json_value(data, "speculative.p_min", params.p_draft_min);
// Clamp speculative parameters
slot.params.speculative.n_min = std::min(slot.params.speculative.n_max, slot.params.speculative.n_min);
slot.params.speculative.n_min = std::max(slot.params.speculative.n_min, 0);
slot.params.speculative.n_max = std::max(slot.params.speculative.n_max, 0);
if (slot.sparams.penalty_last_n < -1) {
throw std::runtime_error("Error: repeat_last_n must be >= -1");
}
@@ -2737,6 +2843,118 @@ struct server_context {
slot.i_batch = -1;
}
// Do speculative decoding
for (auto & slot : slots) {
if (!slot.is_processing() || !slot.spec) {
continue;
}
if (slot.state != SLOT_STATE_PROCESSING) {
continue;
}
// determine the max draft that fits the current slot state
int n_draft_max = slot.params.speculative.n_max;
// note: n_past is not yet increased for the `id` token sampled above
// also, need to leave space for 1 extra token to allow context shifts
n_draft_max = std::min(n_draft_max, slot.n_ctx - slot.n_past - 2);
if (slot.n_predict > 0) {
n_draft_max = std::min(n_draft_max, slot.n_predict - slot.n_decoded - 1);
}
LOG_VERBOSE("max possible draft", {
{"id_slot", slot.id},
{"n_draft_max", n_draft_max}
});
if (n_draft_max < slot.params.speculative.n_min) {
LOG_VERBOSE("the max possible draft is too small", {
{"id_slot", slot.id},
{"n_draft_max", n_draft_max},
{"n_min", slot.params.speculative.n_min}
});
continue;
}
llama_token id = slot.sampled;
struct llama_speculative_params params_spec;
params_spec.n_draft = n_draft_max;
params_spec.n_reuse = cparams_dft.n_ctx - slot.params.speculative.n_max;
params_spec.p_min = slot.params.speculative.p_min;
const std::vector<llama_token> & cached_text_tokens = slot.cache_tokens;
std::vector<llama_token> draft = llama_speculative_gen_draft(slot.spec, params_spec, cached_text_tokens, id);
// ignore small drafts
if (slot.params.speculative.n_min > (int) draft.size()) {
LOG_VERBOSE("ignoring small draft", {
{"id_slot", slot.id},
{"draft_size", (int) draft.size()},
{"n_min", slot.params.speculative.n_min}
});
continue;
}
// keep track of total number of drafted tokens tested
slot.n_draft_total += draft.size();
// construct the speculation batch
llama_batch_clear(slot.batch_spec);
llama_batch_add(slot.batch_spec, id, slot.n_past, { slot.id + 1 }, true);
for (size_t i = 0; i < draft.size(); ++i) {
llama_batch_add(slot.batch_spec, draft[i], slot.n_past + 1 + i, { slot.id + 1 }, true);
}
LOG_VERBOSE("decoding speculative batch", {
{"id_slot", slot.id},
{"size", slot.batch_spec.n_tokens}
});
llama_decode(ctx, slot.batch_spec);
// the accepted tokens from the speculation
std::vector<llama_token> ids = llama_sampling_sample_and_accept_n(slot.ctx_sampling, ctx, draft);
slot.n_past += ids.size();
slot.n_decoded += ids.size();
// update how many tokens out of those tested were accepted
slot.n_draft_accepted += ids.size() - 1;
slot.cache_tokens.push_back(id);
slot.cache_tokens.insert(slot.cache_tokens.end(), ids.begin(), ids.end() - 1);
llama_kv_cache_seq_rm(ctx, slot.id + 1, slot.n_past, -1);
for (size_t i = 0; i < ids.size(); ++i) {
completion_token_output result;
result.tok = ids[i];
result.text_to_send = llama_token_to_piece(ctx, result.tok, params.special);
// result.prob = 1.0f; // set later
if (!process_token(result, slot)) {
// release slot because of stop condition
slot.release();
slot.print_timings();
send_final_response(slot);
metrics.on_prediction(slot);
break;
}
}
LOG_VERBOSE("speculative decoding result", {
{"id_slot", slot.id},
{"accepted", (int) ids.size() - 1},
{"total", (int) draft.size()},
{"new_n_past", slot.n_past}
});
}
}
LOG_VERBOSE("run slots completed", {});
@@ -2763,10 +2981,10 @@ static json format_final_response_oaicompat(const json& request, json result, co
// Parse tool calls using model-specific format detection
std::string model_name = json_value(request, "model", std::string(""));
// Use the same parsing logic as streaming path for consistency
ik_chat_msg parsed_msg = parse_chat_message_incremental(content, false, model_name);
// Convert to JSON format for compatibility
json tool_calls = json::array();
for (const auto & tc : parsed_msg.tool_calls) {
@@ -2779,9 +2997,9 @@ static json format_final_response_oaicompat(const json& request, json result, co
{"id", tc.id}
});
}
bool has_tool_calls = !tool_calls.empty();
// Use cleaned content from parser (following original llama.cpp pattern)
if (has_tool_calls) {
content = parsed_msg.content; // Parser already cleaned the content
@@ -2863,14 +3081,14 @@ static std::vector<json> format_partial_response_oaicompat(server_task_result ta
// Use generated_text (complete content) for finish_reason logic, not content (empty in streaming)
std::string generated_text = json_value(result, "generated_text", std::string(""));
ik_chat_msg final_msg = parse_chat_message_incremental(generated_text, false, modelname);
// Debug logging
LOG_INFO("DEBUG: Streaming finish_reason check", {
{"generated_text", generated_text},
{"model_name", modelname},
{"model_name", modelname},
{"tool_calls_count", final_msg.tool_calls.size()}
});
finish_reason = final_msg.tool_calls.empty() ? "stop" : "tool_calls";
}
@@ -2878,18 +3096,18 @@ static std::vector<json> format_partial_response_oaicompat(server_task_result ta
// Follow original llama.cpp pattern: Always process diffs and add final chunk
std::vector<json> streaming_chunks;
// Extract diffs from task result (populated by send_partial_response)
// Following original llama.cpp pattern where diffs are stored in task result
std::vector<ik_chat_msg_diff> diffs;
if (result.contains("oaicompat_msg_diffs") && result["oaicompat_msg_diffs"].is_array()) {
for (const auto & diff_json : result["oaicompat_msg_diffs"]) {
ik_chat_msg_diff diff;
// Extract content delta
diff.content_delta = diff_json.value("content_delta", "");
// Extract tool call data
if (diff_json.contains("tool_call_index")) {
diff.tool_call_index = diff_json["tool_call_index"];
@@ -2902,13 +3120,13 @@ static std::vector<json> format_partial_response_oaicompat(server_task_result ta
} else {
diff.tool_call_index = std::string::npos;
}
diffs.push_back(diff);
}
}
streaming_chunks = generate_streaming_chunks(diffs, completion_id, modelname);
// Always add final chunk (like original llama.cpp)
if (!finish_reason.empty()) {
json finish_chunk = {
@@ -2922,6 +3140,7 @@ static std::vector<json> format_partial_response_oaicompat(server_task_result ta
};
streaming_chunks.push_back(finish_chunk);
}
if (server_task_result_dict.count(task_result.id) > 0)
{
for (auto& chunk : streaming_chunks)
@@ -3092,7 +3311,7 @@ int main(int argc, char ** argv) {
// TODO: not great to use extern vars
server_log_json = params.log_json;
server_verbose = params.verbosity > 0;
// struct that contains llama context and inference
server_context ctx_server;