Introduce the environment variable ROCM_FLASH_ATTN_CU_NUM to control the persistent-kernel grid size of the bwd group mode

This commit is contained in:
Ye Wang
2026-05-07 15:48:17 -05:00
parent 3e3cb36c7a
commit e9af75800d

View File

@@ -8,6 +8,7 @@
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_selector.hpp"
#include <cstdlib>
#include <string>
#include <type_traits>
#include <utility>
@@ -27,6 +28,20 @@
namespace ck_tile {
// Return get_num_cus() unless ROCM_FLASH_ATTN_CU_NUM is set to a valid positive
// integer, in which case use that value instead.
// This lets you cap the persistent bwd kernel to fewer CUs without recompiling.
CK_TILE_HOST inline index_t get_num_cus_override()
{
    const char* env = std::getenv("ROCM_FLASH_ATTN_CU_NUM");
    if(env != nullptr)
    {
        // std::strtol instead of std::atoi: atoi has undefined behavior on
        // overflow and gives no way to tell a malformed string from "0".
        char* end    = nullptr;
        const long v = std::strtol(env, &end, 10);
        // Accept only a fully-consumed, strictly positive value that round-trips
        // through index_t without truncation; anything else falls back to the
        // detected CU count below.
        if(end != env && *end == '\0' && v > 0 &&
           static_cast<long>(static_cast<index_t>(v)) == v)
            return static_cast<index_t>(v);
    }
    return get_num_cus();
}
// Per-CU state for group-mode deterministic persistent scheduling.
// Packed into a single array to reduce kargs pointer count (5 pointers → 1).
// alignas(16): enables aligned 128-bit loads; sizeof == 32 (6×4 + 8 pad).
@@ -126,7 +141,7 @@ struct FmhaBwdWorkspaceManager
return 0;
const size_t raw = GetDqAccSplitsSize<kUseQrQtrDorPipeline>(batch) +
GetDqAccOffsetsSize(batch) + GetPrefixBatchSize(batch) +
GetCuStateSize(get_num_cus()) + GetBatchStateSize(batch);
GetCuStateSize(get_num_cus_override()) + GetBatchStateSize(batch);
// Pad to 4K so dq_acc buffer always starts on a page-aligned boundary.
return integer_least_multiple(raw, static_cast<size_t>(4096));
}
@@ -147,7 +162,7 @@ struct FmhaBwdWorkspaceManager
}
CK_TILE_HOST static size_t GetBatchStateOffset(const int batch)
{
    // The batch-state region starts immediately after the per-CU state region,
    // which itself is sized by the (possibly env-overridden) CU count.
    const size_t cu_state_begin = GetCuStateOffset(batch);
    const size_t cu_state_bytes = GetCuStateSize(get_num_cus_override());
    return cu_state_begin + cu_state_bytes;
}
template <bool kUseQrQtrDorPipeline>
CK_TILE_HOST static size_t GetDqAccDataOffset(const int batch)
@@ -193,7 +208,7 @@ struct FmhaBwdWorkspaceManager
else if constexpr(kIsGroupMode)
{ // deterministic group mode (persistent)
// Step 1: compute prefix_batch and target_w using per-batch seqlens
const index_t num_cus = get_num_cus();
const index_t num_cus = get_num_cus_override();
auto* prefix_batch = reinterpret_cast<index_t*>(reinterpret_cast<char*>(cpu_ws) +
GetDqAccSplitsSize<false>(batch_size) +
GetDqAccOffsetsSize(batch_size));
@@ -310,7 +325,7 @@ struct FmhaBwdWorkspaceManager
}
else // deterministic batch mode (kUsePersistent)
{
const index_t dqdqkdv_workers = get_num_cus();
const index_t dqdqkdv_workers = get_num_cus_override();
const index_t jobs_per_head = integer_divide_ceil(seqlen_k, kN0);
const index_t total_jobs = batch_size * nhead_q * jobs_per_head;
const index_t jobs_per_worker = integer_divide_ceil(total_jobs, dqdqkdv_workers);
@@ -1074,7 +1089,7 @@ struct FmhaBwdDQDKDVKernel
const index_t jobs_per_head =
kUseQrQtrDorPipeline ? 1 : integer_divide_ceil(seqlen_k_, FmhaPipeline::kN0);
if constexpr(kUsePersistent)
return dim3(get_num_cus(), 1, 1);
return dim3(get_num_cus_override(), 1, 1);
else
return dim3(jobs_per_head, nhead_, batch_size_);
}