spec : add self speculative decoding, ngram and refactor (#1261)

* spec : add self speculative decoding, ngram-mod, and refactor

common : use common_ prefix for common library functions

llama : use LLAMA_TOKEN_NULL

spec : add self speculative decoding (no draft model required) + refactor

spec : add ngram-mod

spec : various improvements to ngram-map + docs

spec : fix the check-rate logic of ngram-simple

common : add common_speculative_is_compat()

spec : simplify time measurement using common_time_meas

refactor common_sampler_init

refactor common_token_to_piece

refactor and fix cur_p bug

clean up

* spec : remove check rate

* spec : show warnings instead of aborting
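
The "self speculative decoding (no draft model required)" change drafts candidate tokens from the model's own generation history via n-gram lookup instead of running a separate draft model. Below is a minimal sketch of that lookup idea; the function name, parameters, and fallback behaviour are illustrative assumptions, not the actual common_speculative API added by this commit.

```cpp
// Illustrative sketch only: draft tokens by matching the trailing n-gram of the
// generated history against an earlier occurrence and copying what followed it.
// Names (draft_from_ngram, ngram_size, n_draft) are hypothetical, not this repo's API.
#include <cstdint>
#include <vector>

using llama_token = int32_t;

static std::vector<llama_token> draft_from_ngram(
        const std::vector<llama_token> & history,
        size_t ngram_size,
        size_t n_draft) {
    std::vector<llama_token> draft;
    if (history.size() <= ngram_size) {
        return draft;
    }
    const size_t cur = history.size() - ngram_size; // start of the trailing n-gram
    for (size_t pos = cur; pos-- > 0; ) {           // scan backwards for an earlier match
        bool match = true;
        for (size_t j = 0; j < ngram_size; ++j) {
            if (history[pos + j] != history[cur + j]) {
                match = false;
                break;
            }
        }
        if (match) {
            // propose the tokens that followed the earlier occurrence as the draft
            for (size_t j = pos + ngram_size; j < history.size() && draft.size() < n_draft; ++j) {
                draft.push_back(history[j]);
            }
            break;
        }
    }
    return draft; // may be empty; the caller then falls back to normal decoding
}
```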

---------

Co-authored-by: firecoperana <firecoperana>
Co-authored-by: Sascha Rogmann <59577610+srogmann@users.noreply.github.com>
Authored by firecoperana on 2026-02-13 12:04:55 -06:00, committed by GitHub
parent 1fdbc0dafe
commit 1cb7e1bf39
54 changed files with 2652 additions and 779 deletions


@@ -6,8 +6,10 @@
#include <nlohmann/json.hpp>
using json = nlohmann::ordered_json;
struct llama_sampling_context * common_sampler_init(const struct llama_vocab* vocab, const struct llama_sampling_params & params) {
struct llama_sampling_context * result = new llama_sampling_context();
struct common_sampler * common_sampler_init(const struct llama_model * model, const struct common_params_sampling & params) {
const llama_vocab * vocab = llama_model_get_vocab(model);
struct common_sampler * result = new common_sampler();
result->params = params;
result->grammar = nullptr;
@@ -107,7 +109,7 @@ struct llama_sampling_context * common_sampler_init(const struct llama_vocab* vo
return result;
}
void common_sampler_free(struct llama_sampling_context * ctx) {
void common_sampler_free(struct common_sampler * ctx) {
if (ctx->grammar != NULL) {
llama_grammar_free(ctx->grammar);
}
@@ -116,7 +118,7 @@ void common_sampler_free(struct llama_sampling_context * ctx) {
delete ctx;
}
static void llama_grammar_reset(llama_sampling_context * ctx) {
static void llama_grammar_reset(common_sampler * ctx) {
ctx->prev.clear();
if (!ctx->grammar) {
return;
@@ -135,19 +137,19 @@ static void llama_grammar_reset(llama_sampling_context * ctx) {
ctx->grammar = grammar_new;
}
void common_sampler_reset(const struct llama_vocab * vocab, llama_sampling_context * ctx) {
void common_sampler_reset(common_sampler * ctx) {
llama_grammar_reset(ctx);
llama_sampler_dry_reset(ctx->smpl);
}
void llama_sampling_set_rng_seed(struct llama_sampling_context * ctx, uint32_t seed) {
void llama_sampling_set_rng_seed(struct common_sampler * ctx, uint32_t seed) {
if (seed == LLAMA_DEFAULT_SEED) {
seed = std::random_device{}();
}
ctx->rng.seed(seed);
}
void common_sampler_clone(llama_sampling_context * src, llama_sampling_context * dst) {
void common_sampler_clone(common_sampler * src, common_sampler * dst) {
if (dst->grammar) {
llama_grammar_free(dst->grammar);
dst->grammar = nullptr;
@@ -163,11 +165,11 @@ void common_sampler_clone(llama_sampling_context * src, llama_sampling_context *
dst->smpl = llama_sampler_dry_clone(src->smpl);
}
llama_token llama_sampling_last(llama_sampling_context * ctx) {
llama_token llama_sampling_last(common_sampler * ctx) {
return ctx->prev.back();
}
std::string llama_sampling_prev_str(llama_sampling_context * ctx_sampling, llama_context * ctx_main, int n) {
std::string llama_sampling_prev_str(common_sampler * ctx_sampling, llama_context * ctx_main, int n) {
const int size = ctx_sampling->prev.size();
n = std::min(n, size);
@@ -181,7 +183,7 @@ std::string llama_sampling_prev_str(llama_sampling_context * ctx_sampling, llama
return result;
}
std::string llama_sampling_print(const llama_sampling_params & params) {
std::string llama_sampling_print(const common_params_sampling & params) {
char result[1024];
snprintf(result, sizeof(result),
@@ -199,7 +201,7 @@ std::string llama_sampling_print(const llama_sampling_params & params) {
return std::string(result);
}
std::string llama_sampling_order_print(const llama_sampling_params & params) {
std::string llama_sampling_order_print(const common_params_sampling & params) {
std::string result = "CFG -> Penalties ";
if (params.mirostat == 0) {
for (auto sampler_type : params.samplers_sequence) {
@@ -315,8 +317,8 @@ std::vector<llama_sampler_type> llama_sampling_types_from_chars(const std::strin
// no reasons to expose this function in header
static void sampler_queue(
struct llama_context* ctx_main,
const llama_sampling_params& params,
llama_sampling_context * ctx_sampling,
const common_params_sampling& params,
common_sampler * ctx_sampling,
llama_token_data_array& cur_p,
size_t min_keep) {
const float temp = params.temp;
@@ -343,6 +345,7 @@ static void sampler_queue(
case llama_sampler_type::MIN_P : llama_sample_min_p (ctx_main, &cur_p, min_p, min_keep); break;
case llama_sampler_type::XTC : llama_sample_xtc (ctx_main, &cur_p, xtc_probability, xtc_threshold, min_keep); break;
case llama_sampler_type::TOP_N_SIGMA: llama_sample_top_n_sigma(ctx_main, &cur_p, top_n_sigma); break;
case llama_sampler_type::DIST : llama_sample_dist (ctx_main, &cur_p); break;
case llama_sampler_type::TEMPERATURE:
if (dynatemp_range > 0) {
float dynatemp_min = std::max(0.0f, temp - dynatemp_range);
@@ -364,12 +367,12 @@ static void sampler_queue(
}
static llama_token llama_sampling_sample_impl(
struct llama_sampling_context * ctx_sampling,
struct common_sampler * ctx_sampling,
struct llama_context * ctx_main,
struct llama_context * ctx_cfg,
const int idx,
bool is_resampling) {
const llama_sampling_params & params = ctx_sampling->params;
const common_params_sampling & params = ctx_sampling->params;
const float temp = params.temp;
const int mirostat = params.mirostat;
@@ -378,7 +381,8 @@ static llama_token llama_sampling_sample_impl(
const float adaptive_target = params.adaptive_target;
std::vector<float> original_logits;
auto cur_p = llama_sampling_prepare(ctx_sampling, ctx_main, ctx_cfg, idx, /* apply_grammar= */ is_resampling, &original_logits);
llama_sampling_prepare(ctx_sampling, ctx_main, ctx_cfg, idx, /* apply_grammar= */ is_resampling, &original_logits);
llama_token_data_array & cur_p = ctx_sampling->cur_p;
if (ctx_sampling->grammar != NULL && !is_resampling) {
GGML_ASSERT(!original_logits.empty());
}
@@ -414,22 +418,9 @@ static llama_token llama_sampling_sample_impl(
// temperature sampling
size_t min_keep = std::max(1, params.min_keep);
sampler_queue(ctx_main, params,ctx_sampling, cur_p, min_keep);
id = llama_sample_token_with_rng(ctx_main, &cur_p, ctx_sampling->rng);
//{
// const int n_top = 10;
// LOG("top %d candidates:\n", n_top);
// for (int i = 0; i < n_top; i++) {
// const llama_token id = cur_p.data[i].id;
// (void)id; // To avoid a warning that id is unused when logging is disabled.
// LOG(" - %5d: '%12s' (%.3f)\n", id, common_token_to_piece(ctx_main, id).c_str(), cur_p.data[i].p);
// }
//}
//LOG("sampled token: %5d: '%s'\n", id, common_token_to_piece(ctx_main, id).c_str());
}
}
@@ -457,20 +448,19 @@ static llama_token llama_sampling_sample_impl(
return llama_sampling_sample_impl(ctx_sampling, ctx_main, ctx_cfg, idx, /* is_resampling= */ true);
}
}
ctx_sampling->n_valid = temp == 0.0f ? 0 : cur_p.size;
return id;
}
static llama_token_data_array llama_sampling_prepare_impl(
struct llama_sampling_context * ctx_sampling,
struct common_sampler * ctx_sampling,
struct llama_context * ctx_main,
struct llama_context * ctx_cfg,
const int idx,
bool apply_grammar,
std::vector<float> * original_logits) {
const llama_sampling_params & params = ctx_sampling->params;
const common_params_sampling & params = ctx_sampling->params;
const int n_vocab = llama_n_vocab(llama_get_model(ctx_main));
@@ -541,8 +531,8 @@ static llama_token_data_array llama_sampling_prepare_impl(
return cur_p;
}
llama_token common_sampler_sample(
struct llama_sampling_context * ctx_sampling,
llama_token common_sampler_sample_legacy(
struct common_sampler * ctx_sampling,
struct llama_context * ctx_main,
struct llama_context * ctx_cfg,
const int idx) {
@@ -550,8 +540,17 @@ llama_token common_sampler_sample(
return llama_sampling_sample_impl(ctx_sampling, ctx_main, ctx_cfg, idx, /* is_resampling= */ false);
}
llama_token common_sampler_sample(
struct common_sampler * ctx_sampling,
struct llama_context * ctx_main,
const int idx,
bool grammar_first) {
// Call the implementation function with is_resampling set to false by default
return llama_sampling_sample_impl(ctx_sampling, ctx_main, nullptr, idx, /* is_resampling= */ grammar_first);
}
llama_token_data_array llama_sampling_prepare(
struct llama_sampling_context * ctx_sampling,
struct common_sampler * ctx_sampling,
struct llama_context * ctx_main,
struct llama_context * ctx_cfg,
const int idx,
@@ -561,7 +560,7 @@ llama_token_data_array llama_sampling_prepare(
}
void common_sampler_accept(
struct llama_sampling_context * ctx_sampling,
struct common_sampler * ctx_sampling,
struct llama_context * ctx_main,
llama_token id,
bool apply_grammar) {
@@ -579,11 +578,32 @@ void common_sampler_accept(
}
}
llama_token_data_array * common_sampler_get_candidates(struct llama_sampling_context * ctx_sampling) {
return &ctx_sampling->cur_p;
llama_token_data_array * common_sampler_get_candidates(struct common_sampler * gsmpl, bool do_sort) {
auto * res = &gsmpl->cur_p;
if (do_sort && !res->sorted) {
// remember the selected token before sorting
const llama_token id = res->data[res->selected].id;
std::sort(res->data, res->data + res->size, [](const llama_token_data & a, const llama_token_data & b) {
return a.p > b.p;
});
// restore the selected token after sorting
for (size_t i = 0; i < res->size; ++i) {
if (res->data[i].id == id) {
res->selected = i;
break;
}
}
res->sorted = true;
}
return res;
}
std::vector<llama_token> llama_sampling_sample_and_accept_n(struct llama_sampling_context * gsmpl, struct llama_context * ctx, const std::vector<llama_token> & draft) {
std::vector<llama_token> llama_sampling_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<llama_token> & draft) {
std::vector<int> idxs(draft.size() + 1);
for (size_t i = 0; i < idxs.size(); ++i) {
idxs[i] = i;
@@ -592,7 +612,7 @@ std::vector<llama_token> llama_sampling_sample_and_accept_n(struct llama_samplin
return common_sampler_sample_and_accept_n(gsmpl, ctx, idxs, draft);
}
std::vector<llama_token> common_sampler_sample_and_accept_n(struct llama_sampling_context * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const std::vector<llama_token> & draft) {
std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const std::vector<llama_token> & draft, bool grammar_first) {
GGML_ASSERT(idxs.size() == draft.size() + 1 && "idxs.size() must be draft.size() + 1");
std::vector<llama_token> result;
@@ -600,7 +620,7 @@ std::vector<llama_token> common_sampler_sample_and_accept_n(struct llama_samplin
size_t i = 0;
for (; i < draft.size(); i++) {
const llama_token id = common_sampler_sample(gsmpl, ctx, nullptr, idxs[i]);
const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i], grammar_first);
common_sampler_accept(gsmpl, ctx, id, true);
@@ -612,7 +632,7 @@ std::vector<llama_token> common_sampler_sample_and_accept_n(struct llama_samplin
}
if (i == draft.size()) {
const llama_token id = common_sampler_sample(gsmpl, ctx, nullptr, idxs[i]);
const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i], grammar_first);
common_sampler_accept(gsmpl, ctx, id, true);
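
For context on the refactored signatures above, here is a hedged caller-side sketch of how a speculative-decoding loop might consume llama_sampling_sample_and_accept_n once the draft tokens have been decoded into the target context. The acceptance semantics assumed here (sampling stops at the first disagreement with the draft, and the last returned token is the target model's own sample) follow the usual speculative-verification scheme and are not spelled out by this hunk.

```cpp
// Hedged usage sketch; assumes the sampling header from this commit
// (common/sampling.h) is included and that the draft has already been decoded,
// so logits are available at batch indices 0..draft.size() in ctx.
#include <vector>

static size_t accept_draft(
        struct common_sampler * gsmpl,
        struct llama_context  * ctx,
        const std::vector<llama_token> & draft) {
    // one sample per draft position plus one final token from the target model
    const std::vector<llama_token> res = llama_sampling_sample_and_accept_n(gsmpl, ctx, draft);

    // assumed convention: res.size() - 1 draft tokens were accepted, and res.back()
    // is the target model's continuation after the accepted prefix
    return res.size() - 1;
}
```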