Port speculative decoding from upstream to llama-server (#645)
* server : integrate speculative decoding
* server: Fix field names
* server: fix include, whitespace
* fix compile errors in speculative.cpp
* add llama_sampling_sample_and_accept_n to sampling
* finish porting speculative decoding in server
* port functions from common/speculative, common/sampling
* remove arg
* fix function names
* init params_dft to none
* correct value for n_ctx
* prefix kv cache tensors with model name to avoid conflict
* fix call arguments
* fix spec decoding args
* correct slot.id
* use n_max
* port the rest of sampling funcs
* fix func arguments
* slot.id starts at 1?
* Revert "prefix kv cache tensors with model name to avoid conflict"
This reverts commit fbd5dfd866.
* disable draft logging
* disable logging in speculative.cpp
in mainline, these would be LOG_DEBUG, but since ik_llama doesn't support
it, logging is disabled entirely
* add more draft model parameters
* fix
* pass flash_attn
* add speculative params for parity
* set speculative params in launch_slot_with_task instead
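The core of the port is the llama_sampling_sample_and_accept_n helper added to the common sampling code (diff below). As a rough illustration only, and not the server's actual code (the function name, parameters, and batch layout here are assumptions; KV-cache rollback of rejected draft positions is not shown), a slot that already has draft tokens from the draft model could verify them against the target model roughly like this, using the standard llama.h / common batch helpers:

// Sketch only: verify already-generated draft tokens against the target model
// and keep the longest accepted prefix. Requires llama.h, common/common.h and
// common/sampling.h.
static std::vector<llama_token> verify_draft(
        struct llama_context          * ctx,     // target model context
        struct llama_sampling_context * smpl,    // the slot's sampling context
        llama_token                     id_last, // last token already accepted
        const std::vector<llama_token> & draft,  // tokens proposed by the draft model
        llama_pos                        n_past,
        llama_seq_id                     seq_id) {
    // One batch: the last accepted token followed by every draft token,
    // each with logits enabled so every position can be sampled from.
    llama_batch batch = llama_batch_init((int32_t) draft.size() + 1, 0, 1);
    llama_batch_add(batch, id_last, n_past, { seq_id }, true);
    for (size_t i = 0; i < draft.size(); ++i) {
        llama_batch_add(batch, draft[i], n_past + 1 + (llama_pos) i, { seq_id }, true);
    }

    std::vector<llama_token> accepted;
    if (llama_decode(ctx, batch) == 0) {
        // Walks the draft left to right, sampling with the slot's sampler at
        // each position; stops at the first mismatch and always returns at
        // least one token sampled by the target model.
        accepted = llama_sampling_sample_and_accept_n(smpl, ctx, draft);
    }

    llama_batch_free(batch);
    return accepted;
}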
@@ -442,7 +442,9 @@ static llama_token_data_array llama_sampling_prepare_impl(
         cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f};
     }
 
-    llama_token_data_array cur_p = { cur.data(), cur.size(), false };
+    ctx_sampling->cur_p = { cur.data(), cur.size(), false };
+
+    llama_token_data_array & cur_p = ctx_sampling->cur_p;
 
     // apply penalties
     const auto& penalty_tokens = params.use_penalty_prompt_tokens ? params.penalty_prompt_tokens : prev;
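The hunk above makes the candidate array persistent: instead of a function-local llama_token_data_array, the candidates now live in ctx_sampling->cur_p, which is what the llama_sampling_get_candidates accessor in the next hunk exposes. A minimal sketch of reading it back after sampling (illustrative only; the function name is made up, and it needs common/sampling.h and <cstdio>):

// Sketch only: sample one token, then inspect the candidate list that the
// sampling context now keeps in ctx_sampling->cur_p.
static llama_token sample_and_dump_candidates(
        struct llama_sampling_context * smpl,
        struct llama_context          * ctx,
        int                             idx) { // index of the output logits to sample from
    const llama_token id = llama_sampling_sample(smpl, ctx, nullptr, idx);

    llama_token_data_array * cur_p = llama_sampling_get_candidates(smpl);
    for (size_t i = 0; i < cur_p->size && i < 5; ++i) {
        fprintf(stderr, "candidate %zu: token=%d logit=%.3f\n",
                i, cur_p->data[i].id, cur_p->data[i].logit);
    }
    return id;
}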
@@ -506,3 +508,47 @@ void llama_sampling_accept(
         llama_sampler_dry_accept(ctx_sampling->smpl, id);
     }
 }
+
+llama_token_data_array * llama_sampling_get_candidates(struct llama_sampling_context * ctx_sampling) {
+    return &ctx_sampling->cur_p;
+}
+
+std::vector<llama_token> llama_sampling_sample_and_accept_n(struct llama_sampling_context * gsmpl, struct llama_context * ctx, const std::vector<llama_token> & draft) {
+    std::vector<int> idxs(draft.size() + 1);
+    for (size_t i = 0; i < idxs.size(); ++i) {
+        idxs[i] = i;
+    }
+
+    return llama_sampling_sample_and_accept_n(gsmpl, ctx, idxs, draft);
+}
+
+std::vector<llama_token> llama_sampling_sample_and_accept_n(struct llama_sampling_context * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const std::vector<llama_token> & draft) {
+    GGML_ASSERT(idxs.size() == draft.size() + 1 && "idxs.size() must be draft.size() + 1");
+
+    std::vector<llama_token> result;
+    result.reserve(idxs.size());
+
+    size_t i = 0;
+    for (; i < draft.size(); i++) {
+        const llama_token id = llama_sampling_sample(gsmpl, ctx, nullptr, idxs[i]);
+
+        llama_sampling_accept(gsmpl, ctx, id, true);
+
+        result.push_back(id);
+
+        if (draft[i] != id) {
+            break;
+        }
+    }
+
+    if (i == draft.size()) {
+        const llama_token id = llama_sampling_sample(gsmpl, ctx, nullptr, idxs[i]);
+
+        llama_sampling_accept(gsmpl, ctx, id, true);
+
+        result.push_back(id);
+    }
+
+    return result;
+}
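The convenience overload assumes the draft tokens occupy output indices 0..draft.size() of the last decoded batch; the idxs overload covers the case where a slot's tokens sit at an offset inside a shared batch. A hedged sketch of that case (the helper and its offset parameter are illustrative, not how the server actually tracks batch indices):

// Sketch only: verify a draft whose logits start at output index i_batch_first,
// e.g. because several slots were decoded in one shared batch.
static std::vector<llama_token> verify_draft_at(
        struct llama_sampling_context * smpl,
        struct llama_context          * ctx,
        const std::vector<llama_token> & draft,
        int                              i_batch_first) {
    std::vector<int> idxs(draft.size() + 1);
    for (size_t i = 0; i < idxs.size(); ++i) {
        idxs[i] = i_batch_first + (int) i;
    }
    return llama_sampling_sample_and_accept_n(smpl, ctx, idxs, draft);
}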