Mirror of https://github.com/ikawrakow/ik_llama.cpp.git
Synced 2026-02-23 14:44:09 +00:00
Graph reuse: add command line argument to turn it on
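This commit teaches llama_context to remember the compute graph built for the previous decode step, together with the batch signature it was built for (starting position all_pos_0, position stride all_pos_1, sequence id, KV view size kv_self.n, and number of outputs), and to reuse that graph, skipping graph construction and scheduler allocation, when a new single-token batch matches the recorded signature. The feature is experimental and off by default; it is enabled with the new -reuse / --graph-reuse flag. A hypothetical invocation (the binary name is illustrative; any example that goes through gpt_params parsing picks the flag up):

    ./llama-cli -m model.gguf -p "Once upon a time" -n 128 --graph-reuse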
@@ -1180,6 +1180,10 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
         params.rope_cache = true;
         return true;
     }
+    if (arg == "-reuse" || arg == "--graph-reuse") {
+        params.graph_reuse = true;
+        return true;
+    }
     if (arg == "-ser" || arg == "--smart-expert-reduction") {
         CHECK_ARG
         auto values = string_split_pairs<int,float>(argv[i], ',');
@@ -2004,6 +2008,7 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
     options.push_back({ "*", "-no-fug, --no-fused-up-gate", "disable fused up-gate (default: %s)", params.fused_up_gate ? "enabled" : "disabled" });
     options.push_back({ "*", "-no-mmad, --no-fused-mul-multiadd", "disable fused mul-multi_add (default: %s)", params.fused_mmad ? "enabled" : "disabled" });
     options.push_back({ "*", "-rcache, --rope-cache", "enable RoPE cache (default: %s)", params.rope_cache ? "enabled" : "disabled" });
+    options.push_back({ "*", "-reuse, --graph-reuse", "enable graph reuse (default: %s)", params.graph_reuse ? "enabled" : "disabled" });
     options.push_back({ "*", "-ser, --smart-expert-reduction", "experts reduction (default: %d,%g)", params.min_experts, params.thresh_experts });
     options.push_back({ "*", "-mqkv, --merge-qkv", "merge Q,K,V (default: %d)", params.merge_qkv });
     options.push_back({ "*", "-p, --prompt PROMPT", "prompt to start generation with\n"
@@ -2979,6 +2984,7 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
     cparams.fused_up_gate = params.fused_up_gate;
     cparams.fused_mmad = params.fused_mmad;
     cparams.rope_cache = params.rope_cache;
+    cparams.graph_reuse = params.graph_reuse;
     cparams.min_experts = params.min_experts;
     cparams.thresh_experts = params.thresh_experts;
     cparams.only_active_experts = params.only_active_exps;
@@ -4123,7 +4129,8 @@ void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const l
     fprintf(stream, "grouped_expert_routing: %s # default: false\n", params.grouped_expert_routing ? "true" : "false");
     fprintf(stream, "fused_up_gate: %s # default: true\n", params.fused_up_gate ? "true" : "false");
     fprintf(stream, "fused_mmad: %s # default: true\n", params.fused_mmad ? "true" : "false");
-    fprintf(stream, "rope_cache: %s # default: true\n", params.rope_cache ? "true" : "false");
+    fprintf(stream, "rope_cache: %s # default: false\n", params.rope_cache ? "true" : "false");
+    fprintf(stream, "graph_reuse: %s # default: false\n", params.graph_reuse ? "true" : "false");
     fprintf(stream, "ser: %d,%g # default: -1,0\n", params.min_experts, params.thresh_experts);
     fprintf(stream, "temp: %f # default: 0.8\n", sparams.temp);
@@ -254,6 +254,7 @@ struct gpt_params {
     bool fused_mmad = true;              // fused mul+multi_add op
     bool grouped_expert_routing = false; // whether to use grouped expert routing (BailingMoeV2 arch)
     bool rope_cache = false;             // whether to use RoPE cache (for supported models)
+    bool graph_reuse = false;            // whether to reuse compute graphs
     int min_experts = -1;
     float thresh_experts = 0;
@@ -429,6 +429,7 @@ extern "C" {
         bool fused_up_gate;  // whether to use fused up/gate op [EXPERIMENTAL]
         bool fused_mmad;     // whether to use fused mul+multi_add op [EXPERIMENTAL]
         bool rope_cache;     // whether to use RoPE cache [EXPERIMENTAL]
+        bool graph_reuse;    // whether to reuse graphs when possible [EXPERIMENTAL]
        int min_experts;
        float thresh_experts;
        bool only_active_experts;
@@ -9,6 +9,7 @@ struct llama_model;
 #include <vector>
 #include <map>
 #include <set>
+#include <memory>

 struct llama_kv_cell {
     llama_pos pos = -1;
@@ -205,4 +206,10 @@ struct llama_context {

     ggml_backend_t ggml_backend_by_name(const char * name);

+    struct Prev;
+    std::unique_ptr<Prev> prev;
+
+    void reset_scheduler();
+    bool can_reuse_graph(const llama_batch & u_batch) const;
+
 };
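Prev is only forward-declared here; its definition (and with it the ggml_cgraph dependency) stays in src/llama.cpp, which is why the member is held through std::unique_ptr and why #include <memory> is added to this header above.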
@@ -38,6 +38,7 @@ struct llama_cparams {
     bool fused_up_gate;
     bool fused_mmad;
     bool rope_cache;
+    bool graph_reuse;
     int min_experts;
     float thresh_experts;
src/llama.cpp
@@ -535,6 +535,42 @@ static size_t llama_get_device_memory(const llama_model & model, int device) {
     GGML_UNUSED(model);
     GGML_UNUSED(device);
 }
+//llama_batch u_batch = {
+//    /* .n_tokens   = */ (int32_t) n_tokens,
+//    /* .token      = */ batch_all.token    ? batch_all.token    + cur_token        : nullptr,
+//    /* .embd       = */ batch_all.embd     ? batch_all.embd     + cur_token*n_embd : nullptr,
+//    /* .pos        = */ batch_all.pos      ? batch_all.pos      + cur_token        : nullptr,
+//    /* .n_seq_id   = */ batch_all.n_seq_id ? batch_all.n_seq_id + cur_token        : nullptr,
+//    /* .seq_id     = */ batch_all.seq_id   ? batch_all.seq_id   + cur_token        : nullptr,
+//    /* .logits     = */ batch_all.logits   ? batch_all.logits   + cur_token        : nullptr,
+//    /* .all_pos_0  = */ batch_all.all_pos_0 + (llama_pos) cur_token*batch_all.all_pos_1,
+//    /* .all_pos_1  = */ batch_all.all_pos_1,
+//    /* .all_seq_id = */ batch_all.all_seq_id,
+//};
+
+struct llama_context::Prev {
+    int all_pos_0, all_pos_1, all_seq_id;
+    int n_outputs;
+    int n_kv;
+    ggml_cgraph * graph;
+};
+
+void llama_context::reset_scheduler() {
+    ggml_backend_sched_reset(sched);
+    prev.reset();
+}
+
+bool llama_context::can_reuse_graph(const llama_batch & u_batch) const {
+    if (!prev || !prev->graph) return false;
+    if (u_batch.n_tokens > 1) return false;
+    if (u_batch.embd) return false;
+    if (!cparams.graph_reuse) return false;
+    return u_batch.all_pos_0  == prev->all_pos_0 &&
+           u_batch.all_pos_1  == prev->all_pos_1 &&
+           u_batch.all_seq_id == prev->all_seq_id &&
+           kv_self.n          == prev->n_kv      &&
+           n_outputs          == prev->n_outputs;
+}

 llama_context::llama_context(const llama_model & model)
     : model(model) , sampling(llama_n_vocab(&model)) , t_start_us(model.t_start_us) , t_load_us(model.t_load_us) {}
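The decode-side hunk below is the heart of the change. As a reading aid, here is a condensed sketch of the control flow it introduces (not part of the diff; timing instrumentation and commented-out debug output elided):

    ggml_cgraph * gf = nullptr;
    if (!lctx.can_reuse_graph(u_batch)) {
        // slow path: tear down scheduler state (this also clears lctx.prev),
        // rebuild the graph, and re-run scheduler allocation
        lctx.reset_scheduler();
        ggml_backend_sched_set_eval_callback(lctx.sched, lctx.cparams.cb_eval, lctx.cparams.cb_eval_user_data);
        gf = llm_build_context::llama_build_graph(lctx, u_batch, false);
        ggml_backend_sched_alloc_graph(lctx.sched, gf);
        // snapshot the batch signature so the next step can reuse gf
        if (u_batch.n_tokens == 1 && u_batch.embd == nullptr && lctx.cparams.graph_reuse) {
            lctx.prev = std::make_unique<llama_context::Prev>(llama_context::Prev{
                (int)u_batch.all_pos_0, (int)u_batch.all_pos_1, (int)u_batch.all_seq_id,
                (int)lctx.n_outputs, (int)lctx.kv_self.n, gf});
        }
    } else {
        // fast path: matching single-token step, skip build and alloc entirely
        gf = lctx.prev->graph;
    }

Note that ggml_backend_sched_alloc_graph also moves inside the slow path; it used to run unconditionally after graph build, and the second decode hunk below deletes the old call site.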
@@ -2876,27 +2912,60 @@ static int llama_decode_internal(
     printf("prelude(...): %d us\n", int(tim2-tim1));
 #endif

+    //if (n_tokens_all == 1) {
+    //    printf("================= %s\n", __func__);
+    //    printf("    all_pos_0 = %d, all_pos_1 = %d, all_seq_id = %d\n", batch_all.all_pos_0, batch_all.all_pos_1, batch_all.all_seq_id);
+    //    printf("    embd = %p, logits = %p, token = %p\n", (const void *)batch_all.embd, (const void *)batch_all.logits, (const void *)batch_all.token);
+    //    printf("    n_outputs = %d, kv_self.n = %d\n", n_outputs, kv_self.n);
+    //}
+    //printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self.n, kv_self.used, kv_self.head);

 #if IK_PRINT_TIMING
     tim1 = ggml_time_us();
 #endif
-    ggml_backend_sched_reset(lctx.sched);
-    ggml_backend_sched_set_eval_callback(lctx.sched, lctx.cparams.cb_eval, lctx.cparams.cb_eval_user_data);
+    ggml_cgraph * gf = nullptr;
+    if (!lctx.can_reuse_graph(u_batch)) {
+        lctx.reset_scheduler();
+        ggml_backend_sched_set_eval_callback(lctx.sched, lctx.cparams.cb_eval, lctx.cparams.cb_eval_user_data);
 #if IK_PRINT_TIMING
-    tim2 = ggml_time_us();
-    printf("sched_reset(...): %d us\n", int(tim2-tim1));
+        tim2 = ggml_time_us();
+        printf("sched_reset(...): %d us\n", int(tim2-tim1));
 #endif

 #if IK_PRINT_TIMING
-    tim1 = ggml_time_us();
+        tim1 = ggml_time_us();
 #endif
-    ggml_cgraph * gf = llm_build_context::llama_build_graph(lctx, u_batch, false);
+        gf = llm_build_context::llama_build_graph(lctx, u_batch, false);
 #if IK_PRINT_TIMING
-    tim2 = ggml_time_us();
-    printf("build_graph(...): %d us\n", int(tim2-tim1));
+        tim2 = ggml_time_us();
+        printf("build_graph(...): %d us\n", int(tim2-tim1));
 #endif

+#if IK_PRINT_TIMING
+        tim1 = ggml_time_us();
+#endif
+        ggml_backend_sched_alloc_graph(lctx.sched, gf);
+#if IK_PRINT_TIMING
+        tim2 = ggml_time_us();
+        printf("sched_alloc_graph(...): %d us\n", int(tim2-tim1));
+#endif
+        //struct llama_context::Prev {
+        //    int all_pos_0, all_pos_1, all_seq_id;
+        //    int n_outputs;
+        //    int n_kv;
+        //    ggml_cgraph * graph;
+        //};
+        if (u_batch.n_tokens == 1 && u_batch.embd == nullptr && lctx.cparams.graph_reuse) {
+            lctx.prev = std::make_unique<llama_context::Prev>(llama_context::Prev{
+                (int)u_batch.all_pos_0, (int)u_batch.all_pos_1, (int)u_batch.all_seq_id,
+                (int)lctx.n_outputs, (int)lctx.kv_self.n, gf});
+        }
+    } else {
+        //printf("Reusing graph\n");
+        gf = lctx.prev->graph;
+    }

     // the output is always the last tensor in the graph
     struct ggml_tensor * res  = gf->nodes[gf->n_nodes - 1];
     struct ggml_tensor * embd = gf->nodes[gf->n_nodes - 2];
@@ -2921,15 +2990,6 @@ static int llama_decode_internal(
     }
     // LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);

-#if IK_PRINT_TIMING
-    tim1 = ggml_time_us();
-#endif
-    ggml_backend_sched_alloc_graph(lctx.sched, gf);
-#if IK_PRINT_TIMING
-    tim2 = ggml_time_us();
-    printf("sched_alloc_graph(...): %d us\n", int(tim2-tim1));
-#endif
-
 #if IK_PRINT_TIMING == 1
     tim1 = ggml_time_us();
 #endif
@@ -3060,9 +3120,11 @@ static int llama_decode_internal(
     // Reset state for the next token before backend sync, to allow the CPU activities in the reset to
     // overlap with device computation.
 #if IK_PRINT_TIMING
     auto tim1 = ggml_time_us();
 #endif
-    ggml_backend_sched_reset(lctx.sched);
+    if (!lctx.prev) {
+        lctx.reset_scheduler();
+    }
 #if IK_PRINT_TIMING
     auto tim2 = ggml_time_us();
     printf("sched_reset(...): %d us\n", int(tim2-tim1));
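Two details are worth noting in the decode epilogue above: the reset is now skipped whenever a snapshot is held (lctx.prev is only set when graph reuse captured the last graph), because resetting the scheduler would invalidate the allocation that prev->graph still points into; and every other place that used to call ggml_backend_sched_reset directly, in the hunks below, now goes through lctx.reset_scheduler() so any cached graph is dropped together with the scheduler state.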
@@ -3158,7 +3220,7 @@ static int llama_encode_internal(
         batch.seq_id = seq_id_arr.data();
     }

-    ggml_backend_sched_reset(lctx.sched);
+    lctx.reset_scheduler();
     ggml_backend_sched_set_eval_callback(lctx.sched, lctx.cparams.cb_eval, lctx.cparams.cb_eval_user_data);

     ggml_cgraph * gf = llm_build_context::llama_build_graph(lctx, batch, false);
@@ -3248,7 +3310,7 @@ static int llama_encode_internal(

     // Reset state for the next token before backend sync, to allow the CPU activities in the reset to
     // overlap with device computation.
-    ggml_backend_sched_reset(lctx.sched);
+    lctx.reset_scheduler();

     return 0;
 }
@@ -3462,7 +3524,7 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
 #else
     // ggml_graph defrag

-    ggml_backend_sched_reset(lctx.sched);
+    lctx.reset_scheduler();

     ggml_cgraph * gf = llm_build_context::llama_build_graph_defrag(lctx, ids);

@@ -3484,7 +3546,7 @@ static int32_t llama_kv_cache_update_internal(struct llama_context & lctx) {
     }

     {
-        ggml_backend_sched_reset(lctx.sched);
+        lctx.reset_scheduler();

         ggml_cgraph * gf = llm_build_context::llama_build_graph_k_shift(lctx);

@@ -3510,7 +3572,7 @@ static int32_t llama_kv_cache_update_internal(struct llama_context & lctx) {

     if (lctx.kv_self.recurrent && lctx.kv_self.do_copy) {
         {
-            ggml_backend_sched_reset(lctx.sched);
+            lctx.reset_scheduler();

             ggml_cgraph * gf = llm_build_context::llama_build_graph_s_copy(lctx);

@@ -3553,7 +3615,7 @@ static int32_t llama_kv_cache_update_internal(struct llama_context & lctx) {
     ggml_cgraph * gf = llm_build_context::llama_build_graph(lctx, llama_batch_get_one(&token, n_tokens, n_past, 0), true);

     // initialize scheduler with the worst-case graph
-    ggml_backend_sched_reset(lctx.sched);
+    lctx.reset_scheduler();
     if (!ggml_backend_sched_reserve(lctx.sched, gf)) {
         LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__);
     }
@@ -3840,6 +3902,7 @@ struct llama_context_params llama_context_default_params() {
     /*.fused_up_gate       =*/ true,
     /*.fused_mmad          =*/ true,
     /*.rope_cache          =*/ false,
+    /*.graph_reuse         =*/ false,
     /*.min_experts         =*/ -1,
     /*.thresh_experts      =*/ 0.0f,
     /*.only_active_experts =*/ false,
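graph_reuse defaults to false in llama_context_default_params(), matching the gpt_params default earlier in the diff, so existing callers see no behavior change unless they opt in via -reuse/--graph-reuse.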
@@ -4144,6 +4207,7 @@ struct llama_context * llama_new_context_with_model(
     cparams.fused_up_gate = params.fused_up_gate;
     cparams.fused_mmad = params.fused_mmad;
     cparams.rope_cache = params.rope_cache;
+    cparams.graph_reuse = params.graph_reuse;
     cparams.min_experts = params.min_experts;
     cparams.thresh_experts = params.thresh_experts;
     cparams.cuda_params = params.cuda_params;
@@ -4230,6 +4294,7 @@ struct llama_context * llama_new_context_with_model(
     LLAMA_LOG_INFO("%s: fused_up_gate = %d\n", __func__, cparams.fused_up_gate);
     LLAMA_LOG_INFO("%s: fused_mmad    = %d\n", __func__, cparams.fused_mmad);
     LLAMA_LOG_INFO("%s: rope_cache    = %d\n", __func__, cparams.rope_cache);
+    LLAMA_LOG_INFO("%s: graph_reuse   = %d\n", __func__, cparams.graph_reuse);
     LLAMA_LOG_INFO("%s: ser           = %d, %g\n", __func__, cparams.min_experts, cparams.thresh_experts);
     LLAMA_LOG_INFO("%s: freq_base     = %.1f\n", __func__, cparams.rope_freq_base);
     LLAMA_LOG_INFO("%s: freq_scale    = %g\n", __func__, cparams.rope_freq_scale);