diff --git a/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp
index 5be57af61b..f238c2470d 100644
--- a/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp
+++ b/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp
@@ -8,6 +8,7 @@
 #include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
 #include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_selector.hpp"
+#include <cstdlib>
 #include <string>
 #include <type_traits>
 #include <utility>
@@ -27,6 +28,20 @@
 namespace ck_tile {
 
+// Return get_num_cus() unless ROCM_FLASH_ATTN_CU_NUM is set, in which case use that value.
+// This lets you cap the persistent bwd kernel to fewer CUs without recompiling.
+CK_TILE_HOST inline index_t get_num_cus_override()
+{
+    const char* env = std::getenv("ROCM_FLASH_ATTN_CU_NUM");
+    if(env != nullptr)
+    {
+        const int v = std::atoi(env);
+        if(v > 0)
+            return static_cast<index_t>(v);
+    }
+    return get_num_cus();
+}
+
 // Per-CU state for group-mode deterministic persistent scheduling.
 // Packed into a single array to reduce kargs pointer count (5 pointers → 1).
 // alignas(16): enables aligned 128-bit loads; sizeof == 32 (6×4 + 8 pad).
@@ -126,7 +141,7 @@ struct FmhaBwdWorkspaceManager
             return 0;
         const size_t raw = GetDqAccSplitsSize(batch) + GetDqAccOffsetsSize(batch) +
                            GetPrefixBatchSize(batch) +
-                           GetCuStateSize(get_num_cus()) + GetBatchStateSize(batch);
+                           GetCuStateSize(get_num_cus_override()) + GetBatchStateSize(batch);
         // Pad to 4K so dq_acc buffer always starts on a page-aligned boundary.
         return integer_least_multiple(raw, static_cast<size_t>(4096));
     }
@@ -147,7 +162,7 @@ struct FmhaBwdWorkspaceManager
     }
     CK_TILE_HOST static size_t GetBatchStateOffset(const int batch)
     {
-        return GetCuStateOffset(batch) + GetCuStateSize(get_num_cus());
+        return GetCuStateOffset(batch) + GetCuStateSize(get_num_cus_override());
     }
     template <typename DataType>
     CK_TILE_HOST static size_t GetDqAccDataOffset(const int batch)
@@ -193,7 +208,7 @@ struct FmhaBwdWorkspaceManager
         else if constexpr(kIsGroupMode)
         { // deterministic group mode (persistent)
             // Step 1: compute prefix_batch and target_w using per-batch seqlens
-            const index_t num_cus = get_num_cus();
+            const index_t num_cus = get_num_cus_override();
             auto* prefix_batch =
                 reinterpret_cast<index_t*>(reinterpret_cast<char*>(cpu_ws) +
                                            GetDqAccSplitsSize(batch_size) +
                                            GetDqAccOffsetsSize(batch_size));
@@ -310,7 +325,7 @@ struct FmhaBwdWorkspaceManager
         }
         else // deterministic batch mode (kUsePersistent)
         {
-            const index_t dqdqkdv_workers = get_num_cus();
+            const index_t dqdqkdv_workers = get_num_cus_override();
             const index_t jobs_per_head   = integer_divide_ceil(seqlen_k, kN0);
             const index_t total_jobs      = batch_size * nhead_q * jobs_per_head;
             const index_t jobs_per_worker = integer_divide_ceil(total_jobs, dqdqkdv_workers);
@@ -1074,7 +1089,7 @@ struct FmhaBwdDQDKDVKernel
         const index_t jobs_per_head =
             kUseQrQtrDorPipeline ? 1 : integer_divide_ceil(seqlen_k_, FmhaPipeline::kN0);
         if constexpr(kUsePersistent)
-            return dim3(get_num_cus(), 1, 1);
+            return dim3(get_num_cus_override(), 1, 1);
         else
             return dim3(jobs_per_head, nhead_, batch_size_);
     }