mirror of https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-01-26 17:20:01 +00:00
server: improve speed of speculative decoding (#1119)
* server: improve speed of speculative decoding
  change logs
  rpc: add recompute spec dec
  fix
* Fix n_batch_size not set to context size for draft model

---------

Co-authored-by: firecoperana <firecoperana>
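In short, this change batches the draft tokens into the server's main decode batch: each slot asks the draft model for up to get_n_draft_max() tokens, appends them (together with the last sampled token) to the shared batch while recording their indices in i_batch_dft, and after the single llama_decode the main model samples and accepts a prefix of the draft in one pass. The toy program below only illustrates why this pays off; the decode-call counts and acceptance rate are hypothetical, not measurements from this commit.

    // Toy simulation (no real models): compares how many "decode calls" plain decoding
    // vs. speculative decoding would issue for the same output, given a made-up
    // average acceptance rate. Purely illustrative.
    #include <cstdio>

    int main() {
        const int n_generate = 256; // tokens to produce (hypothetical)
        const int n_accept   = 5;   // draft tokens accepted per speculative step on average (hypothetical)

        // plain decoding: one decode call per generated token
        const int calls_plain = n_generate;

        // speculative decoding: each main-model decode verifies a whole draft at once
        // and yields (accepted + 1) tokens
        int produced = 0, calls_spec = 0;
        while (produced < n_generate) {
            produced += n_accept + 1; // accepted draft tokens + one token sampled by the main model
            calls_spec++;
        }

        std::printf("plain: %d decode calls, speculative: %d decode calls\n", calls_plain, calls_spec);
        return 0;
    }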
@@ -484,7 +484,7 @@ bool server_sent_event(httplib::DataSink& sink, const json& data) {
        data.dump(-1, ' ', false, json::error_handler_t::replace) +
        "\n\n"; // required by RFC 8895 - A message is terminated by a blank line (two line terminators in a row).

    LOG_VERBOSE("data stream, to_send: %s", str.c_str());
    //LOG_VERBOSE("data stream, to_send: %s", str.c_str());

    return sink.write(str.c_str(), str.size());
}

@@ -336,6 +336,10 @@ public:

    llama_pos pos_next() const;

    int n_tokens() const {
        return tokens.size();
    }

    // for debugging
    std::string str() const;

@@ -117,12 +117,13 @@ bool server_context::load_model(const gpt_params& params_) {
    LLAMA_LOG_INFO("\n\n==================================loading DRAFT model==================================\n\n");

    gpt_params params_dft;
    params_dft.devices = params.devices_draft;
    params_dft.model = params.model_draft;
    params_dft.devices = params.devices_draft;
    params_dft.model = params.model_draft;
    params_dft.n_gpu_layers = params.n_gpu_layers_draft;
    params_dft.rpc_servers = params.rpc_servers;
    params_dft.cache_type_k = params.cache_type_k_draft.empty() ? params.cache_type_k : params.cache_type_k_draft;
    params_dft.cache_type_v = params.cache_type_v_draft.empty() ? params.cache_type_v : params.cache_type_v_draft;
    params_dft.flash_attn = params.flash_attn;
    params_dft.flash_attn = params.flash_attn;
    if (!params.draft_params.empty()) {
        auto [argc, argv] = parse_command_line("llama-server " + params.draft_params);
        if (!gpt_params_parse(argc, argv, params_dft)) {

@@ -138,7 +139,7 @@ bool server_context::load_model(const gpt_params& params_) {
    }
    params_dft.n_ctx = params_dft.n_ctx == 0 ? params.n_ctx / params.n_parallel : params_dft.n_ctx;
    params_dft.n_parallel = 1;

    params_dft.n_batch = params_dft.n_ctx;
    llama_init_result llama_init_dft = llama_init_from_gpt_params(params_dft);

    llama_model* model_dft = llama_init_dft.model;

@@ -154,7 +155,6 @@ bool server_context::load_model(const gpt_params& params_) {
    const int n_ctx_dft = llama_n_ctx(llama_init_dft.context);

    cparams_dft = llama_context_params_from_gpt_params(params_dft);
    cparams_dft.n_batch = n_ctx_dft;

    model_draft = llama_init_dft.model;
    ctx_draft = llama_init_dft.context;
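The second part of the commit message ("Fix n_batch_size not set to context size for draft model") corresponds to the params_dft.n_batch assignment in the hunk above: the draft context is created by llama_init_from_gpt_params(params_dft), so the batch size has to be set on params_dft before that call; patching cparams_dft.n_batch afterwards appears to come too late for the already-created context. A minimal sketch of the ordering, assuming the llama.cpp common headers used by this server (the 4096 context size is made up):

    gpt_params params_dft;
    params_dft.n_ctx   = 4096;                 // hypothetical draft context size
    params_dft.n_batch = params_dft.n_ctx;     // set before init so the draft context picks it up

    llama_init_result llama_init_dft = llama_init_from_gpt_params(params_dft);

    // previous approach (removed in the hunk above): adjusting cparams_dft after init
    // does not change the batch size of the context that was just created
    // cparams_dft.n_batch = llama_n_ctx(llama_init_dft.context);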
@@ -317,6 +317,10 @@ void server_slot::reset() {
    stopping_word = "";
    n_past = 0;
    n_sent_text = 0;

    drafted.clear();
    i_batch_dft.clear();

    n_sent_token_probs = 0;
    infill = false;
    ga_i = 0;

@@ -368,6 +372,31 @@ void server_slot::add_token_string(const completion_token_output& token) {
    generated_token_probs.push_back(token);
}

int server_slot::get_n_draft_max() const {
    if (!ctx_dft) {
        return 0;
    }

    // determine the max draft that fits the current slot state
    int n_draft_max = params.speculative.n_max;

    // note: slot.prompt is not yet expanded with the `id` token sampled above
    // also, need to leave space for 1 extra token to allow context shifts
    n_draft_max = std::min(n_draft_max, n_ctx - n_past - 2);

    if (n_remaining > 0) {
        n_draft_max = std::min(n_draft_max, n_remaining - 1);
    }

    SLT_DBG(*this, "max possible draft: %d\n", n_draft_max);

    if (n_draft_max < params.speculative.n_min) {
        SLT_DBG(*this, "the max possible draft is too small: %d < %d - skipping speculative decoding\n", n_draft_max, params.speculative.n_min);
        n_draft_max = 0;
    }
    return n_draft_max;
}

void server_slot::release() {
    if (state == SLOT_STATE_PROCESSING) {
        t_token_generation = (ggml_time_us() - t_start_generation) / 1e3;
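A small worked example of the clamping in get_n_draft_max(), with made-up numbers (n_ctx = 4096, n_past = 4090, speculative n_max = 16, n_min = 5): the context cap leaves min(16, 4096 - 4090 - 2) = 4 draft tokens, which is below n_min, so the function returns 0 and that decode step falls back to normal decoding. The same arithmetic as a standalone snippet:

    #include <algorithm>
    #include <cstdio>

    int main() {
        // Hypothetical values, not taken from a real run:
        const int n_ctx = 4096, n_past = 4090, n_max = 16, n_min = 5, n_remaining = -1;

        int n_draft_max = n_max;
        n_draft_max = std::min(n_draft_max, n_ctx - n_past - 2); // leave room for 1 extra token -> 4
        if (n_remaining > 0) {
            n_draft_max = std::min(n_draft_max, n_remaining - 1);
        }
        if (n_draft_max < n_min) {
            n_draft_max = 0; // too small -> skip speculative decoding this step
        }
        std::printf("n_draft_max = %d\n", n_draft_max); // prints 0
        return 0;
    }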
@@ -468,48 +497,43 @@ size_t server_slot::find_stopping_strings(const std::string& text, const size_t

void server_slot::print_timings() const {
    char buffer[512];

    double t_token = t_prompt_processing / n_prompt_tokens_processed;
    double n_tokens_second = 1e3 / t_prompt_processing * n_prompt_tokens_processed;

    snprintf(buffer, 512, "prompt eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)",
        t_prompt_processing, n_prompt_tokens_processed,
        t_token, n_tokens_second);
    //snprintf(buffer, 512, "prompt eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)",
    //    t_prompt_processing, n_prompt_tokens_processed,
    //    t_token, n_tokens_second);

    LOG_INFO(buffer, {
        {"id_slot", id},
        {"id_task", id_task},
        {"t_prompt_processing", t_prompt_processing},
        {"n_prompt_tokens_processed", n_prompt_tokens_processed},
        {"t_token", t_token},
        {"n_tokens_second", n_tokens_second},
    });
    //LOG_INFO(buffer, {});

    t_token = t_token_generation / n_decoded;
    n_tokens_second = 1e3 / t_token_generation * n_decoded;
    double t_token_gen = t_token_generation / n_decoded;
    double n_tokens_second_gen = 1e3 / t_token_generation * n_decoded;

    snprintf(buffer, 512, "generation eval time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)",
        t_token_generation, n_decoded,
        t_token, n_tokens_second);
    //snprintf(buffer, 512, "generation eval time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)",
    //    t_token_generation, n_decoded,
    //    t_token, n_tokens_second);

    LOG_INFO(buffer, {
        {"id_slot", id},
        {"id_task", id_task},
        {"t_token_generation", t_token_generation},
        {"n_decoded", n_decoded},
        {"t_token", t_token},
        {"n_tokens_second", n_tokens_second},
    });
    //LOG_INFO(buffer, {});

    snprintf(buffer, 512, " total time = %10.2f ms", t_prompt_processing + t_token_generation);
    //snprintf(buffer, 512, " total time = %10.2f ms", t_prompt_processing + t_token_generation);

    LOG_INFO(buffer, {
        {"id_slot", id},
        {"id_task", id_task},
        {"t_prompt_processing", t_prompt_processing},
        {"t_token_generation", t_token_generation},
        {"t_total", t_prompt_processing + t_token_generation},
    });
    //LOG_INFO(buffer, {});
    SLT_INF(*this,
        "\n"
        "prompt eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n"
        " eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n"
        " total time = %10.2f ms / %5d tokens\n",
        t_prompt_processing, n_prompt_tokens_processed, t_token, n_tokens_second,
        t_token_generation, n_decoded, t_token_gen, n_tokens_second_gen,
        t_prompt_processing + t_token_generation, n_prompt_tokens_processed + n_decoded);

    if (n_draft_total > 0) {
        const float draft_ratio = (float)n_draft_accepted / n_draft_total;
        SLT_CNT(*this,
            "draft acceptance rate = %0.5f (%5d accepted / %5d generated)\n",
            draft_ratio, n_draft_accepted, n_draft_total
        );
    }
}

void server_metrics::init() {
@@ -2173,31 +2197,62 @@ void server_context::update_slots() {
                continue;
            }

            slot.i_batch = batch.n_tokens;
            // generate draft tokens in speculative decoding mode
            // TODO: rework to have a single draft llama_context shared across all slots [TAG_SERVER_SPEC_REWORK]
            // perform the speculative drafting for all sequences at the same time in a single batch
            int n_draft_max = slot.get_n_draft_max();
            if (n_draft_max > 0) {
                if (mctx) {
                    // we should never reach this, as speculative is automatically disabled if mmproj is loaded
                    GGML_ABORT("not supported by multimodal");
                }

                const int32_t slot_npast = slot.n_past_se > 0 ? slot.n_past_se : slot.n_past;
                struct llama_speculative_params params_spec;
                params_spec.n_draft = n_draft_max;
                params_spec.n_reuse = llama_n_ctx(slot.ctx_dft) - slot.params.speculative.n_max;
                params_spec.p_min = slot.params.speculative.p_min;
                const llama_tokens& cached_text_tokens = slot.cache_tokens.get_text_tokens();
                llama_tokens draft = llama_speculative_gen_draft(slot.spec, params_spec, cached_text_tokens, slot.sampled);

                // TODO: we always have to take into account the "system_tokens"
                // this is not great and needs to be improved somehow
                llama_batch_add(batch, slot.sampled, system_tokens.size() + slot.cache_tokens.pos_next(), { slot.id }, true);

                slot.n_past += 1;

                if (slot.params.cache_prompt) {
                    // add the sampled token to the batch
                    slot.i_batch_dft.push_back(batch.n_tokens);
                    llama_batch_add(batch, slot.sampled, slot.cache_tokens.pos_next(), { slot.id }, true);
                    slot.cache_tokens.push_back(slot.sampled);

                if (slot.params.speculative.n_min > (int)draft.size()) {
                    SLT_DBG(slot, "ignoring small draft: %d < %d\n", (int)draft.size(), slot.params.speculative.n_min);
                    // fallback to normal decoding
                    slot.i_batch = slot.i_batch_dft[0];
                    slot.drafted.clear();
                    slot.i_batch_dft.clear();
                }
                else {
                    // keep track of total number of drafted tokens tested
                    slot.n_draft_total += draft.size();

                    // add all drafted tokens to the batch
                    for (size_t i = 0; i < draft.size(); i++) {
                        slot.i_batch_dft.push_back(batch.n_tokens);
                        llama_batch_add(batch, draft[i], slot.cache_tokens.pos_next(), { slot.id }, true);
                        slot.cache_tokens.push_back(draft[i]);
                    }
                    slot.drafted = std::move(draft);
                }
            }
            else {
                // no speculative decoding
                slot.i_batch = batch.n_tokens;

                LOG_VERBOSE("slot decode token", {
                    {"id_slot", slot.id},
                    {"id_task", slot.id_task},
                    {"n_ctx", n_ctx},
                    {"n_past", slot.n_past},
                    {"n_system_tokens", system_tokens.size()},
                    {"n_cache_tokens", slot.cache_tokens.size()},
                    {"truncated", slot.truncated}
                });
                llama_batch_add(batch, slot.sampled, slot.cache_tokens.pos_next(), { slot.id }, true);

                slot.cache_tokens.push_back(slot.sampled);

                SLT_DBG(slot, "slot decode token, n_ctx = %d, n_tokens = %d, truncated = %d\n",
                    slot.n_ctx, slot.cache_tokens.size(), slot.truncated);
            }
            slot.n_past = slot.cache_tokens.n_tokens();
        }

        // process in chunks of params.n_batch
        int32_t n_batch = llama_n_batch(ctx);
        int32_t n_ubatch = llama_n_ubatch(ctx);
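The bookkeeping in the hunk above boils down to: append the slot's last sampled token and then every draft token to the shared batch, remembering each token's batch index in i_batch_dft so the acceptance pass later knows which logits belong to this slot. A simplified, self-contained model of that bookkeeping (plain ints instead of llama_token, made-up values):

    #include <cstdio>
    #include <vector>

    int main() {
        std::vector<int> batch;        // the shared decode batch for all slots
        std::vector<int> i_batch_dft;  // indices of this slot's speculative tokens in that batch

        int sampled = 20;                       // last token sampled for this slot (hypothetical)
        std::vector<int> draft = {21, 22, 23};  // tokens proposed by the draft model (hypothetical)

        // add the sampled token first, then every draft token, recording their batch indices
        i_batch_dft.push_back((int)batch.size());
        batch.push_back(sampled);
        for (int t : draft) {
            i_batch_dft.push_back((int)batch.size());
            batch.push_back(t);
        }

        // after the single decode over `batch`, the logits at i_batch_dft[0..] are used
        // to sample-and-accept a prefix of `draft` in one pass
        std::printf("batch size = %zu, speculative indices for this slot = %zu\n",
                    batch.size(), i_batch_dft.size());
        return 0;
    }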
@@ -2638,6 +2693,10 @@ void server_context::update_slots() {
                continue; // continue loop of n_batch
            }

            // technically, measuring the time here excludes the sampling time for the last batch
            // but on the other hand, we don't want to do too many system calls to measure the time, so it's ok
            const int64_t t_current = ggml_time_us();

            for (auto& slot : slots) {
                if (slot.state != SLOT_STATE_PROCESSING || slot.i_batch < (int)i || slot.i_batch >= (int)(i + n_tokens)) {
                    continue; // continue loop of slots

@@ -2652,6 +2711,9 @@ void server_context::update_slots() {
                }

                completion_token_output result;
                if (slot.i_batch_dft.size() > 0) {
                    continue; // sample using speculative decoding
                }
                const int tok_idx = slot.i_batch - i;
                const llama_token id = llama_sampling_sample(slot.ctx_sampling, ctx, NULL, tok_idx);

@@ -2667,7 +2729,8 @@ void server_context::update_slots() {
                    metrics.on_prompt_eval(slot);
                }

                slot.t_token_generation = (t_current - slot.t_start_generation) / 1e3;
                //slot.t_token_generation = (t_current - slot.t_start_generation) / 1e3;
                slot.t_token_generation = std::max<int64_t>(1, t_current - slot.t_start_generation) / 1e3;

                result.tok = id;
                result.prob = 1.0f; // TODO: set it here instead of doing inside populate_token_probs
@@ -2687,92 +2750,34 @@ void server_context::update_slots() {
                slot.i_batch = -1;
            }

            // Do speculative decoding
            // speculative decoding - main model sample and accept
            for (auto& slot : slots) {
                if (!slot.is_processing() || !slot.spec) {
                if (slot.state != SLOT_STATE_PROCESSING || slot.i_batch_dft.empty()) {
                    continue;
                }

                if (slot.state != SLOT_STATE_PROCESSING) {
                    continue;
                }

                if (mctx) {
                    // we should never reach this, as speculative is automatically disabled if mmproj is loaded
                    GGML_ABORT("not supported by multimodal");
                }

                // determine the max draft that fits the current slot state
                int n_draft_max = slot.params.speculative.n_max;

                // note: n_past is not yet increased for the `id` token sampled above
                // also, need to leave space for 1 extra token to allow context shifts
                n_draft_max = std::min(n_draft_max, slot.n_ctx - slot.n_past - 2);

                if (slot.n_predict > 0) {
                    n_draft_max = std::min(n_draft_max, slot.n_predict - slot.n_decoded - 1);
                }

                LOG_VERBOSE("max possible draft", {
                    {"id_slot", slot.id},
                    {"n_draft_max", n_draft_max}
                });

                if (n_draft_max < slot.params.speculative.n_min) {
                    LOG_VERBOSE("the max possible draft is too small", {
                        {"id_slot", slot.id},
                        {"n_draft_max", n_draft_max},
                        {"n_min", slot.params.speculative.n_min}
                    });
                    continue;
                }

                llama_token id = slot.sampled;

                struct llama_speculative_params params_spec;
                params_spec.n_draft = n_draft_max;
                params_spec.n_reuse = cparams_dft.n_ctx - slot.params.speculative.n_max;
                params_spec.p_min = slot.params.speculative.p_min;

                const std::vector<llama_token>& cached_text_tokens = slot.cache_tokens.tokens_data();
                std::vector<llama_token> draft = llama_speculative_gen_draft(slot.spec, params_spec, cached_text_tokens, id);

                // ignore small drafts
                if (slot.params.speculative.n_min > (int)draft.size()) {
                    SLT_DBG(slot, "ignoring small draft: %d < %d\n", (int)draft.size(), slot.params.speculative.n_min);
                    continue;
                }

                // keep track of total number of drafted tokens tested
                slot.n_draft_total += draft.size();

                // construct the speculation batch
                llama_batch_clear(slot.batch_spec);
                llama_batch_add(slot.batch_spec, id, slot.cache_tokens.pos_next(), { slot.id }, true);

                for (size_t i = 0; i < draft.size(); ++i) {
                    llama_batch_add(slot.batch_spec, draft[i], slot.cache_tokens.pos_next() + 1 + i, { slot.id }, true);
                }

                LOG_VERBOSE("decoding speculative batch", {
                    {"id_slot", slot.id},
                    {"size", slot.batch_spec.n_tokens}
                });

                llama_decode(ctx, slot.batch_spec);
                size_t n_draft = slot.drafted.size();

                // the accepted tokens from the speculation
                std::vector<llama_token> ids = llama_sampling_sample_and_accept_n(slot.ctx_sampling, ctx, draft);
                const auto ids = llama_sampling_sample_and_accept_n(slot.ctx_sampling, ctx, slot.i_batch_dft, slot.drafted);
                slot.i_batch_dft.clear();
                slot.drafted.clear();

                slot.n_past += ids.size();
                slot.n_decoded += ids.size();

                slot.t_token_generation = std::max<int64_t>(1, t_current - slot.t_start_generation) / 1e3;

                // update how many tokens out of those tested were accepted
                slot.n_draft_accepted += ids.size() - 1;

                slot.cache_tokens.push_back(id);
                // rollback to the state before sampling the draft tokens
                slot.cache_tokens.keep_first(slot.cache_tokens.n_tokens() - n_draft);
                // slot.n_past -= n_draft;
                // add accepted tokens to the prompt
                slot.cache_tokens.insert({ ids.begin(), ids.end() - 1 });

                slot.sampled = ids.back(); // last accepted token
                slot.n_past = slot.cache_tokens.n_tokens();
                llama_kv_cache_seq_rm(ctx, slot.id, slot.n_past, -1);

                for (size_t i = 0; i < ids.size(); ++i) {

@@ -2795,11 +2800,11 @@ void server_context::update_slots() {
                        break;
                    }
                }

                SLT_DBG(slot, "accepted %d/%d draft tokens, new n_tokens = %d\n", (int)ids.size() - 1, (int)slot.drafted.size(), slot.n_past);
                LOG_VERBOSE("speculative decoding result", {
                    {"id_slot", slot.id},
                    {"accepted", (int)ids.size() - 1},
                    {"total", (int)draft.size()},
                    {"total", (int)slot.drafted.size()},
                    {"new_n_past", slot.n_past}
                });
            }
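The acceptance pass above can be summarized as: sample-and-accept over the drafted batch positions, roll the token cache back to the state before the draft tokens were appended, re-insert only the accepted prefix, and keep the last returned token as the next sampled token (the KV cache is trimmed to match via llama_kv_cache_seq_rm). A simplified, self-contained model of that cache bookkeeping with made-up token values:

    #include <cstdio>
    #include <vector>

    int main() {
        // Plain ints instead of llama_token, std::vector instead of the server's cache type.
        std::vector<int> cache   = {10, 11, 12, 20, 21, 22, 23}; // prompt + sampled token (20) + drafted tokens, as batched
        std::vector<int> drafted = {21, 22, 23};                  // the draft tokens that were appended
        std::vector<int> ids     = {21, 22, 99};                  // accepted draft prefix + newly sampled token

        // rollback to the state before the draft tokens were appended
        cache.resize(cache.size() - drafted.size());              // ~ cache_tokens.keep_first(n_tokens - n_draft)

        // add the accepted tokens (all but the last, which becomes the next "sampled" token)
        cache.insert(cache.end(), ids.begin(), ids.end() - 1);    // ~ cache_tokens.insert({ids.begin(), ids.end() - 1})
        int sampled = ids.back();                                  // feeds the next decode step

        std::printf("cache size = %zu, accepted = %zu, sampled = %d\n",
                    cache.size(), ids.size() - 1, sampled);
        return 0;
    }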
@@ -98,6 +98,11 @@ struct server_slot {

    std::string generated_text;

    // idx of draft tokens in the main batch
    // non-empty if we went to evaluate draft tokens
    // ref: https://github.com/ggml-org/llama.cpp/pull/17808
    std::vector<int32_t> i_batch_dft;

    std::vector<completion_token_output> generated_token_probs;
    common_chat_msg chat_msg;

@@ -122,7 +127,9 @@ struct server_slot {
    void prompt_load(server_prompt_cache& prompt_cache, const server_tokens& tokens);

    // sampling
    llama_token sampled;
    llama_token sampled; // in speculative mode, this is the last accepted token
    llama_tokens drafted;

    struct llama_sampling_params sparams;
    llama_sampling_context* ctx_sampling = nullptr;
    json json_schema;

@@ -168,6 +175,8 @@ struct server_slot {

    void add_token_string(const completion_token_output& token);

    int get_n_draft_max() const;

    void release();

    json get_formated_timings() const;