From c1d2cf086946854faaaae5aaf8aa2b73fce56f69 Mon Sep 17 00:00:00 2001 From: ArthurLiu Date: Tue, 28 Apr 2026 02:14:42 +0800 Subject: [PATCH] [CK][CK_TILE] Fix FMHA codegen group mode dispatch (#6764) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Motivation FMHA codegen had incorrect dispatch behavior in group mode. Two root causes: 1. Wrong field names in dispatch conditions — Used batch-mode fields (seqlen_q, seqlen_k) instead of group-mode fields (max_seqlen_q, max_seqlen_k), causing wrong kernel selection at runtime on gfx950. 2. Missing kernel variants — Group mode was overly filtered out from smaller-tile specializations (bwd) and lacked spatial-padding pipeline variants on gfx950 (fwd). gfx942 don't support trload pipeline. ## Technical Details fmha_bwd.py: - max_seq_q_cond and extra_cond now emit t.max_seqlen_q / t.max_seqlen_k for group mode. - Relaxed kernel filtering: group mode no longer skips tiles with max_seq_q != 0. fmha_fwd.py: - get_bm0_cond emits a.max_seqlen_q for group mode tile-size dispatch. - Added two qr_async_trload pipeline variants with spatial padding for gfx950 group mode. ## Test Plan Triggering AITER CI job: ## Submission Checklist - [ x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests. --- example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py | 12 +++++++++--- example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py | 3 +++ 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py index 7105f1aa5c..5e7e2a2ffd 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py @@ -827,14 +827,20 @@ class FmhaBwdApiTrait: @property def max_seq_q_cond(self) -> str: if self.tile.max_seq_q != 0: - return f" && (t.seqlen_q <= {self.tile.max_seq_q})" + if self.mode == "group": + return f" && (t.max_seqlen_q <= {self.tile.max_seq_q})" + else: + return f" && (t.seqlen_q <= {self.tile.max_seq_q})" else: return "" @property def extra_cond(self) -> str: if self.tr_load == "t" and self.tile.max_seq_q == 0 and self.tile.F_bn0 == 128 and self.tile.F_bhdq == 128: - return " && (t.seqlen_k <= 256)" + if self.mode == "group": + return " && (t.max_seqlen_k <= 256)" + else: + return " && (t.seqlen_k <= 256)" else: return "" @@ -1057,7 +1063,7 @@ def get_bwd_blobs( hdim = tile.F_bhdq if (mode == "group") and (spad1d == "f"): continue - if (mode == "group" or ("no" not in mask)) and tile.max_seq_q != 0: + if ("no" not in mask) and tile.max_seq_q != 0: continue if (bias == "no" or bias == "alibi") and dbias == "t": continue diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index 542bf2f2fa..635f44037c 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -328,6 +328,8 @@ class FmhaFwdApiTrait: if self.bm0 == max_bm0 or self.bm0 == 64: return "true/*fall back to largest tile*/" else: + if self.mode == "group": + return f"a.max_seqlen_q <= {self.bm0}" return f"a.seqlen_q <= {self.bm0}" @property @@ -1136,6 +1138,7 @@ class KernelComponentFactoryGfx950( ): pipelines.append(FmhaFwdPipeline("qr_async_trload", "row", "f", "f", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "t", sink)) # fmt: skip pipelines.append(FmhaFwdPipeline("qr_async_trload", "row", "f", "f", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "t", sink)) # fmt: skip + pipelines.append(FmhaFwdPipeline("qr_async_trload", "row", "t", "t", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "t", sink)) # fmt: skip # group mode spad # # qr_async_trload_v3 bf16/fp16 not ready # if (hdim, hdim_v) == (128, 128):