From 559ad6f0b16f986a7740bd582caf29f7d14bdafb Mon Sep 17 00:00:00 2001 From: Hosang <156028780+hyoon1@users.noreply.github.com> Date: Tue, 10 Mar 2026 12:46:41 -0400 Subject: [PATCH] [CK_TILE] Update gfx11 FMHA forward kernel configs (#5088) ## Motivation Tune gfx11 FMHA codegen to recover performance for mainly PSSK (padded seqlen_q/k) cases. This tuning is based on heuristic search and improves performance in most tested shapes. Performance should be evaluated on top of [`ROCm/rocm-libraries#5018`](https://github.com/ROCm/rocm-libraries/pull/5018) (required baseline). ## Technical Details - Updated gfx11 codegen heuristic choices for tile size and occupancy. - Updated gfx11 pipeline selection: - Disabled the `npad` (`f,f,f,f`) qr entry because it was consistently slower than the `pssk` (`t,t,f,f`) path, and kept `pssk` enabled so npad cases are dispatched to the faster kernel path.` - Kept gfx12 unchanged: with PSSK support from [`ROCm/rocm-libraries#4957`](https://github.com/ROCm/rocm-libraries/pull/4957), existing gfx12 config is already sufficient. - Tuning rationale: - In some cases, higher `kBlockPerCu` lowers register pressure. - On RDNA, this generally aligns with better performance when `waves_per_eu >= 6`. ## Test Plan - test_ck_tile_fmha - tile_example_fmha_fwd: tested this on gfx1100 and gfx1151 ./build/bin/tile_example_fmha_fwd -prec=bf16 -mode={0/1} -b=1 -h=24 -d=128 -s={seqlen} -s_k={seqlen} -lse=0 -iperm={0/1} -operm={0/1} ## Test Result - TFLOPs by sequence length target: `gfx1100` layout: `bhsd` - mode: batch / VGPR usage: 225 vs 214 SeqLen | Baseline | Tuned | Gain -- | -- | -- | -- 1024 | 74.10 | 71.97 | 0.97x 4096 | 66.26 | 77.79 | 1.17x 8192 | 68.18 | 75.88 | 1.11x 12288 | 68.47 | 80.44 | 1.17x 16384 | 59.54 | 79.66 | 1.34x 20480 | 55.78 | 77.91 | 1.40x 24576 | 55.08 | 77.47 | 1.41x 27280 | 47.45 | 77.16 | 1.63x - mode: group / VGPR usage: 256 vs 214 SeqLen | Baseline | Tuned | Gain -- | -- | -- | -- 1024 | 71.47 | 70.6 | 0.99x 4096 | 64.74 | 77.06 | 1.19x 8192 | 64.68 | 75.47 | 1.17x 12288 | 66.43 | 79.95 | 1.20x 16384 | 56.02 | 79.73 | 1.42x 20480 | 50.21 | 78.15 | 1.56x 24576 | 47.29 | 77.53 | 1.64x 27280 | 46.13 | 77.04 | 1.67x - TFLOPs by sequence length target: `gfx1151` layout: `bshd` - mode: batch / VGPR usage: 225 vs 223 Batch | Baseline | Tuned | Gain -- | -- | -- | -- 1024 | 26.85 | 29.17 | 1.09x 4096 | 24.75 | 26.01 | 1.05x 8192 | 25.24 | 25.50 | 1.01x 12288 | 25.18 | 25.00 | 0.99x 16384 | 24.79 | 25.91 | 1.05x 20480 | 25.56 | 25.24 | 0.99x 24576 | 25.13 | 26.20 | 1.04x 27280 | 10.78 | 26.35 | 2.44x - mode: group / VGPR usage: 256 vs 229 Batch | Baseline | Tuned | Gain -- | -- | -- | -- 1024 | 27.44 | 26.71 | 0.97x 4096 | 21.89 | 23.09 | 1.05x 8192 | 22.85 | 24.49 | 1.07x 12288 | 24.33 | 24.42 | 1.00x 16384 | 20.05 | 24.98 | 1.24x 20480 | 14.70 | 25.15 | 1.71x 24576 | 11.30 | 26.31 | 2.33x 27280 | 10.10 | 26.32 | 2.61x ## Submission Checklist - [ ] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests. --- .../ck_tile/01_fmha/codegen/ops/fmha_fwd.py | 30 +++++++++++++++---- include/ck_tile/core.hpp | 1 + include/ck_tile/core/arch/arch.hpp | 5 ++++ include/ck_tile/core/config.hpp | 3 ++ 4 files changed, 34 insertions(+), 5 deletions(-) diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index e888cbd383..627352e226 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -1095,7 +1095,10 @@ class KernelComponentFactoryGfx950( class KernelComponentFactoryGfx11(CompatibilityRuleFactory): - arch = ArchTrait("gfx11") + arch = ArchTrait( + "gfx11", + preprocessor_check="defined(__gfx11__) && !defined(__gfx115__)", + ) _DT_FP16_BF16 = ("fp16", "bf16") @@ -1109,10 +1112,12 @@ class KernelComponentFactoryGfx11(CompatibilityRuleFactory): return { # bm0, bn0, bk0, bn1, bk1, ( 32, 32) : [FmhaFwdTileSize( 64, 64, 16, 32, 32, 32, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], - ( 64, 64) : [FmhaFwdTileSize( 64, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], - (128, 128) : [FmhaFwdTileSize( 64, 64, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], + ( 64, 64) : [FmhaFwdTileSize( 64, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1, CppConstraint("a.max_seqlen_q < 4096")), + FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 8, 1, 1, 8, 1, 1, 16, 16, 16, 16, 16, 16, -1)], + (128, 128) : [FmhaFwdTileSize( 64, 64, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1, CppConstraint("a.max_seqlen_q < 4096")), + FmhaFwdTileSize(128, 64, 32, 128, 32, 128, 8, 1, 1, 8, 1, 1, 16, 16, 16, 16, 16, 16, 6)], (192, 128) : [FmhaFwdTileSize( 64, 64, 32, 128, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], - (256, 256) : [FmhaFwdTileSize( 64, 64, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], + (256, 256) : [FmhaFwdTileSize(128, 64, 32, 256, 32, 256, 8, 1, 1, 8, 1, 1, 16, 16, 16, 16, 16, 16, 6)] } # fmt: skip else: raise ValueError(f"unsupported dtype={dtype}") @@ -1133,12 +1138,25 @@ class KernelComponentFactoryGfx11(CompatibilityRuleFactory): ["t", "f"], ["t", "f"], ): - pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip + # Keep only ttff/tttt for gfx11: ffff path is often similar or worse + # pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip return pipelines +class KernelComponentFactoryGfx115(KernelComponentFactoryGfx11): + arch = ArchTrait("gfx115") + + @classmethod + def get_hdim_tile_size_dict(cls, dtype: str) -> Optional[dict]: + result = super().get_hdim_tile_size_dict(dtype) + if dtype in cls._DT_FP16_BF16: + result[(64, 64)] = [FmhaFwdTileSize( 64, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)] # fmt: skip + result[(256, 256)] = [FmhaFwdTileSize(128, 64, 32, 256, 32, 256, 8, 1, 1, 8, 1, 1, 16, 16, 16, 16, 16, 16, -1)] # fmt: skip + return result + + class KernelComponentFactoryGfx12(CompatibilityRuleFactory): arch = ArchTrait("gfx12") @@ -1230,6 +1248,8 @@ def get_factory(target: str): if target.startswith("gfx9"): return KernelComponentFactoryGfx9 + if target.startswith("gfx115"): + return KernelComponentFactoryGfx115 if target.startswith("gfx11"): return KernelComponentFactoryGfx11 if target.startswith("gfx12"): diff --git a/include/ck_tile/core.hpp b/include/ck_tile/core.hpp index 2d4964f86a..f42526ddf7 100644 --- a/include/ck_tile/core.hpp +++ b/include/ck_tile/core.hpp @@ -23,6 +23,7 @@ #include "ck_tile/core/arch/mma/mma_selector.hpp" #include "ck_tile/core/arch/mma/mma_traits.hpp" #include "ck_tile/core/arch/mma/mma_transforms.hpp" +#include "ck_tile/core/arch/mma/utility/tile_distribution_encoding_register_mapper.hpp" #include "ck_tile/core/arch/mma/wmma/wmma.hpp" #include "ck_tile/core/arch/mma/wmma/wmma_gfx11.hpp" #include "ck_tile/core/arch/mma/wmma/wmma_gfx12.hpp" diff --git a/include/ck_tile/core/arch/arch.hpp b/include/ck_tile/core/arch/arch.hpp index 5069172386..62d7971a8a 100644 --- a/include/ck_tile/core/arch/arch.hpp +++ b/include/ck_tile/core/arch/arch.hpp @@ -1141,6 +1141,9 @@ struct gfx103_t struct gfx11_t { }; +struct gfx115_t +{ +}; struct gfx12_t { }; @@ -1174,6 +1177,8 @@ CK_TILE_DEVICE static constexpr auto get_n_lds_banks(gfx103_t) { return 32; } CK_TILE_DEVICE static constexpr auto get_n_lds_banks(gfx11_t) { return 32; } +CK_TILE_DEVICE static constexpr auto get_n_lds_banks(gfx115_t) { return 32; } + CK_TILE_DEVICE static constexpr auto get_n_lds_banks(gfx12_t) { return 32; } CK_TILE_DEVICE static constexpr auto get_n_lds_banks(gfx950_t) { return 64; } diff --git a/include/ck_tile/core/config.hpp b/include/ck_tile/core/config.hpp index ed102c86a8..a057ae9052 100644 --- a/include/ck_tile/core/config.hpp +++ b/include/ck_tile/core/config.hpp @@ -24,6 +24,9 @@ defined(__gfx1152__) || defined(__gfx1153__) || defined(__gfx11_generic__) #define __gfx11__ #endif +#if defined(__gfx1150__) || defined(__gfx1151__) || defined(__gfx1152__) || defined(__gfx1153__) +#define __gfx115__ +#endif #if defined(__gfx1200__) || defined(__gfx1201__) || defined(__gfx12_generic__) #define __gfx12__ #endif