mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
[CK_TILE] For FMHA forward kernels, assign block indices reversely if using mask (#2209)
* Assign block indices reversely if kHasMask=true * Assign block indices reversely for splitkv kernel
This commit is contained in:
@@ -983,7 +983,15 @@ struct FmhaFwdKernel
|
||||
|
||||
const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1);
|
||||
|
||||
return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch);
|
||||
if constexpr(kHasMask)
|
||||
{
|
||||
// assume that num_tile_n1 is always 1
|
||||
return ck_tile::make_tuple(gridDim.z - 1 - i_tile_m, i_tile_n, i_nhead, i_batch);
|
||||
}
|
||||
else
|
||||
{
|
||||
return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -1003,7 +1011,15 @@ struct FmhaFwdKernel
|
||||
|
||||
const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1);
|
||||
|
||||
return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch);
|
||||
if constexpr(kHasMask)
|
||||
{
|
||||
// assume that num_tile_n1 is always 1
|
||||
return ck_tile::make_tuple(gridDim.x - 1 - i_tile_m, i_tile_n, i_nhead, i_batch);
|
||||
}
|
||||
else
|
||||
{
|
||||
return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -561,7 +561,16 @@ struct FmhaFwdSplitKVKernel
|
||||
const index_t i_nhead = blockIdx.y;
|
||||
const index_t i_batch = blockIdx.z;
|
||||
|
||||
return ck_tile::make_tuple(i_tile_m, i_tile_n, i_split, i_nhead, i_batch);
|
||||
if constexpr(kHasMask)
|
||||
{
|
||||
// assume that num_tile_n1 is always 1
|
||||
return ck_tile::make_tuple(
|
||||
(gridDim.x / kargs.num_splits) - 1 - i_tile_m, i_tile_n, i_split, i_nhead, i_batch);
|
||||
}
|
||||
else
|
||||
{
|
||||
return ck_tile::make_tuple(i_tile_m, i_tile_n, i_split, i_nhead, i_batch);
|
||||
}
|
||||
}
|
||||
|
||||
__host__ static constexpr auto BlockSize() { return dim3(kBlockSize); }
|
||||
|
||||
Reference in New Issue
Block a user