spec : add self speculative decoding, ngram and refactor (#1261)

* spec : add self speculative decoding, ngram-mod, and refactor

common : use the common_ prefix for common library functions

llama : use LLAMA_TOKEN_NULL
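
The named sentinel replaces bare -1 checks, e.g. the EOT fallback shown in the diff below:

```cpp
llama_token eot = llama_token_eot(model);
if (eot == LLAMA_TOKEN_NULL) { // previously: eot == -1
    eot = llama_token_eos(model);
}
```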

spec : add self speculative decoding (no draft model required) + refactor
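
The idea, as a minimal sketch with illustrative names (this is not the actual common/speculative API): reuse the context itself as the drafter by finding the most recent earlier occurrence of the trailing n-gram and proposing the tokens that followed it; the target model then verifies the whole draft in one batched decode.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

using llama_token = int32_t;

// Propose up to n_draft tokens by matching the trailing n-gram of `tokens`
// against its own history (hypothetical helper, not the real API).
static std::vector<llama_token> draft_from_history(
        const std::vector<llama_token> & tokens, size_t n_gram, size_t n_draft) {
    std::vector<llama_token> draft;
    if (tokens.size() < n_gram + 1) {
        return draft;
    }
    const size_t tail = tokens.size() - n_gram; // start of the trailing n-gram
    for (size_t i = tail; i-- > 0; ) {          // scan backwards for a match
        bool match = true;
        for (size_t j = 0; j < n_gram; ++j) {
            if (tokens[i + j] != tokens[tail + j]) { match = false; break; }
        }
        if (match) {
            // draft the tokens that followed the previous occurrence
            for (size_t k = i + n_gram; k < tokens.size() && draft.size() < n_draft; ++k) {
                draft.push_back(tokens[k]);
            }
            break;
        }
    }
    return draft;
}
```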

spec : add ngram-mod

spec : various improvements to ngram-map + docs
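
For context, an n-gram map in this style keys a small window of recent tokens to the continuation last observed after it. A minimal sketch with hypothetical type and field names (not the actual ngram-map implementation):

```cpp
#include <algorithm>
#include <array>
#include <cstddef>
#include <cstdint>
#include <unordered_map>
#include <vector>

using llama_token = int32_t;

constexpr size_t N_GRAM = 4; // window size for the sketch

struct ngram_hash {
    size_t operator()(const std::array<llama_token, N_GRAM> & k) const {
        size_t h = 1469598103934665603ull; // FNV-1a style combine
        for (llama_token t : k) {
            h = (h ^ (size_t)(uint32_t) t) * 1099511628211ull;
        }
        return h;
    }
};

// each n-gram maps to the token that most recently followed it
using ngram_map = std::unordered_map<std::array<llama_token, N_GRAM>, llama_token, ngram_hash>;

// incremental update: call once per newly accepted token
static void ngram_map_update(ngram_map & m, const std::vector<llama_token> & toks) {
    if (toks.size() < N_GRAM + 1) {
        return;
    }
    std::array<llama_token, N_GRAM> key;
    std::copy(toks.end() - N_GRAM - 1, toks.end() - 1, key.begin());
    m[key] = toks.back();
}
```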

spec : fix the check-rate logic of ngram-simple

common : add common_speculative_is_compat()
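
The exact signature of common_speculative_is_compat() is not shown in this excerpt; presumably it guards drafting with a check of roughly this shape (argument list assumed):

```cpp
// assumed shape: check target/draft compatibility (vocab etc.) up front
if (!common_speculative_is_compat(ctx_tgt, ctx_dft)) {
    fprintf(stderr, "%s: target and draft are not compatible, disabling speculation\n", __func__);
    // fall back to regular decoding (see "show warnings instead of abort" below)
}
```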

spec : simplify time measurement using common_time_meas
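
common_time_meas is an RAII timer: instead of pairing start/stop timestamps by hand, a scope accumulates its elapsed time into a counter on exit. An illustrative stand-in showing the pattern (the real common_time_meas lives in common/ and its constructor details are assumed here):

```cpp
#include <chrono>
#include <cstdint>

struct time_meas_sketch {
    explicit time_meas_sketch(int64_t & t_acc_us) : t_acc_us(t_acc_us),
        t_start(std::chrono::steady_clock::now()) {}
    ~time_meas_sketch() {
        t_acc_us += std::chrono::duration_cast<std::chrono::microseconds>(
            std::chrono::steady_clock::now() - t_start).count();
    }
    int64_t & t_acc_us;
    const std::chrono::steady_clock::time_point t_start;
};

// usage: the scope's wall time is added to t_draft_us on exit
// {
//     time_meas_sketch tm(t_draft_us);
//     ... build the draft ...
// }
```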

refactor common_sampler_init

refactor common_token_to_piece

refactor and fix cur_p bug

clean up

* spec : remove check rate

* spec : show warnings instead of aborting

---------

Co-authored-by: firecoperana <firecoperana>
Co-authored-by: Sascha Rogmann <59577610+srogmann@users.noreply.github.com>
Authored by firecoperana, committed via GitHub on 2026-02-13 12:04:55 -06:00
parent 1fdbc0dafe
commit 1cb7e1bf39
54 changed files with 2652 additions and 779 deletions


@@ -136,7 +136,7 @@ int main(int argc, char ** argv) {
         return 1;
     }

-    llama_sampling_params & sparams = params.sparams;
+    common_params_sampling & sparams = params.sparams;

 #ifndef LOG_DISABLE_LOGS
     log_set_target(log_filename_generator("main", "log"));
@@ -211,8 +211,8 @@ int main(int argc, char ** argv) {
     model = llama_init.model;
     ctx = llama_init.context;

     if (sparams.cfg_scale > 1.f) {
-        struct llama_context_params lparams = llama_context_params_from_gpt_params(params);
-        ctx_guidance = llama_new_context_with_model(model, lparams);
+        struct llama_context_params lparams = common_context_params_to_llama(params);
+        ctx_guidance = llama_init_from_model(model, lparams);
     }

     if (model == NULL) {
@@ -299,7 +299,7 @@ int main(int argc, char ** argv) {
     if (params.interactive_first || !params.prompt.empty() || session_tokens.empty()) {
         LOG("tokenize the prompt\n");
-        embd_inp = ::llama_tokenize(ctx, prompt, true, true);
+        embd_inp = ::common_tokenize(ctx, prompt, true, true);
     } else {
         LOG("use session tokens\n");
         embd_inp = session_tokens;
@@ -327,10 +327,10 @@ int main(int argc, char ** argv) {
     if (ctx_guidance) {
         LOG("cfg_negative_prompt: \"%s\"\n", log_tostr(sparams.cfg_negative_prompt));

-        guidance_inp = ::llama_tokenize(ctx_guidance, sparams.cfg_negative_prompt, true, true);
+        guidance_inp = ::common_tokenize(ctx_guidance, sparams.cfg_negative_prompt, true, true);
         LOG("guidance_inp tokenized: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_guidance, guidance_inp).c_str());

-        std::vector<llama_token> original_inp = ::llama_tokenize(ctx, params.prompt, true, true);
+        std::vector<llama_token> original_inp = ::common_tokenize(ctx, params.prompt, true, true);
         LOG("original_inp tokenized: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, original_inp).c_str());

         original_prompt_len = original_inp.size();
@@ -447,7 +447,7 @@ int main(int argc, char ** argv) {
         for (const auto & antiprompt : params.antiprompt) {
             LOG_TEE("Reverse prompt: '%s'\n", antiprompt.c_str());
             if (params.verbose_prompt) {
-                auto tmp = ::llama_tokenize(ctx, antiprompt, false, true);
+                auto tmp = ::common_tokenize(ctx, antiprompt, false, true);
                 for (int i = 0; i < (int) tmp.size(); i++) {
                     LOG_TEE("%6d -> '%s'\n", tmp[i], common_token_to_piece(ctx, tmp[i]).c_str());
                 }
@@ -462,7 +462,7 @@ int main(int argc, char ** argv) {
         if (!params.input_prefix.empty()) {
             LOG_TEE("Input prefix: '%s'\n", params.input_prefix.c_str());
             if (params.verbose_prompt) {
-                auto tmp = ::llama_tokenize(ctx, params.input_prefix, true, true);
+                auto tmp = ::common_tokenize(ctx, params.input_prefix, true, true);
                 for (int i = 0; i < (int) tmp.size(); i++) {
                     LOG_TEE("%6d -> '%s'\n", tmp[i], common_token_to_piece(ctx, tmp[i]).c_str());
                 }
@@ -472,7 +472,7 @@ int main(int argc, char ** argv) {
         if (!params.input_suffix.empty()) {
             LOG_TEE("Input suffix: '%s'\n", params.input_suffix.c_str());
             if (params.verbose_prompt) {
-                auto tmp = ::llama_tokenize(ctx, params.input_suffix, false, true);
+                auto tmp = ::common_tokenize(ctx, params.input_suffix, false, true);
                 for (int i = 0; i < (int) tmp.size(); i++) {
                     LOG_TEE("%6d -> '%s'\n", tmp[i], common_token_to_piece(ctx, tmp[i]).c_str());
                 }
@@ -546,10 +546,10 @@ int main(int argc, char ** argv) {
     antiprompt_ids.reserve(params.antiprompt.size());
     for (const std::string & antiprompt : params.antiprompt) {
-        antiprompt_ids.emplace_back(::llama_tokenize(ctx, antiprompt, false, true));
+        antiprompt_ids.emplace_back(::common_tokenize(ctx, antiprompt, false, true));
     }

-    struct llama_sampling_context * ctx_sampling = common_sampler_init(llama_get_model_vocab(model), sparams);
+    common_sampler * ctx_sampling = common_sampler_init(model, sparams);
     if (!ctx_sampling) {
         fprintf(stderr, "%s: failed to initialize sampling subsystem\n", __func__);
         exit(1);
@@ -565,7 +565,7 @@ int main(int argc, char ** argv) {
     }

     llama_token decoder_start_token_id = llama_model_decoder_start_token(model);
-    if (decoder_start_token_id == -1) {
+    if (decoder_start_token_id == LLAMA_TOKEN_NULL) {
         decoder_start_token_id = llama_token_bos(model);
     }
@@ -750,7 +750,7 @@ int main(int argc, char ** argv) {
LOG("saved session to %s\n", path_session.c_str());
}
const llama_token id = common_sampler_sample(ctx_sampling, ctx, ctx_guidance);
const llama_token id = common_sampler_sample_legacy(ctx_sampling, ctx, ctx_guidance);
common_sampler_accept(ctx_sampling, ctx, id, /* apply_grammar= */ true);
@@ -861,7 +861,7 @@ int main(int argc, char ** argv) {
         if (params.interactive) {
             if (!params.antiprompt.empty()) {
                 // tokenize and inject first reverse prompt
-                const auto first_antiprompt = ::llama_tokenize(ctx, params.antiprompt.front(), false, true);
+                const auto first_antiprompt = ::common_tokenize(ctx, params.antiprompt.front(), false, true);
                 embd_inp.insert(embd_inp.end(), first_antiprompt.begin(), first_antiprompt.end());
                 is_antiprompt = true;
             }
@@ -935,16 +935,16 @@ int main(int argc, char ** argv) {
                     ? chat_add_and_format(model, *chat_templates, chat_msgs, "user", std::move(buffer))
                     : std::move(buffer);

                 // TODO: one inconvenient of current chat template implementation is that we can't distinguish between user input and special tokens (prefix/postfix)
-                const auto line_pfx = ::llama_tokenize(ctx, params.input_prefix, false, true);
-                const auto line_inp = ::llama_tokenize(ctx, user_inp, false, format_chat);
-                const auto line_sfx = ::llama_tokenize(ctx, params.input_suffix, false, true);
+                const auto line_pfx = ::common_tokenize(ctx, params.input_prefix, false, true);
+                const auto line_inp = ::common_tokenize(ctx, user_inp, false, format_chat);
+                const auto line_sfx = ::common_tokenize(ctx, params.input_suffix, false, true);

                 LOG("input tokens: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, line_inp).c_str());

                 // if user stop generation mid-way, we must add EOT to finish model's last response
                 if (need_insert_eot && format_chat) {
                     llama_token eot = llama_token_eot(model);
-                    embd_inp.push_back(eot == -1 ? llama_token_eos(model) : eot);
+                    embd_inp.push_back(eot == LLAMA_TOKEN_NULL ? llama_token_eos(model) : eot);
                     need_insert_eot = false;
                 }
@@ -973,7 +973,7 @@ int main(int argc, char ** argv) {
             if (n_past > 0 || waiting_for_first_input) {
                 if (is_interacting) {
-                    common_sampler_reset(llama_get_model_vocab(model), ctx_sampling);
+                    common_sampler_reset(ctx_sampling);
                 }

                 is_interacting = false;
                 waiting_for_first_input = false;