From 7ce58f6b46f5f225505f3ebda6d5cbf450f8fcb9 Mon Sep 17 00:00:00 2001 From: Po Yen Chen Date: Tue, 27 May 2025 10:58:58 +0800 Subject: [PATCH] [CK_TILE] For FMHA forward kernels, assign block indices reversely if using mask (#2209) * Assign block indices reversely if kHasMask=true * Assign block indices reversely for splitkv kernel [ROCm/composable_kernel commit: c42b957d654826bd9c218ccb66225865019a5140] --- .../ops/fmha/kernel/fmha_fwd_kernel.hpp | 20 +++++++++++++++++-- .../fmha/kernel/fmha_fwd_splitkv_kernel.hpp | 11 +++++++++- 2 files changed, 28 insertions(+), 3 deletions(-) diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp index bedf20626f..ac37f5dd06 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp @@ -983,7 +983,15 @@ struct FmhaFwdKernel const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1); - return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch); + if constexpr(kHasMask) + { + // assume that num_tile_n1 is always 1 + return ck_tile::make_tuple(gridDim.z - 1 - i_tile_m, i_tile_n, i_nhead, i_batch); + } + else + { + return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch); + } } else { @@ -1003,7 +1011,15 @@ struct FmhaFwdKernel const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1); - return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch); + if constexpr(kHasMask) + { + // assume that num_tile_n1 is always 1 + return ck_tile::make_tuple(gridDim.x - 1 - i_tile_m, i_tile_n, i_nhead, i_batch); + } + else + { + return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch); + } } } diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp index 63011d2ba9..501aa26667 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp @@ -561,7 +561,16 @@ struct FmhaFwdSplitKVKernel const index_t i_nhead = blockIdx.y; const index_t i_batch = blockIdx.z; - return ck_tile::make_tuple(i_tile_m, i_tile_n, i_split, i_nhead, i_batch); + if constexpr(kHasMask) + { + // assume that num_tile_n1 is always 1 + return ck_tile::make_tuple( + (gridDim.x / kargs.num_splits) - 1 - i_tile_m, i_tile_n, i_split, i_nhead, i_batch); + } + else + { + return ck_tile::make_tuple(i_tile_m, i_tile_n, i_split, i_nhead, i_batch); + } } __host__ static constexpr auto BlockSize() { return dim3(kBlockSize); }