[CK_TILE] Optimize FMHA head-dim padded path on gfx11/gfx12 (#6156)

## Motivation
On gfx11/gfx12, FMHA forward kernels that require head-dim padding show
a large performance drop compared to the exact-head-dim path. In
practice, padded cases such as `HDIM=72` and `HDIM=80` were falling too
far off the fast path.

This PR improves padded-head-dim FMHA performance on gfx11/gfx12 while
keeping the behavior for other GPUs unchanged.

## Technical Details

- Add/scope a dedicated padded-head-dim (`qr_hpad`) FMHA forward path
for gfx11/gfx12.
- For `receipt=0`, keep support conservative and only enable the padded
fast path for vector-safe cases (`head_dim % 8 == 0`), matching the
existing assumption used on other GPUs.
- Move `v_prefetch` later only for the head-dim-padded path on
gfx11/gfx12. This reduces live ranges and removes the register-spill
behavior seen in the earlier scheduling.
- Enable the buffer-load OOB check offset trick for the padded path on
gfx11/gfx12.

## Test Plan

./build/bin/tile_example_fmha_fwd -prec=bf16 -mode={0/1} -b=1 -h=16
-d={72/80} -s={seqlen} -s_k={seqlen} -lse=0 -iperm={0/1} -operm={0/1}

## Test Result

Observed padded-head-dim performance improvements for HDIM=72/80:

- gfx11: about ~3.5x
- gfx1151: about ~2.0x
- gfx12: about ~1.3x


## Submission Checklist

- [ ] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
Hosang Yoon
2026-04-08 10:51:53 -04:00
committed by GitHub
parent c953982434
commit 65ad35becd
4 changed files with 144 additions and 26 deletions

View File

@@ -139,6 +139,7 @@ LAYOUT_MAP = {"row": "true", "col": "false"}
PIPELINE_MAP = {
"qr": "ck_tile::BlockFmhaPipelineQRKSVS",
"qr_hpad": "ck_tile::BlockFmhaPipelineQRKSVSHpad",
"qr_async": "ck_tile::BlockFmhaPipelineQRKSVSAsync",
"qs": "ck_tile::BlockFmhaPipelineQSKSVS",
"qr_async_trload": "ck_tile::BlockFmhaPipelineQRKSVSAsyncTrload",
@@ -147,6 +148,7 @@ PIPELINE_MAP = {
PIPELINE_ENUM_MAP = {
"qr": "ck_tile::BlockFmhaPipelineEnum::QRKSVS",
"qr_hpad": "ck_tile::BlockFmhaPipelineEnum::QRKSVS_HPAD",
"qr_async": "ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC",
"qr_nwarp_sshuffle": "ck_tile::BlockFmhaPipelineEnum::QRKSVS",
"qs": "ck_tile::BlockFmhaPipelineEnum::QSKSVS",

View File

@@ -60,6 +60,22 @@ FMHA_FWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT
#include "fmha_fwd.hpp"
"""
FMHA_FWD_KERNEL_HEADER_QR_HPAD = """// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.\n
// auto generated by generate.py
#if defined(__HIP_DEVICE_COMPILE__) && \
(defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || \
defined(__gfx1103__) || defined(__gfx1150__) || defined(__gfx1151__) || \
defined(__gfx1152__) || defined(__gfx1153__) || defined(__gfx11_generic__) || \
defined(__gfx1200__) || defined(__gfx1201__) || defined(__gfx12_generic__))
#if !defined(CK_TILE_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK)
#define CK_TILE_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1
#endif
#endif
#include "ck_tile/ops/fmha/block/variants.hpp"
#include "fmha_fwd.hpp"
"""
FMHA_FWD_KERNEL_BODY_TEMPLATE = """
#include <iostream>
@@ -300,7 +316,7 @@ class FmhaFwdApiTrait:
return "true" # always support
else:
return "true"
elif self.pipeline_tag in ["qr", "qs"]:
elif self.pipeline_tag in ["qr", "qr_hpad", "qs"]:
if self.spad == "t":
return f"true /*a.seqlen_q % {self.bm0} != 0*/" # TODO: order of get_pipelines() matters! (ugly)
else:
@@ -323,7 +339,7 @@ class FmhaFwdApiTrait:
return f"(a.cu_seqlen_k_ptr != nullptr) || (a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0)"
else:
return f"(a.cu_seqlen_k_ptr == nullptr) && (a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0)"
elif self.pipeline_tag in ["qr", "qs"]:
elif self.pipeline_tag in ["qr", "qr_hpad", "qs"]:
if self.skpad == "t":
return f"true /*a.seqlen_k % {self.bn0} != 0*/" # TODO: order of get_pipelines() matters! (ugly)
else:
@@ -344,6 +360,11 @@ class FmhaFwdApiTrait:
return f"a.hdim_q % {vec} == 0"
else:
assert False
elif self.pipeline_tag == "qr_hpad":
if self.dpad == "t":
return "a.hdim_q % 8 == 0"
else:
assert False
elif self.pipeline_tag in ["qr", "qs", "qr_async_trload", "qr_async_trload_v3"]:
bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max]
if self.dpad == "t":
@@ -361,6 +382,11 @@ class FmhaFwdApiTrait:
return f"a.hdim_v % {vec} == 0"
else:
assert False
elif self.pipeline_tag == "qr_hpad":
if self.dvpad == "t":
return "a.hdim_v % 8 == 0"
else:
assert False
elif self.pipeline_tag in ["qr", "qs", "qr_async_trload", "qr_async_trload_v3"]:
bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max]
if self.dvpad == "t":
@@ -634,6 +660,7 @@ class FmhaFwdKernel:
F_pipeline: FmhaFwdPipeline
_KERNEL_HEADER: ClassVar[str] = FMHA_FWD_KERNEL_HEADER
_KERNEL_HEADER_QR_HPAD: ClassVar[str] = FMHA_FWD_KERNEL_HEADER_QR_HPAD
_KERNEL_BODY_TEMPLATE: ClassVar[str] = FMHA_FWD_KERNEL_BODY_TEMPLATE
@classmethod
@@ -643,6 +670,12 @@ class FmhaFwdKernel:
else:
return "ck_tile::FmhaFwdKernel"
@classmethod
def _get_kernel_header(cls, pipeline_tag):
if pipeline_tag == "qr_hpad":
return cls._KERNEL_HEADER_QR_HPAD
return cls._KERNEL_HEADER
@classmethod
def _get_cpp_kargs_creator_func_name(cls, pipeline_tag):
if pipeline_tag == "qr_async_trload_v3":
@@ -651,7 +684,9 @@ class FmhaFwdKernel:
return "fmha_fwd_create_kargs_and_grids"
def render(self) -> str:
return type(self)._KERNEL_HEADER + type(self)._KERNEL_BODY_TEMPLATE.format(
return type(self)._get_kernel_header(self.F_pipeline.tag) + type(
self
)._KERNEL_BODY_TEMPLATE.format(
F_kname=self.name,
F_arch=self.F_arch,
F_hdim=self.F_hdim,
@@ -1144,6 +1179,37 @@ class KernelComponentFactoryGfx11(CompatibilityRuleFactory):
def supported_dtypes(cls) -> Tuple[str]:
return cls._DT_FP16_BF16
@classmethod
def get_rules(cls) -> List[CompatibilityRule]:
rules = super().get_rules()
# For gfx11 fp16/bf16 d128, use dpad=dvpad=t for the 64x32 tile:
# the exact-hdim variant (dpad=dvpad=f) is much slower here.
def check_d128_tile_pipeline(
problem_ctx: ProblemContext, kernel_ctx: KernelContext
) -> bool:
if problem_ctx.dtype not in cls._DT_FP16_BF16:
return True
if (problem_ctx.hdim, problem_ctx.hdim_v) != (128, 128):
return True
is_64x32_tile = kernel_ctx.tile.F_bm0 == 64 and kernel_ctx.tile.F_bn0 == 32
pads_hdim = (
kernel_ctx.pipeline.F_dpad == "t" and kernel_ctx.pipeline.F_dvpad == "t"
)
exact_hdim = (
kernel_ctx.pipeline.F_dpad == "f" and kernel_ctx.pipeline.F_dvpad == "f"
)
if is_64x32_tile:
return pads_hdim
return exact_hdim
rules.append(check_d128_tile_pipeline)
return rules
@classmethod
def get_hdim_tile_size_dict(cls, dtype: str) -> Optional[dict]:
if dtype in cls._DT_FP16_BF16:
@@ -1152,7 +1218,8 @@ class KernelComponentFactoryGfx11(CompatibilityRuleFactory):
( 32, 32) : [FmhaFwdTileSize( 64, 64, 16, 32, 32, 32, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
( 64, 64) : [FmhaFwdTileSize( 64, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1, CppConstraint("a.max_seqlen_q < 4096")),
FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 8, 1, 1, 8, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
(128, 128) : [FmhaFwdTileSize( 64, 64, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1, CppConstraint("a.max_seqlen_q < 4096")),
(128, 128) : [FmhaFwdTileSize( 64, 32, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, 6, CppConstraint("a.hdim_q != 128 || a.hdim_v != 128")),
FmhaFwdTileSize( 64, 64, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1, CppConstraint("a.max_seqlen_q < 4096")),
FmhaFwdTileSize(128, 64, 32, 128, 32, 128, 8, 1, 1, 8, 1, 1, 16, 16, 16, 16, 16, 16, 6)],
(192, 128) : [FmhaFwdTileSize( 64, 64, 32, 128, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
(256, 256) : [FmhaFwdTileSize(128, 64, 32, 256, 32, 256, 8, 1, 1, 8, 1, 1, 16, 16, 16, 16, 16, 16, 6)]
@@ -1179,7 +1246,9 @@ class KernelComponentFactoryGfx11(CompatibilityRuleFactory):
# Keep only ttff/tttt for gfx11: ffff path is often similar or worse
# pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip
pipelines.append(FmhaFwdPipeline("qr_hpad", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip
if receipt == 1:
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip
return pipelines
@@ -1251,7 +1320,9 @@ class KernelComponentFactoryGfx12(CompatibilityRuleFactory):
):
# pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip
pipelines.append(FmhaFwdPipeline("qr_hpad", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip
if receipt == 1:
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip
elif dtype in cls._DT_FP8_FP8BF16 or dtype in cls._DT_FP8FP32:
# no need lse/dropout kernels
for logits, qscale, mask, bias in itertools.product(

View File

@@ -13,6 +13,7 @@ enum class BlockFmhaPipelineEnum
QSKSVS,
QRKSVS_ASYNC_TRLOAD,
QRKSVS_ASYNC_TRLOAD_V3,
QRKSVS_HPAD,
};
template <BlockFmhaPipelineEnum>
@@ -40,4 +41,10 @@ struct BlockFmhaPipelineEnumToStr<BlockFmhaPipelineEnum::QRKSVS_ASYNC_TRLOAD>
static constexpr const char* name = "qr_async_trload";
};
template <>
struct BlockFmhaPipelineEnumToStr<BlockFmhaPipelineEnum::QRKSVS_HPAD>
{
static constexpr const char* name = "qr_hpad";
};
} // namespace ck_tile

View File

@@ -14,7 +14,9 @@
namespace ck_tile {
// This pipeline is qkv all located in LDS
template <typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSDefaultPolicy>
template <typename Problem_,
typename Policy_ = BlockFmhaPipelineQRKSVSDefaultPolicy,
bool PaddedVecLoadStore_ = false>
struct BlockFmhaPipelineQRKSVS
{
using Problem = remove_cvref_t<Problem_>;
@@ -54,17 +56,18 @@ struct BlockFmhaPipelineQRKSVS
static_assert(kSubQKHeaddim <= 256, "hdim bigger than 256 is not suitable for this pipeline!");
static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ;
static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK;
static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
static constexpr bool kHasLogitsSoftCap = Problem::kHasLogitsSoftCap;
static constexpr auto BiasEnum = Problem::BiasEnum;
static constexpr bool kStoreLSE = Problem::kStoreLSE;
static constexpr bool kHasDropout = Problem::kHasDropout;
static constexpr auto QScaleEnum = Problem::QScaleEnum;
static constexpr bool kHasSink = Problem::kHasSink;
static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ;
static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK;
static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
static constexpr bool kHasLogitsSoftCap = Problem::kHasLogitsSoftCap;
static constexpr auto BiasEnum = Problem::BiasEnum;
static constexpr bool kStoreLSE = Problem::kStoreLSE;
static constexpr bool kHasDropout = Problem::kHasDropout;
static constexpr auto QScaleEnum = Problem::QScaleEnum;
static constexpr bool kHasSink = Problem::kHasSink;
static constexpr bool kPaddedVecLoadStore = PaddedVecLoadStore_;
static constexpr ck_tile::index_t kQKScaleGranularity = Problem::kQKScaleGranularity;
static constexpr ck_tile::index_t kVScaleGranularity = Problem::kVScaleGranularity;
@@ -80,23 +83,29 @@ struct BlockFmhaPipelineQRKSVS
(kHasLogitsSoftCap && Problem::BiasEnum == BlockAttentionBiasEnum::NO_BIAS ||
!kHasLogitsSoftCap)) ||
(!CK_TILE_FMHA_FWD_FAST_EXP2 && !kHasLogitsSoftCap));
static_assert(!kPaddedVecLoadStore || (kPadHeadDimQ && kPadHeadDimV),
"padded vector load/store fast path only applies to padded head-dim kernels");
// last dimension vector length used to create tensor view(and decide buffer_load vector length)
// ... together with tensor distribution. tensor dist should able to overwrite this
static constexpr index_t kAlignmentQ = kPadHeadDimQ ? numeric_traits<QDataType>::PackedSize
: Policy::template GetAlignmentQ<Problem>();
static constexpr index_t kAlignmentK = kPadHeadDimQ ? numeric_traits<KDataType>::PackedSize
: Policy::template GetAlignmentK<Problem>();
static constexpr index_t kAlignmentQ = (kPadHeadDimQ && !kPaddedVecLoadStore)
? numeric_traits<QDataType>::PackedSize
: Policy::template GetAlignmentQ<Problem>();
static constexpr index_t kAlignmentK = (kPadHeadDimQ && !kPaddedVecLoadStore)
? numeric_traits<KDataType>::PackedSize
: Policy::template GetAlignmentK<Problem>();
static constexpr index_t kAlignmentV = []() {
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
return kPadHeadDimV ? 1 : Policy::template GetAlignmentV<Problem>();
return (kPadHeadDimV && !kPaddedVecLoadStore)
? 1
: Policy::template GetAlignmentV<Problem>();
else
return kPadSeqLenK ? numeric_traits<VDataType>::PackedSize
: Policy::template GetAlignmentV<Problem>();
}();
static constexpr index_t kAlignmentO =
kPadHeadDimV ? 1 : Policy::template GetAlignmentO<Problem>();
(kPadHeadDimV && !kPaddedVecLoadStore) ? 1 : Policy::template GetAlignmentO<Problem>();
static constexpr index_t kAlignmentBias =
kPadSeqLenK ? 1 : Policy::template GetAlignmentBias<Problem>();
static constexpr index_t kAlignmentRandVal =
@@ -548,8 +557,25 @@ struct BlockFmhaPipelineQRKSVS
});
}
const auto v_prefetch = load_tile(v_dram_window); // prefetch load v tile
{ // tail
auto v_prefetch = decltype(load_tile(v_dram_window)){};
enum class VPrefetchPoint
{
BeforeGemm0Tail,
AfterGemm0Tail,
AfterSoftmax
};
#if defined(__gfx11__) || defined(__gfx12__)
constexpr auto kVPrefetch =
kPadHeadDimV ? VPrefetchPoint::AfterSoftmax : VPrefetchPoint::AfterGemm0Tail;
#else
constexpr auto kVPrefetch = VPrefetchPoint::BeforeGemm0Tail;
#endif
if constexpr(kVPrefetch == VPrefetchPoint::BeforeGemm0Tail)
{
load_tile(v_prefetch, v_dram_window); // prefetch load v tile
}
{ // tail
block_sync_lds();
run_gemm_0(number<k0_loops - 2>{});
block_sync_lds();
@@ -562,6 +588,10 @@ struct BlockFmhaPipelineQRKSVS
run_gemm_0(number<k0_loops - 1>{});
}
if constexpr(kVPrefetch == VPrefetchPoint::AfterGemm0Tail)
{
load_tile(v_prefetch, v_dram_window);
}
// dequant
auto s_acc_element_func_ = [&s_acc_element_func, k_descale]() {
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE)
@@ -819,6 +849,11 @@ struct BlockFmhaPipelineQRKSVS
randval_ptr, seq_offset, p_compute, randval_dram_window);
}
if constexpr(kVPrefetch == VPrefetchPoint::AfterSoftmax)
{
load_tile(v_prefetch, v_dram_window);
}
block_sync_lds();
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
@@ -1098,4 +1133,7 @@ struct BlockFmhaPipelineQRKSVS
}
};
template <typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSDefaultPolicy>
using BlockFmhaPipelineQRKSVSHpad = BlockFmhaPipelineQRKSVS<Problem_, Policy_, true>;
} // namespace ck_tile