[CK_TILE] For FMHA forward kernels, assign block indices reversely if using mask (#2209)

* Assign block indices reversely if kHasMask=true

* Assign block indices reversely for splitkv kernel
This commit is contained in:
Po Yen Chen
2025-05-27 10:58:58 +08:00
committed by GitHub
parent 5727af98d1
commit c42b957d65
2 changed files with 28 additions and 3 deletions

View File

@@ -983,7 +983,15 @@ struct FmhaFwdKernel
const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1);
return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch);
if constexpr(kHasMask)
{
// assume that num_tile_n1 is always 1
return ck_tile::make_tuple(gridDim.z - 1 - i_tile_m, i_tile_n, i_nhead, i_batch);
}
else
{
return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch);
}
}
else
{
@@ -1003,7 +1011,15 @@ struct FmhaFwdKernel
const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1);
return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch);
if constexpr(kHasMask)
{
// assume that num_tile_n1 is always 1
return ck_tile::make_tuple(gridDim.x - 1 - i_tile_m, i_tile_n, i_nhead, i_batch);
}
else
{
return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch);
}
}
}

View File

@@ -561,7 +561,16 @@ struct FmhaFwdSplitKVKernel
const index_t i_nhead = blockIdx.y;
const index_t i_batch = blockIdx.z;
return ck_tile::make_tuple(i_tile_m, i_tile_n, i_split, i_nhead, i_batch);
if constexpr(kHasMask)
{
// assume that num_tile_n1 is always 1
return ck_tile::make_tuple(
(gridDim.x / kargs.num_splits) - 1 - i_tile_m, i_tile_n, i_split, i_nhead, i_batch);
}
else
{
return ck_tile::make_tuple(i_tile_m, i_tile_n, i_split, i_nhead, i_batch);
}
}
__host__ static constexpr auto BlockSize() { return dim3(kBlockSize); }