[rocm-libraries] ROCm/rocm-libraries#6156 (commit 367565a)

[CK_TILE] Optimize FMHA head-dim padded path on gfx11/gfx12 (#6156)

## Motivation
On gfx11/gfx12, FMHA forward kernels that require head-dim padding show
a large performance drop compared to the exact-head-dim path. In
practice, padded cases such as `HDIM=72` and `HDIM=80` fell far off the
fast path.

This PR improves padded-head-dim FMHA performance on gfx11/gfx12 while
keeping the behavior for other GPUs unchanged.

## Technical Details

- Add a dedicated padded-head-dim (`qr_hpad`) FMHA forward path, scoped
to gfx11/gfx12.
- For `receipt=0`, keep support conservative and enable the padded fast
path only for vector-safe cases (`head_dim % 8 == 0`), matching the
existing assumption used on other GPUs (see the sketch after this
list).
- Move `v_prefetch` later, only on the head-dim-padded path for
gfx11/gfx12. This shortens live ranges and removes the register
spilling seen with the earlier scheduling.
- Enable the buffer-load OOB-check offset trick for the padded path on
gfx11/gfx12.
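
As a rough illustration of the `receipt=0` gating described above, here
is a minimal Python sketch (a hypothetical helper, not the generator's
actual API; the real predicate is emitted as C++ by the generator, as
the diff below shows):

```python
# Hypothetical sketch of the receipt=0 gating; not the generator's API.
GFX11_GFX12 = {
    "gfx1100", "gfx1101", "gfx1102", "gfx1103",
    "gfx1150", "gfx1151", "gfx1152", "gfx1153",
    "gfx1200", "gfx1201",
}

def takes_qr_hpad_path(arch: str, hdim_q: int, hdim_v: int) -> bool:
    """True if a padded head-dim problem may use the qr_hpad fast path."""
    if arch not in GFX11_GFX12:
        return False  # other GPUs keep their existing behavior
    # Vector-safe: both head dims must be multiples of 8 elements,
    # matching the assumption already used on other GPUs.
    return hdim_q % 8 == 0 and hdim_v % 8 == 0

assert takes_qr_hpad_path("gfx1100", 72, 72)      # 72 = 8 * 9
assert takes_qr_hpad_path("gfx1200", 80, 80)      # 80 = 8 * 10
assert not takes_qr_hpad_path("gfx1100", 76, 76)  # not vector-safe
```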

## Test Plan

./build/bin/tile_example_fmha_fwd -prec=bf16 -mode={0/1} -b=1 -h=16 -d={72/80} -s={seqlen} -s_k={seqlen} -lse=0 -iperm={0/1} -operm={0/1}
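
For reference, the full grid can be swept with a small driver like the
following (a hypothetical helper; the binary path and the chosen
sequence lengths are assumptions, not part of the PR):

```python
# Hypothetical sweep over the test-plan grid; binary path and seqlens
# are assumptions.
import itertools
import subprocess

BIN = "./build/bin/tile_example_fmha_fwd"  # assumed build location

for mode, d, s, iperm, operm in itertools.product(
    (0, 1), (72, 80), (512, 1024, 4096), (0, 1), (0, 1)
):
    cmd = [
        BIN, "-prec=bf16", f"-mode={mode}", "-b=1", "-h=16",
        f"-d={d}", f"-s={s}", f"-s_k={s}", "-lse=0",
        f"-iperm={iperm}", f"-operm={operm}",
    ]
    subprocess.run(cmd, check=True)
```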

## Test Result

Observed padded-head-dim performance improvements for HDIM=72/80:

- gfx11: ~3.5x
- gfx1151: ~2.0x
- gfx12: ~1.3x

## Submission Checklist

- [ ] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
Commit 4c0e73ab12 (parent 7d6c8e5afa), authored by Hosang Yoon on
2026-04-08 14:53:18 +00:00, committed by assistant-librarian[bot].
4 changed files with 144 additions and 26 deletions


@@ -60,6 +60,22 @@ FMHA_FWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT
 #include "fmha_fwd.hpp"
 """
+FMHA_FWD_KERNEL_HEADER_QR_HPAD = """// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.\n
+// auto generated by generate.py
+#if defined(__HIP_DEVICE_COMPILE__) && \
+(defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || \
+defined(__gfx1103__) || defined(__gfx1150__) || defined(__gfx1151__) || \
+defined(__gfx1152__) || defined(__gfx1153__) || defined(__gfx11_generic__) || \
+defined(__gfx1200__) || defined(__gfx1201__) || defined(__gfx12_generic__))
+#if !defined(CK_TILE_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK)
+#define CK_TILE_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1
+#endif
+#endif
+#include "ck_tile/ops/fmha/block/variants.hpp"
+#include "fmha_fwd.hpp"
+"""
 FMHA_FWD_KERNEL_BODY_TEMPLATE = """
 #include <iostream>
@@ -300,7 +316,7 @@ class FmhaFwdApiTrait:
return "true" # always support
else:
return "true"
elif self.pipeline_tag in ["qr", "qs"]:
elif self.pipeline_tag in ["qr", "qr_hpad", "qs"]:
if self.spad == "t":
return f"true /*a.seqlen_q % {self.bm0} != 0*/" # TODO: order of get_pipelines() matters! (ugly)
else:
@@ -323,7 +339,7 @@ class FmhaFwdApiTrait:
return f"(a.cu_seqlen_k_ptr != nullptr) || (a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0)"
else:
return f"(a.cu_seqlen_k_ptr == nullptr) && (a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0)"
elif self.pipeline_tag in ["qr", "qs"]:
elif self.pipeline_tag in ["qr", "qr_hpad", "qs"]:
if self.skpad == "t":
return f"true /*a.seqlen_k % {self.bn0} != 0*/" # TODO: order of get_pipelines() matters! (ugly)
else:
@@ -344,6 +360,11 @@ class FmhaFwdApiTrait:
return f"a.hdim_q % {vec} == 0"
else:
assert False
elif self.pipeline_tag == "qr_hpad":
if self.dpad == "t":
return "a.hdim_q % 8 == 0"
else:
assert False
elif self.pipeline_tag in ["qr", "qs", "qr_async_trload", "qr_async_trload_v3"]:
bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max]
if self.dpad == "t":
@@ -361,6 +382,11 @@ class FmhaFwdApiTrait:
return f"a.hdim_v % {vec} == 0"
else:
assert False
elif self.pipeline_tag == "qr_hpad":
if self.dvpad == "t":
return "a.hdim_v % 8 == 0"
else:
assert False
elif self.pipeline_tag in ["qr", "qs", "qr_async_trload", "qr_async_trload_v3"]:
bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max]
if self.dvpad == "t":
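
A note on the two `% 8` predicates above: our reading (an assumption,
not stated in the diff) is that for 2-byte fp16/bf16 elements, 8
elements make one 16-byte buffer load, so `hdim % 8 == 0` keeps every
row vector-loadable. The motivating head dims qualify:

```python
# Assumption: 8 fp16/bf16 elements (2 bytes each) = 16 bytes, one
# full-width buffer load; hdim % 8 == 0 keeps rows vector-loadable.
ELEM_BYTES = 2  # fp16 / bf16

def vector_safe(hdim: int, vec_elems: int = 8) -> bool:
    return hdim % vec_elems == 0

for hdim in (72, 80, 76):
    print(f"hdim={hdim}: vector_safe={vector_safe(hdim)}, "
          f"row={hdim * ELEM_BYTES} bytes")
# hdim=72 and hdim=80 (the PR's motivating cases) qualify; 76 would not.
```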
@@ -634,6 +660,7 @@ class FmhaFwdKernel:
     F_pipeline: FmhaFwdPipeline
     _KERNEL_HEADER: ClassVar[str] = FMHA_FWD_KERNEL_HEADER
+    _KERNEL_HEADER_QR_HPAD: ClassVar[str] = FMHA_FWD_KERNEL_HEADER_QR_HPAD
     _KERNEL_BODY_TEMPLATE: ClassVar[str] = FMHA_FWD_KERNEL_BODY_TEMPLATE
     @classmethod
@@ -643,6 +670,12 @@ class FmhaFwdKernel:
         else:
             return "ck_tile::FmhaFwdKernel"
+    @classmethod
+    def _get_kernel_header(cls, pipeline_tag):
+        if pipeline_tag == "qr_hpad":
+            return cls._KERNEL_HEADER_QR_HPAD
+        return cls._KERNEL_HEADER
     @classmethod
     def _get_cpp_kargs_creator_func_name(cls, pipeline_tag):
         if pipeline_tag == "qr_async_trload_v3":
@@ -651,7 +684,9 @@ class FmhaFwdKernel:
return "fmha_fwd_create_kargs_and_grids"
def render(self) -> str:
return type(self)._KERNEL_HEADER + type(self)._KERNEL_BODY_TEMPLATE.format(
return type(self)._get_kernel_header(self.F_pipeline.tag) + type(
self
)._KERNEL_BODY_TEMPLATE.format(
F_kname=self.name,
F_arch=self.F_arch,
F_hdim=self.F_hdim,
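
Taken together, the three hunks above route `qr_hpad` kernels to the
new header at render time. A condensed sketch of that selection
(simplified stand-ins, not the real `FmhaFwdKernel` class):

```python
# Simplified stand-ins for the class attributes shown in the diff.
KERNEL_HEADER = "// default kernel header\n"
KERNEL_HEADER_QR_HPAD = "// gfx11/gfx12 header enabling the OOB-check offset trick\n"

def get_kernel_header(pipeline_tag: str) -> str:
    # mirrors _get_kernel_header(): only qr_hpad gets the special header
    return KERNEL_HEADER_QR_HPAD if pipeline_tag == "qr_hpad" else KERNEL_HEADER

def render(pipeline_tag: str, body: str) -> str:
    # mirrors render(): the header now depends on the pipeline tag
    return get_kernel_header(pipeline_tag) + body

assert render("qr_hpad", "kernel body").startswith("// gfx11/gfx12")
assert render("qr", "kernel body").startswith("// default")
```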
@@ -1144,6 +1179,37 @@ class KernelComponentFactoryGfx11(CompatibilityRuleFactory):
     def supported_dtypes(cls) -> Tuple[str]:
         return cls._DT_FP16_BF16
+    @classmethod
+    def get_rules(cls) -> List[CompatibilityRule]:
+        rules = super().get_rules()
+
+        # For gfx11 fp16/bf16 d128, use dpad=dvpad=t for the 64x32 tile:
+        # the exact-hdim variant (dpad=dvpad=f) is much slower here.
+        def check_d128_tile_pipeline(
+            problem_ctx: ProblemContext, kernel_ctx: KernelContext
+        ) -> bool:
+            if problem_ctx.dtype not in cls._DT_FP16_BF16:
+                return True
+            if (problem_ctx.hdim, problem_ctx.hdim_v) != (128, 128):
+                return True
+            is_64x32_tile = kernel_ctx.tile.F_bm0 == 64 and kernel_ctx.tile.F_bn0 == 32
+            pads_hdim = (
+                kernel_ctx.pipeline.F_dpad == "t" and kernel_ctx.pipeline.F_dvpad == "t"
+            )
+            exact_hdim = (
+                kernel_ctx.pipeline.F_dpad == "f" and kernel_ctx.pipeline.F_dvpad == "f"
+            )
+            if is_64x32_tile:
+                return pads_hdim
+            return exact_hdim
+
+        rules.append(check_d128_tile_pipeline)
+        return rules
     @classmethod
     def get_hdim_tile_size_dict(cls, dtype: str) -> Optional[dict]:
         if dtype in cls._DT_FP16_BF16:
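
The compatibility rule above pins the d128 tile choice on gfx11: for
fp16/bf16 with `hdim_q == hdim_v == 128`, the new 64x32 tile is
admitted only with the padded pipeline, while every other tile keeps
the exact-hdim pipeline. A truth-table sketch of the core predicate
(hypothetical stand-ins for the contexts):

```python
# Truth-table sketch of check_d128_tile_pipeline's core decision;
# hypothetical stand-ins, applies only to fp16/bf16 d128 problems.
def admitted(is_64x32_tile: bool, dpad: str, dvpad: str) -> bool:
    pads_hdim = dpad == "t" and dvpad == "t"
    exact_hdim = dpad == "f" and dvpad == "f"
    return pads_hdim if is_64x32_tile else exact_hdim

assert admitted(True, "t", "t")       # 64x32 must pad head dims
assert not admitted(True, "f", "f")   # exact-hdim 64x32 is filtered out
assert admitted(False, "f", "f")      # other tiles keep the exact path
assert not admitted(False, "t", "t")  # and drop the padded variant
```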
@@ -1152,7 +1218,8 @@ class KernelComponentFactoryGfx11(CompatibilityRuleFactory):
             ( 32, 32)  : [FmhaFwdTileSize( 64, 64, 16, 32, 32, 32, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
             ( 64, 64)  : [FmhaFwdTileSize( 64, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1, CppConstraint("a.max_seqlen_q < 4096")),
                           FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 8, 1, 1, 8, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
-            (128, 128) : [FmhaFwdTileSize( 64, 64, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1, CppConstraint("a.max_seqlen_q < 4096")),
+            (128, 128) : [FmhaFwdTileSize( 64, 32, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, 6, CppConstraint("a.hdim_q != 128 || a.hdim_v != 128")),
+                          FmhaFwdTileSize( 64, 64, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1, CppConstraint("a.max_seqlen_q < 4096")),
                           FmhaFwdTileSize(128, 64, 32, 128, 32, 128, 8, 1, 1, 8, 1, 1, 16, 16, 16, 16, 16, 16, 6)],
             (192, 128) : [FmhaFwdTileSize( 64, 64, 32, 128, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
             (256, 256) : [FmhaFwdTileSize(128, 64, 32, 256, 32, 256, 8, 1, 1, 8, 1, 1, 16, 16, 16, 16, 16, 16, 6)]
@@ -1179,7 +1246,9 @@ class KernelComponentFactoryGfx11(CompatibilityRuleFactory):
                 # Keep only ttff/tttt for gfx11: ffff path is often similar or worse
                 # pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip
                 pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip
-                pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip
+                pipelines.append(FmhaFwdPipeline("qr_hpad", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip
+                if receipt == 1:
+                    pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip
         return pipelines
@@ -1251,7 +1320,9 @@ class KernelComponentFactoryGfx12(CompatibilityRuleFactory):
             ):
                 # pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip
                 pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip
-                pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip
+                pipelines.append(FmhaFwdPipeline("qr_hpad", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip
+                if receipt == 1:
+                    pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip
         elif dtype in cls._DT_FP8_FP8BF16 or dtype in cls._DT_FP8FP32:
             # no need lse/dropout kernels
             for logits, qscale, mask, bias in itertools.product(