spec : add self speculative decoding, ngram and refactor (#1261)

* spec : add self speculative decoding and ngram-mod and refactor

common : use common_ prefix for common library function

llama : use LLAMA_TOKEN_NULL

spec : add self speculative decoding (no draft model required) + refactor

spec : add ngram-mod

spec : various improvements ton ngram-map + docs

spec : fix the check-rate logic of ngram-simple

common : add common_speculative_is_compat()

spec : simplify time measurement using common_time_meas

refactor common_sampler_init

refactor common_token_to_piece

refactor and fix cur_p bug

clean up

* spec : remove check rate

* spec: show warnings instead of abort

---------

Co-authored-by: firecoperana <firecoperana>
Co-authored-by: Sascha Rogmann <59577610+srogmann@users.noreply.github.com>
This commit is contained in:
firecoperana
2026-02-13 12:04:55 -06:00
committed by GitHub
parent 1fdbc0dafe
commit 1cb7e1bf39
54 changed files with 2652 additions and 779 deletions

View File

@@ -87,6 +87,13 @@
#endif
using json = nlohmann::ordered_json;
common_time_meas::common_time_meas(int64_t & t_acc, bool disable) : t_start_us(disable ? -1 : ggml_time_us()), t_acc(t_acc) {}
common_time_meas::~common_time_meas() {
if (t_start_us >= 0) {
t_acc += ggml_time_us() - t_start_us;
}
}
//
// Environment variable utils
//
@@ -403,7 +410,7 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
bool invalid_param = false;
std::string arg;
const std::string arg_prefix = "--";
llama_sampling_params & sparams = params.sparams;
common_params_sampling & sparams = params.sparams;
for (int i = 1; i < argc; i++) {
arg = argv[i];
@@ -440,7 +447,7 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
}
}
for (auto & rep : params.replacements_draft) {
for (auto & rep : params.speculative.replacements) {
string_process_escapes(rep.first);
string_process_escapes(rep.second);
}
@@ -566,7 +573,7 @@ std::vector<std::pair<T1,T2>> string_split_pairs(const std::string & str, char d
bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_params & params, int & i, bool & invalid_param) {
const char split_delim = ',';
llama_sampling_params & sparams = params.sparams;
common_params_sampling & sparams = params.sparams;
if (arg == "-s" || arg == "--seed") {
CHECK_ARG
@@ -593,17 +600,17 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
}
if (arg == "-td" || arg == "--threads-draft") {
CHECK_ARG
params.n_threads_draft = std::stoi(argv[i]);
if (params.n_threads_draft <= 0) {
params.n_threads_draft = std::thread::hardware_concurrency();
params.speculative.n_threads = std::stoi(argv[i]);
if (params.speculative.n_threads <= 0) {
params.speculative.n_threads = std::thread::hardware_concurrency();
}
return true;
}
if (arg == "-tbd" || arg == "--threads-batch-draft") {
CHECK_ARG
params.n_threads_batch_draft = std::stoi(argv[i]);
if (params.n_threads_batch_draft <= 0) {
params.n_threads_batch_draft = std::thread::hardware_concurrency();
params.speculative.n_threads_batch = std::stoi(argv[i]);
if (params.speculative.n_threads_batch <= 0) {
params.speculative.n_threads_batch = std::thread::hardware_concurrency();
}
return true;
}
@@ -696,7 +703,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
}
if (arg == "-cd" || arg == "--ctx-size-draft") {
CHECK_ARG
params.n_ctx_draft = std::stoi(argv[i]);
params.speculative.n_ctx = std::stoi(argv[i]);
return true;
}
if (arg == "--grp-attn-n" || arg == "-gan") {
@@ -949,7 +956,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
std::string target = argv[i];
CHECK_ARG
std::string draft = argv[i];
params.replacements_draft.emplace_back(std::move(target), std::move(draft));
params.speculative.replacements.emplace_back(std::move(target), std::move(draft));
return true;
}
if (arg == "--cfg-negative-prompt") {
@@ -993,17 +1000,17 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
}
if (arg == "--draft" || arg == "--draft-max" || arg == "--draft-n") {
CHECK_ARG
params.n_draft = std::stoi(argv[i]);
params.speculative.n_max = std::stoi(argv[i]);
return true;
}
if (arg == "--draft-min" || arg == "--draft-n-min") {
CHECK_ARG
params.n_draft_min = std::stoi(argv[i]);
params.speculative.n_min = std::stoi(argv[i]);
return true;
}
if (arg == "--draft-p-min") {
CHECK_ARG
params.p_draft_min = std::stof(argv[i]);
params.speculative.p_min = std::stof(argv[i]);
return true;
}
if (arg == "--chunks") {
@@ -1033,7 +1040,54 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
}
if (arg == "-md" || arg == "--model-draft") {
CHECK_ARG
params.model_draft = argv[i];
params.speculative.model = argv[i];
return true;
}
if (arg == "--spec-type") {
CHECK_ARG
std::string value = argv[i];
if (value == "none") {
params.speculative.type = COMMON_SPECULATIVE_TYPE_NONE;
} else if (value == "ngram-cache") {
params.speculative.type = COMMON_SPECULATIVE_TYPE_NGRAM_CACHE;
} else if (value == "ngram-simple") {
params.speculative.type = COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE;
} else if (value == "ngram-map-k") {
params.speculative.type = COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K;
} else if (value == "ngram-map-k4v") {
params.speculative.type = COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V;
} else if (value == "ngram-mod") {
params.speculative.type = COMMON_SPECULATIVE_TYPE_NGRAM_MOD;
} else {
throw std::invalid_argument("unknown speculative decoding type without draft model");
}
return true;
}
if (arg == "--spec-ngram-size-n") {
CHECK_ARG
int value = std::stoi(argv[i]);
if (value < 1 || value > 1024) {
throw std::invalid_argument("ngram size N must be between 1 and 1024 inclusive");
}
params.speculative.ngram_size_n = value;
return true;
}
if (arg == "--spec-ngram-size-m") {
CHECK_ARG
int value = std::stoi(argv[i]);
if (value < 1 || value > 1024) {
throw std::invalid_argument("ngram size M must be between 1 and 1024 inclusive");
}
params.speculative.ngram_size_m = value;
return true;
}
if (arg == "--spec-ngram-min-hits") {
CHECK_ARG
int value = std::stoi(argv[i]);
if (value < 1) {
throw std::invalid_argument("ngram min hits must be at least 1");
}
params.speculative.ngram_min_hits = value;
return true;
}
if (arg == "-a" || arg == "--alias") {
@@ -1190,11 +1244,11 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
return true;
}
if (arg == "-ctkd" || arg == "--cache-type-k-draft") {
params.cache_type_k_draft = argv[++i];
params.speculative.cache_type_k = argv[++i];
return true;
}
if (arg == "-ctvd" || arg == "--cache-type-v-draft") {
params.cache_type_v_draft = argv[++i];
params.speculative.cache_type_v = argv[++i];
return true;
}
if (arg == "-mli" || arg == "--multiline-input") {
@@ -1304,7 +1358,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
}
if (arg == "-ngld" || arg == "--gpu-layers-draft" || arg == "--n-gpu-layers-draft") {
CHECK_ARG
params.n_gpu_layers_draft = std::stoi(argv[i]);
params.speculative.n_gpu_layers = std::stoi(argv[i]);
if (!llama_supports_gpu_offload()) {
fprintf(stderr, "warning: not compiled with GPU offload support, --gpu-layers-draft option will be ignored\n");
fprintf(stderr, "warning: see main README.md for information on enabling GPU BLAS support\n");
@@ -1409,7 +1463,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
}
if (arg == "-draft" || arg == "--draft-params") {
CHECK_ARG
params.draft_params = argv[i];
params.speculative.params = argv[i];
return true;
}
if (arg == "--cpu-moe" || arg == "-cmoe") {
@@ -1500,7 +1554,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
if (arg == "-devd" || arg == "--device-draft") {
CHECK_ARG
std::string value(argv[i]);
params.devices_draft = parse_device_list(value);
params.speculative.devices = parse_device_list(value);
return true;
}
if (arg == "-v" || arg == "--verbose") {
@@ -2111,7 +2165,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
#endif
void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
const llama_sampling_params & sparams = params.sparams;
const common_params_sampling & sparams = params.sparams;
std::string sampler_type_chars;
std::string sampler_type_names;
@@ -2165,7 +2219,7 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
"path to dynamic lookup cache to use for lookup decoding (updated by generation)" });
options.push_back({ "*", "-c, --ctx-size N", "size of the prompt context (default: %d, 0 = loaded from model)", params.n_ctx });
options.push_back({ "*", "-cd, --ctx-size-draft N", "size of the prompt context for the draft model (default: %d, 0 = loaded from model)", params.n_ctx_draft });
options.push_back({ "*", "-cd, --ctx-size-draft N", "size of the prompt context for the draft model (default: %d, 0 = loaded from model)", params.speculative.n_ctx });
options.push_back({ "*", "-cram, --cache-ram N", "set the maximum cache size in MiB (default: %d, -1 - no limit, 0 - disable)",params.cache_ram_mib });
options.push_back({ "*", "-crs, --cache-ram-similarity N", "max of similarity of prompt tokens to cache tokens that triggers prompt cache (default: %.2f).",params.cache_ram_similarity });
options.push_back({ "*", "-cram-n-min --cache-ram-n-min N", "minimum number of the cached tokens that triggers prompt cache (default: %d).", params.cache_ram_n_min });
@@ -2420,9 +2474,15 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
options.push_back({ "*", "-hff, --hf-file FILE", "Hugging Face model file (default: unused)" });
options.push_back({ "*", "-hft, --hf-token TOKEN", "Hugging Face access token (default: value from HF_TOKEN environment variable)" });
options.push_back({ "*", "--draft-max, --draft, --draft-n N",
"number of tokens to draft for speculative decoding (default: %d)", params.n_draft });
"number of tokens to draft for speculative decoding (default: %d)", params.speculative.n_max });
options.push_back({ "*", "--draft-min, --draft-n-min N", "minimum number of draft tokens to use for speculative decoding" });
options.push_back({ "*", "--draft-p-min P", "minimum speculative decoding probability (greedy) (default: %.1f)", (double)params.p_draft_min });
options.push_back({ "*", "--draft-p-min P", "minimum speculative decoding probability (greedy) (default: %.1f)", (double)params.speculative.p_min });
options.push_back({ "*", "--spec-type Name", "[none | ngram - cache | ngram - simple | ngram - map - k | ngram - map - k4v | ngram - mod]", "type of speculative decoding to use when no draft model is provided (default: %s)\n", (double)params.speculative.type});
options.push_back({ "*", "--spec-ngram-size-n N", "ngram size N for ngram-simple/ngram-map speculative decoding, length of lookup n-gram (default: %d)\n",params.speculative.ngram_size_n });
options.push_back({ "*", "--spec-ngram-size-m N", "ngram size M for ngram-simple/ngram-map speculative decoding, length of draft m-gram (default: %d)\n", params.speculative.ngram_size_m });
options.push_back({ "*", "--spec-ngram-min-hits N", "minimum hits for ngram-map speculative decoding (default: %d)\n", params.speculative.ngram_min_hits });
options.push_back({ "retrieval" });
options.push_back({ "retrieval", " --context-file FNAME", "file to load context from (repeat to specify multiple files)" });
@@ -2998,7 +3058,7 @@ std::string fs_get_cache_file(const std::string & filename) {
//
struct llama_init_result llama_init_from_gpt_params(gpt_params & params) {
llama_init_result iparams;
auto mparams = llama_model_params_from_gpt_params(params);
auto mparams = common_model_params_to_llama(params);
llama_model * model = nullptr;
@@ -3007,7 +3067,7 @@ struct llama_init_result llama_init_from_gpt_params(gpt_params & params) {
} else if (!params.model_url.empty()) {
model = llama_load_model_from_url(params.model_url.c_str(), params.model.c_str(), params.hf_token.c_str(), mparams);
} else {
model = llama_load_model_from_file(params.model.c_str(), mparams);
model = llama_model_load_from_file(params.model.c_str(), mparams);
}
if (model == NULL) {
@@ -3015,9 +3075,9 @@ struct llama_init_result llama_init_from_gpt_params(gpt_params & params) {
return iparams;
}
auto cparams = llama_context_params_from_gpt_params(params);
auto cparams = common_context_params_to_llama(params);
llama_context * lctx = llama_new_context_with_model(model, cparams);
llama_context * lctx = llama_init_from_model(model, cparams);
if (lctx == NULL) {
fprintf(stderr, "%s: error: failed to create context with model '%s'\n", __func__, params.model.c_str());
llama_free_model(model);
@@ -3096,7 +3156,7 @@ struct llama_init_result llama_init_from_gpt_params(gpt_params & params) {
if (llama_model_has_encoder(model)) {
llama_encode(lctx, llama_batch_get_one(tmp.data(), tmp.size(), 0, 0));
llama_token decoder_start_token_id = llama_model_decoder_start_token(model);
if (decoder_start_token_id == -1) {
if (decoder_start_token_id == LLAMA_TOKEN_NULL) {
decoder_start_token_id = bos;
}
tmp.clear();
@@ -3124,7 +3184,7 @@ void llama_lora_adapters_apply(struct llama_context * ctx, std::vector<llama_lor
}
}
struct llama_model_params llama_model_params_from_gpt_params(const gpt_params & params) {
struct llama_model_params common_model_params_to_llama(const gpt_params & params) {
auto mparams = llama_model_default_params();
mparams.devices = params.devices.c_str();
@@ -3215,7 +3275,7 @@ static ggml_type ggml_type_from_str(const std::string & s) {
throw std::runtime_error("Invalid graph reduce type: " + s);
}
struct llama_context_params llama_context_params_from_gpt_params(const gpt_params & params) {
struct llama_context_params common_context_params_to_llama(const gpt_params & params) {
auto cparams = llama_context_default_params();
int n_batch = params.n_batch;
int n_ubatch = params.n_ubatch;
@@ -3658,7 +3718,7 @@ void common_batch_add(
// Vocab utils
//
std::vector<llama_token> llama_tokenize(
std::vector<llama_token> common_tokenize(
const struct llama_context * ctx,
const std::string & text,
bool add_special,
@@ -3740,13 +3800,19 @@ std::string llama_token_to_piece(const struct llama_model* model, llama_token to
return piece;
}
std::string common_token_to_piece(const llama_context * ctx, const std::vector<llama_token> & tokens, bool special) {
std::string common_detokenize(const struct llama_context * ctx, const std::vector<llama_token> & tokens, bool special) {
const llama_model * model = llama_get_model(ctx);
const llama_vocab * vocab = llama_model_get_vocab(model);
return common_detokenize(vocab, tokens, special);
}
std::string common_detokenize(const struct llama_vocab * vocab, const std::vector<llama_token> & tokens, bool special) {
std::string text;
text.resize(std::max(text.capacity(), tokens.size()));
int32_t n_chars = llama_detokenize(llama_get_model(ctx), tokens.data(), (int32_t)tokens.size(), &text[0], (int32_t)text.size(), false, special);
int32_t n_chars = llama_detokenize(vocab, tokens.data(), (int32_t)tokens.size(), &text[0], (int32_t)text.size(), false, special);
if (n_chars < 0) {
text.resize(-n_chars);
n_chars = llama_detokenize(llama_get_model(ctx), tokens.data(), (int32_t)tokens.size(), &text[0], (int32_t)text.size(), false, special);
n_chars = llama_detokenize(vocab, tokens.data(), (int32_t)tokens.size(), &text[0], (int32_t)text.size(), false, special);
GGML_ASSERT(n_chars <= (int32_t)text.size()); // whitespace trimming is performed after per-token detokenization
}
@@ -3756,11 +3822,25 @@ std::string common_token_to_piece(const llama_context * ctx, const std::vector<l
return text;
}
std::string common_token_to_piece(const struct llama_vocab * vocab, llama_token token, bool special) {
std::string piece;
piece.resize(piece.capacity()); // using string internal cache, 15 bytes + '\n'
const int n_chars = llama_token_to_piece_vocab(vocab, token, &piece[0], piece.size(), 0, special);
if (n_chars < 0) {
piece.resize(-n_chars);
int check = llama_token_to_piece_vocab(vocab, token, &piece[0], piece.size(), 0, special);
GGML_ASSERT(check == -n_chars);
} else {
piece.resize(n_chars);
}
return piece;
}
bool llama_should_add_bos_token(const llama_model * model) {
const int add_bos = llama_add_bos_token(model);
return add_bos != -1 ? bool(add_bos) : (llama_vocab_type(model) == LLAMA_VOCAB_TYPE_SPM);
const llama_vocab * vocab = llama_get_model_vocab(model);
return add_bos != -1 ? bool(add_bos) : (llama_vocab_type(vocab) == LLAMA_VOCAB_TYPE_SPM);
}
@@ -4086,7 +4166,7 @@ void yaml_dump_string_multiline(FILE * stream, const char * prop_name, const cha
void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const llama_context * lctx,
const std::string & timestamp, const std::vector<int> & prompt_tokens, const char * model_desc) {
const llama_sampling_params & sparams = params.sparams;
const common_params_sampling & sparams = params.sparams;
fprintf(stream, "build_commit: %s\n", LLAMA_COMMIT);
fprintf(stream, "build_number: %d\n", LLAMA_BUILD_NUMBER);
@@ -4198,7 +4278,7 @@ void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const l
fprintf(stream, "top_n_sigma: %f # default: 0.0\n", sparams.top_n_sigma);
fprintf(stream, "mlock: %s # default: false\n", params.use_mlock ? "true" : "false");
fprintf(stream, "model: %s # default: %s\n", params.model.c_str(), DEFAULT_MODEL_PATH);
fprintf(stream, "model_draft: %s # default:\n", params.model_draft.c_str());
fprintf(stream, "model_draft: %s # default:\n", params.speculative.model.c_str());
fprintf(stream, "multiline_input: %s # default: false\n", params.multiline_input ? "true" : "false");
fprintf(stream, "n_gpu_layers: %d # default: -1\n", params.n_gpu_layers);
fprintf(stream, "n_predict: %d # default: -1 (unlimited)\n", params.n_predict);