From 642b70a64b63a4f89dd454a9d73e2af76fbe5bfa Mon Sep 17 00:00:00 2001
From: "T. M."
Date: Fri, 25 Jul 2025 04:00:02 +0000
Subject: [PATCH] finish porting speculative decoding in server

---
 common/sampling.cpp        |  9 +++----
 common/sampling.h          | 19 ++-------------
 examples/server/server.cpp | 48 +++++++++++++++++---------------------
 3 files changed, 27 insertions(+), 49 deletions(-)

diff --git a/common/sampling.cpp b/common/sampling.cpp
index bd915626..7d460b57 100644
--- a/common/sampling.cpp
+++ b/common/sampling.cpp
@@ -507,15 +507,12 @@ void llama_sampling_accept(
     }
 }
 
-std::vector<llama_token> llama_sampling_sample_and_accept_n(struct llama_sampling_context * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const std::vector<llama_token> & draft) {
-    GGML_ASSERT(idxs.size() == draft.size() + 1 && "idxs.size() must be draft.size() + 1");
-
+std::vector<llama_token> llama_sampling_sample_and_accept_n(struct llama_sampling_context * gsmpl, struct llama_context * ctx, const std::vector<llama_token> & draft) {
     std::vector<llama_token> result;
-    result.reserve(idxs.size());
 
     size_t i = 0;
     for (; i < draft.size(); i++) {
-        const llama_token id = llama_sampling_sample(gsmpl, ctx, nullptr, idxs[i]);
+        const llama_token id = llama_sampling_sample(gsmpl, ctx, nullptr, i);
 
         llama_sampling_accept(gsmpl, ctx, id, true);
 
@@ -527,7 +524,7 @@ std::vector<llama_token> llama_sampling_sample_and_accept_n(struct llama_samplin
     }
 
     if (i == draft.size()) {
-        const llama_token id = llama_sampling_sample(gsmpl, ctx, nullptr, idxs[i]);
+        const llama_token id = llama_sampling_sample(gsmpl, ctx, nullptr, i);
 
         llama_sampling_accept(gsmpl, ctx, id, true);
 
diff --git a/common/sampling.h b/common/sampling.h
index 2517daee..405f5a63 100644
--- a/common/sampling.h
+++ b/common/sampling.h
@@ -177,21 +177,6 @@ void llama_sampling_accept(
         llama_token id,
         bool apply_grammar);
 
-// generalized version of common_sampler_sample
-//
-// will cross-reference the sampled tokens with a batch of draft tokens and accept those that match
-// if the sampler disagrees at some point, we stop and return the accepted tokens up to now
-//
-// common_sampler_sample_n(gsmpl, ctx, { idx }, {});
-//
-// is equivalent to
-//
-// common_sampler_sample(gsmpl, ctx, idx);
-// common_sampler_accept(gsmpl, token, true);
-//
-// requires: idxs.size() == draft.size() + 1
-//
-// returns at least 1 token, up to idxs.size()
-//
-std::vector<llama_token> llama_sampling_sample_and_accept_n(struct llama_sampling_context * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const std::vector<llama_token> & draft);
+// returns at least 1 token, up to draft.size() + 1
+std::vector<llama_token> llama_sampling_sample_and_accept_n(struct llama_sampling_context * gsmpl, struct llama_context * ctx, const std::vector<llama_token> & draft);
 
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index d724bcf0..ad934137 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -3,6 +3,7 @@
 #include "common.h"
 #include "speculative.h"
+#include "sampling.h"
 #include "json-schema-to-grammar.h"
 #include "llama.h"
 #include "grammar-parser.h"
 
@@ -819,7 +820,8 @@ struct server_context {
     bool add_bos_token = true;
 
     // For speculative decoding
-    llama_init_result model_dft_owned;
+    llama_model * model_draft = nullptr;
+    llama_context * ctx_draft = nullptr;
     llama_context_params cparams_dft;
 
     int32_t n_ctx; // total context for all clients / slots
@@ -853,6 +855,16 @@ struct server_context {
             model = nullptr;
         }
 
+        // Free draft model and context if they exist
+        if (ctx_draft) {
+            llama_free(ctx_draft);
+            ctx_draft = nullptr;
+        }
+        if (model_draft) {
+            llama_free_model(model_draft);
+            model_draft = nullptr;
+        }
+
         // Clear any sampling context
         for (server_slot & slot : slots) {
             if (slot.ctx_sampling != nullptr) {
@@ -917,29 +929,13 @@ struct server_context {
                 return false;
             }
-            // Store the draft context initialization parameters for later use
-            cparams_dft = llama_context_default_params();
-            cparams_dft.n_ctx = params_dft.n_ctx;
-            cparams_dft.n_batch = cparams_dft.n_ctx;
-            cparams_dft.n_ubatch = params_dft.n_ubatch;
-            cparams_dft.freq_base = params_dft.rope_freq_base;
-            cparams_dft.freq_scale = params_dft.rope_freq_scale;
-            cparams_dft.yarn_ext_factor = params_dft.yarn_ext_factor;
-            cparams_dft.yarn_attn_factor = params_dft.yarn_attn_factor;
-            cparams_dft.yarn_beta_fast = params_dft.yarn_beta_fast;
-            cparams_dft.yarn_beta_slow = params_dft.yarn_beta_slow;
-            cparams_dft.yarn_orig_ctx = params_dft.yarn_orig_ctx;
-            cparams_dft.clip_kqv = params_dft.clip_kqv;
-            cparams_dft.pooling_type = params_dft.pooling_type;
-            cparams_dft.defrag_thold = params_dft.defrag_thold;
-            cparams_dft.type_k = params_dft.type_k;
-            cparams_dft.type_v = params_dft.type_v;
-            cparams_dft.logits_all = false;
-            cparams_dft.embedding = false;
-            cparams_dft.offload_kqv = params_dft.offload_kqv;
+            const int n_ctx_dft = llama_n_ctx(llama_init_dft.context);
 
-            // Keep the draft model alive
-            model_dft_owned = llama_init_dft;
+            cparams_dft = llama_context_params_from_gpt_params(params_dft);
+            cparams_dft.n_batch = n_ctx_dft;
+
+            model_draft = llama_init_dft.model;
+            ctx_draft = llama_init_dft.context;
         }
 
         return true;
     }
@@ -993,10 +989,10 @@ struct server_context {
             slot.sparams = params.sparams;
 
             // Initialize speculative decoding if a draft model is loaded
-            if (model_dft_owned.context) {
+            if (ctx_draft) {
                 slot.batch_spec = llama_batch_init(params.n_draft + 1, 0, 1);
 
-                slot.ctx_dft = llama_init_from_model(model_dft_owned.model, cparams_dft);
+                slot.ctx_dft = llama_new_context_with_model(model_draft, cparams_dft);
                 if (slot.ctx_dft == nullptr) {
                     LOG_ERROR("failed to create draft context", {});
                     return;
@@ -2906,7 +2902,7 @@ struct server_context {
 
                     result.tok = ids[i];
                     result.text_to_send = llama_token_to_piece(ctx, result.tok, params.special);
-                    result.prob = 1.0f; // set later
+                    // result.prob = 1.0f; // set later
 
                     if (!process_token(result, slot)) {
                         // release slot because of stop condition
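
The server-side call site of the new two-argument llama_sampling_sample_and_accept_n() is only partially visible in this diff (the "result.tok = ids[i]" hunk). Below is a minimal sketch of the call pattern the change implies. llama_batch_clear(), llama_batch_add() and llama_decode() are the existing common/llama APIs, and slot.batch_spec, slot.ctx_sampling, slot.id and params.n_draft appear in this patch; speculate_step(), id_last, n_past and the draft vector are illustrative names, not code from the patch.

    // Sketch only: verify a batch of draft tokens against the target model.
    // Since the idxs parameter was dropped, logit row i must be the distribution
    // used to verify draft[i]; evaluating id_last plus every draft token with
    // logits enabled gives exactly that layout.
    static std::vector<llama_token> speculate_step(
            llama_context * ctx,                       // target model context
            server_slot   & slot,                      // holds batch_spec and ctx_sampling
            llama_token     id_last,                   // last token accepted by the target sampler
            llama_pos       n_past,                    // tokens already in the target KV cache
            const std::vector<llama_token> & draft) {  // candidates from the draft model
        llama_batch_clear(slot.batch_spec);
        llama_batch_add(slot.batch_spec, id_last, n_past, { slot.id }, true);
        for (size_t i = 0; i < draft.size(); ++i) {
            llama_batch_add(slot.batch_spec, draft[i], n_past + 1 + (llama_pos) i, { slot.id }, true);
        }
        llama_decode(ctx, slot.batch_spec);

        // samples with the target sampler and accepts draft tokens until the first
        // mismatch; returns between 1 and draft.size() + 1 tokens
        return llama_sampling_sample_and_accept_n(slot.ctx_sampling, ctx, draft);
    }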
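
Downstream of that call, the "result.prob" hunk sits inside a loop over the accepted tokens. A sketch of that loop follows, assuming the completion_token_output struct, process_token() and slot.release() already present in server.cpp; the surrounding loop and the note about discarding rejected draft positions are illustrative, not part of this patch.

    // Sketch only: feed the accepted tokens through the normal token pipeline.
    // Rejected draft positions (anything past ids.size() - 1) must also be removed
    // from the target and draft KV caches before the next step (omitted here).
    const std::vector<llama_token> ids = speculate_step(ctx, slot, id_last, slot.n_past, draft);

    for (size_t i = 0; i < ids.size(); ++i) {
        completion_token_output result;

        result.tok = ids[i];
        result.text_to_send = llama_token_to_piece(ctx, result.tok, params.special);
        // result.prob is left unset for now, matching the patch

        if (!process_token(result, slot)) {
            // release slot because of stop condition
            slot.release();
            break;
        }
    }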