From 6d05977940c6483f1ae86446f4ff0fc86f9c0ef6 Mon Sep 17 00:00:00 2001
From: Iwan Kawrakow
Date: Sat, 25 Oct 2025 09:32:01 +0300
Subject: [PATCH] Change flash attention to be on by default

---
 common/common.cpp                    |  6 +++---
 common/common.h                      |  2 +-
 examples/llama-bench/llama-bench.cpp |  2 +-
 src/llama.cpp                        | 26 +++++++++++++-------------
 4 files changed, 18 insertions(+), 18 deletions(-)

diff --git a/common/common.cpp b/common/common.cpp
index 0e3da14d..7c0bd16d 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -994,8 +994,8 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
         params.cont_batching = false;
         return true;
     }
-    if (arg == "-fa" || arg == "--flash-attn") {
-        params.flash_attn = true;
+    if (arg == "-no-fa" || arg == "--no-flash-attn") {
+        params.flash_attn = false;
         return true;
     }
     if (arg == "-mla" || arg == "--mla-use") {
@@ -1804,7 +1804,7 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
     options.push_back({ "*", "-ub, --ubatch-size N", "physical maximum batch size (default: %d)", params.n_ubatch });
     options.push_back({ "*", " --keep N", "number of tokens to keep from the initial prompt (default: %d, -1 = all)", params.n_keep });
     options.push_back({ "*", " --chunks N", "max number of chunks to process (default: %d, -1 = all)", params.n_chunks });
-    options.push_back({ "*", "-fa, --flash-attn", "enable Flash Attention (default: %s)", params.flash_attn ? "enabled" : "disabled" });
+    options.push_back({ "*", "-no-fa, --no-flash-attn", "disable Flash Attention (default: %s)", params.flash_attn ? "enabled" : "disabled" });
     options.push_back({ "*", "-mla, --mla-use", "enable MLA (default: %d)", params.mla_attn });
     options.push_back({ "*", "-amb, --attention-max-batch", "max batch size for attention computations (default: %d)", params.attn_max_batch});
     options.push_back({ "*", "-no-fmoe, --no-fused-moe", "disable fused MoE (default: %s)", params.fused_moe_up_gate ? "enabled" : "disabled" });
diff --git a/common/common.h b/common/common.h
index 2deece17..f41e7580 100644
--- a/common/common.h
+++ b/common/common.h
@@ -230,7 +230,7 @@ struct gpt_params {
     bool multiline_input = false; // reverse the usage of `\`
     bool simple_io = false; // improves compatibility with subprocesses and limited consoles
     bool cont_batching = true; // insert new sequences for decoding on-the-fly
-    bool flash_attn = false; // flash attention
+    bool flash_attn = true; // flash attention
     int mla_attn = 0; // MLA 0: standard attention, 1: MLA with K and transposed V cache, 2: MLA with just K cache
     int attn_max_batch = 0; // Max batch size to use when computing attention (only applicable if flash_attn = false)
     bool fused_moe_up_gate = true; // fused up*unary(gate) op for MoE models
diff --git a/examples/llama-bench/llama-bench.cpp b/examples/llama-bench/llama-bench.cpp
index 135bc029..3e948e68 100644
--- a/examples/llama-bench/llama-bench.cpp
+++ b/examples/llama-bench/llama-bench.cpp
@@ -285,7 +285,7 @@ static const cmd_params cmd_params_defaults = {
     /* split_mode     */ {LLAMA_SPLIT_MODE_LAYER},
     /* main_gpu       */ {0},
     /* no_kv_offload  */ {false},
-    /* flash_attn     */ {false},
+    /* flash_attn     */ {true},
     /* mla_attn       */ {0},
     /* attn_max_batch */ {0},
     /* ser            */ {{-1,0.0f}},
diff --git a/src/llama.cpp b/src/llama.cpp
index 0011e59b..2cce5384 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -3750,7 +3750,7 @@ struct llama_context_params llama_context_default_params() {
         /*.logits_all        =*/ false,
         /*.embeddings        =*/ false,
         /*.offload_kqv       =*/ true,
-        /*.flash_attn        =*/ false,
+        /*.flash_attn        =*/ true,
         /*.mla_attn          =*/ 0,
         /*.attn_max_batch    =*/ 0,
         /*.fused_moe_up_gate =*/ true,
@@ -4040,19 +4040,19 @@ struct llama_context * llama_new_context_with_model(
         cparams.mla_attn = 0;
     }
 
-    LLAMA_LOG_INFO("%s: n_ctx      = %u\n",     __func__, cparams.n_ctx);
-    LLAMA_LOG_INFO("%s: n_batch    = %u\n",     __func__, cparams.n_batch);
-    LLAMA_LOG_INFO("%s: n_ubatch   = %u\n",     __func__, cparams.n_ubatch);
-    LLAMA_LOG_INFO("%s: flash_attn = %d\n",     __func__, cparams.flash_attn);
-    LLAMA_LOG_INFO("%s: mla_attn   = %d\n",     __func__, cparams.mla_attn);
-    LLAMA_LOG_INFO("%s: attn_max_b = %d\n",     __func__, cparams.attn_max_batch);
-    LLAMA_LOG_INFO("%s: fused_moe  = %d\n",     __func__, cparams.fused_moe_up_gate);
-    LLAMA_LOG_INFO("%s: grouped er = %d\n",     __func__, cparams.grouped_expert_routing);
+    LLAMA_LOG_INFO("%s: n_ctx         = %u\n",     __func__, cparams.n_ctx);
+    LLAMA_LOG_INFO("%s: n_batch       = %u\n",     __func__, cparams.n_batch);
+    LLAMA_LOG_INFO("%s: n_ubatch      = %u\n",     __func__, cparams.n_ubatch);
+    LLAMA_LOG_INFO("%s: flash_attn    = %d\n",     __func__, cparams.flash_attn);
+    LLAMA_LOG_INFO("%s: mla_attn      = %d\n",     __func__, cparams.mla_attn);
+    LLAMA_LOG_INFO("%s: attn_max_b    = %d\n",     __func__, cparams.attn_max_batch);
+    LLAMA_LOG_INFO("%s: fused_moe     = %d\n",     __func__, cparams.fused_moe_up_gate);
+    LLAMA_LOG_INFO("%s: grouped er    = %d\n",     __func__, cparams.grouped_expert_routing);
     LLAMA_LOG_INFO("%s: fused_up_gate = %d\n",     __func__, cparams.fused_up_gate);
-    LLAMA_LOG_INFO("%s: fused_mmad = %d\n",     __func__, cparams.fused_mmad);
-    LLAMA_LOG_INFO("%s: ser        = %d, %g\n", __func__, cparams.min_experts, cparams.thresh_experts);
-    LLAMA_LOG_INFO("%s: freq_base  = %.1f\n",   __func__, cparams.rope_freq_base);
-    LLAMA_LOG_INFO("%s: freq_scale = %g\n",     __func__, cparams.rope_freq_scale);
+    LLAMA_LOG_INFO("%s: fused_mmad    = %d\n",     __func__, cparams.fused_mmad);
+    LLAMA_LOG_INFO("%s: ser           = %d, %g\n", __func__, cparams.min_experts, cparams.thresh_experts);
+    LLAMA_LOG_INFO("%s: freq_base     = %.1f\n",   __func__, cparams.rope_freq_base);
+    LLAMA_LOG_INFO("%s: freq_scale    = %g\n",     __func__, cparams.rope_freq_scale);
 
     ctx->abort_callback      = params.abort_callback;
     ctx->abort_callback_data = params.abort_callback_data;
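
Usage note: with this patch, flash attention is enabled by default, and the old
opt-in flag (-fa / --flash-attn) is replaced by an opt-out flag
(-no-fa / --no-flash-attn). For programs using the C API, the sketch below shows
how a caller would restore the previous behaviour under the new default. It is a
minimal illustration, not part of the patch; it assumes "model" is an already
loaded llama_model pointer and elides error handling.

    #include "llama.h"

    // flash_attn now defaults to true in llama_context_default_params(),
    // so callers that relied on the old default must opt out explicitly.
    llama_context_params cparams = llama_context_default_params();
    cparams.flash_attn = false; // restore the pre-patch default

    // Context creation itself is unchanged; only the default flipped.
    llama_context * ctx = llama_new_context_with_model(model, cparams);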