From 7ecbf82708905697c7764b89e9656de9eaa9aae6 Mon Sep 17 00:00:00 2001 From: Anton Gorenko Date: Wed, 3 Jun 2026 06:16:10 +0000 Subject: [PATCH] [rocm-libraries] ROCm/rocm-libraries#7500 (commit f5cd4fd) [CK_TILE][FMHA] Optimize long-context decoding on gfx11/12 (#7500) ## Motivation Relevant issue: ROCM-22065 FMHA has less-than-optimal performance of long-context decoding (i.e. when seqlen_q = 1) on gfx11/12. This PR optimizes the splitkv pipeline and configs for such scenarios. ## Technical Details Optimizations applied in this PR: 1. use tiles with smaller M0 (16 vs 64), these tiles are used when seqlen_q <= 16 2. adapt qr_nwarp_sshuffle pipeline for gfx11, it allows to use more warps even for M0 = 16 (the qr pipeline parallelizes work between warps in M dim so with M0 = 16 it allows to use only 1 warp) 3. enable kMergeNumHeadGroupsSeqLenQ (an optimization that merges one group of heads in GQA) for all hdim values, not only 128 4. increase the number of splits (multiply by the number of head groups) if (3) is used 5. increase the number of splits for RDNAs (`multiProcessorCount` is the number of WGPs on RDNAs, not CUs, so it should be doubled to have meaning similar to CDNAs) Performance on gfx1151: | Case | develop (GB/s) | This PR (GB/s) | |:-------|-------:|-------:| | [fp16\|group\|bshd] b:1, h:32/32, s:1/45056, d:64/64 | 127.58 | 183.11 | | [fp16\|group\|bhsd] b:1, h:32/32, s:1/45056, d:64/64 | 153.64 | 215.02 | | [fp16\|group\|bshd] b:1, h:16/8, s:1/77184, d:128/128 | 120.51 | 225.76 | | [fp16\|group\|bhsd] b:1, h:16/8, s:1/77184, d:128/128 | 130.62 | 223.84 | | [fp16\|group\|bshd] b:1, h:32/32, s:1/9600, d:128/128 | 82.65 | 138.44 | | [fp16\|group\|bhsd] b:1, h:32/32, s:1/9600, d:128/128 | 105.75 | 220.45 | | [fp16\|group\|bshd] b:1, h:8/1, s:1/401024, d:256/256 | 16.27 | 187.89 | | [fp16\|group\|bhsd] b:1, h:8/1, s:1/401024, d:256/256 | 16.28 | 188.19 | ## Test Plan An additional test case is added to the exiting test. It uses seqlen_q = 1, GQA, no mask to trigger the changes ``` ninja test_ck_tile_fmha_fwd_fp16 && bin/test_ck_tile_fmha_fwd_fp16 --gtest_filter="*SplitKV* ninja test_ck_tile_fmha_fwd_bf16 && bin/test_ck_tile_fmha_fwd_bf16 --gtest_filter="*SplitKV* ``` Manual testing can be done with these commands: ``` bin/tile_example_fmha_fwd -prec=fp16 -mode=1 -page_block_size=128 -b=1 -h=32 -h_k=32 -d=64 -s=1 -s_k=$((352 * 128)) -lse=1 -mask=0 -num_splits=0 -kname=1 -v=1 bin/tile_example_fmha_fwd -prec=fp16 -mode=1 -page_block_size=128 -b=1 -h=16 -h_k=8 -d=128 -s=1 -s_k=$((603 * 128)) -lse=1 -mask=0 -num_splits=0 -kname=1 -v=1 bin/tile_example_fmha_fwd -prec=fp16 -mode=1 -page_block_size=128 -b=1 -h=32 -h_k=32 -d=128 -s=1 -s_k=$((75 * 128)) -lse=1 -mask=0 -num_splits=0 -kname=1 -v=1 bin/tile_example_fmha_fwd -prec=fp16 -mode=1 -page_block_size=128 -b=1 -h=8 -h_k=1 -d=256 -s=1 -s_k=$((3133 * 128)) -lse=1 -mask=0 -num_splits=0 -kname=1 -v=1 ``` ## Test Result All the tests must pass. ## Submission Checklist - [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests. --- .../01_fmha/codegen/ops/fmha_fwd_splitkv.py | 55 ++++++++++++++----- example/ck_tile/01_fmha/fmha_fwd_runner.hpp | 16 +++++- ...litkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp | 10 +++- ...nwarp_sshuffle_qr_ks_vs_default_policy.hpp | 23 +++++++- test/ck_tile/fmha/test_fmha_fwd.cpp | 1 + 5 files changed, 84 insertions(+), 21 deletions(-) diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py index 849f463afa..ed025dcf5f 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py @@ -128,7 +128,7 @@ using trait_{F_idx} = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F namespace {{ template void run_instance(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a) {{ - if constexpr ({F_hdim} == 128 && {F_bias} == ck_tile::BlockAttentionBiasEnum::NO_BIAS + if constexpr ({F_bias} == ck_tile::BlockAttentionBiasEnum::NO_BIAS && (std::is_same_v<{F_mask}, ck_tile::SimplifiedGenericAttentionMask> || std::is_same_v<{F_mask}, FmhaMasks::NoMask>)) {{ if (a.max_seqlen_q == 1 && a.nhead_k < a.nhead_q) {{ @@ -283,7 +283,7 @@ float fmha_fwd_splitkv(fmha_fwd_splitkv_traits t, fmha_fwd_splitkv_args a, const """ FMHA_FWD_SPLITKV_API_INNER_DISPATCH = """{F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.do_fp8_static_quant == {F_squant}) && - ((a.block_table_ptr != nullptr) == {F_pagedkv}) && (t.has_sink == {F_sink}) && ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{ + ((a.block_table_ptr != nullptr) == {F_pagedkv}) && (t.has_sink == {F_sink}) && ({F_scheck}) && ({F_seqtune}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{ using traits_ = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, true, {F_squant}, {F_pagedkv},{F_sink}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; // get combine kernel tile sizes @@ -364,6 +364,14 @@ class FmhaFwdSplitKVApiTrait: else: assert False + def seqtune(self, max_bm0: int) -> str: + if self.bm0 == max_bm0: + return "true/*fall back to largest tile*/" + else: + if self.mode == "group": + return f"a.max_seqlen_q <= {self.bm0}" + return f"a.seqlen_q <= {self.bm0}" + @property def skcheck(self) -> str: if self.mode == "group": @@ -561,6 +569,7 @@ class FmhaFwdSplitKVApiPool: for i_dtype, (dtype, pool_by_dtype) in enumerate(pool_by_arch.items()): per_hdim_case = str() for i_hdim, (hdim, pool_by_hdim) in enumerate(pool_by_dtype.items()): + max_bm0 = max((t.bm0 for t in pool_by_hdim), default=0) inners = str() for i_trait, trait in enumerate(pool_by_hdim): inners += FMHA_FWD_SPLITKV_API_INNER_DISPATCH.format( @@ -579,6 +588,7 @@ class FmhaFwdSplitKVApiPool: F_pagedkv=BOOL_MAP[trait.pagedkv], F_sink=BOOL_MAP[trait.sink], F_scheck=trait.scheck, + F_seqtune=trait.seqtune(max_bm0), F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, @@ -763,6 +773,7 @@ class KernelComponentFactoryBase: pipelines.append(Pipeline("qr", "row", "t", "f", "f", "f", logits, bias, "t", squant, pagedkv, sink, mask)) # fmt: skip pipelines.append(Pipeline("qr", "row", "t", "t", "f", "f", logits, bias, "t", squant, pagedkv, sink, mask)) # fmt: skip pipelines.append(Pipeline("qr", "row", "t", "t", "t", "t", logits, bias, "t", squant, pagedkv, sink, mask)) # fmt: skip + pipelines.append(Pipeline("qr_nwarp_sshuffle", "row", "t", "t", "f", "f", logits, bias, "t", squant, pagedkv, sink, mask)) # fmt: skip elif dtype in ["fp8", "bf8"]: for logits, mask, bias in itertools.product( ["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys() @@ -846,11 +857,15 @@ class KernelComponentFactoryGfx11(KernelComponentFactoryBase): def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]: if dtype in ["fp16", "bf16"]: return { - # bm0, bn0, bk0, bn1, bk1, - "32" : FmhaFwdTileSize( 64, 64, 16, 32, 32, 32, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), - "64" : FmhaFwdTileSize( 64, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), - "128": FmhaFwdTileSize( 64, 64, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), - "256": FmhaFwdTileSize( 64, 64, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), + # bm0, bn0, bk0, bn1, bk1, + "32" : [FmhaFwdTileSize( 16, 64, 16, 32, 32, 32, 1, 2, 1, 1, 2, 1, 16, 16, 16, 16, 16, 16, -1), + FmhaFwdTileSize( 64, 64, 16, 32, 32, 32, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], + "64" : [FmhaFwdTileSize( 16, 64, 32, 64, 32, 64, 1, 4, 1, 1, 4, 1, 16, 16, 16, 16, 16, 16, -1), + FmhaFwdTileSize( 64, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], + "128": [FmhaFwdTileSize( 16, 64, 32, 128, 32, 128, 1, 4, 1, 1, 4, 1, 16, 16, 16, 16, 16, 16, -1), + FmhaFwdTileSize( 64, 64, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], + "256": [FmhaFwdTileSize( 16, 64, 32, 256, 32, 256, 1, 4, 1, 1, 4, 1, 16, 16, 16, 16, 16, 16, -1), + FmhaFwdTileSize( 64, 64, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], } # fmt: skip else: return None @@ -863,11 +878,15 @@ class KernelComponentFactoryGfx12(KernelComponentFactoryBase): def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]: if dtype in ["fp16", "bf16"]: return { - # bm0, bn0, bk0, bn1, bk1, - "32" : FmhaFwdTileSize( 64, 64, 16, 32, 32, 32, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), - "64" : FmhaFwdTileSize( 64, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), - "128": FmhaFwdTileSize( 64, 64, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), - "256": FmhaFwdTileSize( 64, 64, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), + # bm0, bn0, bk0, bn1, bk1, + "32" : [FmhaFwdTileSize( 16, 64, 16, 32, 32, 32, 1, 2, 1, 1, 2, 1, 16, 16, 16, 16, 16, 16, -1), + FmhaFwdTileSize( 64, 64, 16, 32, 32, 32, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], + "64" : [FmhaFwdTileSize( 16, 64, 32, 64, 32, 64, 1, 4, 1, 1, 4, 1, 16, 16, 16, 16, 16, 16, -1), + FmhaFwdTileSize( 64, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], + "128": [FmhaFwdTileSize( 16, 128, 32, 128, 32, 128, 1, 8, 1, 1, 8, 1, 16, 16, 16, 16, 16, 16, -1), + FmhaFwdTileSize( 64, 64, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], + "256": [FmhaFwdTileSize( 16, 128, 32, 256, 32, 256, 1, 8, 1, 1, 8, 1, 16, 16, 16, 16, 16, 16, -1), + FmhaFwdTileSize( 64, 64, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], } # fmt: skip elif dtype in ["fp8", "bf8"]: return { @@ -930,11 +949,17 @@ def get_fwd_splitkv_blobs( d = factory.get_hdim_tile_size_dict(dtype) if d is None: continue - # for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]): for hdim_str, mode in itertools.product(d.keys(), MODE_MAP.keys()): - tile = d[hdim_str] + tiles = d[hdim_str] + if not isinstance(tiles, list): + tiles = [tiles] hdim = int(hdim_str) - for pipeline in factory.get_pipelines(dtype, hdim, mask_impl): + for tile, pipeline in itertools.product( + tiles, factory.get_pipelines(dtype, hdim, mask_impl) + ): + # Use qr_nwarp_sshuffle with multiple N warps and qr otherwise + if (tile.F_rn0 != 1) != (pipeline.tag == "qr_nwarp_sshuffle"): + continue if mode == "group": if pipeline.F_spad != "t" or pipeline.F_skpad != "t": # in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not diff --git a/example/ck_tile/01_fmha/fmha_fwd_runner.hpp b/example/ck_tile/01_fmha/fmha_fwd_runner.hpp index 0b51dffa46..243ff87faa 100644 --- a/example/ck_tile/01_fmha/fmha_fwd_runner.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd_runner.hpp @@ -165,8 +165,10 @@ int override_num_splits_if_necessary( if(num_splits < 1 && p_drop == 0.0f) { + // props.multiProcessorCount for >=gfx10 is the number of WGPs (each has 2 CUs) + const int num_blocks_per_SM = props.warpSize == 32 ? 4 : 2; return num_splits_heuristic( - batch * nhead * num_m_blocks, props.multiProcessorCount * 2, 128); + batch * nhead * num_m_blocks, props.multiProcessorCount * num_blocks_per_SM, 128); } return num_splits; @@ -648,8 +650,18 @@ fwd_result fmha_fwd_run(mode_enum mode, // legalize num_splits according to other options if(num_splits < 1) { + int nhead_merged = nhead; + int max_seqlen_q_merged = max_seqlen_q; + // When max_seqlen_q == 1 and multiple head groups are merged (kMergeNumHeadGroupsSeqLenQ) + // then more splits are required + if(bias.type == bias_enum::no_bias && mask.type == mask_enum::no_mask && + max_seqlen_q == 1 && nhead_k < nhead) + { + nhead_merged = nhead_k; + max_seqlen_q_merged = max_seqlen_q * (nhead / nhead_k); + } num_splits = override_num_splits_if_necessary( - batch, nhead, max_seqlen_q, hdim_v, p_drop, num_splits); + batch, nhead_merged, max_seqlen_q_merged, hdim_v, p_drop, num_splits); } if(128 < num_splits) { diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp index adc8ea5a90..bdc598f754 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp @@ -6,6 +6,7 @@ #include "ck_tile/core.hpp" #include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs_default_policy.hpp" +#include "ck_tile/ops/gemm/warp/warp_wmma_gemm_gfx11_utils.hpp" #include "ck_tile/ops/reduce/block/block_reduce.hpp" namespace ck_tile { @@ -257,7 +258,7 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS clear_tile(o_acc); if((__builtin_isinf_sign(sink_v) >= 0) && i_split == 0) { - set_tile(m, SMPLComputeDataType{sink_v * C_LOG2E}); + set_tile(m, SMPLComputeDataType{sink_v * static_cast(C_LOG2E)}); set_tile(l, SMPLComputeDataType{1.0f}); } else @@ -698,8 +699,15 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS block_tile_reduce_sync(rowsum_p, f_sum, bool_constant{}); +#if defined(__gfx11__) + auto p = make_static_distributed_tensor( + decltype(gemm_1)::template MakeABlockTileDistribution()); + PermuteWarpGemmCToA( + p, cast_tile(tile_elementwise_in(p_compute_element_func, p_compute))); +#else const auto p = cast_tile(tile_elementwise_in(p_compute_element_func, p_compute)); +#endif // l{j}, Oacc{j} constexpr auto o_spans = decltype(o_acc)::get_distributed_spans(); diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs_default_policy.hpp index c5af751cd5..316720ac22 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs_default_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs_default_policy.hpp @@ -5,8 +5,6 @@ #include "ck_tile/core.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp" -#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp" -#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1.hpp" namespace ck_tile { @@ -163,6 +161,25 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVSDefaultPolicy constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; constexpr index_t kTileK = Problem::BlockFmhaShape::kN0; +#if defined(__gfx11__) + // Keep C distribution and replicate data for NWarp to prevent doubling registers + // PermuteWarpGemmCToA will convert C distribution to A for matrix P later + constexpr index_t K1 = kKPerBlock / WG::kM; + constexpr index_t K0 = kTileK / kKPerBlock; + constexpr index_t M1 = MWarp; + constexpr index_t M0 = kMPerBlock / WG::kN; + + constexpr auto s2_block_outer_dstr_encoding = + tile_distribution_encoding, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2, 2>, + sequence<0, 0, 1>>{}; + + constexpr auto s2_block_dstr_encoding = detail::make_embed_tile_distribution_encoding( + s2_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{}); +#else // K2 is equal to Impl::kABKPerLane * kKIterPerWarpGemm constexpr index_t K3 = WG::kK / WG::WarpGemmAttribute::Impl::kABKLane; constexpr index_t K2 = WG::WarpGemmAttribute::Impl::kABKLane; @@ -179,7 +196,7 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVSDefaultPolicy tuple, sequence<2, 2>>, sequence<1, 2, 2, 2>, sequence<0, 0, 1, 3>>{}; - +#endif constexpr auto s2_block_dstr = make_static_tile_distribution(s2_block_dstr_encoding); return s2_block_dstr; diff --git a/test/ck_tile/fmha/test_fmha_fwd.cpp b/test/ck_tile/fmha/test_fmha_fwd.cpp index 6ae33da30f..bdfc2d17c4 100644 --- a/test/ck_tile/fmha/test_fmha_fwd.cpp +++ b/test/ck_tile/fmha/test_fmha_fwd.cpp @@ -735,6 +735,7 @@ INSTANTIATE_TEST_SUITE_P(TestCkTileFmhaFwd, Values(3, 4), Values(std::tuple{4, 3, 1, 200, 1024, "0"}, std::tuple{2, 2, -1, 512, 2000, "0"}, + std::tuple{2, 8, 2, 1, 1024, "0"}, std::tuple{3, 2, -1, 230, 899, "t:128,128"}))); TEST_P(SplitKV, DataTypeConfig)