diff --git a/examples/llama-bench/llama-bench.cpp b/examples/llama-bench/llama-bench.cpp index 52b1dbbd..a12543a5 100644 --- a/examples/llama-bench/llama-bench.cpp +++ b/examples/llama-bench/llama-bench.cpp @@ -251,6 +251,7 @@ struct cmd_params { std::vector mla_attn; std::vector attn_max_batch; std::vector ser; + std::vector reuse; std::vector> tensor_split; std::vector use_mmap; std::vector embeddings; @@ -289,9 +290,10 @@ static const cmd_params cmd_params_defaults = { /* main_gpu */ {0}, /* no_kv_offload */ {false}, /* flash_attn */ {true}, - /* mla_attn */ {0}, + /* mla_attn */ {3}, /* attn_max_batch */ {0}, /* ser */ {{-1,0.0f}}, + /* reuse */ {false}, /* tensor_split */ {std::vector(llama_max_devices(), 0.0f)}, /* use_mmap */ {true}, /* embeddings */ {false}, @@ -339,6 +341,7 @@ static void print_usage(int /* argc */, char ** argv) { printf(" -mla, --mla-attn <0|1|2> (default: %s)\n", join(cmd_params_defaults.mla_attn, ",").c_str()); printf(" -amb, --attn-max-batch (default: %s)\n", join(cmd_params_defaults.attn_max_batch, ",").c_str()); printf(" -ser, --smart-expert-reduction (default: %s)\n", join(cmd_params_defaults.attn_max_batch, ",").c_str()); + printf(" -reuse, --graph-reuse <0|1> (default: %s)\n", join(cmd_params_defaults.reuse, ",").c_str()); printf(" -mmp, --mmap <0|1> (default: %s)\n", join(cmd_params_defaults.use_mmap, ",").c_str()); printf(" --numa (default: disabled)\n"); printf(" -embd, --embeddings <0|1> (default: %s)\n", join(cmd_params_defaults.embeddings, ",").c_str()); @@ -681,6 +684,13 @@ static cmd_params parse_cmd_params(int argc, char ** argv) { } auto p = string_split(argv[i], split_delim); params.attn_max_batch.insert(params.attn_max_batch.end(), p.begin(), p.end()); + } else if (arg == "-reuse" || arg == "--graph-reuse") { + if (++i >= argc) { + invalid_param = true; + break; + } + auto p = string_split(argv[i], split_delim); + params.reuse.insert(params.reuse.end(), p.begin(), p.end()); } else if (arg == "-ser" || arg == "--smart-expert-reduction") { if (++i >= argc) { invalid_param = true; @@ -852,6 +862,7 @@ static cmd_params parse_cmd_params(int argc, char ** argv) { if (params.flash_attn.empty()) { params.flash_attn = cmd_params_defaults.flash_attn; } if (params.mla_attn.empty()) { params.mla_attn = cmd_params_defaults.mla_attn; } if (params.attn_max_batch.empty()){ params.attn_max_batch = cmd_params_defaults.attn_max_batch; } + if (params.reuse.empty()) { params.reuse = cmd_params_defaults.reuse; } if (params.ser.empty()) { params.ser = cmd_params_defaults.ser; } if (params.tensor_split.empty()) { params.tensor_split = cmd_params_defaults.tensor_split; } if (params.use_mmap.empty()) { params.use_mmap = cmd_params_defaults.use_mmap; } @@ -891,6 +902,7 @@ struct cmd_params_instance { bool flash_attn; int mla_attn; int attn_max_batch; + bool reuse; Ser ser; std::vector tensor_split; std::string cuda_params; @@ -950,6 +962,7 @@ struct cmd_params_instance { cparams.flash_attn = flash_attn; cparams.mla_attn = mla_attn; cparams.attn_max_batch = attn_max_batch; + cparams.graph_reuse = reuse; cparams.fused_moe_up_gate = fmoe; cparams.grouped_expert_routing = ger; cparams.rope_cache = rcache; @@ -984,6 +997,7 @@ static std::vector get_cmd_params_instances(const cmd_param for (const auto & fa : params.flash_attn) for (const auto & mla : params.mla_attn) for (const auto & amb : params.attn_max_batch) + for (const auto & reuse : params.reuse) for (const auto & ser : params.ser) for (const auto & nt : params.n_threads) { for (const auto & n_prompt : params.n_prompt) { @@ -1008,6 +1022,7 @@ static std::vector get_cmd_params_instances(const cmd_param /* .flash_attn = */ fa, /* .mla_attn = */ mla, /* .attn_max_b = */ amb, + /* .reuse = */ reuse, /* .ser = */ ser, /* .tensor_split = */ ts, /* .cuda_params = */ params.cuda_params, @@ -1048,6 +1063,7 @@ static std::vector get_cmd_params_instances(const cmd_param /* .flash_attn = */ fa, /* .mla_attn = */ mla, /* .attn_max_b = */ amb, + /* .reuse = */ reuse, /* .ser = */ ser, /* .tensor_split = */ ts, /* .cuda_params = */ params.cuda_params, @@ -1088,6 +1104,7 @@ static std::vector get_cmd_params_instances(const cmd_param /* .flash_attn = */ fa, /* .mla_attn = */ mla, /* .attn_max_b = */ amb, + /* .reuse = */ reuse, /* .ser = */ ser, /* .tensor_split = */ ts, /* .cuda_params = */ params.cuda_params, @@ -1128,6 +1145,7 @@ static std::vector get_cmd_params_instances(const cmd_param /* .flash_attn = */ fa, /* .mla_attn = */ mla, /* .attn_max_b = */ amb, + /* .reuse = */ reuse, /* .ser = */ ser, /* .tensor_split = */ ts, /* .cuda_params = */ params.cuda_params, @@ -1179,6 +1197,7 @@ struct test { bool flash_attn; int mla_attn; int attn_max_batch; + bool reuse; Ser ser; std::vector tensor_split; std::string cuda_params; @@ -1219,6 +1238,7 @@ struct test { flash_attn = inst.flash_attn; mla_attn = inst.mla_attn; attn_max_batch = inst.attn_max_batch; + reuse = inst.reuse; ser = inst.ser; tensor_split = inst.tensor_split; cuda_params = inst.cuda_params; @@ -1321,7 +1341,7 @@ struct test { "n_batch", "n_ubatch", "n_threads", "type_k", "type_v", "n_gpu_layers", "split_mode", - "main_gpu", "no_kv_offload", "flash_attn", "mla_attn", "attn_max_batch", "ser", + "main_gpu", "no_kv_offload", "flash_attn", "mla_attn", "attn_max_batch", "ser", "reuse", "tensor_split", "use_mmap", "embeddings", "repack", "mqkv", "fused_moe", "grouped_er", "fused_up_gate", "use_thp", "ooae", "rcache", "n_prompt", "n_gen", "test_time", @@ -1346,7 +1366,7 @@ struct test { field == "gpu_blas" || field == "blas" || field == "sycl" ||field == "f16_kv" || field == "no_kv_offload" || field == "flash_attn" || field == "use_mmap" || field == "embeddings" || field == "repack" || field == "use_thp" || field == "fused_moe" || field == "grouped_er" || field == "fused_up_gate" || field == "ooae" || field == "mqkv" || - field == "rcache") { + field == "rcache" || field == "reuse") { return BOOL; } if (field == "avg_ts" || field == "stddev_ts") { @@ -1387,7 +1407,7 @@ struct test { std::to_string(is_gen ? n_threads.first : n_threads.second), ggml_type_name(type_k), ggml_type_name(type_v), std::to_string(n_gpu_layers), split_mode_str(split_mode), std::to_string(main_gpu), std::to_string(no_kv_offload), std::to_string(flash_attn), - std::to_string(mla_attn), std::to_string(attn_max_batch), ser_to_string(ser), + std::to_string(mla_attn), std::to_string(attn_max_batch), ser_to_string(ser), std::to_string(reuse), tensor_split_str, std::to_string(use_mmap), std::to_string(embeddings), std::to_string(repack), std::to_string(fmoe), std::to_string(ger), std::to_string(rcache), std::to_string(no_fug), std::to_string(use_thp), std::to_string(no_ooae), std::to_string(mqkv), @@ -1559,6 +1579,9 @@ struct markdown_printer : public printer { if (field == "attn_max_batch") { return 5; } + if (field == "reuse") { + return 6; + } if (field == "ser") { return 10; } @@ -1623,7 +1646,10 @@ struct markdown_printer : public printer { if (field == "attn_max_batch") { return "amb"; } - if (field == "attn_max_batch") { + if (field == "reuse") { + return "reuse"; + } + if (field == "ser") { return "ser"; } if (field == "use_mmap") { @@ -1702,9 +1728,12 @@ struct markdown_printer : public printer { if (params.mla_attn.size() > 1 || params.mla_attn != cmd_params_defaults.mla_attn) { fields.emplace_back("mla_attn"); } - if (params.attn_max_batch.size() > 1 || params.attn_max_batch != cmd_params_defaults.mla_attn) { + if (params.attn_max_batch.size() > 1 || params.attn_max_batch != cmd_params_defaults.attn_max_batch) { fields.emplace_back("attn_max_batch"); } + if (params.reuse.size() > 1 || params.reuse != cmd_params_defaults.reuse) { + fields.emplace_back("reuse"); + } if (params.ser.size() > 1 || params.ser != cmd_params_defaults.ser) { fields.emplace_back("ser"); } diff --git a/src/llama.cpp b/src/llama.cpp index d5fab312..fe75f039 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -535,21 +535,9 @@ static size_t llama_get_device_memory(const llama_model & model, int device) { GGML_UNUSED(model); GGML_UNUSED(device); } - //llama_batch u_batch = { - // /* .n_tokens = */ (int32_t) n_tokens, - // /* .token = */ batch_all.token ? batch_all.token + cur_token : nullptr, - // /* .embd = */ batch_all.embd ? batch_all.embd + cur_token*n_embd : nullptr, - // /* .pos = */ batch_all.pos ? batch_all.pos + cur_token : nullptr, - // /* .n_seq_id = */ batch_all.n_seq_id ? batch_all.n_seq_id + cur_token : nullptr, - // /* .seq_id = */ batch_all.seq_id ? batch_all.seq_id + cur_token : nullptr, - // /* .logits = */ batch_all.logits ? batch_all.logits + cur_token : nullptr, - // /* .all_pos_0 = */ batch_all.all_pos_0 + (llama_pos) cur_token*batch_all.all_pos_1, - // /* .all_pos_1 = */ batch_all.all_pos_1, - // /* .all_seq_id = */ batch_all.all_seq_id, - //}; struct llama_context::Prev { - int all_pos_0, all_pos_1, all_seq_id; + int all_seq_id; int n_outputs; int n_kv; ggml_cgraph * graph; @@ -565,9 +553,8 @@ bool llama_context::can_reuse_graph(const llama_batch & u_batch) const { if (u_batch.n_tokens > 1) return false; if (u_batch.embd) return false; if (!cparams.graph_reuse) return false; - return u_batch.all_pos_0 == prev->all_pos_0 && - u_batch.all_pos_1 == prev->all_pos_1 && - u_batch.all_seq_id == prev->all_seq_id && + return u_batch.all_seq_id == prev->all_seq_id && + kv_self.head > 0 && kv_self.n == prev->n_kv && n_outputs == prev->n_outputs; } @@ -2950,16 +2937,9 @@ static int llama_decode_internal( tim2 = ggml_time_us(); printf("sched_alloc_graph(...): %d us\n", int(tim2-tim1)); #endif -//struct llama_context::Prev { -// int all_pos_0, all_pos_1, all_seq_id; -// int n_outputs; -// int n_kv; -// ggml_cgraph * graph; -//}; if (u_batch.n_tokens == 1 && u_batch.embd == nullptr && lctx.cparams.graph_reuse) { lctx.prev = std::make_unique(llama_context::Prev{ - (int)u_batch.all_pos_0, (int)u_batch.all_pos_1, (int)u_batch.all_seq_id, - (int)lctx.n_outputs, (int)lctx.kv_self.n, gf}); + (int)u_batch.all_seq_id, (int)lctx.n_outputs, (int)lctx.kv_self.n, gf}); } } else { //printf("Reusing graph\n");