[CK_TILE] fMHA batch_prefill block index & logits soft-capping optimizations (#2198)

* Write soft-sign in inline asm

* Change tile idx computation

* Add macro to turn off soft-sign asm opt

* Use simple for loop to avoid register spill

* Only do block id transform for masking cases
This commit is contained in:
Po Yen Chen
2025-05-16 15:14:46 +08:00
committed by GitHub
parent 8cb0474b3d
commit 791802b381
3 changed files with 63 additions and 9 deletions

View File

@@ -651,8 +651,15 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
};
const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1);
return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch);
if constexpr(kHasMask)
{
// assume that num_tile_n1 is always 1
return ck_tile::make_tuple(gridDim.z - 1 - i_tile_m, i_tile_n, i_nhead, i_batch);
}
else
{
return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch);
}
}
else
{
@@ -672,7 +679,15 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1);
return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch);
if constexpr(kHasMask)
{
// assume that num_tile_n1 is always 1
return ck_tile::make_tuple(gridDim.x - 1 - i_tile_m, i_tile_n, i_nhead, i_batch);
}
else
{
return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch);
}
}
}