From 34920de53d0d5f6b1bd2aa2ca24cd8d5b7d98e03 Mon Sep 17 00:00:00 2001 From: "assistant-librarian[bot]" Date: Tue, 27 May 2025 03:24:26 +0000 Subject: [PATCH] Merge commit 'c42b957d654826bd9c218ccb66225865019a5140' into develop --- .../ops/fmha/kernel/fmha_fwd_kernel.hpp | 20 +++++++++++++++++-- .../fmha/kernel/fmha_fwd_splitkv_kernel.hpp | 11 +++++++++- 2 files changed, 28 insertions(+), 3 deletions(-) diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp index bedf20626f..ac37f5dd06 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp @@ -983,7 +983,15 @@ struct FmhaFwdKernel const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1); - return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch); + if constexpr(kHasMask) + { + // assume that num_tile_n1 is always 1 + return ck_tile::make_tuple(gridDim.z - 1 - i_tile_m, i_tile_n, i_nhead, i_batch); + } + else + { + return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch); + } } else { @@ -1003,7 +1011,15 @@ struct FmhaFwdKernel const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1); - return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch); + if constexpr(kHasMask) + { + // assume that num_tile_n1 is always 1 + return ck_tile::make_tuple(gridDim.x - 1 - i_tile_m, i_tile_n, i_nhead, i_batch); + } + else + { + return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch); + } } } diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp index 63011d2ba9..501aa26667 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp @@ -561,7 +561,16 @@ struct FmhaFwdSplitKVKernel const index_t i_nhead = blockIdx.y; const index_t i_batch = blockIdx.z; - return ck_tile::make_tuple(i_tile_m, i_tile_n, i_split, i_nhead, i_batch); + if constexpr(kHasMask) + { + // assume that num_tile_n1 is always 1 + return ck_tile::make_tuple( + (gridDim.x / kargs.num_splits) - 1 - i_tile_m, i_tile_n, i_split, i_nhead, i_batch); + } + else + { + return ck_tile::make_tuple(i_tile_m, i_tile_n, i_split, i_nhead, i_batch); + } } __host__ static constexpr auto BlockSize() { return dim3(kBlockSize); }