mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
[CK_TILE][FMHA]Add new tile size for async (#3623)
* Revert "Revert "[CK_TILE][FMHA] Add new tile size for async (#3586)" (#3613)"
This reverts commit cfdad49edda4b2ccef92571f23646a8505bb2859.
* Add new tile_size for async pipeline
Signed-off-by: Linjun-AMD <Jun.Lin@amd.com>
* Update include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
---------
Signed-off-by: Linjun-AMD <Jun.Lin@amd.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
[ROCm/composable_kernel commit: 0b13697a88]
This commit is contained in:
@@ -315,7 +315,7 @@ class FmhaFwdApiTrait:
|
||||
assert False
|
||||
|
||||
def seqtune(self, max_bm0: int) -> str:
|
||||
if self.bm0 == max_bm0:
|
||||
if self.bm0 == max_bm0 or self.bm0 == 64:
|
||||
return "true/*fall back to largest tile*/"
|
||||
else:
|
||||
return f"a.seqlen_q <= {self.bm0}"
|
||||
@@ -847,6 +847,11 @@ class CompatibilityRuleFactoryGfx9(CompatibilityRuleFactory):
|
||||
(problem_ctx.hdim, problem_ctx.hdim_v) != (128, 128)
|
||||
and kernel_ctx.tile.F_bm0 != 128
|
||||
)
|
||||
or (
|
||||
(problem_ctx.hdim, problem_ctx.hdim_v) == (128, 128)
|
||||
and kernel_ctx.pipeline.tag != "qr_async"
|
||||
and kernel_ctx.tile.F_bk0 == 64
|
||||
)
|
||||
):
|
||||
# non qr_async_trload only support km0=128 tile size when hdim is not 128
|
||||
# non qr_async only support kn0=128 tile size when hdim is 128
|
||||
@@ -942,6 +947,7 @@ class KernelComponentFactoryGfx9(CompatibilityRuleFactoryGfx9):
|
||||
( 96, 128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
|
||||
(128, 128) : [FmhaFwdTileSize( 16, 32, 64, 128, 32, 128, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, -1),
|
||||
FmhaFwdTileSize( 32, 32, 128, 128, 32, 128, 1, 1, 1, 1, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
FmhaFwdTileSize( 64, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 32, 16, 16, 16, -1, CppConstraint('get_num_blocks(64) <= num_cus')),
|
||||
FmhaFwdTileSize(128, 64, 32, 128, 16, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
|
||||
# (160, 160) : [FmhaFwdTileSize(128, 128 , 32, 160, 32, 160, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)],
|
||||
|
||||
@@ -329,6 +329,8 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
{
|
||||
if(num_total_loop <= 0)
|
||||
{
|
||||
buffer_load_fence(0); // rocm-7.1.1, if whole tile is masked out, need to fence(0)
|
||||
// otherwise will have compute error(maybe compiler bug?)
|
||||
if constexpr(kStoreLSE)
|
||||
{
|
||||
auto lse =
|
||||
@@ -345,10 +347,8 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
|
||||
store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse));
|
||||
}
|
||||
buffer_load_fence(0); // rocm-6.1, if whole tile is masked out, need to fence(0)
|
||||
// otherwise will have compute error(maybe compiler bug?)
|
||||
|
||||
// Note: here occ are all cleard, return it
|
||||
// Note: here occ are all cleared, return it
|
||||
return o_acc;
|
||||
}
|
||||
__builtin_amdgcn_sched_barrier(0); // make sure sched_barrier(0) for this check
|
||||
|
||||
Reference in New Issue
Block a user