mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-04 13:41:24 +00:00
[rocm-libraries] ROCm/rocm-libraries#6156 (commit 367565a)
[CK_TILE] Optimize FMHA head-dim padded path on gfx11/gfx12 (#6156)

## Motivation

On gfx11/gfx12, FMHA forward kernels that require head-dim padding show a large performance drop compared to the exact-head-dim path. In practice, padded cases such as `HDIM=72` and `HDIM=80` were falling too far off the fast path.

This PR improves padded-head-dim FMHA performance on gfx11/gfx12 while keeping the behavior for other GPUs unchanged.

## Technical Details

- Add/scope a dedicated padded-head-dim (`qr_hpad`) FMHA forward path for gfx11/gfx12.
- For `receipt=0`, keep support conservative and only enable the padded fast path for vector-safe cases (`head_dim % 8 == 0`), matching the existing assumption used on other GPUs.
- Move `v_prefetch` later only for the head-dim-padded path on gfx11/gfx12. This reduces live ranges and removes the register-spill behavior seen in the earlier scheduling.
- Enable the buffer-load OOB check offset trick for the padded path on gfx11/gfx12.

## Test Plan

./build/bin/tile_example_fmha_fwd -prec=bf16 -mode={0/1} -b=1 -h=16 -d={72/80} -s={seqlen} -s_k={seqlen} -lse=0 -iperm={0/1} -operm={0/1}

## Test Result

Observed padded-head-dim performance improvements for HDIM=72/80:

- gfx11: about 3.5x
- gfx1151: about 2.0x
- gfx12: about 1.3x

## Submission Checklist

- [ ] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
committed by
assistant-librarian[bot]
parent
7d6c8e5afa
commit
4c0e73ab12
@@ -13,6 +13,7 @@ enum class BlockFmhaPipelineEnum
|
||||
QSKSVS,
|
||||
QRKSVS_ASYNC_TRLOAD,
|
||||
QRKSVS_ASYNC_TRLOAD_V3,
|
||||
QRKSVS_HPAD,
|
||||
};
|
||||
|
||||
template <BlockFmhaPipelineEnum>
|
||||
@@ -40,4 +41,10 @@ struct BlockFmhaPipelineEnumToStr<BlockFmhaPipelineEnum::QRKSVS_ASYNC_TRLOAD>
|
||||
static constexpr const char* name = "qr_async_trload";
|
||||
};
|
||||
|
||||
// Specialization mapping BlockFmhaPipelineEnum::QRKSVS_HPAD to the name string "qr_hpad".
template <>
|
||||
struct BlockFmhaPipelineEnumToStr<BlockFmhaPipelineEnum::QRKSVS_HPAD>
|
||||
{
|
||||
static constexpr const char* name = "qr_hpad";
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -14,7 +14,9 @@
|
||||
namespace ck_tile {
|
||||
|
||||
// This pipeline is qkv all located in LDS
|
||||
template <typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSDefaultPolicy>
|
||||
template <typename Problem_,
|
||||
typename Policy_ = BlockFmhaPipelineQRKSVSDefaultPolicy,
|
||||
bool PaddedVecLoadStore_ = false>
|
||||
struct BlockFmhaPipelineQRKSVS
|
||||
{
|
||||
using Problem = remove_cvref_t<Problem_>;
|
||||
@@ -54,17 +56,18 @@ struct BlockFmhaPipelineQRKSVS
|
||||
|
||||
static_assert(kSubQKHeaddim <= 256, "hdim bigger than 256 is not suitable for this pipeline!");
|
||||
|
||||
static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
|
||||
static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ;
|
||||
static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK;
|
||||
static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
|
||||
static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
|
||||
static constexpr bool kHasLogitsSoftCap = Problem::kHasLogitsSoftCap;
|
||||
static constexpr auto BiasEnum = Problem::BiasEnum;
|
||||
static constexpr bool kStoreLSE = Problem::kStoreLSE;
|
||||
static constexpr bool kHasDropout = Problem::kHasDropout;
|
||||
static constexpr auto QScaleEnum = Problem::QScaleEnum;
|
||||
static constexpr bool kHasSink = Problem::kHasSink;
|
||||
static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
|
||||
static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ;
|
||||
static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK;
|
||||
static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
|
||||
static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
|
||||
static constexpr bool kHasLogitsSoftCap = Problem::kHasLogitsSoftCap;
|
||||
static constexpr auto BiasEnum = Problem::BiasEnum;
|
||||
static constexpr bool kStoreLSE = Problem::kStoreLSE;
|
||||
static constexpr bool kHasDropout = Problem::kHasDropout;
|
||||
static constexpr auto QScaleEnum = Problem::QScaleEnum;
|
||||
static constexpr bool kHasSink = Problem::kHasSink;
|
||||
static constexpr bool kPaddedVecLoadStore = PaddedVecLoadStore_;
|
||||
|
||||
static constexpr ck_tile::index_t kQKScaleGranularity = Problem::kQKScaleGranularity;
|
||||
static constexpr ck_tile::index_t kVScaleGranularity = Problem::kVScaleGranularity;
|
||||
@@ -80,23 +83,29 @@ struct BlockFmhaPipelineQRKSVS
|
||||
(kHasLogitsSoftCap && Problem::BiasEnum == BlockAttentionBiasEnum::NO_BIAS ||
|
||||
!kHasLogitsSoftCap)) ||
|
||||
(!CK_TILE_FMHA_FWD_FAST_EXP2 && !kHasLogitsSoftCap));
|
||||
static_assert(!kPaddedVecLoadStore || (kPadHeadDimQ && kPadHeadDimV),
|
||||
"padded vector load/store fast path only applies to padded head-dim kernels");
|
||||
|
||||
// last dimension vector length used to create tensor view(and decide buffer_load vector length)
|
||||
// ... together with tensor distribution. tensor dist should be able to overwrite this
|
||||
static constexpr index_t kAlignmentQ = kPadHeadDimQ ? numeric_traits<QDataType>::PackedSize
|
||||
: Policy::template GetAlignmentQ<Problem>();
|
||||
static constexpr index_t kAlignmentK = kPadHeadDimQ ? numeric_traits<KDataType>::PackedSize
|
||||
: Policy::template GetAlignmentK<Problem>();
|
||||
static constexpr index_t kAlignmentQ = (kPadHeadDimQ && !kPaddedVecLoadStore)
|
||||
? numeric_traits<QDataType>::PackedSize
|
||||
: Policy::template GetAlignmentQ<Problem>();
|
||||
static constexpr index_t kAlignmentK = (kPadHeadDimQ && !kPaddedVecLoadStore)
|
||||
? numeric_traits<KDataType>::PackedSize
|
||||
: Policy::template GetAlignmentK<Problem>();
|
||||
static constexpr index_t kAlignmentV = []() {
|
||||
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
return kPadHeadDimV ? 1 : Policy::template GetAlignmentV<Problem>();
|
||||
return (kPadHeadDimV && !kPaddedVecLoadStore)
|
||||
? 1
|
||||
: Policy::template GetAlignmentV<Problem>();
|
||||
else
|
||||
return kPadSeqLenK ? numeric_traits<VDataType>::PackedSize
|
||||
: Policy::template GetAlignmentV<Problem>();
|
||||
}();
|
||||
|
||||
static constexpr index_t kAlignmentO =
|
||||
kPadHeadDimV ? 1 : Policy::template GetAlignmentO<Problem>();
|
||||
(kPadHeadDimV && !kPaddedVecLoadStore) ? 1 : Policy::template GetAlignmentO<Problem>();
|
||||
static constexpr index_t kAlignmentBias =
|
||||
kPadSeqLenK ? 1 : Policy::template GetAlignmentBias<Problem>();
|
||||
static constexpr index_t kAlignmentRandVal =
|
||||
@@ -548,8 +557,25 @@ struct BlockFmhaPipelineQRKSVS
|
||||
});
|
||||
}
|
||||
|
||||
const auto v_prefetch = load_tile(v_dram_window); // prefetch load v tile
|
||||
{ // tail
|
||||
auto v_prefetch = decltype(load_tile(v_dram_window)){};
|
||||
enum class VPrefetchPoint
|
||||
{
|
||||
BeforeGemm0Tail,
|
||||
AfterGemm0Tail,
|
||||
AfterSoftmax
|
||||
};
|
||||
|
||||
#if defined(__gfx11__) || defined(__gfx12__)
|
||||
constexpr auto kVPrefetch =
|
||||
kPadHeadDimV ? VPrefetchPoint::AfterSoftmax : VPrefetchPoint::AfterGemm0Tail;
|
||||
#else
|
||||
constexpr auto kVPrefetch = VPrefetchPoint::BeforeGemm0Tail;
|
||||
#endif
|
||||
if constexpr(kVPrefetch == VPrefetchPoint::BeforeGemm0Tail)
|
||||
{
|
||||
load_tile(v_prefetch, v_dram_window); // prefetch load v tile
|
||||
}
|
||||
{ // tail
|
||||
block_sync_lds();
|
||||
run_gemm_0(number<k0_loops - 2>{});
|
||||
block_sync_lds();
|
||||
@@ -562,6 +588,10 @@ struct BlockFmhaPipelineQRKSVS
|
||||
|
||||
run_gemm_0(number<k0_loops - 1>{});
|
||||
}
|
||||
if constexpr(kVPrefetch == VPrefetchPoint::AfterGemm0Tail)
|
||||
{
|
||||
load_tile(v_prefetch, v_dram_window);
|
||||
}
|
||||
// dequant
|
||||
auto s_acc_element_func_ = [&s_acc_element_func, k_descale]() {
|
||||
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE)
|
||||
@@ -819,6 +849,11 @@ struct BlockFmhaPipelineQRKSVS
|
||||
randval_ptr, seq_offset, p_compute, randval_dram_window);
|
||||
}
|
||||
|
||||
if constexpr(kVPrefetch == VPrefetchPoint::AfterSoftmax)
|
||||
{
|
||||
load_tile(v_prefetch, v_dram_window);
|
||||
}
|
||||
|
||||
block_sync_lds();
|
||||
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
@@ -1098,4 +1133,7 @@ struct BlockFmhaPipelineQRKSVS
|
||||
}
|
||||
};
|
||||
|
||||
// Alias of BlockFmhaPipelineQRKSVS with its third (bool) template argument fixed to true.
template <typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSDefaultPolicy>
|
||||
using BlockFmhaPipelineQRKSVSHpad = BlockFmhaPipelineQRKSVS<Problem_, Policy_, true>;
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
Reference in New Issue
Block a user