Port speculative decoding from upstream to llama-server (#645)

* server : integrate speculative decoding

* server: Fix field names

* server: fix include, whitespace

* fix compile errors in speculative.cpp

* add llama_sampling_sample_and_accept_n to sampling

* finish porting speculative decoding in server

* port functions from common/speculative, common/sampling

* remove arg

* fix function names

* init params_dft to none

* correct value for n_ctx

* prefix kv cache tensors with model name to avoid conflict

* fix call arguments

* fix spec decoding args

* correct slot.id

* use n_max

* port the rest of sampling funcs

* fix func arguments

* slot.id starts at 1?

* Revert "prefix kv cache tensors with model name to avoid conflict"

This reverts commit fbd5dfd866.

* disable draft logging

* disable logging in speculative.cpp

In mainline, these would be LOG_DEBUG calls, but since ik_llama doesn't
support that log level, logging is disabled entirely

* add more draft model parameters

* fix

* pass flash_attn

* add speculative params for parity

* set speculative params in launch_slot_with_task instead
Author: g2mt
Date: 2025-08-15 21:26:44 -07:00
Committed by: GitHub
Commit: b6bc5eedad (parent 2e2abddaa8)
8 changed files with 655 additions and 41 deletions
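For context on what the port implements, below is a minimal, self-contained C++ sketch of the draft/verify/accept cycle used by speculative decoding. The roles of --draft-max, --draft-min and --draft-p-min follow the parameters added in this diff; the helper functions and concrete values are placeholders, not code from the server.

```cpp
// Hypothetical sketch of one speculative-decoding iteration. draft_next() and
// target_greedy() stand in for the real draft/target model calls and are NOT
// functions from this codebase.
#include <cstdio>
#include <vector>

struct draft_token { int id; float p; };

// placeholder: sample one token (and its probability) from the small draft model
static draft_token draft_next() { return {42, 0.9f}; }

// placeholder: greedy token the target model would emit at position i of the draft
static int target_greedy(size_t i) { return i < 3 ? 42 : 7; }

int main() {
    const int   n_draft     = 16;    // --draft-max: upper bound on drafted tokens
    const int   n_draft_min = 0;     // --draft-min: below this, keep drafting regardless of confidence
    const float p_draft_min = 0.75f; // --draft-p-min: stop drafting when confidence falls below this

    // 1) draft: keep proposing tokens while the draft model is confident enough
    std::vector<int> drafted;
    for (int i = 0; i < n_draft; ++i) {
        draft_token t = draft_next();
        if ((int) drafted.size() >= n_draft_min && t.p < p_draft_min) break;
        drafted.push_back(t.id);
    }

    // 2) verify: the target model scores all drafted positions in one batch and
    //    accepts the longest prefix that matches its own (greedy) choices
    size_t n_accept = 0;
    for (size_t i = 0; i < drafted.size(); ++i) {
        if (target_greedy(i) != drafted[i]) break;
        ++n_accept;
    }

    // the target's own token at the first mismatch (or past the end) is always
    // kept, so every iteration yields at least one new token
    printf("drafted %zu, accepted %zu\n", drafted.size(), n_accept);
    return 0;
}
```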


@@ -505,6 +505,11 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
params.n_ctx = std::stoi(argv[i]);
return true;
}
if (arg == "-cd" || arg == "--ctx-size-draft") {
CHECK_ARG
params.n_ctx_draft = std::stoi(argv[i]);
return true;
}
if (arg == "--grp-attn-n" || arg == "-gan") {
CHECK_ARG
params.grp_attn_n = std::stoi(argv[i]);
@@ -725,7 +730,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
}
}
return true;
}
}
if (arg == "--cfg-negative-prompt") {
CHECK_ARG
sparams.cfg_negative_prompt = argv[i];
@@ -765,11 +770,21 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
params.n_keep = std::stoi(argv[i]);
return true;
}
if (arg == "--draft") {
if (arg == "--draft" || arg == "--draft-max" || arg == "--draft-n") {
CHECK_ARG
params.n_draft = std::stoi(argv[i]);
return true;
}
if (arg == "--draft-min" || arg == "--draft-n-min") {
CHECK_ARG
params.n_draft_min = std::stoi(argv[i]);
return true;
}
if (arg == "--draft-p-min") {
CHECK_ARG
params.p_draft_min = std::stof(argv[i]);
return true;
}
if (arg == "--chunks") {
CHECK_ARG
params.n_chunks = std::stoi(argv[i]);
@@ -934,6 +949,14 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
params.cache_type_v = argv[++i];
return true;
}
if (arg == "-ctkd" || arg == "--cache-type-k-draft") {
params.cache_type_k_draft = argv[++i];
return true;
}
if (arg == "-ctvd" || arg == "--cache-type-v-draft") {
params.cache_type_v_draft = argv[++i];
return true;
}
if (arg == "-mli" || arg == "--multiline-input") {
params.multiline_input = true;
return true;
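The hunks above parse several new draft-model fields on gpt_params. Their declarations are not part of this diff; the sketch below is an assumed layout only, with field names taken from the parsing code, types inferred from the std::stoi/std::stof and string assignments, and placeholder defaults.

```cpp
// Assumed sketch: the actual struct lives elsewhere in common and its defaults may differ.
#include <cstdint>
#include <string>

struct draft_params_sketch {
    int32_t     n_ctx_draft        = 0;     // -cd,  --ctx-size-draft (0 = loaded from model)
    int32_t     n_draft            = 16;    // --draft-max / --draft / --draft-n
    int32_t     n_draft_min        = 0;     // --draft-min / --draft-n-min
    float       p_draft_min        = 0.75f; // --draft-p-min
    std::string cache_type_k_draft = "f16"; // -ctkd, --cache-type-k-draft
    std::string cache_type_v_draft = "f16"; // -ctvd, --cache-type-v-draft
};
```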
@@ -1071,7 +1094,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
size_t pos = 0;
while ((pos = servers.find(",")) != std::string::npos) {
std::string server = servers.substr(0, pos);
ggml_backend_rpc_buffer_type(server.c_str());
ggml_backend_rpc_buffer_type(server.c_str());
servers.erase(0, pos + 1);
}
ggml_backend_rpc_buffer_type(servers.c_str());
@@ -1693,7 +1716,6 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
options.push_back({ "speculative", "-td, --threads-draft N", "number of threads to use during generation (default: same as --threads)" });
options.push_back({ "speculative", "-tbd, --threads-batch-draft N",
"number of threads to use during batch and prompt processing (default: same as --threads-draft)" });
options.push_back({ "speculative", " --draft N", "number of tokens to draft for speculative decoding (default: %d)", params.n_draft });
options.push_back({ "speculative", "-ps, --p-split N", "speculative decoding split probability (default: %.1f)", (double)params.p_split });
options.push_back({ "*", "-lcs, --lookup-cache-static FNAME",
"path to static lookup cache to use for lookup decoding (not updated by generation)" });
@@ -1701,6 +1723,7 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
"path to dynamic lookup cache to use for lookup decoding (updated by generation)" });
options.push_back({ "*", "-c, --ctx-size N", "size of the prompt context (default: %d, 0 = loaded from model)", params.n_ctx });
options.push_back({ "*", "-cd, --ctx-size-draft N", "size of the prompt context for the draft model (default: %d, 0 = loaded from model)", params.n_ctx_draft });
options.push_back({ "*", "-n, --predict N", "number of tokens to predict (default: %d, -1 = infinity, -2 = until context filled)", params.n_predict });
options.push_back({ "*", "-b, --batch-size N", "logical maximum batch size (default: %d)", params.n_batch });
options.push_back({ "*", "-ub, --ubatch-size N", "physical maximum batch size (default: %d)", params.n_ubatch });
@@ -1811,6 +1834,8 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
options.push_back({ "*", "-nkvo, --no-kv-offload", "disable KV offload" });
options.push_back({ "*", "-ctk, --cache-type-k TYPE", "KV cache data type for K (default: %s)", params.cache_type_k.c_str() });
options.push_back({ "*", "-ctv, --cache-type-v TYPE", "KV cache data type for V (default: %s)", params.cache_type_v.c_str() });
options.push_back({ "*", "-ctkd, --cache-type-k-draft TYPE", "KV cache data type for K for the draft model" });
options.push_back({ "*", "-ctvd, --cache-type-v-draft TYPE", "KV cache data type for V for the draft model" });
options.push_back({ "perplexity" });
options.push_back({ "perplexity", " --all-logits", "return logits for all tokens in the batch (default: %s)", params.logits_all ? "true" : "false" });
@@ -1893,6 +1918,10 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
options.push_back({ "*", "-hfr, --hf-repo REPO", "Hugging Face model repository (default: unused)" });
options.push_back({ "*", "-hff, --hf-file FILE", "Hugging Face model file (default: unused)" });
options.push_back({ "*", "-hft, --hf-token TOKEN", "Hugging Face access token (default: value from HF_TOKEN environment variable)" });
options.push_back({ "*", "--draft-max, --draft, --draft-n N",
"number of tokens to draft for speculative decoding (default: %d)", params.n_draft });
options.push_back({ "*", "--draft-min, --draft-n-min N", "minimum number of draft tokens to use for speculative decoding" });
options.push_back({ "*", "--draft-p-min P", "minimum speculative decoding probability (greedy) (default: %.1f)", (double)params.p_draft_min });
options.push_back({ "retrieval" });
options.push_back({ "retrieval", " --context-file FNAME", "file to load context from (repeat to specify multiple files)" });
@@ -2052,7 +2081,7 @@ std::string string_join(const std::vector<std::string> & strs, const std::string
if (strs.empty()) {
return "";
}
std::ostringstream oss;
for (size_t i = 0; i < strs.size(); ++i) {
if (i > 0) {