From 9d9ed6a032699bec20833e633474e49eb83c2883 Mon Sep 17 00:00:00 2001
From: Kawrakow
Date: Tue, 13 Jan 2026 10:23:50 +0200
Subject: [PATCH] Add -sas, --scheduler-async to llama-bench (#1140)

---
 examples/llama-bench/llama-bench.cpp | 68 +++++++++++++++++++---------
 1 file changed, 47 insertions(+), 21 deletions(-)

diff --git a/examples/llama-bench/llama-bench.cpp b/examples/llama-bench/llama-bench.cpp
index 0672d1b2..8a16701a 100644
--- a/examples/llama-bench/llama-bench.cpp
+++ b/examples/llama-bench/llama-bench.cpp
@@ -270,6 +270,7 @@ struct cmd_params {
     bool mqkv = false;
     bool muge = false;
     bool rcache = false;
+    bool sas = false;
     output_formats output_format;
     output_formats output_format_stderr;
 };
@@ -313,6 +314,7 @@ static const cmd_params cmd_params_defaults = {
     /* mqkv */ false,
     /* muge */ false,
     /* rcache */ false,
+    /* sas */ false,
     /* output_format */ MARKDOWN,
     /* output_format_stderr */ NONE,
 };
@@ -364,6 +366,7 @@ static void print_usage(int /* argc */, char ** argv) {
     printf("  -ger, --grouped-expert-routing <0|1> (default: %s)\n", cmd_params_defaults.ger ? "1" : "0");
     printf("  -no-fug, --no-fused-up-gate <0|1> (default: %s)\n", cmd_params_defaults.no_fug? "1" : "0");
     printf("  -no-ooae, --no-offload-only-active-experts <0|1> (default: %s)\n", cmd_params_defaults.no_ooae? "1" : "0");
+    printf("  -sas, --scheduler-async <0|1> (default: %s)\n", cmd_params_defaults.sas ? "1" : "0");
     printf("\n");
     printf("Multiple values can be given for each parameter by separating them with ',' or by specifying the parameter multiple times.\n");
 }
@@ -798,6 +801,12 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
                 break;
             }
             params.muge = std::stoi(argv[i]);
+        } else if (arg == "-sas" || arg == "--scheduler-async") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.sas = std::stoi(argv[i]);
         } else if (arg == "-rcache" || arg == "--rope-cache") {
             if (++i >= argc) {
                 invalid_param = true;
                 break;
             }
@@ -937,6 +946,7 @@ struct cmd_params_instance {
     bool mqkv = false;
     bool muge = false;
     bool rcache = false;
+    bool sas = false;
     const llama_model_tensor_buft_override* buft_overrides;
 
     llama_model_params to_llama_mparams() const {
@@ -996,6 +1006,7 @@ struct cmd_params_instance {
         cparams.thresh_experts = ser.second;
         cparams.embeddings = embeddings;
         cparams.cuda_params = (void *)cuda_params.data();
+        cparams.scheduler_async = sas;
 
         return cparams;
     }
@@ -1061,6 +1072,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
                 /* .mqkv = */ params.mqkv,
                 /* .muge = */ params.muge,
                 /* .rcache = */ params.rcache,
+                /* .sas = */ params.sas,
                 /* .buft_overrides=*/ params.buft_overrides.data(),
             };
             instances.push_back(instance);
@@ -1103,6 +1115,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
                 /* .mqkv = */ params.mqkv,
                 /* .muge = */ params.muge,
                 /* .rcache = */ params.rcache,
+                /* .sas = */ params.sas,
                 /* .buft_overrides=*/ params.buft_overrides.data(),
             };
             instances.push_back(instance);
@@ -1145,6 +1158,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
                 /* .mqkv = */ params.mqkv,
                 /* .muge = */ params.muge,
                 /* .rcache = */ params.rcache,
+                /* .sas = */ params.sas,
                 /* .buft_overrides=*/ params.buft_overrides.data(),
             };
             instances.push_back(instance);
@@ -1187,6 +1201,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
                 /* .mqkv = */ params.mqkv,
                 /* .muge = */ params.muge,
                 /* .rcache = */ params.rcache,
+                /* .sas = */ params.sas,
                 /* .buft_overrides=*/ params.buft_overrides.data(),
             };
             instances.push_back(instance);
