[rocm-libraries] ROCm/rocm-libraries#6156 (commit 367565a)

[CK_TILE] Optimize FMHA head-dim padded path on gfx11/gfx12
 (#6156)

## Motivation
On gfx11/gfx12, FMHA forward kernels that require head-dim padding show
a large performance drop compared to the exact-head-dim path. In
practice, padded cases such as `HDIM=72` and `HDIM=80` fell far off
the fast path.

This PR improves padded-head-dim FMHA performance on gfx11/gfx12 while
keeping the behavior for other GPUs unchanged.

## Technical Details

- Add/scope a dedicated padded-head-dim (`qr_hpad`) FMHA forward path
for gfx11/gfx12.
- For `receipt=0`, keep support conservative and only enable the padded
fast path for vector-safe cases (`head_dim % 8 == 0`), matching the
existing assumption used on other GPUs.
- Move `v_prefetch` later only for the head-dim-padded path on
gfx11/gfx12. This reduces live ranges and removes the register-spill
behavior seen in the earlier scheduling.
- Enable the buffer-load OOB check offset trick for the padded path on
gfx11/gfx12.

## Test Plan

./build/bin/tile_example_fmha_fwd -prec=bf16 -mode={0/1} -b=1 -h=16
-d={72/80} -s={seqlen} -s_k={seqlen} -lse=0 -iperm={0/1} -operm={0/1}

## Test Result

Observed padded-head-dim performance improvements for HDIM=72/80:

- gfx11: ~3.5x
- gfx1151: ~2.0x
- gfx12: ~1.3x

## Submission Checklist

- [ ] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
Hosang Yoon
2026-04-08 14:53:18 +00:00
committed by assistant-librarian[bot]
parent 7d6c8e5afa
commit 4c0e73ab12
4 changed files with 144 additions and 26 deletions

View File

@@ -13,6 +13,7 @@ enum class BlockFmhaPipelineEnum
QSKSVS,
QRKSVS_ASYNC_TRLOAD,
QRKSVS_ASYNC_TRLOAD_V3,
QRKSVS_HPAD,
};
template <BlockFmhaPipelineEnum>
@@ -40,4 +41,10 @@ struct BlockFmhaPipelineEnumToStr<BlockFmhaPipelineEnum::QRKSVS_ASYNC_TRLOAD>
static constexpr const char* name = "qr_async_trload";
};
// Human-readable tag for the padded-head-dim QRKSVS pipeline variant.
template <>
struct BlockFmhaPipelineEnumToStr<BlockFmhaPipelineEnum::QRKSVS_HPAD>
{
    static constexpr char const* name = "qr_hpad";
};
} // namespace ck_tile

View File

