diff --git a/CHANGELOG.md b/CHANGELOG.md index c99fc1d065..04ba0283ab 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -17,6 +17,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj * Added gpt-oss sink support for FMHA FWD, include qr_ks_vs, qr_async, qr_async_trload and splitkv pipelines. * Added persistent async input scheduler for CK Tile universal GEMM kernels to support asynchronous input streaming. * Added FP8 block scale quantization for FMHA forward kernel. +* Added gfx11 support for FMHA. ### Changed diff --git a/example/ck_tile/01_fmha/CMakeLists.txt b/example/ck_tile/01_fmha/CMakeLists.txt index fbd6551091..35afb1181e 100644 --- a/example/ck_tile/01_fmha/CMakeLists.txt +++ b/example/ck_tile/01_fmha/CMakeLists.txt @@ -3,9 +3,9 @@ set(INST_TARGETS ${SUPPORTED_GPU_TARGETS}) # Currently only gfx9 and gfx12 archs are supported by FMHA -list(FILTER INST_TARGETS INCLUDE REGEX "gfx9|gfx12") +list(FILTER INST_TARGETS INCLUDE REGEX "gfx9|gfx1[12]") if(NOT INST_TARGETS) - message(WARNING "Skipping Tile Engine FMHA compilation: No supported GPU targets (gfx9, gfx12) found in SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}") + message(WARNING "Skipping Tile Engine FMHA compilation: No supported GPU targets (gfx9, gfx11, gfx12) found in SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}") return() endif() diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py index 355224fe2c..39950d9a33 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py @@ -457,6 +457,24 @@ class KernelComponentFactoryGfx950(KernelComponentFactoryGfx9): return results +class KernelComponentFactoryGfx11(KernelComponentFactoryBase): + arch = ArchTrait("gfx11") + + @staticmethod + def get_dq_dk_dv_tiles(dtype: str, tr_load: str) -> List[FmhaBwdDQDKDVTileSize]: + if tr_load == "t": + return [] + if dtype in ["fp16", "bf16"]: + return [ + # bm0, bn0, bk0, 
bk1, bk2, bk3, bk4, bhdq, bhdv, + FmhaBwdDQDKDVTileSize( 32, 64, 32, 32, 32, 32, 64, 32, 32, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 16, 16, 16, 16, -1), + FmhaBwdDQDKDVTileSize( 32, 64, 64, 32, 64, 32, 32, 64, 64, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 16, 16, 16, 16, -1), + FmhaBwdDQDKDVTileSize( 16, 64, 128, 16, 128, 16, 32, 128, 128, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 16, 16, 16, 16, -1), + FmhaBwdDQDKDVTileSize( 16, 64, 256, 16, 256, 16, 32, 256, 256, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 16, 16, 16, 16, -1), + ] # fmt: skip + return [] + + class KernelComponentFactoryGfx12(KernelComponentFactoryBase): arch = ArchTrait("gfx12") @@ -483,6 +501,8 @@ def get_factory(target: str): if target.startswith("gfx9"): return KernelComponentFactoryGfx9 + if target.startswith("gfx11"): + return KernelComponentFactoryGfx11 if target.startswith("gfx12"): return KernelComponentFactoryGfx12 diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index f9301878c4..aa29633edc 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -1094,6 +1094,50 @@ class KernelComponentFactoryGfx950( return pipelines +class KernelComponentFactoryGfx11(CompatibilityRuleFactory): + arch = ArchTrait("gfx11") + + _DT_FP16_BF16 = ("fp16", "bf16") + + @classmethod + def supported_dtypes(cls) -> Tuple[str]: + return cls._DT_FP16_BF16 + + @classmethod + def get_hdim_tile_size_dict(cls, dtype: str) -> Optional[dict]: + if dtype in cls._DT_FP16_BF16: + return { + # bm0, bn0, bk0, bn1, bk1, + ( 32, 32) : [FmhaFwdTileSize( 64, 64, 16, 32, 32, 32, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], + ( 64, 64) : [FmhaFwdTileSize( 64, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], + (128, 128) : [FmhaFwdTileSize( 64, 64, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], + (192, 128) : [FmhaFwdTileSize( 64, 64, 32, 128, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, 
-1)], + (256, 256) : [FmhaFwdTileSize( 64, 64, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], + } # fmt: skip + else: + raise ValueError(f"unsupported dtype={dtype}") + + @classmethod + def get_pipelines( + cls, dtype, hdim, hdim_v, receipt, mask_impl + ) -> List[FmhaFwdPipeline]: + pipelines = [] + if dtype in cls._DT_FP16_BF16: + qscale = "no" + for logits, mask, bias, lse, dropout, skip, sink in itertools.product( + ["t", "f"], + get_mask_map(mask_impl).keys(), + BIAS_MAP.keys(), + ["t", "f"], + ["t", "f"], + ["t", "f"], + ["t", "f"], + ): + pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip + pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip + return pipelines + + class KernelComponentFactoryGfx12(CompatibilityRuleFactory): arch = ArchTrait("gfx12") @@ -1183,6 +1227,8 @@ def get_factory(target: str): if target.startswith("gfx9"): return KernelComponentFactoryGfx9 + if target.startswith("gfx11"): + return KernelComponentFactoryGfx11 if target.startswith("gfx12"): return KernelComponentFactoryGfx12 diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py index 02518c6c0a..72b76b011a 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py @@ -388,6 +388,22 @@ class KernelComponentFactoryGfx9(KernelComponentFactoryBase): arch = ArchTrait("gfx9") +class KernelComponentFactoryGfx11(KernelComponentFactoryBase): + arch = ArchTrait("gfx11") + + @staticmethod + def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]: + if dtype in ["fp16", "bf16"]: + return KernelComponentFactoryBase.get_hdim_tile_size_dict(dtype) + return None + + @staticmethod + def get_pipelines(dtype, hdim) -> List[FmhaFwdAppendKVPipeline]: + if dtype in 
["fp16", "bf16"]: + return KernelComponentFactoryBase.get_pipelines(dtype, hdim) + return [] + + class KernelComponentFactoryGfx12(KernelComponentFactoryBase): arch = ArchTrait("gfx12") @@ -398,6 +414,8 @@ def get_factory(target: str): if target.startswith("gfx9"): return KernelComponentFactoryGfx9 + if target.startswith("gfx11"): + return KernelComponentFactoryGfx11 if target.startswith("gfx12"): return KernelComponentFactoryGfx12 diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py index 9105900fc7..def90a5429 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py @@ -836,6 +836,23 @@ class KernelComponentFactoryGfx9(KernelComponentFactoryBase): return None +class KernelComponentFactoryGfx11(KernelComponentFactoryBase): + arch = ArchTrait("gfx11") + + @staticmethod + def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]: + if dtype in ["fp16", "bf16"]: + return { + # bm0, bn0, bk0, bn1, bk1, + "32" : FmhaFwdTileSize( 64, 64, 16, 32, 32, 32, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), + "64" : FmhaFwdTileSize( 64, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), + "128": FmhaFwdTileSize( 64, 64, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), + "256": FmhaFwdTileSize( 64, 64, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), + } # fmt: skip + else: + return None + + class KernelComponentFactoryGfx12(KernelComponentFactoryBase): arch = ArchTrait("gfx12") @@ -865,6 +882,8 @@ def get_factory(target: str): if target.startswith("gfx9"): return KernelComponentFactoryGfx9 + if target.startswith("gfx11"): + return KernelComponentFactoryGfx11 if target.startswith("gfx12"): return KernelComponentFactoryGfx12 diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_pagedkv_prefill.py b/example/ck_tile/01_fmha/codegen/ops/fmha_pagedkv_prefill.py index cdb43c3480..45e5f9c705 
100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_pagedkv_prefill.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_pagedkv_prefill.py @@ -597,6 +597,24 @@ class KernelComponentFactoryGfx9(KernelComponentFactoryBase): return None +class KernelComponentFactoryGfx11(KernelComponentFactoryBase): + arch = ArchTrait("gfx11") + + @staticmethod + def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]: + if dtype in ["fp16", "bf16"]: + return { + # bm0, bn0, bk0, bn1, bk1, + # "32": FmhaFwdTileSize( 64, 64, 16, 32, 32, 32, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), + # "64": FmhaFwdTileSize( 64, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), + "128": FmhaFwdTileSize( 64, 64, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), + # "192": FmhaFwdTileSize( 64, 64, 32, 128, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), + # "256": FmhaFwdTileSize( 64, 64, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), + } # fmt: skip + else: + return None + + class KernelComponentFactoryGfx12(KernelComponentFactoryBase): arch = ArchTrait("gfx12") @@ -628,6 +646,8 @@ def get_factory(target: str): if target.startswith("gfx9"): return KernelComponentFactoryGfx9 + if target.startswith("gfx11"): + return KernelComponentFactoryGfx11 if target.startswith("gfx12"): return KernelComponentFactoryGfx12 diff --git a/example/ck_tile/01_fmha/fmha_fwd.hpp b/example/ck_tile/01_fmha/fmha_fwd.hpp index ee404010ef..3123e2bd59 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd.hpp @@ -1635,8 +1635,8 @@ struct fmha_fwd_splitkv_traits mask_enum mask_type; bias_enum bias_type; // 0:no bias, 1:elementwise bias, 2:alibi. 
sync with BlockAttentionBiasEnum bool has_lse; - bool do_fp8_static_quant; - bool has_sink = false; + bool do_fp8_static_quant = false; + bool has_sink = false; // TODO: padding check is inside this api }; float fmha_fwd_splitkv(fmha_fwd_splitkv_traits, diff --git a/example/ck_tile/01_fmha/mask.hpp b/example/ck_tile/01_fmha/mask.hpp index c780bf7b6b..bcaa7d596d 100644 --- a/example/ck_tile/01_fmha/mask.hpp +++ b/example/ck_tile/01_fmha/mask.hpp @@ -73,6 +73,7 @@ struct mask_info tmp.x = r.at(ck_tile::number<1>{}); tmp.left = left_size; tmp.right = right_size; + tmp.sink = 0; } else if(t == "t" || t == "b" || t == "g") { @@ -147,7 +148,10 @@ struct mask_info } else if(str == "0") { - tmp.type = mask_enum::no_mask; + tmp.type = mask_enum::no_mask; + tmp.left = -1; + tmp.right = -1; + tmp.sink = 0; } else if(str == "1" || str == "t") { diff --git a/include/ck_tile/core/arch/arch.hpp b/include/ck_tile/core/arch/arch.hpp index 92d0dd7f73..40bdb2ff31 100644 --- a/include/ck_tile/core/arch/arch.hpp +++ b/include/ck_tile/core/arch/arch.hpp @@ -911,14 +911,15 @@ struct WaitcntLayoutGfx12 }; struct WaitcntLayoutGfx11 -{ // vm[15:10] (6), lgkm[9:4] (6), exp unused +{ // vm[15:10] (6), lgkm[9:4] (6), exp [2:0] (3) CK_TILE_DEVICE static constexpr index_t VM_MASK = 0x3F; CK_TILE_DEVICE static constexpr index_t LGKM_MASK = 0x3F; - CK_TILE_DEVICE static constexpr bool HAS_EXP = false; + CK_TILE_DEVICE static constexpr index_t EXP_MASK = 0x07; + CK_TILE_DEVICE static constexpr bool HAS_EXP = true; CK_TILE_DEVICE static constexpr index_t pack_vm(index_t c) { return ((c & VM_MASK) << 10); } CK_TILE_DEVICE static constexpr index_t pack_lgkm(index_t c) { return ((c & LGKM_MASK) << 4); } - CK_TILE_DEVICE static constexpr index_t pack_exp(index_t) { return 0; } + CK_TILE_DEVICE static constexpr index_t pack_exp(index_t c) { return (c & EXP_MASK); } }; struct WaitcntLayoutLegacy @@ -952,10 +953,14 @@ using Waitcnt = WaitcntLayoutLegacy; struct waitcnt_arg { // kMax* exposed for callers; 
match field widths per-arch -#if defined(__gfx12__) || defined(__gfx11__) +#if defined(__gfx12__) CK_TILE_DEVICE static constexpr index_t kMaxVmCnt = 0x3F; // 6 bits CK_TILE_DEVICE static constexpr index_t kMaxLgkmCnt = 0x3F; // 6 bits CK_TILE_DEVICE static constexpr index_t kMaxExpCnt = 0x0; // none +#elif defined(__gfx11__) + CK_TILE_DEVICE static constexpr index_t kMaxVmCnt = 0x3F; // 6 bits + CK_TILE_DEVICE static constexpr index_t kMaxLgkmCnt = 0x3F; // 6 bits + CK_TILE_DEVICE static constexpr index_t kMaxExpCnt = 0x07; // 3 bits #else CK_TILE_DEVICE static constexpr index_t kMaxVmCnt = 0x3F; // 6 bits (split) CK_TILE_DEVICE static constexpr index_t kMaxLgkmCnt = 0x0F; // 4 bits @@ -981,8 +986,8 @@ struct waitcnt_arg { if constexpr(Waitcnt::HAS_EXP) { - // EXP_MASK only exists on legacy -#if !defined(__gfx12__) && !defined(__gfx11__) + // EXP_MASK only exists on pre-gfx12 +#if !defined(__gfx12__) static_assert((cnt & ~Waitcnt::EXP_MASK) == 0, "expcnt out of range"); return Waitcnt::pack_exp(cnt); #else diff --git a/include/ck_tile/ops/fmha/block/block_dropout.hpp b/include/ck_tile/ops/fmha/block/block_dropout.hpp index 6e01ea5dda..37c1fe4805 100644 --- a/include/ck_tile/ops/fmha/block/block_dropout.hpp +++ b/include/ck_tile/ops/fmha/block/block_dropout.hpp @@ -33,6 +33,42 @@ namespace ck_tile { namespace detail { // The number of Philox 4x32 results required to fill 32x32 tile of 8-bit values constexpr index_t philox_per_tile = 64; + +// C distribution of gfx11 WMMA differs from C distribution of gfx9 MFMA and gfx12 WMMA. +// This function deinterleaves the generated random values to make them compatible with other +// architectures and verification code on host. 
+template +CK_TILE_DEVICE void PermuteBlockDropoutRandval(uint8_t (&random_uint8_t)[N]) +{ +#if defined(__gfx11__) + static_for<0, N, 8>{}([&](auto i_offset) { + array rs; + static_for<0, 8, 1>{}([&](auto i) { rs.data[i] = random_uint8_t[i_offset + i]; }); + + const uint32_t r0 = rs.template get_as(number<0>{}); + const uint32_t r1 = rs.template get_as(number<1>{}); + + // Deinterleave values (even and odd indices) + const uint32_t v0 = __builtin_amdgcn_perm(r1, r0, 0x06'04'02'00); + const uint32_t v1 = __builtin_amdgcn_perm(r1, r0, 0x07'05'03'01); + + // Swap rows (lane <-> lane ^ 16) + const uint32_t w0 = + __builtin_amdgcn_permlanex16(0, v0, 0x76543210, 0xfedcba98, false, true); + const uint32_t w1 = + __builtin_amdgcn_permlanex16(0, v1, 0x76543210, 0xfedcba98, false, true); + + rs.template set_as(number<0>{}, get_lane_id() < 16 ? v0 : w1); + rs.template set_as(number<1>{}, get_lane_id() < 16 ? w0 : v1); + + static_for<0, 8, 1>{}([&](auto i) { random_uint8_t[i_offset + i] = rs.data[i]; }); + }); +#else + static_assert(false, "PermuteBlockDropoutRandval is only for gfx11"); + ignore = random_uint8_t; +#endif +} + } // namespace detail struct NullBlockDropout @@ -295,6 +331,9 @@ struct BlockDropout static_assert(randval_dist_generated.kThreadElementSpaceSize == 16); ph.get_random_16x8(random_uint8_t, ph_subsequence); } +#if defined(__gfx11__) + detail::PermuteBlockDropoutRandval(random_uint8_t); +#endif } else { @@ -566,6 +605,9 @@ struct BlockDropoutBwd static_assert(randval_dist_generated.kThreadElementSpaceSize == 16); ph.get_random_16x8(random_uint8_t, ph_subsequence); } +#if defined(__gfx11__) + detail::PermuteBlockDropoutRandval(random_uint8_t); +#endif } else { diff --git a/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp index 06b0d76a0d..b5d3f490ed 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp @@ -1181,7 +1181,7 @@ struct 
FmhaBwdDQDKDVKernel scale_rp_undrop, dropout); -#if defined(__gfx12__) +#if defined(__gfx11__) || defined(__gfx12__) // Workaround for a compiler bug (SWDEV-559729): v_wmma instructions can be incorrectly // placed in divergent branches used to store padded tensors (when some lanes are // inactive due to padding). Inline asm with dummy dependencies on VGPRs of the tensors diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp index e04e08258e..e67a525ac4 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp @@ -8,6 +8,7 @@ #include "ck_tile/ops/gemm/block/block_gemm_problem.hpp" #include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp" #include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp" +#include "ck_tile/ops/gemm/warp/warp_wmma_gemm_gfx11_utils.hpp" #include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_custom_policy.hpp" #include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1.hpp" #include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1_custom_policy.hpp" @@ -1692,8 +1693,10 @@ struct BlockFmhaBwdPipelineDefaultPolicy using AWarpDstr = typename WarpGemm::AWarpDstr; using CWarpDstr = typename WarpGemm::CWarpDstr; - auto pt_warp_tensor = + auto p_warp_tensor = make_static_distributed_tensor(CWarpDstr{}); + auto pt_warp_tensor = + make_static_distributed_tensor(AWarpDstr{}); constexpr auto a_warp_y_lengths = to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); @@ -1705,10 +1708,15 @@ struct BlockFmhaBwdPipelineDefaultPolicy static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - pt_warp_tensor.get_thread_buffer() = p_in.get_y_sliced_thread_data( + p_warp_tensor.get_thread_buffer() = p_in.get_y_sliced_thread_data( merge_sequences(sequence{}, 
c_warp_y_index_zeros), merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); +#if defined(__gfx11__) + PermuteWarpGemmCToA(pt_warp_tensor, p_warp_tensor); +#else + pt_warp_tensor.get_thread_buffer() = p_warp_tensor.get_thread_buffer(); +#endif pt_out.set_y_sliced_thread_data( merge_sequences(sequence{}, a_warp_y_index_zeros), merge_sequences(sequence<1, 1>{}, a_warp_y_lengths), @@ -1742,8 +1750,10 @@ struct BlockFmhaBwdPipelineDefaultPolicy using AWarpDstr = typename WarpGemm::AWarpDstr; using CWarpDstr = typename WarpGemm::CWarpDstr; - auto dst_warp_tensor = + auto ds_warp_tensor = make_static_distributed_tensor(CWarpDstr{}); + auto dst_warp_tensor = + make_static_distributed_tensor(AWarpDstr{}); constexpr auto a_warp_y_lengths = to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); @@ -1755,10 +1765,15 @@ struct BlockFmhaBwdPipelineDefaultPolicy static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - dst_warp_tensor.get_thread_buffer() = ds_in.get_y_sliced_thread_data( + ds_warp_tensor.get_thread_buffer() = ds_in.get_y_sliced_thread_data( merge_sequences(sequence{}, c_warp_y_index_zeros), merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); +#if defined(__gfx11__) + PermuteWarpGemmCToA(dst_warp_tensor, ds_warp_tensor); +#else + dst_warp_tensor.get_thread_buffer() = ds_warp_tensor.get_thread_buffer(); +#endif dst_out.set_y_sliced_thread_data( merge_sequences(sequence{}, a_warp_y_index_zeros), merge_sequences(sequence<1, 1>{}, a_warp_y_lengths), diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp index e516fc8eea..d7696f0f76 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp @@ -6,6 +6,7 @@ #include "ck_tile/core.hpp" #include 
"ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_pagedkv_pipeline_qr_ks_vs_default_policy.hpp" +#include "ck_tile/ops/gemm/warp/warp_wmma_gemm_gfx11_utils.hpp" #include "ck_tile/ops/reduce/block/block_reduce.hpp" namespace ck_tile { @@ -675,8 +676,15 @@ struct BlockFmhaFwdPagedKVPipelineQRKSVS i_page_block_v = v_page_block_navigator.move_tile_window( i_page_block_v, v_dram_window, {0, kK1}, physical_next_block_id_v); +#if defined(__gfx11__) + auto p = make_static_distributed_tensor( + decltype(gemm_1)::template MakeABlockTileDistribution()); + PermuteWarpGemmCToA( + p, cast_tile(tile_elementwise_in(p_compute_element_func, p_compute))); +#else const auto p = cast_tile(tile_elementwise_in(p_compute_element_func, p_compute)); +#endif // STAGE 3, KV gemm if constexpr(k1_loops > 1) diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp index c09330f847..ef6ed8b4e8 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp @@ -6,6 +6,7 @@ #include "ck_tile/core.hpp" #include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_default_policy.hpp" +#include "ck_tile/ops/gemm/warp/warp_wmma_gemm_gfx11_utils.hpp" #include "ck_tile/ops/reduce/block/block_reduce.hpp" namespace ck_tile { @@ -704,8 +705,15 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS i_page_block_v = v_page_block_navigator.move_tile_window( i_page_block_v, v_dram_window, {0, kK1}, physical_next_block_id_v); +#if defined(__gfx11__) + auto p = make_static_distributed_tensor( + decltype(gemm_1)::template MakeABlockTileDistribution()); + PermuteWarpGemmCToA( + p, cast_tile(tile_elementwise_in(p_compute_element_func, p_compute))); +#else const auto p = 
cast_tile(tile_elementwise_in(p_compute_element_func, p_compute)); +#endif // STAGE 3, KV gemm if constexpr(k1_loops > 1) diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp index 2fbc9fdb54..35654840bd 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp @@ -7,6 +7,7 @@ #include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" #include "ck_tile/ops/fmha/block/block_dropout.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_default_policy.hpp" +#include "ck_tile/ops/gemm/warp/warp_wmma_gemm_gfx11_utils.hpp" #include "ck_tile/ops/reduce/block/block_reduce.hpp" namespace ck_tile { @@ -717,8 +718,15 @@ struct BlockFmhaPipelineQRKSVS move_tile_window(v_dram_window, {0, kK1}); +#if defined(__gfx11__) + auto p = make_static_distributed_tensor( + decltype(gemm_1)::template MakeABlockTileDistribution()); + PermuteWarpGemmCToA( + p, cast_tile(tile_elementwise_in(p_compute_element_func, p_compute))); +#else const auto p = cast_tile(tile_elementwise_in(p_compute_element_func, p_compute)); +#endif float v_descale = 1.0f; if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE) diff --git a/include/ck_tile/ops/gemm.hpp b/include/ck_tile/ops/gemm.hpp index 230c614649..b1681e07e4 100644 --- a/include/ck_tile/ops/gemm.hpp +++ b/include/ck_tile/ops/gemm.hpp @@ -77,6 +77,7 @@ #include "ck_tile/ops/gemm/warp/warp_gemm_impl.hpp" #include "ck_tile/ops/gemm/warp/warp_gemm_smfmac_impl.hpp" #include "ck_tile/ops/gemm/warp/warp_wmma_gemm.hpp" +#include "ck_tile/ops/gemm/warp/warp_wmma_gemm_gfx11_utils.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" #include "ck_tile/ops/common/load_interleaved_pk_type.hpp" #include "ck_tile/ops/common/streamk_common.hpp" diff --git a/include/ck_tile/ops/gemm/warp/warp_wmma_gemm_gfx11_utils.hpp 
b/include/ck_tile/ops/gemm/warp/warp_wmma_gemm_gfx11_utils.hpp
new file mode 100644
index 0000000000..4ce787b19d
--- /dev/null
+++ b/include/ck_tile/ops/gemm/warp/warp_wmma_gemm_gfx11_utils.hpp
@@ -0,0 +1,45 @@
+// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
+// SPDX-License-Identifier: MIT
+
+#pragma once
+
+#include "ck_tile/core.hpp"
+
+namespace ck_tile {
+
+// C distribution of gfx11 WMMA is not compatible with A distribution:
+// C: 2 lanes per row (lane and lane + 16), 8 values per lane are interleaved.
+// A: 1 lane per row, 16 values, lane and lane + 16 have the same values.
+// This function transforms one distribution to another for GEMM-GEMM scenarios.
+template <typename OutTensor, typename InTensor>
+CK_TILE_DEVICE static constexpr void PermuteWarpGemmCToA(OutTensor& out, const InTensor& in)
+{
+#if defined(__gfx11__)
+    static_assert(sizeof(typename OutTensor::DataType) == 2);
+    static_assert(std::is_same_v<typename OutTensor::DataType, typename InTensor::DataType>);
+
+    constexpr index_t n_out = OutTensor::get_thread_buffer_size();
+    static_assert(n_out == InTensor::get_thread_buffer_size() * 2);
+
+    // Perm byte selectors are swapped for the second row (16 lanes) because it needs to be done
+    // once instead of swapping w and v every time
+    const uint32_t byte_selector0 = get_lane_id() < 16 ? 0x05'04'01'00 : 0x01'00'05'04;
+    const uint32_t byte_selector1 = get_lane_id() < 16 ?
0x07'06'03'02 : 0x03'02'07'06; + static_for<0, n_out, 1>{}([&](auto i) { + const auto v = in.get_thread_buffer().template get_as(i); + // Swap rows (lane <-> lane ^ 16) + const auto w = __builtin_amdgcn_permlanex16(0, v, 0x76543210, 0xfedcba98, false, true); + // Interleave values of lane and lane ^ 16 + out.get_thread_buffer().template set_as( + number{}, __builtin_amdgcn_perm(w, v, byte_selector0)); + out.get_thread_buffer().template set_as( + number{}, __builtin_amdgcn_perm(w, v, byte_selector1)); + }); +#else + static_assert(false, "PermuteWarpGemmCToA is only for gfx11"); + ignore = out; + ignore = in; +#endif +} + +} // namespace ck_tile diff --git a/test/ck_tile/fmha/CMakeLists.txt b/test/ck_tile/fmha/CMakeLists.txt index e591d5066f..60779f7f51 100644 --- a/test/ck_tile/fmha/CMakeLists.txt +++ b/test/ck_tile/fmha/CMakeLists.txt @@ -1,11 +1,6 @@ # Copyright (c) Advanced Micro Devices, Inc., or its affiliates. # SPDX-License-Identifier: MIT -# Keep in sync with example/ck_tile/01_fmha/CMakeLists.txt -if(NOT SUPPORTED_GPU_TARGETS MATCHES "gfx9|gfx12") - return() -endif() - set(FMHA_BWD_INSTANCES "tile_fmha_bwd_instances") set(FMHA_FWD_INSTANCES "tile_fmha_fwd_instances") @@ -18,9 +13,19 @@ function(add_gtest_fwd test_group) set(CPP_TYPE_fp8bf16 "FmhaFwdFp8Bf16") set(CPP_TYPE_fp32 "FmhaFwdFp32") + set(sources) + if(TARGET ${FMHA_FWD_INSTANCES}) + get_target_property(sources ${FMHA_FWD_INSTANCES} SOURCES) + message(VERBOSE "${FMHA_FWD_INSTANCES} SOURCES ${sources}") + endif() + set(all_tests) foreach(type ${V_TYPES}) set(name "${test_group}_${type}") + if(NOT sources MATCHES "_${type}_") + message(STATUS "No FMHA FWD instances for ${type}, skip ${name}") + continue() + endif() add_gtest_executable(${name} test_fmha_fwd.cpp) get_test_property(${name} LABELS COMMON_LABELS) set_tests_properties(${name} PROPERTIES LABELS "${COMMON_LABELS};${TEST_NAME};${test_group}") @@ -38,9 +43,19 @@ function(add_gtest_bwd test_group) set(CPP_TYPE_bf16 "FmhaBwdBf16") 
set(CPP_TYPE_fp32 "FmhaBwdFp32") + set(sources) + if(TARGET ${FMHA_BWD_INSTANCES}) + get_target_property(sources ${FMHA_BWD_INSTANCES} SOURCES) + message(VERBOSE "${FMHA_BWD_INSTANCES} SOURCES ${sources}") + endif() + set(all_tests) foreach(type ${V_TYPES}) set(name "${test_group}_${type}") + if(NOT sources MATCHES "_${type}_") + message(STATUS "No FMHA BWD instances for ${type}, skip ${name}") + continue() + endif() add_gtest_executable(${name} test_fmha_bwd.cpp) get_test_property(${name} LABELS COMMON_LABELS) set_tests_properties(${name} PROPERTIES LABELS "${COMMON_LABELS};${TEST_NAME};${test_group}")