@@ -1240,6 +1255,7 @@ struct test {
     bool mqkv = false;
     bool muge = false;
     bool rcache = false;
+    bool sas = false;
    std::string override_tensor;
     int n_prompt;
     int n_gen;
@@ -1280,6 +1296,7 @@ struct test {
         fmoe = inst.fmoe;
         ger = inst.ger;
         rcache = inst.rcache;
+        sas = inst.sas;
         no_fug = inst.no_fug;
         use_thp = inst.use_thp;
         no_ooae = inst.no_ooae;
@@ -1376,25 +1393,6 @@ struct test {
         return "CPU";
     }
 
-    static const std::vector<std::string> & get_fields() {
-        static const std::vector<std::string> fields = {
-            "build_commit", "build_number",
-            "cuda", "vulkan", "kompute", "metal", "sycl", "rpc", "gpu_blas", "blas",
-            "cpu_info", "gpu_info",
-            "model_filename", "model_type", "model_size", "model_n_params",
-            "n_batch", "n_ubatch",
-            "n_threads", "type_k", "type_v",
-            "n_gpu_layers", "split_mode",
-            "main_gpu", "no_kv_offload", "flash_attn", "mla_attn", "attn_max_batch", "ser", "reuse",
-            "tensor_split", "use_mmap", "embeddings", "repack", "mqkv", "muge", "fused_moe", "grouped_er",
-            "no_fused_up_gate", "use_thp", "no_ooae", "rcache", "cuda_params", "override_tensor",
-            "n_prompt", "n_gen", "test_time",
-            "avg_ns", "stddev_ns",
-            "avg_ts", "stddev_ts", "test",
-        };
-        return fields;
-    }
-
     enum field_type {STRING, BOOL, INT, FLOAT};
 
     static field_type get_field_type(const std::string & field) {
@@ -1410,7 +1408,7 @@ struct test {
             field == "gpu_blas" || field == "blas" || field == "sycl" || field == "no_kv_offload" || field == "flash_attn" ||
             field == "use_mmap" || field == "embeddings" || field == "repack" || field == "use_thp" || field == "fused_moe" ||
             field == "grouped_er" || field == "no_fused_up_gate" || field == "no_ooae" || field == "mqkv" ||
-            field == "rcache" || field == "reuse" || field == "muge") {
+            field == "rcache" || field == "reuse" || field == "muge" || field == "sas") {
             return BOOL;
         }
         if (field == "avg_ts" || field == "stddev_ts") {
@@ -1454,7 +1452,7 @@ struct test {
            std::to_string(mla_attn), std::to_string(attn_max_batch), ser_to_string(ser), std::to_string(reuse), tensor_split_str,
            std::to_string(use_mmap), std::to_string(embeddings), std::to_string(repack), std::to_string(mqkv), std::to_string(muge),
            std::to_string(fmoe), std::to_string(ger),
-            std::to_string(no_fug), std::to_string(use_thp), std::to_string(no_ooae), std::to_string(rcache),
+            std::to_string(no_fug), std::to_string(use_thp), std::to_string(no_ooae), std::to_string(rcache), std::to_string(sas),
            cuda_params, override_tensor,
            std::to_string(n_prompt), std::to_string(n_gen), test_time,
            std::to_string(avg_ns()), std::to_string(stdev_ns()),
@@ -1464,6 +1462,25 @@ struct test {
         return values;
     }
 
+    static const std::vector<std::string> & get_fields() {
+        static const std::vector<std::string> fields = {
+            "build_commit", "build_number",
+            "cuda", "vulkan", "kompute", "metal", "sycl", "rpc", "gpu_blas", "blas",
+            "cpu_info", "gpu_info",
+            "model_filename", "model_type", "model_size", "model_n_params",
+            "n_batch", "n_ubatch",
+            "n_threads", "type_k", "type_v",
+            "n_gpu_layers", "split_mode",
+            "main_gpu", "no_kv_offload", "flash_attn", "mla_attn", "attn_max_batch", "ser", "reuse",
+            "tensor_split", "use_mmap", "embeddings", "repack", "mqkv", "muge", "fused_moe", "grouped_er",
+            "no_fused_up_gate", "use_thp", "no_ooae", "rcache", "sas", "cuda_params", "override_tensor",
+            "n_prompt", "n_gen", "test_time",
+            "avg_ns", "stddev_ns",
+            "avg_ts", "stddev_ts", "test",
+        };
+        return fields;
+    }
+
     std::map<std::string, std::string> get_map() const {
         std::map<std::string, std::string> map;
         auto fields = get_fields();
@@ -1642,6 +1659,9 @@ struct markdown_printer : public printer {
         if (field == "muge") {
             return 4;
         }
+        if (field == "sas") {
+            return 3;
+        }
         if (field == "use_thp") {
             return 3;
         }
@@ -1712,6 +1732,9 @@ struct markdown_printer : public printer {
         if (field == "muge") {
             return "muge";
         }
+        if (field == "sas") {
+            return "sas";
+        }
         if (field == "use_thp") {
             return "thp";
         }
@@ -1815,6 +1838,9 @@ struct markdown_printer : public printer {
         if (params.mqkv != cmd_params_defaults.mqkv) {
             fields.emplace_back("mqkv");
         }
+        if (params.sas != cmd_params_defaults.sas) {
+            fields.emplace_back("sas");
+        }
         if (params.muge != cmd_params_defaults.muge) {
             fields.emplace_back("muge");
         }