From a6d55da47f6e9e360571d0208d0c5fa78ae9a7c1 Mon Sep 17 00:00:00 2001 From: liang <38024827+smallmou@users.noreply.github.com> Date: Sat, 26 Jul 2025 02:46:55 +0800 Subject: [PATCH] reorder grid dim schedule (#2533) Co-authored-by: smallmou Co-authored-by: Po Yen Chen [ROCm/composable_kernel commit: d2459878cf993565b8f55f1c1c0915251b944105] --- include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp index 561e5fb00a..8d257a3329 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp @@ -955,9 +955,9 @@ struct FmhaFwdKernel else { // TODO: this may need tuning - return dim3(ck_tile::integer_divide_ceil(seqlen_q_, FmhaPipeline::kM0) * + return dim3(nhead_, + ck_tile::integer_divide_ceil(seqlen_q_, FmhaPipeline::kM0) * ck_tile::integer_divide_ceil(hdim_v_, FmhaPipeline::kN1), - nhead_, batch_size_); } } @@ -1003,8 +1003,8 @@ struct FmhaFwdKernel const index_t num_tile_n1 = ck_tile::integer_divide_ceil(kargs.hdim_v, FmhaPipeline::kN1); - const index_t i_block = blockIdx.x; - const index_t i_nhead = blockIdx.y; + const index_t i_block = blockIdx.y; // blockIdx.x + const index_t i_nhead = blockIdx.x; // blockIdx.y const index_t i_batch = blockIdx.z; const auto f = [](index_t dividend, index_t divisor) { @@ -1018,7 +1018,7 @@ struct FmhaFwdKernel if constexpr(kHasMask) { // assume that num_tile_n1 is always 1 - return ck_tile::make_tuple(gridDim.x - 1 - i_tile_m, i_tile_n, i_nhead, i_batch); + return ck_tile::make_tuple(gridDim.y - 1 - i_tile_m, i_tile_n, i_nhead, i_batch); } else {