finish porting speculative decoding in server

T. M.
2025-07-25 04:00:02 +00:00
parent 99c1ef3c01
commit 642b70a64b
3 changed files with 27 additions and 49 deletions

@@ -507,15 +507,12 @@ void llama_sampling_accept(
}
}
std::vector<llama_token> llama_sampling_sample_and_accept_n(struct llama_sampling_context * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const std::vector<llama_token> & draft) {
GGML_ASSERT(idxs.size() == draft.size() + 1 && "idxs.size() must be draft.size() + 1");
std::vector<llama_token> llama_sampling_sample_and_accept_n(struct llama_sampling_context * gsmpl, struct llama_context * ctx, const std::vector<llama_token> & draft) {
std::vector<llama_token> result;
result.reserve(idxs.size());
size_t i = 0;
for (; i < draft.size(); i++) {
const llama_token id = llama_sampling_sample(gsmpl, ctx, nullptr, idxs[i]);
const llama_token id = llama_sampling_sample(gsmpl, ctx, nullptr, i);
llama_sampling_accept(gsmpl, ctx, id, true);
@@ -527,7 +524,7 @@ std::vector<llama_token> llama_sampling_sample_and_accept_n(struct llama_samplin
}
if (i == draft.size()) {
const llama_token id = llama_sampling_sample(gsmpl, ctx, nullptr, idxs[i]);
const llama_token id = llama_sampling_sample(gsmpl, ctx, nullptr, i);
llama_sampling_accept(gsmpl, ctx, id, true);
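A note on the hunk above: with the explicit idxs argument removed, the helper implicitly reads output logits at indices 0 .. draft.size(), so callers only pass the draft tokens. The following is a minimal sketch of the intended call site, not code from this commit; it assumes the draft tokens have already been decoded on the target context ctx with logits enabled for each of those positions, and that gsmpl is the caller's llama_sampling_context.

#include "sampling.h"
#include "llama.h"
#include <vector>

// Illustration only: verify a batch of draft tokens on the target context in
// one call. The helper returns every draft token the target sampler agreed
// with plus the first token where it diverged, so the result always holds
// between 1 and draft.size() + 1 tokens, and each returned token has already
// been pushed through llama_sampling_accept internally.
static std::vector<llama_token> verify_draft(
        struct llama_sampling_context * gsmpl,
        struct llama_context          * ctx,
        const std::vector<llama_token> & draft) {
    return llama_sampling_sample_and_accept_n(gsmpl, ctx, draft);
}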

@@ -177,21 +177,6 @@ void llama_sampling_accept(
llama_token id,
bool apply_grammar);
// generalized version of common_sampler_sample
//
// will cross-reference the sampled tokens with a batch of draft tokens and accept those that match
// if the sampler disagrees at some point, we stop and return the accepted tokens up to now
//
// common_sampler_sample_n(gsmpl, ctx, { idx }, {});
//
// is equivalent to
//
// common_sampler_sample(gsmpl, ctx, idx);
// common_sampler_accept(gsmpl, token, true);
//
// requires: idxs.size() == draft.size() + 1
//
// returns at least 1 token, up to idxs.size()
//
std::vector<llama_token> llama_sampling_sample_and_accept_n(struct llama_sampling_context * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const std::vector<llama_token> & draft);
// returns at least 1 token, up to draft.size() + 1
std::vector<llama_token> llama_sampling_sample_and_accept_n(struct llama_sampling_context * gsmpl, struct llama_context * ctx, const std::vector<llama_token> & draft);
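The deleted comment block described the contract in terms of the old idxs parameter. With the single-argument form, passing an empty draft reduces to an ordinary sample plus accept at output index 0, which is the same equivalence the old comment stated. A small sketch of that degenerate case, assuming only the declarations above:

#include "sampling.h"
#include "llama.h"

// Illustration only: with an empty draft the helper behaves like
//   const llama_token id = llama_sampling_sample(gsmpl, ctx, nullptr, 0);
//   llama_sampling_accept(gsmpl, ctx, id, true);
// and returns a one-element vector { id }.
static llama_token sample_one(struct llama_sampling_context * gsmpl, struct llama_context * ctx) {
    return llama_sampling_sample_and_accept_n(gsmpl, ctx, {}).front();
}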

@@ -3,6 +3,7 @@
#include "common.h"
#include "speculative.h"
#include "sampling.h"
#include "json-schema-to-grammar.h"
#include "llama.h"
#include "grammar-parser.h"
@@ -819,7 +820,8 @@ struct server_context {
bool add_bos_token = true;
// For speculative decoding
llama_init_result model_dft_owned;
llama_model * model_draft = nullptr;
llama_context * ctx_draft = nullptr;
llama_context_params cparams_dft;
int32_t n_ctx; // total context for all clients / slots
@@ -853,6 +855,16 @@ struct server_context {
model = nullptr;
}
// Free draft model and context if they exist
if (ctx_draft) {
llama_free(ctx_draft);
ctx_draft = nullptr;
}
if (model_draft) {
llama_free_model(model_draft);
model_draft = nullptr;
}
// Clear any sampling context
for (server_slot & slot : slots) {
if (slot.ctx_sampling != nullptr) {
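A note on the new cleanup block: the usual teardown order is to free every llama context before the model it references, and besides ctx_draft that also covers the per-slot draft contexts created from model_draft further down. A hedged sketch of that order, written as it would read inside this destructor; only the ctx_draft and model_draft frees appear in the hunk above, the slot loop is illustrative.

// Illustration: free every context derived from the draft model before the model itself.
for (server_slot & slot : slots) {
    if (slot.ctx_dft != nullptr) {
        llama_free(slot.ctx_dft);       // per-slot draft context
        slot.ctx_dft = nullptr;
    }
}
if (ctx_draft != nullptr) {
    llama_free(ctx_draft);              // shared draft context
    ctx_draft = nullptr;
}
if (model_draft != nullptr) {
    llama_free_model(model_draft);      // draft model, freed last
    model_draft = nullptr;
}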
@@ -917,29 +929,13 @@ struct server_context {
return false;
}
// Store the draft context initialization parameters for later use
cparams_dft = llama_context_default_params();
cparams_dft.n_ctx = params_dft.n_ctx;
cparams_dft.n_batch = cparams_dft.n_ctx;
cparams_dft.n_ubatch = params_dft.n_ubatch;
cparams_dft.freq_base = params_dft.rope_freq_base;
cparams_dft.freq_scale = params_dft.rope_freq_scale;
cparams_dft.yarn_ext_factor = params_dft.yarn_ext_factor;
cparams_dft.yarn_attn_factor = params_dft.yarn_attn_factor;
cparams_dft.yarn_beta_fast = params_dft.yarn_beta_fast;
cparams_dft.yarn_beta_slow = params_dft.yarn_beta_slow;
cparams_dft.yarn_orig_ctx = params_dft.yarn_orig_ctx;
cparams_dft.clip_kqv = params_dft.clip_kqv;
cparams_dft.pooling_type = params_dft.pooling_type;
cparams_dft.defrag_thold = params_dft.defrag_thold;
cparams_dft.type_k = params_dft.type_k;
cparams_dft.type_v = params_dft.type_v;
cparams_dft.logits_all = false;
cparams_dft.embedding = false;
cparams_dft.offload_kqv = params_dft.offload_kqv;
const int n_ctx_dft = llama_n_ctx(llama_init_dft.context);
// Keep the draft model alive
model_dft_owned = llama_init_dft;
cparams_dft = llama_context_params_from_gpt_params(params_dft);
cparams_dft.n_batch = n_ctx_dft;
model_draft = llama_init_dft.model;
ctx_draft = llama_init_dft.context;
}
return true;
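Replacing the field-by-field copy with llama_context_params_from_gpt_params keeps the draft context settings in sync with whatever the shared parameter struct contains; only n_batch is then set to the draft context size so an entire draft window fits in one decode call. A standalone sketch of the same derivation, assuming params_dft is the gpt_params the draft model was loaded with:

#include "common.h"
#include "llama.h"

// Illustration only: derive the per-slot draft context parameters from the
// already-loaded draft context instead of copying individual fields.
static llama_context_params make_draft_cparams(const gpt_params & params_dft,
                                               llama_context * ctx_dft_loaded) {
    llama_context_params cp = llama_context_params_from_gpt_params(params_dft);
    cp.n_batch = llama_n_ctx(ctx_dft_loaded); // batch as large as the draft context
    return cp;
}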
@@ -993,10 +989,10 @@ struct server_context {
slot.sparams = params.sparams;
// Initialize speculative decoding if a draft model is loaded
if (model_dft_owned.context) {
if (ctx_draft) {
slot.batch_spec = llama_batch_init(params.n_draft + 1, 0, 1);
slot.ctx_dft = llama_init_from_model(model_dft_owned.model, cparams_dft);
slot.ctx_dft = llama_new_context_with_model(model_draft, cparams_dft);
if (slot.ctx_dft == nullptr) {
LOG_ERROR("failed to create draft context", {});
return;
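batch_spec is sized n_draft + 1 presumably because a verification pass carries the last accepted token plus every drafted token, all with logits enabled so the target model can score each position in a single decode. A sketch of how such a batch would be filled and verified; the llama_batch_clear and llama_batch_add helpers and the names draft, id_last, slot.n_past and slot.id are assumptions for illustration, not identifiers taken from this diff:

// Illustration only: build and evaluate the verification batch on the target context.
llama_batch_clear(slot.batch_spec);
llama_batch_add(slot.batch_spec, id_last, slot.n_past, { slot.id }, true);   // last accepted token
for (size_t i = 0; i < draft.size(); ++i) {
    llama_batch_add(slot.batch_spec, draft[i], slot.n_past + 1 + (llama_pos) i, { slot.id }, true);
}
llama_decode(ctx, slot.batch_spec);                                          // target scores every draft position
const std::vector<llama_token> ids = llama_sampling_sample_and_accept_n(slot.ctx_sampling, ctx, draft);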
@@ -2906,7 +2902,7 @@ struct server_context {
result.tok = ids[i];
result.text_to_send = llama_token_to_piece(ctx, result.tok, params.special);
result.prob = 1.0f; // set later
// result.prob = 1.0f; // set later
if (!process_token(result, slot)) {
// release slot because of stop condition
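For completeness, the other half of the exchange, producing the draft tokens on the slot's draft context, lives behind speculative.h and is not shown in this diff. Below is a self-contained sketch of a minimal greedy drafter using only core llama.h calls; the signatures follow the older llama.h API this tree appears to use, and id_last, pos and n_draft are illustrative parameters, so the real helper may look quite different.

#include "llama.h"
#include <vector>

// Illustration only: greedily draft up to n_draft tokens on the draft context,
// starting from the last accepted token at position pos.
static std::vector<llama_token> draft_greedy(llama_context * ctx_dft, llama_token id_last,
                                             llama_pos pos, int n_draft) {
    std::vector<llama_token> draft;
    const int n_vocab = llama_n_vocab(llama_get_model(ctx_dft));

    llama_token cur = id_last;
    for (int i = 0; i < n_draft; ++i) {
        // evaluate one token, then pick the arg-max of the resulting logits
        llama_batch batch = llama_batch_get_one(&cur, 1, pos + i, 0);
        if (llama_decode(ctx_dft, batch) != 0) {
            break;
        }
        const float * logits = llama_get_logits(ctx_dft);
        llama_token best = 0;
        for (llama_token t = 1; t < n_vocab; ++t) {
            if (logits[t] > logits[best]) {
                best = t;
            }
        }
        draft.push_back(best);
        cur = best;
    }
    return draft;
}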