Port speculative decoding from upstream to llama-server (#645)

* server : integrate speculative decoding

* server: Fix field names

* server: fix include, whitespace

* fix compile errors in speculative.cpp

* add llama_sampling_sample_and_accept_n to sampling

* finish porting speculative decoding in server

* port functions from common/speculative, common/sampling

* remove arg

* fix function names

* init params_dft to none

* correct value for n_ctx

* prefix kv cache tensors with model name to avoid conflict

* fix call arguments

* fix spec decoding args

* correct slot.id

* use n_max

* port the rest of sampling funcs

* fix func arguments

* slot.id starts at 1?

* Revert "prefix kv cache tensors with model name to avoid conflict"

This reverts commit fbd5dfd866.

* disable draft logging

* disable logging in speculative.cpp

In mainline these would be LOG_DEBUG calls, but since ik_llama doesn't support that log level, logging here is disabled entirely (see the sketch after this list).

* add more draft model parameters

* fix

* pass flash_attn

* add speculative params for parity

* set speculative params in launch_slot_with_task instead
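A minimal sketch of one way "disabled entirely" can be realized (hypothetical; the LOG_DEBUG name mirrors the mainline calls mentioned in the revert note above and is not defined by ik_llama):

    // hypothetical stub: compile mainline-style debug logging out of speculative.cpp
    #ifndef LOG_DEBUG
    #define LOG_DEBUG(...) ((void)0)  // expands to a no-op; the arguments are discarded at compile time
    #endif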
Author: g2mt (committed by GitHub)
Date: 2025-08-15 21:26:44 -07:00
Commit: b6bc5eedad (parent 2e2abddaa8)
8 changed files with 655 additions and 41 deletions


common/sampling.h
@@ -101,6 +101,8 @@ struct llama_sampling_context {
     size_t n_valid; // Number of correct top tokens with correct probabilities.
+
+    llama_token_data_array cur_p; // current candidates
     std::mt19937 rng;
 };
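For context, the llama_sampling_get_candidates accessor added in the next hunk most plausibly just exposes this new member; a minimal sketch, assuming the layout above (the body is an assumption, not part of the diff):

    // sketch: return the candidate array filled in by the most recent sampling call
    llama_token_data_array * llama_sampling_get_candidates(struct llama_sampling_context * ctx_sampling) {
        return &ctx_sampling->cur_p;
    }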
@@ -176,3 +178,11 @@ void llama_sampling_accept(
         struct llama_context * ctx_main,
         llama_token id,
         bool apply_grammar);
+
+// access the internal list of current candidate tokens
+llama_token_data_array * llama_sampling_get_candidates(struct llama_sampling_context * ctx_sampling);
+
+// returns at least 1 token, up to draft.size()
+std::vector<llama_token> llama_sampling_sample_and_accept_n(struct llama_sampling_context * gsmpl, struct llama_context * ctx, const std::vector<llama_token> & draft);
+
+std::vector<llama_token> llama_sampling_sample_and_accept_n(struct llama_sampling_context * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const std::vector<llama_token> & draft);
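A usage sketch of the new overloads from the server's speculative step (illustrative caller code; slot.ctx_sampling, ctx_tgt, and draft are assumed names, not part of the diff):

    // verify the whole draft against the target model in one call; logits for
    // positions 0..draft.size() must already be computed in ctx_tgt
    std::vector<llama_token> ids =
        llama_sampling_sample_and_accept_n(slot.ctx_sampling, ctx_tgt, draft);

    // per the header comment, ids holds at least one token (the target model's
    // own sample), so generation progresses even when every draft token is
    // rejected; a matching draft prefix is accepted in a single pass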