mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-13 17:55:48 +00:00
introduce env ROCM_FLASH_ATTN_CU_NUM to control bwd group mode persistent kernel grid size
This commit is contained in:
@@ -8,6 +8,7 @@
|
||||
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_selector.hpp"
|
||||
|
||||
#include <cstdlib>
|
||||
#include <string>
|
||||
#include <type_traits>
|
||||
#include <utility>
|
||||
@@ -27,6 +28,20 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// Return get_num_cus() unless ROCM_FLASH_ATTN_CU_NUM is set, in which case use that value.
|
||||
// This lets you cap the persistent bwd kernel to fewer CUs without recompiling.
|
||||
CK_TILE_HOST inline index_t get_num_cus_override()
|
||||
{
|
||||
const char* env = std::getenv("ROCM_FLASH_ATTN_CU_NUM");
|
||||
if(env != nullptr)
|
||||
{
|
||||
const int v = std::atoi(env);
|
||||
if(v > 0)
|
||||
return static_cast<index_t>(v);
|
||||
}
|
||||
return get_num_cus();
|
||||
}
|
||||
|
||||
// Per-CU state for group-mode deterministic persistent scheduling.
|
||||
// Packed into a single array to reduce kargs pointer count (5 pointers → 1).
|
||||
// alignas(16): enables aligned 128-bit loads; sizeof == 32 (6×4 + 8 pad).
|
||||
@@ -126,7 +141,7 @@ struct FmhaBwdWorkspaceManager
|
||||
return 0;
|
||||
const size_t raw = GetDqAccSplitsSize<kUseQrQtrDorPipeline>(batch) +
|
||||
GetDqAccOffsetsSize(batch) + GetPrefixBatchSize(batch) +
|
||||
GetCuStateSize(get_num_cus()) + GetBatchStateSize(batch);
|
||||
GetCuStateSize(get_num_cus_override()) + GetBatchStateSize(batch);
|
||||
// Pad to 4K so dq_acc buffer always starts on a page-aligned boundary.
|
||||
return integer_least_multiple(raw, static_cast<size_t>(4096));
|
||||
}
|
||||
@@ -147,7 +162,7 @@ struct FmhaBwdWorkspaceManager
|
||||
}
|
||||
CK_TILE_HOST static size_t GetBatchStateOffset(const int batch)
|
||||
{
|
||||
return GetCuStateOffset(batch) + GetCuStateSize(get_num_cus());
|
||||
return GetCuStateOffset(batch) + GetCuStateSize(get_num_cus_override());
|
||||
}
|
||||
template <bool kUseQrQtrDorPipeline>
|
||||
CK_TILE_HOST static size_t GetDqAccDataOffset(const int batch)
|
||||
@@ -193,7 +208,7 @@ struct FmhaBwdWorkspaceManager
|
||||
else if constexpr(kIsGroupMode)
|
||||
{ // deterministic group mode (persistent)
|
||||
// Step 1: compute prefix_batch and target_w using per-batch seqlens
|
||||
const index_t num_cus = get_num_cus();
|
||||
const index_t num_cus = get_num_cus_override();
|
||||
auto* prefix_batch = reinterpret_cast<index_t*>(reinterpret_cast<char*>(cpu_ws) +
|
||||
GetDqAccSplitsSize<false>(batch_size) +
|
||||
GetDqAccOffsetsSize(batch_size));
|
||||
@@ -310,7 +325,7 @@ struct FmhaBwdWorkspaceManager
|
||||
}
|
||||
else // deterministic batch mode (kUsePersistent)
|
||||
{
|
||||
const index_t dqdqkdv_workers = get_num_cus();
|
||||
const index_t dqdqkdv_workers = get_num_cus_override();
|
||||
const index_t jobs_per_head = integer_divide_ceil(seqlen_k, kN0);
|
||||
const index_t total_jobs = batch_size * nhead_q * jobs_per_head;
|
||||
const index_t jobs_per_worker = integer_divide_ceil(total_jobs, dqdqkdv_workers);
|
||||
@@ -1074,7 +1089,7 @@ struct FmhaBwdDQDKDVKernel
|
||||
const index_t jobs_per_head =
|
||||
kUseQrQtrDorPipeline ? 1 : integer_divide_ceil(seqlen_k_, FmhaPipeline::kN0);
|
||||
if constexpr(kUsePersistent)
|
||||
return dim3(get_num_cus(), 1, 1);
|
||||
return dim3(get_num_cus_override(), 1, 1);
|
||||
else
|
||||
return dim3(jobs_per_head, nhead_, batch_size_);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user