@@ -14,7 +14,9 @@
namespace ck_tile {
// This pipeline is qkv all located in LDS
template <typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSDefaultPolicy>
template <typename Problem_,
typename Policy_ = BlockFmhaPipelineQRKSVSDefaultPolicy,
bool PaddedVecLoadStore_ = false>
struct BlockFmhaPipelineQRKSVS
{
using Problem = remove_cvref_t<Problem_>;
@@ -54,17 +56,18 @@ struct BlockFmhaPipelineQRKSVS
static_assert(kSubQKHeaddim <= 256, "hdim bigger than 256 is not suitable for this pipeline!");
static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ;
static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK;
static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
static constexpr bool kHasLogitsSoftCap = Problem::kHasLogitsSoftCap;
static constexpr auto BiasEnum = Problem::BiasEnum;
static constexpr bool kStoreLSE = Problem::kStoreLSE;
static constexpr bool kHasDropout = Problem::kHasDropout;
static constexpr auto QScaleEnum = Problem::QScaleEnum;
static constexpr bool kHasSink = Problem::kHasSink;
static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ;
static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK;
static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
static constexpr bool kHasLogitsSoftCap = Problem::kHasLogitsSoftCap;
static constexpr auto BiasEnum = Problem::BiasEnum;
static constexpr bool kStoreLSE = Problem::kStoreLSE;
static constexpr bool kHasDropout = Problem::kHasDropout;
static constexpr auto QScaleEnum = Problem::QScaleEnum;
static constexpr bool kHasSink = Problem::kHasSink;
static constexpr bool kPaddedVecLoadStore = PaddedVecLoadStore_;
static constexpr ck_tile::index_t kQKScaleGranularity = Problem::kQKScaleGranularity;
static constexpr ck_tile::index_t kVScaleGranularity = Problem::kVScaleGranularity;
@@ -80,23 +83,29 @@ struct BlockFmhaPipelineQRKSVS
(kHasLogitsSoftCap && Problem::BiasEnum == BlockAttentionBiasEnum::NO_BIAS ||
!kHasLogitsSoftCap)) ||
(!CK_TILE_FMHA_FWD_FAST_EXP2 && !kHasLogitsSoftCap));
static_assert(!kPaddedVecLoadStore || (kPadHeadDimQ && kPadHeadDimV),
"padded vector load/store fast path only applies to padded head-dim kernels");
// last dimension vector length used to create tensor view(and decide buffer_load vector length)
// ... together with tensor distribution. tensor dist should able to overwrite this
static constexpr index_t kAlignmentQ = kPadHeadDimQ ? numeric_traits<QDataType>::PackedSize
: Policy::template GetAlignmentQ<Problem>();
static constexpr index_t kAlignmentK = kPadHeadDimQ ? numeric_traits<KDataType>::PackedSize
: Policy::template GetAlignmentK<Problem>();
static constexpr index_t kAlignmentQ = (kPadHeadDimQ && !kPaddedVecLoadStore)
? numeric_traits<QDataType>::PackedSize
: Policy::template GetAlignmentQ<Problem>();
static constexpr index_t kAlignmentK = (kPadHeadDimQ && !kPaddedVecLoadStore)
? numeric_traits<KDataType>::PackedSize
: Policy::template GetAlignmentK<Problem>();
static constexpr index_t kAlignmentV = []() {
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
return kPadHeadDimV ? 1 : Policy::template GetAlignmentV<Problem>();
return (kPadHeadDimV && !kPaddedVecLoadStore)
? 1
: Policy::template GetAlignmentV<Problem>();
else
return kPadSeqLenK ? numeric_traits<VDataType>::PackedSize
: Policy::template GetAlignmentV<Problem>();
}();
static constexpr index_t kAlignmentO =
kPadHeadDimV ? 1 : Policy::template GetAlignmentO<Problem>();
(kPadHeadDimV && !kPaddedVecLoadStore) ? 1 : Policy::template GetAlignmentO<Problem>();
static constexpr index_t kAlignmentBias =
kPadSeqLenK ? 1 : Policy::template GetAlignmentBias<Problem>();
static constexpr index_t kAlignmentRandVal =
@@ -548,8 +557,25 @@ struct BlockFmhaPipelineQRKSVS
});
}
const auto v_prefetch = load_tile(v_dram_window); // prefetch load v tile
{ // tail
auto v_prefetch = decltype(load_tile(v_dram_window)){};
enum class VPrefetchPoint
{
BeforeGemm0Tail,
AfterGemm0Tail,
AfterSoftmax
};
#if defined(__gfx11__) || defined(__gfx12__)
constexpr auto kVPrefetch =
kPadHeadDimV ? VPrefetchPoint::AfterSoftmax : VPrefetchPoint::AfterGemm0Tail;
#else
constexpr auto kVPrefetch = VPrefetchPoint::BeforeGemm0Tail;
#endif
if constexpr(kVPrefetch == VPrefetchPoint::BeforeGemm0Tail)
{
load_tile(v_prefetch, v_dram_window); // prefetch load v tile
}
{ // tail
block_sync_lds();
run_gemm_0(number<k0_loops - 2>{});
block_sync_lds();
@@ -562,6 +588,10 @@ struct BlockFmhaPipelineQRKSVS
run_gemm_0(number<k0_loops - 1>{});
}
if constexpr(kVPrefetch == VPrefetchPoint::AfterGemm0Tail)
{
load_tile(v_prefetch, v_dram_window);
}
// dequant
auto s_acc_element_func_ = [&s_acc_element_func, k_descale]() {
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE)
@@ -819,6 +849,11 @@ struct BlockFmhaPipelineQRKSVS
randval_ptr, seq_offset, p_compute, randval_dram_window);
}
if constexpr(kVPrefetch == VPrefetchPoint::AfterSoftmax)
{
load_tile(v_prefetch, v_dram_window);
}
block_sync_lds();
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
@@ -1098,4 +1133,7 @@ struct BlockFmhaPipelineQRKSVS
}
};
// Convenience alias for the QRKSVS pipeline with the third template
// argument (PaddedVecLoadStore_) set to true, i.e. the padded
// vector load/store head-dim path.
template <typename FmhaProblem_, typename FmhaPolicy_ = BlockFmhaPipelineQRKSVSDefaultPolicy>
using BlockFmhaPipelineQRKSVSHpad = BlockFmhaPipelineQRKSVS<FmhaProblem_, FmhaPolicy_, true>;
} // namespace ck_tile