diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index dd65c0298b..81c7b067d3 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -315,7 +315,7 @@ class FmhaFwdApiTrait: assert False def seqtune(self, max_bm0: int) -> str: - if self.bm0 == max_bm0: + if self.bm0 == max_bm0 or self.bm0 == 64: return "true/*fall back to largest tile*/" else: return f"a.seqlen_q <= {self.bm0}" @@ -847,6 +847,11 @@ class CompatibilityRuleFactoryGfx9(CompatibilityRuleFactory): (problem_ctx.hdim, problem_ctx.hdim_v) != (128, 128) and kernel_ctx.tile.F_bm0 != 128 ) + or ( + (problem_ctx.hdim, problem_ctx.hdim_v) == (128, 128) + and kernel_ctx.pipeline.tag != "qr_async" + and kernel_ctx.tile.F_bk0 == 64 + ) ): # non qr_async_trload only support km0=128 tile size when hdim is not 128 # non qr_async only support kn0=128 tile size when hdim is 128 @@ -942,6 +947,7 @@ class KernelComponentFactoryGfx9(CompatibilityRuleFactoryGfx9): ( 96, 128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], (128, 128) : [FmhaFwdTileSize( 16, 32, 64, 128, 32, 128, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, -1), FmhaFwdTileSize( 32, 32, 128, 128, 32, 128, 1, 1, 1, 1, 1, 1, 32, 32, 16, 32, 32, 16, -1), + FmhaFwdTileSize( 64, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 32, 16, 16, 16, -1, CppConstraint('get_num_blocks(64) <= num_cus')), FmhaFwdTileSize(128, 64, 32, 128, 16, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], # (160, 160) : [FmhaFwdTileSize(128, 128 , 32, 160, 32, 160, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)], diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp index 7224ed3a70..e30d4215d6 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp @@ -321,6 +321,8 @@ struct BlockFmhaPipelineQRKSVSAsync { if(num_total_loop <= 0) { + buffer_load_fence(0); // rocm-7.1.1, if whole tile is masked out, need to fence(0) + // otherwise will have compute error(maybe compiler bug?) if constexpr(kStoreLSE) { auto lse =