[CK][CK_TILE] Fix FMHA codegen group mode dispatch (#6764)

## Motivation

FMHA codegen had incorrect dispatch behavior in group mode. Two root
causes:

1. Wrong field names in dispatch conditions — Used batch-mode fields
(seqlen_q, seqlen_k) instead of group-mode fields (max_seqlen_q,
max_seqlen_k), causing wrong kernel selection at runtime on gfx950.
2. Missing kernel variants — Group mode was overly filtered out from
smaller-tile specializations (bwd) and lacked spatial-padding pipeline
variants on gfx950 (fwd).

gfx942 don't support trload pipeline. 

## Technical Details

 fmha_bwd.py:
- max_seq_q_cond and extra_cond now emit t.max_seqlen_q / t.max_seqlen_k
for group mode.
- Relaxed kernel filtering: group mode no longer skips tiles with
max_seq_q != 0.

  fmha_fwd.py:
  - get_bm0_cond emits a.max_seqlen_q for group mode tile-size dispatch.
- Added two qr_async_trload pipeline variants with spatial padding for
gfx950 group mode.
  
## Test Plan
Triggering AITER CI job:

## Submission Checklist

- [ x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
ArthurLiu
2026-04-28 02:14:42 +08:00
committed by GitHub
parent 0ebeb88ba9
commit c1d2cf0869
2 changed files with 12 additions and 3 deletions

View File

@@ -827,14 +827,20 @@ class FmhaBwdApiTrait:
@property
def max_seq_q_cond(self) -> str:
if self.tile.max_seq_q != 0:
return f" && (t.seqlen_q <= {self.tile.max_seq_q})"
if self.mode == "group":
return f" && (t.max_seqlen_q <= {self.tile.max_seq_q})"
else:
return f" && (t.seqlen_q <= {self.tile.max_seq_q})"
else:
return ""
@property
def extra_cond(self) -> str:
if self.tr_load == "t" and self.tile.max_seq_q == 0 and self.tile.F_bn0 == 128 and self.tile.F_bhdq == 128:
return " && (t.seqlen_k <= 256)"
if self.mode == "group":
return " && (t.max_seqlen_k <= 256)"
else:
return " && (t.seqlen_k <= 256)"
else:
return ""
@@ -1057,7 +1063,7 @@ def get_bwd_blobs(
hdim = tile.F_bhdq
if (mode == "group") and (spad1d == "f"):
continue
if (mode == "group" or ("no" not in mask)) and tile.max_seq_q != 0:
if ("no" not in mask) and tile.max_seq_q != 0:
continue
if (bias == "no" or bias == "alibi") and dbias == "t":
continue

View File

@@ -328,6 +328,8 @@ class FmhaFwdApiTrait:
if self.bm0 == max_bm0 or self.bm0 == 64:
return "true/*fall back to largest tile*/"
else:
if self.mode == "group":
return f"a.max_seqlen_q <= {self.bm0}"
return f"a.seqlen_q <= {self.bm0}"
@property
@@ -1136,6 +1138,7 @@ class KernelComponentFactoryGfx950(
):
pipelines.append(FmhaFwdPipeline("qr_async_trload", "row", "f", "f", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "t", sink)) # fmt: skip
pipelines.append(FmhaFwdPipeline("qr_async_trload", "row", "f", "f", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "t", sink)) # fmt: skip
pipelines.append(FmhaFwdPipeline("qr_async_trload", "row", "t", "t", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "t", sink)) # fmt: skip # group mode spad
# # qr_async_trload_v3 bf16/fp16 not ready
# if (hdim, hdim_v) == (128, 128):