mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-13 17:55:48 +00:00
[CK][CK_TILE] Fix FMHA codegen group mode dispatch (#6764)
## Motivation FMHA codegen had incorrect dispatch behavior in group mode. Two root causes: 1. Wrong field names in dispatch conditions — Used batch-mode fields (seqlen_q, seqlen_k) instead of group-mode fields (max_seqlen_q, max_seqlen_k), causing wrong kernel selection at runtime on gfx950. 2. Missing kernel variants — Group mode was overly filtered out from smaller-tile specializations (bwd) and lacked spatial-padding pipeline variants on gfx950 (fwd). gfx942 don't support trload pipeline. ## Technical Details fmha_bwd.py: - max_seq_q_cond and extra_cond now emit t.max_seqlen_q / t.max_seqlen_k for group mode. - Relaxed kernel filtering: group mode no longer skips tiles with max_seq_q != 0. fmha_fwd.py: - get_bm0_cond emits a.max_seqlen_q for group mode tile-size dispatch. - Added two qr_async_trload pipeline variants with spatial padding for gfx950 group mode. ## Test Plan Triggering AITER CI job: ## Submission Checklist - [ x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
@@ -827,14 +827,20 @@ class FmhaBwdApiTrait:
|
||||
@property
|
||||
def max_seq_q_cond(self) -> str:
|
||||
if self.tile.max_seq_q != 0:
|
||||
return f" && (t.seqlen_q <= {self.tile.max_seq_q})"
|
||||
if self.mode == "group":
|
||||
return f" && (t.max_seqlen_q <= {self.tile.max_seq_q})"
|
||||
else:
|
||||
return f" && (t.seqlen_q <= {self.tile.max_seq_q})"
|
||||
else:
|
||||
return ""
|
||||
|
||||
@property
|
||||
def extra_cond(self) -> str:
|
||||
if self.tr_load == "t" and self.tile.max_seq_q == 0 and self.tile.F_bn0 == 128 and self.tile.F_bhdq == 128:
|
||||
return " && (t.seqlen_k <= 256)"
|
||||
if self.mode == "group":
|
||||
return " && (t.max_seqlen_k <= 256)"
|
||||
else:
|
||||
return " && (t.seqlen_k <= 256)"
|
||||
else:
|
||||
return ""
|
||||
|
||||
@@ -1057,7 +1063,7 @@ def get_bwd_blobs(
|
||||
hdim = tile.F_bhdq
|
||||
if (mode == "group") and (spad1d == "f"):
|
||||
continue
|
||||
if (mode == "group" or ("no" not in mask)) and tile.max_seq_q != 0:
|
||||
if ("no" not in mask) and tile.max_seq_q != 0:
|
||||
continue
|
||||
if (bias == "no" or bias == "alibi") and dbias == "t":
|
||||
continue
|
||||
|
||||
@@ -328,6 +328,8 @@ class FmhaFwdApiTrait:
|
||||
if self.bm0 == max_bm0 or self.bm0 == 64:
|
||||
return "true/*fall back to largest tile*/"
|
||||
else:
|
||||
if self.mode == "group":
|
||||
return f"a.max_seqlen_q <= {self.bm0}"
|
||||
return f"a.seqlen_q <= {self.bm0}"
|
||||
|
||||
@property
|
||||
@@ -1136,6 +1138,7 @@ class KernelComponentFactoryGfx950(
|
||||
):
|
||||
pipelines.append(FmhaFwdPipeline("qr_async_trload", "row", "f", "f", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "t", sink)) # fmt: skip
|
||||
pipelines.append(FmhaFwdPipeline("qr_async_trload", "row", "f", "f", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "t", sink)) # fmt: skip
|
||||
pipelines.append(FmhaFwdPipeline("qr_async_trload", "row", "t", "t", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "t", sink)) # fmt: skip # group mode spad
|
||||
|
||||
# # qr_async_trload_v3 bf16/fp16 not ready
|
||||
# if (hdim, hdim_v) == (128, 128):
|
||||
|
||||
Reference in New Issue
Block a user