spec : add self speculative decoding, ngram and refactor (#1261)

* spec : add self speculative decoding and ngram-mod and refactor

common : use common_ prefix for common library functions

llama : use LLAMA_TOKEN_NULL

spec : add self speculative decoding (no draft model required) + refactor

spec : add ngram-mod

spec : various improvements to ngram-map + docs

spec : fix the check-rate logic of ngram-simple

common : add common_speculative_is_compat()

spec : simplify time measurement using common_time_meas

refactor common_sampler_init

refactor common_token_to_piece

refactor and fix cur_p bug

clean up

* spec : remove check rate

* spec : show warnings instead of aborting

---------

Co-authored-by: firecoperana <firecoperana>
Co-authored-by: Sascha Rogmann <59577610+srogmann@users.noreply.github.com>
Authored by firecoperana on 2026-02-13 12:04:55 -06:00, committed by GitHub
parent 1fdbc0dafe
commit 1cb7e1bf39
54 changed files with 2652 additions and 779 deletions
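
The log above adds self speculative decoding driven by n-gram lookup: draft tokens are proposed from the context that has already been processed, so no separate draft model is required. The commit's actual implementation is not shown in this excerpt; the sketch below only illustrates the general prompt-lookup idea, and the helper draft_ngram is hypothetical.

// Illustration only: drafting tokens from the existing context via n-gram lookup.
// Not the code added by this commit; draft_ngram is a hypothetical helper.
#include <cstdint>
#include <vector>

using llama_token = std::int32_t; // stand-in for the llama.h typedef

// Match the trailing n tokens of `history` against earlier positions and, on the
// most recent hit, propose up to n_draft of the tokens that followed the match.
static std::vector<llama_token> draft_ngram(const std::vector<llama_token> & history,
                                            size_t n, size_t n_draft) {
    std::vector<llama_token> draft;
    if (n == 0 || history.size() <= n) {
        return draft;
    }
    const size_t cur = history.size() - n;  // start of the trailing n-gram
    for (size_t pos = cur; pos-- > 0; ) {   // scan backwards so recent matches win
        bool match = true;
        for (size_t k = 0; k < n; ++k) {
            if (history[pos + k] != history[cur + k]) {
                match = false;
                break;
            }
        }
        if (!match) {
            continue;
        }
        // the continuation after the matched n-gram becomes the speculative draft
        for (size_t k = pos + n; k < history.size() && draft.size() < n_draft; ++k) {
            draft.push_back(history[k]);
        }
        break;
    }
    return draft; // verified by the target model, exactly as with a draft model
}

The drafted tokens are then verified in a single batch by the target model and accepted only while they match its own sampling, which is where the speedup comes from.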


@@ -72,7 +72,7 @@ struct client {
std::string prompt;
std::string response;
-struct llama_sampling_context * ctx_sampling = nullptr;
+struct common_sampler * ctx_sampling = nullptr;
};
static void print_date_time() {
@@ -161,11 +161,11 @@ int main(int argc, char ** argv) {
for (size_t i = 0; i < clients.size(); ++i) {
auto & client = clients[i];
client.id = i;
-client.ctx_sampling = common_sampler_init(llama_get_model_vocab(model), params.sparams);
+client.ctx_sampling = common_sampler_init(model, params.sparams);
}
std::vector<llama_token> tokens_system;
-tokens_system = ::llama_tokenize(ctx, k_system, true);
+tokens_system = ::common_tokenize(ctx, k_system, true);
const int32_t n_tokens_system = tokens_system.size();
llama_seq_id g_seq_id = 0;
@@ -253,11 +253,11 @@ int main(int argc, char ** argv) {
client.prompt = client.input + "\nAssistant:";
client.response = "";
-common_sampler_reset(llama_get_model_vocab(model), client.ctx_sampling);
+common_sampler_reset(client.ctx_sampling);
// do not prepend BOS because we have a system prompt!
std::vector<llama_token> tokens_prompt;
-tokens_prompt = ::llama_tokenize(ctx, client.prompt, false);
+tokens_prompt = ::common_tokenize(ctx, client.prompt, false);
for (size_t i = 0; i < tokens_prompt.size(); ++i) {
common_batch_add(batch, tokens_prompt[i], i + n_tokens_system, { client.id + 1 }, false);
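
For orientation, here is a minimal usage sketch of the per-client setup after the two hunks above, restricted to the calls they show: common_sampler_init now takes the model directly instead of llama_get_model_vocab(model), and common_tokenize replaces the ::llama_tokenize wrapper. The helper client_setup is hypothetical, assumed to sit in this same example file (which defines struct client and pulls in the common headers), and the sampling-parameter struct is left as a template parameter rather than guessing its exact name.

// Hypothetical helper mirroring the calls in the hunks above; assumed to live in
// the same example source file, which already defines struct client and includes
// the common headers. The sampling-params type is templated to avoid guessing it.
template <typename TSamplingParams>
static void client_setup(client & c, llama_model * model, llama_context * ctx,
                         const TSamplingParams & sparams) {
    // one sampler chain per client, built from the model (no vocab lookup needed)
    c.ctx_sampling = common_sampler_init(model, sparams);

    // per-client prompt without BOS: the shared system prompt, tokenized earlier
    // with its leading BOS, already starts every sequence
    std::vector<llama_token> tokens_prompt = common_tokenize(ctx, c.prompt, false);
    (void) tokens_prompt; // the example then queues these via common_batch_add()
}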
@@ -341,7 +341,7 @@ int main(int argc, char ** argv) {
//printf("client %d, seq %d, token %d, pos %d, batch %d\n",
// client.id, client.seq_id, client.sampled, client.n_decoded, client.i_batch);
-const llama_token id = common_sampler_sample(client.ctx_sampling, ctx, NULL, client.i_batch - i);
+const llama_token id = common_sampler_sample(client.ctx_sampling, ctx, client.i_batch - i);
common_sampler_accept(client.ctx_sampling, ctx, id, true);
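
Finally, a sketch of one decode step per client after the refactor, again using only the calls that appear in this last hunk: common_sampler_sample loses the extra NULL argument, while common_sampler_accept keeps the context parameter in this tree. client_step and pos_next are illustrative names, not part of the commit.

// Hypothetical per-token step for one client, mirroring the post-refactor calls
// shown in the hunk above; assumed to live in the same example source file.
static void client_step(client & c, llama_context * ctx, llama_batch & batch,
                        int32_t i, llama_pos pos_next) {
    // sample from the logits stored at this client's slot of the previous batch
    const llama_token id = common_sampler_sample(c.ctx_sampling, ctx, c.i_batch - i);

    // feed the accepted token back so penalty/grammar state stays consistent
    common_sampler_accept(c.ctx_sampling, ctx, id, true);

    // queue it for the next decode on this client's sequence and request logits
    common_batch_add(batch, id, pos_next, { c.id + 1 }, true);
}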