From cd23de8b3874f78562bb738bc8af98df76e25167 Mon Sep 17 00:00:00 2001 From: Anton Gorenko Date: Sat, 21 Feb 2026 06:15:10 +0500 Subject: [PATCH] [CK_TILE][FMHA] Support gfx11 (#4584) ## Motivation Add support of gfx11 architectures (RDNA3) to FMHA. ## Technical Details Distributions (matrix elements to lane registers mapping) of gfx11 WMMA are completely different from distributions of gfx9 MFMA and gfx12 WMMA. There are two cases in FMHA where this difference matters: * usage of results (matrix C) of one GEMM as input (matrix A) of another GEMM. * random number generation for dropout (implementation for gfx9 MFMA, gfx12 WMMA and host validation produce the same results). Both cases are solved by a special remapping implemented using `__builtin_amdgcn_permlanex16` and `__builtin_amdgcn_perm`. Additional changes: * FMHA tests are now build and run only for those types for which instances exist (gfx11 supports only fp16 and bf16). * Two fixes for uninitialized values (`mask.sink` and `do_fp8_static_quant`): they may contain garbage resulting in incorrect dispatching logic, sometimes tests report that there are no instance available for current parameters. * Small fix to remove expcnt(0) from s_waitcnt instruction on gfx11 when they are not requested (i.e. every time), likely has no effect on performance but makes disassembly a bit clearer. ## Test Plan ``` ninja test_ck_tile_fmha bin/test_ck_tile_fmha_fwd_fp16 bin/test_ck_tile_fmha_fwd_bf16 bin/test_ck_tile_fmha_bwd_fp16 bin/test_ck_tile_fmha_bwd_bf16 ``` ## Test Result All tests must pass (some tests may be skipped). ## Submission Checklist - [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests. --------- Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com> --- CHANGELOG.md | 1 + example/ck_tile/01_fmha/CMakeLists.txt | 4 +- .../ck_tile/01_fmha/codegen/ops/fmha_bwd.py | 20 ++++++++ .../ck_tile/01_fmha/codegen/ops/fmha_fwd.py | 46 +++++++++++++++++++ .../01_fmha/codegen/ops/fmha_fwd_appendkv.py | 18 ++++++++ .../01_fmha/codegen/ops/fmha_fwd_splitkv.py | 19 ++++++++ .../codegen/ops/fmha_pagedkv_prefill.py | 20 ++++++++ example/ck_tile/01_fmha/fmha_fwd.hpp | 4 +- example/ck_tile/01_fmha/mask.hpp | 6 ++- include/ck_tile/core/arch/arch.hpp | 17 ++++--- .../ck_tile/ops/fmha/block/block_dropout.hpp | 42 +++++++++++++++++ .../ops/fmha/kernel/fmha_bwd_kernel.hpp | 2 +- ...block_fmha_bwd_pipeline_default_policy.hpp | 23 ++++++++-- ...ock_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp | 8 ++++ ...ock_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp | 8 ++++ .../pipeline/block_fmha_pipeline_qr_ks_vs.hpp | 8 ++++ include/ck_tile/ops/gemm.hpp | 1 + .../gemm/warp/warp_wmma_gemm_gfx11_utils.hpp | 45 ++++++++++++++++++ test/ck_tile/fmha/CMakeLists.txt | 25 ++++++++-- 19 files changed, 296 insertions(+), 21 deletions(-) create mode 100644 include/ck_tile/ops/gemm/warp/warp_wmma_gemm_gfx11_utils.hpp diff --git a/CHANGELOG.md b/CHANGELOG.md index c99fc1d065..04ba0283ab 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -17,6 +17,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj * Added gpt-oss sink support for FMHA FWD, include qr_ks_vs, qr_async, qr_async_trload and splitkv pipelines. * Added persistent async input scheduler for CK Tile universal GEMM kernels to support asynchronous input streaming. * Added FP8 block scale quantization for FMHA forward kernel. +* Added gfx11 support for FMHA. ### Changed diff --git a/example/ck_tile/01_fmha/CMakeLists.txt b/example/ck_tile/01_fmha/CMakeLists.txt index fbd6551091..35afb1181e 100644 --- a/example/ck_tile/01_fmha/CMakeLists.txt +++ b/example/ck_tile/01_fmha/CMakeLists.txt @@ -3,9 +3,9 @@ set(INST_TARGETS ${SUPPORTED_GPU_TARGETS}) # Currently only gfx9 and gfx12 archs are supported by FMHA -list(FILTER INST_TARGETS INCLUDE REGEX "gfx9|gfx12") +list(FILTER INST_TARGETS INCLUDE REGEX "gfx9|gfx1[12]") if(NOT INST_TARGETS) - message(WARNING "Skipping Tile Engine FMHA compilation: No supported GPU targets (gfx9, gfx12) found in SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}") + message(WARNING "Skipping Tile Engine FMHA compilation: No supported GPU targets (gfx9, gfx11, gfx12) found in SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}") return() endif() diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py index 355224fe2c..39950d9a33 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py @@ -457,6 +457,24 @@ class KernelComponentFactoryGfx950(KernelComponentFactoryGfx9): return results +class KernelComponentFactoryGfx11(KernelComponentFactoryBase): + arch = ArchTrait("gfx11") + + @staticmethod + def get_dq_dk_dv_tiles(dtype: str, tr_load: str) -> List[FmhaBwdDQDKDVTileSize]: + if tr_load == "t": + return [] + if dtype in ["fp16", "bf16"]: + return [ + # bm0, bn0, bk0, bk1, bk2, bk3, bk4, bhdq, bhdv, + FmhaBwdDQDKDVTileSize( 32, 64, 32, 32, 32, 32, 64, 32, 32, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 16, 16, 16, 16, -1), + FmhaBwdDQDKDVTileSize( 32, 64, 64, 32, 64, 32, 32, 64, 64, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 16, 16, 16, 16, -1), + FmhaBwdDQDKDVTileSize( 16, 64, 128, 16, 128, 16, 32, 128, 128, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 16, 16, 16, 16, -1), + FmhaBwdDQDKDVTileSize( 16, 64, 256, 16, 256, 16, 32, 256, 256, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 16, 16, 16, 16, -1), + ] # fmt: skip + return [] + + class KernelComponentFactoryGfx12(KernelComponentFactoryBase): arch = ArchTrait("gfx12") @@ -483,6 +501,8 @@ def get_factory(target: str): if target.startswith("gfx9"): return KernelComponentFactoryGfx9 + if target.startswith("gfx11"): + return KernelComponentFactoryGfx11 if target.startswith("gfx12"): return KernelComponentFactoryGfx12 diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index f9301878c4..aa29633edc 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -1094,6 +1094,50 @@ class KernelComponentFactoryGfx950( return pipelines +class KernelComponentFactoryGfx11(CompatibilityRuleFactory): + arch = ArchTrait("gfx11") + + _DT_FP16_BF16 = ("fp16", "bf16") + + @classmethod + def supported_dtypes(cls) -> Tuple[str]: + return cls._DT_FP16_BF16 + + @classmethod + def get_hdim_tile_size_dict(cls, dtype: str) -> Optional[dict]: + if dtype in cls._DT_FP16_BF16: + return { + # bm0, bn0, bk0, bn1, bk1, + ( 32, 32) : [FmhaFwdTileSize( 64, 64, 16, 32, 32, 32, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], + ( 64, 64) : [FmhaFwdTileSize( 64, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], + (128, 128) : [FmhaFwdTileSize( 64, 64, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], + (192, 128) : [FmhaFwdTileSize( 64, 64, 32, 128, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], + (256, 256) : [FmhaFwdTileSize( 64, 64, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], + } # fmt: skip + else: + raise ValueError(f"unsupported dtype={dtype}") + + @classmethod + def get_pipelines( + cls, dtype, hdim, hdim_v, receipt, mask_impl + ) -> List[FmhaFwdPipeline]: + pipelines = [] + if dtype in cls._DT_FP16_BF16: + qscale = "no" + for logits, mask, bias, lse, dropout, skip, sink in itertools.product( + ["t", "f"], + get_mask_map(mask_impl).keys(), + BIAS_MAP.keys(), + ["t", "f"], + ["t", "f"], + ["t", "f"], + ["t", "f"], + ): + pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip + pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip + return pipelines + + class KernelComponentFactoryGfx12(CompatibilityRuleFactory): arch = ArchTrait("gfx12") @@ -1183,6 +1227,8 @@ def get_factory(target: str): if target.startswith("gfx9"): return KernelComponentFactoryGfx9 + if target.startswith("gfx11"): + return KernelComponentFactoryGfx11 if target.startswith("gfx12"): return KernelComponentFactoryGfx12 diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py index 02518c6c0a..72b76b011a 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py @@ -388,6 +388,22 @@ class KernelComponentFactoryGfx9(KernelComponentFactoryBase): arch = ArchTrait("gfx9") +class KernelComponentFactoryGfx11(KernelComponentFactoryBase): + arch = ArchTrait("gfx11") + + @staticmethod + def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]: + if dtype in ["fp16", "bf16"]: + return KernelComponentFactoryBase.get_hdim_tile_size_dict(dtype) + return None + + @staticmethod + def get_pipelines(dtype, hdim) -> List[FmhaFwdAppendKVPipeline]: + if dtype in ["fp16", "bf16"]: + return KernelComponentFactoryBase.get_pipelines(dtype, hdim) + return [] + + class KernelComponentFactoryGfx12(KernelComponentFactoryBase): arch = ArchTrait("gfx12") @@ -398,6 +414,8 @@ def get_factory(target: str): if target.startswith("gfx9"): return KernelComponentFactoryGfx9 + if target.startswith("gfx11"): + return KernelComponentFactoryGfx11 if target.startswith("gfx12"): return KernelComponentFactoryGfx12 diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py index 9105900fc7..def90a5429 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py @@ -836,6 +836,23 @@ class KernelComponentFactoryGfx9(KernelComponentFactoryBase): return None +class KernelComponentFactoryGfx11(KernelComponentFactoryBase): + arch = ArchTrait("gfx11") + + @staticmethod + def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]: + if dtype in ["fp16", "bf16"]: + return { + # bm0, bn0, bk0, bn1, bk1, + "32" : FmhaFwdTileSize( 64, 64, 16, 32, 32, 32, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), + "64" : FmhaFwdTileSize( 64, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), + "128": FmhaFwdTileSize( 64, 64, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), + "256": FmhaFwdTileSize( 64, 64, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), + } # fmt: skip + else: + return None + + class KernelComponentFactoryGfx12(KernelComponentFactoryBase): arch = ArchTrait("gfx12") @@ -865,6 +882,8 @@ def get_factory(target: str): if target.startswith("gfx9"): return KernelComponentFactoryGfx9 + if target.startswith("gfx11"): + return KernelComponentFactoryGfx11 if target.startswith("gfx12"): return KernelComponentFactoryGfx12 diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_pagedkv_prefill.py b/example/ck_tile/01_fmha/codegen/ops/fmha_pagedkv_prefill.py index cdb43c3480..45e5f9c705 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_pagedkv_prefill.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_pagedkv_prefill.py @@ -597,6 +597,24 @@ class KernelComponentFactoryGfx9(KernelComponentFactoryBase): return None +class KernelComponentFactoryGfx11(KernelComponentFactoryBase): + arch = ArchTrait("gfx11") + + @staticmethod + def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]: + if dtype in ["fp16", "bf16"]: + return { + # bm0, bn0, bk0, bn1, bk1, + # "32": FmhaFwdTileSize( 64, 64, 16, 32, 32, 32, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), + # "64": FmhaFwdTileSize( 64, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), + "128": FmhaFwdTileSize( 64, 64, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), + # "192": FmhaFwdTileSize( 64, 64, 32, 128, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), + # "256": FmhaFwdTileSize( 64, 64, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), + } # fmt: skip + else: + return None + + class KernelComponentFactoryGfx12(KernelComponentFactoryBase): arch = ArchTrait("gfx12") @@ -628,6 +646,8 @@ def get_factory(target: str): if target.startswith("gfx9"): return KernelComponentFactoryGfx9 + if target.startswith("gfx11"): + return KernelComponentFactoryGfx11 if target.startswith("gfx12"): return KernelComponentFactoryGfx12 diff --git a/example/ck_tile/01_fmha/fmha_fwd.hpp b/example/ck_tile/01_fmha/fmha_fwd.hpp index ee404010ef..3123e2bd59 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd.hpp @@ -1635,8 +1635,8 @@ struct fmha_fwd_splitkv_traits mask_enum mask_type; bias_enum bias_type; // 0:no bias, 1:elementwise bias, 2:alibi. sync with BlockAttentionBiasEnum bool has_lse; - bool do_fp8_static_quant; - bool has_sink = false; + bool do_fp8_static_quant = false; + bool has_sink = false; // TODO: padding check is inside this api }; float fmha_fwd_splitkv(fmha_fwd_splitkv_traits, diff --git a/example/ck_tile/01_fmha/mask.hpp b/example/ck_tile/01_fmha/mask.hpp index c780bf7b6b..bcaa7d596d 100644 --- a/example/ck_tile/01_fmha/mask.hpp +++ b/example/ck_tile/01_fmha/mask.hpp @@ -73,6 +73,7 @@ struct mask_info tmp.x = r.at(ck_tile::number<1>{}); tmp.left = left_size; tmp.right = right_size; + tmp.sink = 0; } else if(t == "t" || t == "b" || t == "g") { @@ -147,7 +148,10 @@ struct mask_info } else if(str == "0") { - tmp.type = mask_enum::no_mask; + tmp.type = mask_enum::no_mask; + tmp.left = -1; + tmp.right = -1; + tmp.sink = 0; } else if(str == "1" || str == "t") { diff --git a/include/ck_tile/core/arch/arch.hpp b/include/ck_tile/core/arch/arch.hpp index 92d0dd7f73..40bdb2ff31 100644 --- a/include/ck_tile/core/arch/arch.hpp +++ b/include/ck_tile/core/arch/arch.hpp @@ -911,14 +911,15 @@ struct WaitcntLayoutGfx12 }; struct WaitcntLayoutGfx11 -{ // vm[15:10] (6), lgkm[9:4] (6), exp unused +{ // vm[15:10] (6), lgkm[9:4] (6), exp [2:0] (3) CK_TILE_DEVICE static constexpr index_t VM_MASK = 0x3F; CK_TILE_DEVICE static constexpr index_t LGKM_MASK = 0x3F; - CK_TILE_DEVICE static constexpr bool HAS_EXP = false; + CK_TILE_DEVICE static constexpr index_t EXP_MASK = 0x07; + CK_TILE_DEVICE static constexpr bool HAS_EXP = true; CK_TILE_DEVICE static constexpr index_t pack_vm(index_t c) { return ((c & VM_MASK) << 10); } CK_TILE_DEVICE static constexpr index_t pack_lgkm(index_t c) { return ((c & LGKM_MASK) << 4); } - CK_TILE_DEVICE static constexpr index_t pack_exp(index_t) { return 0; } + CK_TILE_DEVICE static constexpr index_t pack_exp(index_t c) { return (c & EXP_MASK); } }; struct WaitcntLayoutLegacy @@ -952,10 +953,14 @@ using Waitcnt = WaitcntLayoutLegacy; struct waitcnt_arg { // kMax* exposed for callers; match field widths per-arch -#if defined(__gfx12__) || defined(__gfx11__) +#if defined(__gfx12__) CK_TILE_DEVICE static constexpr index_t kMaxVmCnt = 0x3F; // 6 bits CK_TILE_DEVICE static constexpr index_t kMaxLgkmCnt = 0x3F; // 6 bits CK_TILE_DEVICE static constexpr index_t kMaxExpCnt = 0x0; // none +#elif defined(__gfx11__) + CK_TILE_DEVICE static constexpr index_t kMaxVmCnt = 0x3F; // 6 bits + CK_TILE_DEVICE static constexpr index_t kMaxLgkmCnt = 0x3F; // 6 bits + CK_TILE_DEVICE static constexpr index_t kMaxExpCnt = 0x07; // 3 bits #else CK_TILE_DEVICE static constexpr index_t kMaxVmCnt = 0x3F; // 6 bits (split) CK_TILE_DEVICE static constexpr index_t kMaxLgkmCnt = 0x0F; // 4 bits @@ -981,8 +986,8 @@ struct waitcnt_arg { if constexpr(Waitcnt::HAS_EXP) { - // EXP_MASK only exists on legacy -#if !defined(__gfx12__) && !defined(__gfx11__) + // EXP_MASK only exists on pre-gfx12 +#if !defined(__gfx12__) static_assert((cnt & ~Waitcnt::EXP_MASK) == 0, "expcnt out of range"); return Waitcnt::pack_exp(cnt); #else diff --git a/include/ck_tile/ops/fmha/block/block_dropout.hpp b/include/ck_tile/ops/fmha/block/block_dropout.hpp index 6e01ea5dda..37c1fe4805 100644 --- a/include/ck_tile/ops/fmha/block/block_dropout.hpp +++ b/include/ck_tile/ops/fmha/block/block_dropout.hpp @@ -33,6 +33,42 @@ namespace ck_tile { namespace detail { // The number of Philox 4x32 results required to fill 32x32 tile of 8-bit values constexpr index_t philox_per_tile = 64; + +// C distribution of gfx11 WMMA differs from C distribution of gfx9 MFMA and gfx12 WMMA. +// This function deinterleaves the generated random values to make them compatible with other +// architectures and verification code on host. +template +CK_TILE_DEVICE void PermuteBlockDropoutRandval(uint8_t (&random_uint8_t)[N]) +{ +#if defined(__gfx11__) + static_for<0, N, 8>{}([&](auto i_offset) { + array rs; + static_for<0, 8, 1>{}([&](auto i) { rs.data[i] = random_uint8_t[i_offset + i]; }); + + const uint32_t r0 = rs.template get_as(number<0>{}); + const uint32_t r1 = rs.template get_as(number<1>{}); + + // Deinterleave values (even and odd indices) + const uint32_t v0 = __builtin_amdgcn_perm(r1, r0, 0x06'04'02'00); + const uint32_t v1 = __builtin_amdgcn_perm(r1, r0, 0x07'05'03'01); + + // Swap rows (lane <-> lane ^ 16) + const uint32_t w0 = + __builtin_amdgcn_permlanex16(0, v0, 0x76543210, 0xfedcba98, false, true); + const uint32_t w1 = + __builtin_amdgcn_permlanex16(0, v1, 0x76543210, 0xfedcba98, false, true); + + rs.template set_as(number<0>{}, get_lane_id() < 16 ? v0 : w1); + rs.template set_as(number<1>{}, get_lane_id() < 16 ? w0 : v1); + + static_for<0, 8, 1>{}([&](auto i) { random_uint8_t[i_offset + i] = rs.data[i]; }); + }); +#else + static_assert(false, "PermuteBlockDropoutRandval is only for gfx11"); + ignore = random_uint8_t; +#endif +} + } // namespace detail struct NullBlockDropout @@ -295,6 +331,9 @@ struct BlockDropout static_assert(randval_dist_generated.kThreadElementSpaceSize == 16); ph.get_random_16x8(random_uint8_t, ph_subsequence); } +#if defined(__gfx11__) + detail::PermuteBlockDropoutRandval(random_uint8_t); +#endif } else { @@ -566,6 +605,9 @@ struct BlockDropoutBwd static_assert(randval_dist_generated.kThreadElementSpaceSize == 16); ph.get_random_16x8(random_uint8_t, ph_subsequence); } +#if defined(__gfx11__) + detail::PermuteBlockDropoutRandval(random_uint8_t); +#endif } else { diff --git a/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp index 06b0d76a0d..b5d3f490ed 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp @@ -1181,7 +1181,7 @@ struct FmhaBwdDQDKDVKernel scale_rp_undrop, dropout); -#if defined(__gfx12__) +#if defined(__gfx11__) || defined(__gfx12__) // Workaround for a compiler bug (SWDEV-559729): v_wmma instructions can be incorrectly // placed in divergent branches used to store padded tensors (when some lanes are // inactive due to padding). Inline asm with dummy dependencies on VGPRs of the tensors diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp index e04e08258e..e67a525ac4 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp @@ -8,6 +8,7 @@ #include "ck_tile/ops/gemm/block/block_gemm_problem.hpp" #include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp" #include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp" +#include "ck_tile/ops/gemm/warp/warp_wmma_gemm_gfx11_utils.hpp" #include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_custom_policy.hpp" #include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1.hpp" #include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1_custom_policy.hpp" @@ -1692,8 +1693,10 @@ struct BlockFmhaBwdPipelineDefaultPolicy using AWarpDstr = typename WarpGemm::AWarpDstr; using CWarpDstr = typename WarpGemm::CWarpDstr; - auto pt_warp_tensor = + auto p_warp_tensor = make_static_distributed_tensor(CWarpDstr{}); + auto pt_warp_tensor = + make_static_distributed_tensor(AWarpDstr{}); constexpr auto a_warp_y_lengths = to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); @@ -1705,10 +1708,15 @@ struct BlockFmhaBwdPipelineDefaultPolicy static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - pt_warp_tensor.get_thread_buffer() = p_in.get_y_sliced_thread_data( + p_warp_tensor.get_thread_buffer() = p_in.get_y_sliced_thread_data( merge_sequences(sequence{}, c_warp_y_index_zeros), merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); +#if defined(__gfx11__) + PermuteWarpGemmCToA(pt_warp_tensor, p_warp_tensor); +#else + pt_warp_tensor.get_thread_buffer() = p_warp_tensor.get_thread_buffer(); +#endif pt_out.set_y_sliced_thread_data( merge_sequences(sequence{}, a_warp_y_index_zeros), merge_sequences(sequence<1, 1>{}, a_warp_y_lengths), @@ -1742,8 +1750,10 @@ struct BlockFmhaBwdPipelineDefaultPolicy using AWarpDstr = typename WarpGemm::AWarpDstr; using CWarpDstr = typename WarpGemm::CWarpDstr; - auto dst_warp_tensor = + auto ds_warp_tensor = make_static_distributed_tensor(CWarpDstr{}); + auto dst_warp_tensor = + make_static_distributed_tensor(AWarpDstr{}); constexpr auto a_warp_y_lengths = to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); @@ -1755,10 +1765,15 @@ struct BlockFmhaBwdPipelineDefaultPolicy static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - dst_warp_tensor.get_thread_buffer() = ds_in.get_y_sliced_thread_data( + ds_warp_tensor.get_thread_buffer() = ds_in.get_y_sliced_thread_data( merge_sequences(sequence{}, c_warp_y_index_zeros), merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); +#if defined(__gfx11__) + PermuteWarpGemmCToA(dst_warp_tensor, ds_warp_tensor); +#else + dst_warp_tensor.get_thread_buffer() = ds_warp_tensor.get_thread_buffer(); +#endif dst_out.set_y_sliced_thread_data( merge_sequences(sequence{}, a_warp_y_index_zeros), merge_sequences(sequence<1, 1>{}, a_warp_y_lengths), diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp index e516fc8eea..d7696f0f76 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp @@ -6,6 +6,7 @@ #include "ck_tile/core.hpp" #include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_pagedkv_pipeline_qr_ks_vs_default_policy.hpp" +#include "ck_tile/ops/gemm/warp/warp_wmma_gemm_gfx11_utils.hpp" #include "ck_tile/ops/reduce/block/block_reduce.hpp" namespace ck_tile { @@ -675,8 +676,15 @@ struct BlockFmhaFwdPagedKVPipelineQRKSVS i_page_block_v = v_page_block_navigator.move_tile_window( i_page_block_v, v_dram_window, {0, kK1}, physical_next_block_id_v); +#if defined(__gfx11__) + auto p = make_static_distributed_tensor( + decltype(gemm_1)::template MakeABlockTileDistribution()); + PermuteWarpGemmCToA( + p, cast_tile(tile_elementwise_in(p_compute_element_func, p_compute))); +#else const auto p = cast_tile(tile_elementwise_in(p_compute_element_func, p_compute)); +#endif // STAGE 3, KV gemm if constexpr(k1_loops > 1) diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp index c09330f847..ef6ed8b4e8 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp @@ -6,6 +6,7 @@ #include "ck_tile/core.hpp" #include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_default_policy.hpp" +#include "ck_tile/ops/gemm/warp/warp_wmma_gemm_gfx11_utils.hpp" #include "ck_tile/ops/reduce/block/block_reduce.hpp" namespace ck_tile { @@ -704,8 +705,15 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS i_page_block_v = v_page_block_navigator.move_tile_window( i_page_block_v, v_dram_window, {0, kK1}, physical_next_block_id_v); +#if defined(__gfx11__) + auto p = make_static_distributed_tensor( + decltype(gemm_1)::template MakeABlockTileDistribution()); + PermuteWarpGemmCToA( + p, cast_tile(tile_elementwise_in(p_compute_element_func, p_compute))); +#else const auto p = cast_tile(tile_elementwise_in(p_compute_element_func, p_compute)); +#endif // STAGE 3, KV gemm if constexpr(k1_loops > 1) diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp index 2fbc9fdb54..35654840bd 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp @@ -7,6 +7,7 @@ #include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" #include "ck_tile/ops/fmha/block/block_dropout.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_default_policy.hpp" +#include "ck_tile/ops/gemm/warp/warp_wmma_gemm_gfx11_utils.hpp" #include "ck_tile/ops/reduce/block/block_reduce.hpp" namespace ck_tile { @@ -717,8 +718,15 @@ struct BlockFmhaPipelineQRKSVS move_tile_window(v_dram_window, {0, kK1}); +#if defined(__gfx11__) + auto p = make_static_distributed_tensor( + decltype(gemm_1)::template MakeABlockTileDistribution()); + PermuteWarpGemmCToA( + p, cast_tile(tile_elementwise_in(p_compute_element_func, p_compute))); +#else const auto p = cast_tile(tile_elementwise_in(p_compute_element_func, p_compute)); +#endif float v_descale = 1.0f; if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE) diff --git a/include/ck_tile/ops/gemm.hpp b/include/ck_tile/ops/gemm.hpp index 230c614649..b1681e07e4 100644 --- a/include/ck_tile/ops/gemm.hpp +++ b/include/ck_tile/ops/gemm.hpp @@ -77,6 +77,7 @@ #include "ck_tile/ops/gemm/warp/warp_gemm_impl.hpp" #include "ck_tile/ops/gemm/warp/warp_gemm_smfmac_impl.hpp" #include "ck_tile/ops/gemm/warp/warp_wmma_gemm.hpp" +#include "ck_tile/ops/gemm/warp/warp_wmma_gemm_gfx11_utils.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" #include "ck_tile/ops/common/load_interleaved_pk_type.hpp" #include "ck_tile/ops/common/streamk_common.hpp" diff --git a/include/ck_tile/ops/gemm/warp/warp_wmma_gemm_gfx11_utils.hpp b/include/ck_tile/ops/gemm/warp/warp_wmma_gemm_gfx11_utils.hpp new file mode 100644 index 0000000000..4ce787b19d --- /dev/null +++ b/include/ck_tile/ops/gemm/warp/warp_wmma_gemm_gfx11_utils.hpp @@ -0,0 +1,45 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck_tile/core.hpp" + +namespace ck_tile { + +// C distribution of gfx11 WMMA is not compatible with A distribution: +// C: 2 lanes per row (lane and lane + 16), 8 values per lane are interleaved. +// A: 1 lane per row, 16 values, lane and lane + 16 have the same values. +// This function transforms one ditribution to another for GEMM-GEMM scenarios. +template +CK_TILE_DEVICE static constexpr void PermuteWarpGemmCToA(OutTensor& out, const InTensor& in) +{ +#if defined(__gfx11__) + static_assert(sizeof(typename OutTensor::DataType) == 2); + static_assert(std::is_same_v); + + constexpr index_t n_out = OutTensor::get_thread_buffer_size(); + static_assert(n_out == InTensor::get_thread_buffer_size() * 2); + + // Perm byte selectors are swapped for the second row (16 lanes) because it needs to be done + // once instead to swapping w and v everytime + const uint32_t byte_selector0 = get_lane_id() < 16 ? 0x05'04'01'00 : 0x01'00'05'04; + const uint32_t byte_selector1 = get_lane_id() < 16 ? 0x07'06'03'02 : 0x03'02'07'06; + static_for<0, n_out, 1>{}([&](auto i) { + const auto v = in.get_thread_buffer().template get_as(i); + // Swap rows (lane <-> lane ^ 16) + const auto w = __builtin_amdgcn_permlanex16(0, v, 0x76543210, 0xfedcba98, false, true); + // Interleave values of lane and lane ^ 16 + out.get_thread_buffer().template set_as( + number{}, __builtin_amdgcn_perm(w, v, byte_selector0)); + out.get_thread_buffer().template set_as( + number{}, __builtin_amdgcn_perm(w, v, byte_selector1)); + }); +#else + static_assert(false, "PermuteWarpGemmCToA is only for gfx11"); + ignore = out; + ignore = in; +#endif +} + +} // namespace ck_tile diff --git a/test/ck_tile/fmha/CMakeLists.txt b/test/ck_tile/fmha/CMakeLists.txt index e591d5066f..60779f7f51 100644 --- a/test/ck_tile/fmha/CMakeLists.txt +++ b/test/ck_tile/fmha/CMakeLists.txt @@ -1,11 +1,6 @@ # Copyright (c) Advanced Micro Devices, Inc., or its affiliates. # SPDX-License-Identifier: MIT -# Keep in sync with example/ck_tile/01_fmha/CMakeLists.txt -if(NOT SUPPORTED_GPU_TARGETS MATCHES "gfx9|gfx12") - return() -endif() - set(FMHA_BWD_INSTANCES "tile_fmha_bwd_instances") set(FMHA_FWD_INSTANCES "tile_fmha_fwd_instances") @@ -18,9 +13,19 @@ function(add_gtest_fwd test_group) set(CPP_TYPE_fp8bf16 "FmhaFwdFp8Bf16") set(CPP_TYPE_fp32 "FmhaFwdFp32") + set(sources) + if(TARGET ${FMHA_FWD_INSTANCES}) + get_target_property(sources ${FMHA_FWD_INSTANCES} SOURCES) + message(VERBOSE "${FMHA_FWD_INSTANCES} SOURCES ${sources}") + endif() + set(all_tests) foreach(type ${V_TYPES}) set(name "${test_group}_${type}") + if(NOT sources MATCHES "_${type}_") + message(STATUS "No FMHA FWD instances for ${type}, skip ${name}") + continue() + endif() add_gtest_executable(${name} test_fmha_fwd.cpp) get_test_property(${name} LABELS COMMON_LABELS) set_tests_properties(${name} PROPERTIES LABELS "${COMMON_LABELS};${TEST_NAME};${test_group}") @@ -38,9 +43,19 @@ function(add_gtest_bwd test_group) set(CPP_TYPE_bf16 "FmhaBwdBf16") set(CPP_TYPE_fp32 "FmhaBwdFp32") + set(sources) + if(TARGET ${FMHA_BWD_INSTANCES}) + get_target_property(sources ${FMHA_BWD_INSTANCES} SOURCES) + message(VERBOSE "${FMHA_BWD_INSTANCES} SOURCES ${sources}") + endif() + set(all_tests) foreach(type ${V_TYPES}) set(name "${test_group}_${type}") + if(NOT sources MATCHES "_${type}_") + message(STATUS "No FMHA BWD instances for ${type}, skip ${name}") + continue() + endif() add_gtest_executable(${name} test_fmha_bwd.cpp) get_test_property(${name} LABELS COMMON_LABELS) set_tests_properties(${name} PROPERTIES LABELS "${COMMON_LABELS};${TEST_NAME};${test_group}")