From a05514ca58ad7b1a578bf3faaa3ec5823a20ef03 Mon Sep 17 00:00:00 2001 From: Yi DING Date: Thu, 13 Nov 2025 20:22:05 +0800 Subject: [PATCH] [CK_TILE] Improve F8F6F4 Scaled WarpGemm (#3197) * [CK_TILE] Improve F8F6F4 Scaled WarpGemm * Thanks, Copilot [ROCm/composable_kernel commit: 8d50001b939691134a0b078ed15a41e22ee08bd0] --- include/ck_tile/ops/gemm/warp/warp_gemm.hpp | 26 +- .../warp/warp_gemm_attribute_mfma_impl.hpp | 188 ++++---------- .../ops/gemm/warp/warp_gemm_dispatcher.hpp | 232 ++++++++---------- 3 files changed, 175 insertions(+), 271 deletions(-) diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm.hpp index ca5d2f872d..0d41461038 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm.hpp @@ -308,50 +308,50 @@ using WarpGemmMfma_f32_16x16x64_bf8_bf8 = WarpGemmImpl using WarpGemmMfma_f32_16x16x128_fp4 = WarpGemmImpl< - WarpGemmAttributeMfma, + WarpGemmAttributeMfma, AttrNumAccess>>; template -using WarpGemmMfma_f32_16x16x128_fp8_fp8 = WarpGemmImpl< - WarpGemmAttributeMfma, +using WarpGemmMfma_f32_16x16x128_fp8_fp8 = WarpGemmImpl< // + WarpGemmAttributeMfma, AttrNumAccess>>; template -using WarpGemmMfma_f32_16x16x128_fp8_bf8 = WarpGemmImpl< - WarpGemmAttributeMfma, +using WarpGemmMfma_f32_16x16x128_fp8_bf8 = WarpGemmImpl< // + WarpGemmAttributeMfma, AttrNumAccess>>; template -using WarpGemmMfma_f32_16x16x128_bf8_fp8 = WarpGemmImpl< - WarpGemmAttributeMfma, +using WarpGemmMfma_f32_16x16x128_bf8_fp8 = WarpGemmImpl< // + WarpGemmAttributeMfma, AttrNumAccess>>; template -using WarpGemmMfma_f32_16x16x128_bf8_bf8 = WarpGemmImpl< - WarpGemmAttributeMfma, +using WarpGemmMfma_f32_16x16x128_bf8_bf8 = WarpGemmImpl< // + WarpGemmAttributeMfma, AttrNumAccess>>; template using WarpGemmMfma_f32_16x16x128_fp8_fp8_CTransposed = WarpGemmImpl, + WarpGemmAttributeMfmaImpl_f32_16x16x128_f8f6f4, AttrNumAccess>>; template using WarpGemmMfma_f32_16x16x128_fp8_bf8_CTransposed = WarpGemmImpl, + WarpGemmAttributeMfmaImpl_f32_16x16x128_f8f6f4, AttrNumAccess>>; template using WarpGemmMfma_f32_16x16x128_bf8_fp8_CTransposed = WarpGemmImpl, + WarpGemmAttributeMfmaImpl_f32_16x16x128_f8f6f4, AttrNumAccess>>; template using WarpGemmMfma_f32_16x16x128_bf8_bf8_CTransposed = WarpGemmImpl, + WarpGemmAttributeMfmaImpl_f32_16x16x128_f8f6f4, AttrNumAccess>>; template diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp index 8237f6fd50..1ddf0c0cf8 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp @@ -1527,113 +1527,15 @@ using WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_bf8 = WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base; template -struct WarpGemmAttributeMfmaImpl_f32_16x16x128_f8_bf8_base +struct WarpGemmAttributeMfmaImpl_f32_16x16x128_f8f6f4 { static constexpr WGAttrCtlEnum Ctrl = Ctrl_; using ADataType = AType_; using BDataType = BType_; using CDataType = float; - using AVecType = ext_vector_t; - using BVecType = ext_vector_t; - using CVecType = ext_vector_t; - - static constexpr index_t kM = 16; - static constexpr index_t kN = 16; - static constexpr index_t kK = 128; - - static constexpr index_t kAMBlock = 1; - static constexpr index_t kBNBlock = 1; - - static constexpr index_t kAMLane = 16; - static constexpr index_t kBNLane = 16; - static constexpr index_t kABKLane = 4; - static constexpr index_t kABKPerLane = 32; - - static constexpr index_t kCMLane = 4; - static constexpr index_t kCNLane = 16; - static constexpr index_t kCM0PerLane = 1; - static constexpr index_t kCM1PerLane = 4; - - // c_vec += a_vec * b_vec - template - CK_TILE_DEVICE void operator()(CVecType& c_vec, - const AVecType& a_vec, - const BVecType& b_vec, - bool_constant = {}) const - { - //__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(a, b, c, cbsz, blgp, opsel, scale_a, - // opsel, scale_b) -#if defined(__gfx950__) - if constexpr(std::is_same_v && std::is_same_v) - c_vec = __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4( - a_vec, b_vec, c_vec, 0, 0, 0, 0, 0, 0); - else if constexpr(std::is_same_v && std::is_same_v) - c_vec = __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4( - a_vec, b_vec, c_vec, 0, 1, 0, 0, 0, 0); - else if constexpr(std::is_same_v && std::is_same_v) - c_vec = __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4( - a_vec, b_vec, c_vec, 1, 0, 0, 0, 0, 0); - else if constexpr(std::is_same_v && std::is_same_v) - c_vec = __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4( - a_vec, b_vec, c_vec, 1, 1, 0, 0, 0, 0); -#else - ck_tile::ignore = c_vec; - ck_tile::ignore = a_vec; - ck_tile::ignore = b_vec; -#endif - } - - // c_vec = a_vec * b_vec - CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const - { -#if defined(__gfx950__) - if constexpr(std::is_same_v && std::is_same_v) - return bit_cast(__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4( - a_vec, b_vec, CVecType{0.f}, 0, 0, 0, 0, 0, 0)); - else if constexpr(std::is_same_v && std::is_same_v) - return bit_cast(__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4( - a_vec, b_vec, CVecType{0.f}, 0, 1, 0, 0, 0, 0)); - else if constexpr(std::is_same_v && std::is_same_v) - return bit_cast(__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4( - a_vec, b_vec, CVecType{0.f}, 1, 0, 0, 0, 0, 0)); - else if constexpr(std::is_same_v && std::is_same_v) - return bit_cast(__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4( - a_vec, b_vec, CVecType{0.f}, 1, 1, 0, 0, 0, 0)); -#else - ck_tile::ignore = a_vec; - ck_tile::ignore = b_vec; - return CVecType{0.f}; -#endif - } -}; - -template -using WarpGemmAttributeMfmaImpl_f32_16x16x128_fp8_fp8 = - WarpGemmAttributeMfmaImpl_f32_16x16x128_f8_bf8_base; - -template -using WarpGemmAttributeMfmaImpl_f32_16x16x128_fp8_bf8 = - WarpGemmAttributeMfmaImpl_f32_16x16x128_f8_bf8_base; - -template -using WarpGemmAttributeMfmaImpl_f32_16x16x128_bf8_fp8 = - WarpGemmAttributeMfmaImpl_f32_16x16x128_f8_bf8_base; - -template -using WarpGemmAttributeMfmaImpl_f32_16x16x128_bf8_bf8 = - WarpGemmAttributeMfmaImpl_f32_16x16x128_f8_bf8_base; - -template -struct WarpGemmAttributeMfmaScaleImpl_f32_16x16x128_fp4 -{ - static constexpr WGAttrCtlEnum Ctrl = Ctrl_; - using ADataType = pk_fp4_t; - using BDataType = pk_fp4_t; - using CDataType = float; - - using AVecType = ext_vector_t; - using BVecType = ext_vector_t; + using AVecType = ext_vector_t::PackedSize>; + using BVecType = ext_vector_t::PackedSize>; using CVecType = ext_vector_t; static constexpr index_t kM = 16; @@ -1662,21 +1564,38 @@ struct WarpGemmAttributeMfmaScaleImpl_f32_16x16x128_fp4 const int32_t& b_scale, bool_constant = {}) const { - //__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(a, b, c, cbsz, blgp, opsel, scale_a, - // opsel, scale_b) #if defined(__gfx950__) - auto arg_a = bit_cast(a_vec); - auto arg_b = bit_cast(b_vec); - c_vec = __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4( - int32x8_t{arg_a[0], arg_a[1], arg_a[2], arg_a[3], 0, 0, 0, 0}, - int32x8_t{arg_b[0], arg_b[1], arg_b[2], arg_b[3], 0, 0, 0, 0}, - c_vec, - 4, - 4, - opselA, - a_scale, - opselB, - b_scale); + auto dtype2conf = [](auto dtype) { + if constexpr(std::is_same_v) + return make_tuple(number<0>{}, int32x8_t{}); + else if constexpr(std::is_same_v) + return make_tuple(number<1>{}, int32x8_t{}); + // else if e2m3 => make_tuple(number<2>{}, int32x6_t{}) + // else if e3m2 => make_tuple(number<3>{}, int32x6_t{}) + else if constexpr(std::is_same_v) + return make_tuple(number<4>{}, int32x4_t{}); + else + static_assert(false, "Unsupported data type for mfma scale"); + }; + auto dtype2code = [&](auto dtype) { return dtype2conf(dtype)(number<0>{}); }; + auto dtype2vec = [&](auto dtype) { return dtype2conf(dtype)(number<1>{}); }; + auto arg256 = [&](auto x) { + if constexpr(sizeof(x) == 16) + return int32x8_t{x[0], x[1], x[2], x[3], 0, 0, 0, 0}; + else if constexpr(sizeof(x) == 24) + return int32x8_t{x[0], x[1], x[2], x[3], x[4], x[5], 0, 0}; + else if constexpr(sizeof(x) == 32) + return x; + else + static_assert(false, "Unexpected vector size for mfma scale"); + }; + + auto arg_a = bit_cast(a_vec); + auto arg_b = bit_cast(b_vec); + constexpr int cbsz = decltype(dtype2code(ADataType{}))::value; + constexpr int blgp = decltype(dtype2code(BDataType{}))::value; + c_vec = __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4( + arg256(arg_a), arg256(arg_b), c_vec, cbsz, blgp, opselA, a_scale, opselB, b_scale); #else ck_tile::ignore = c_vec; ck_tile::ignore = a_vec; @@ -1693,26 +1612,25 @@ struct WarpGemmAttributeMfmaScaleImpl_f32_16x16x128_fp4 const BVecType& b_vec, const int32_t& b_scale) const { -#if defined(__gfx950__) - auto arg_a = bit_cast(a_vec); - auto arg_b = bit_cast(b_vec); - return bit_cast(__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4( - int32x8_t{arg_a[0], arg_a[1], arg_a[2], arg_a[3], 0, 0, 0, 0}, - int32x8_t{arg_b[0], arg_b[1], arg_b[2], arg_b[3], 0, 0, 0, 0}, - CVecType{0.f}, - 4, - 4, - opselA, - a_scale, - opselB, - b_scale)); -#else - ck_tile::ignore = a_vec; - ck_tile::ignore = b_vec; - ck_tile::ignore = a_scale; - ck_tile::ignore = b_scale; - return CVecType{0.f}; -#endif + CVecType c_vec{0.f}; + operator()(c_vec, a_vec, a_scale, b_vec, b_scale); + return c_vec; + } + + // c_vec += a_vec * b_vec + template + CK_TILE_DEVICE void operator()(CVecType& c_vec, + const AVecType& a_vec, + const BVecType& b_vec, + bool_constant = {}) const + { + operator()<0, 0>(c_vec, a_vec, 0, b_vec, 0); + } + + // c_vec = a_vec * b_vec + CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const + { + return operator()<0, 0>(a_vec, 0, b_vec, 0); } }; diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp index 1858f80fa8..fe9a611b55 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp @@ -10,6 +10,13 @@ namespace ck_tile { namespace impl { +namespace warp_gemm_dispatcher { + +// C++20 using enum +static inline constexpr auto ESingle = WGAttrNumAccessEnum::Single; +static inline constexpr auto EDouble = WGAttrNumAccessEnum::Double; +static inline constexpr auto EQuad = WGAttrNumAccessEnum::Quad; + template -struct WarpGemmDispatcher; + WGAttrNumAccessEnum AttrNumAccess = ESingle> +struct Dispatcher; // clang-format off // fp32 // ADataType, BDataType, AccDataType, MPerWave, NPerWave, KPerWave, TransposeC, SwizzleA, UseStructuredSparsity -template<> struct WarpGemmDispatcher { using Type = WarpGemmMfmaF32F32F32M16N16K4; }; -template<> struct WarpGemmDispatcher { using Type = WarpGemmMfmaF32F32F32M16N16K16<>; }; -template<> struct WarpGemmDispatcher { using Type = WarpGemmMfmaF32F32F32M16N16K16TransposedCDistribution<>; }; +template<> struct Dispatcher { using Type = WarpGemmMfmaF32F32F32M16N16K4; }; +template<> struct Dispatcher { using Type = WarpGemmMfmaF32F32F32M16N16K16<>; }; +template<> struct Dispatcher { using Type = WarpGemmMfmaF32F32F32M16N16K16TransposedCDistribution<>; }; // fp16 // ADataType, BDataType, AccDataType, MPerWave, NPerWave, KPerWave, TransposeC, SwizzleA, UseStructuredSparsity -template<> struct WarpGemmDispatcher { using Type = WarpGemmMfmaF16F16F32M32N32K8; }; -template<> struct WarpGemmDispatcher { using Type = WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution; }; -template<> struct WarpGemmDispatcher { using Type = WarpGemmMfmaF16F16F32M32N32K16<>; }; -template<> struct WarpGemmDispatcher { using Type = WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution<>; }; -template<> struct WarpGemmDispatcher { - using Type = WarpGemmMfmaF16F16F32M32N32K16; }; -template<> struct WarpGemmDispatcher { - using Type = WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution; }; -template<> struct WarpGemmDispatcher { using Type = WarpGemmMfmaF16F16F32M16N16K32<>; }; -template<> struct WarpGemmDispatcher { using Type = WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution<>; }; -template<> struct WarpGemmDispatcher { - using Type = WarpGemmMfmaF16F16F32M16N16K32; }; -template<> struct WarpGemmDispatcher { - using Type = WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution; }; -template<> struct WarpGemmDispatcher { using Type = WarpGemmMfmaF16F16F32M4N64K16; }; -template<> struct WarpGemmDispatcher { using Type = WarpGemmMfmaF16F16F32M64N4K16; }; +template<> struct Dispatcher { using Type = WarpGemmMfmaF16F16F32M32N32K8; }; +template<> struct Dispatcher { using Type = WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution; }; +template<> struct Dispatcher { using Type = WarpGemmMfmaF16F16F32M32N32K16<>; }; +template<> struct Dispatcher { using Type = WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution<>; }; +template<> struct Dispatcher { using Type = WarpGemmMfmaF16F16F32M32N32K16; }; +template<> struct Dispatcher { using Type = WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution; }; +template<> struct Dispatcher { using Type = WarpGemmMfmaF16F16F32M16N16K32<>; }; +template<> struct Dispatcher { using Type = WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution<>; }; +template<> struct Dispatcher { using Type = WarpGemmMfmaF16F16F32M16N16K32; }; +template<> struct Dispatcher { using Type = WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution; }; +template<> struct Dispatcher { using Type = WarpGemmMfmaF16F16F32M4N64K16; }; +template<> struct Dispatcher { using Type = WarpGemmMfmaF16F16F32M64N4K16; }; // WMMA cases #if defined(__gfx11__) || defined(__gfx12__) -template struct WarpGemmDispatcher { using Type = WarpGemmWmma_f32_16x16x16_f16_f16;}; +template struct Dispatcher { using Type = WarpGemmWmma_f32_16x16x16_f16_f16;}; #else -template<> struct WarpGemmDispatcher { using Type = WarpGemmMfmaF16F16F32M16N16K16; }; -template<> struct WarpGemmDispatcher { using Type = WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution; }; +template<> struct Dispatcher { using Type = WarpGemmMfmaF16F16F32M16N16K16; }; +template<> struct Dispatcher { using Type = WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution; }; #endif -template<> struct WarpGemmDispatcher { using Type = WarpGemmMfmaF16F16F32M32N32K8SwizzleA; }; -template<> struct WarpGemmDispatcher { using Type = WarpGemmMfmaF16F16F32M32N32K16SwizzleA; }; -template<> struct WarpGemmDispatcher { using Type = WarpGemmMfmaF16F16F32M32N32K8SwizzleBTransposedCDistribution; }; -template<> struct WarpGemmDispatcher { using Type = WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution; }; +template<> struct Dispatcher { using Type = WarpGemmMfmaF16F16F32M32N32K8SwizzleA; }; +template<> struct Dispatcher { using Type = WarpGemmMfmaF16F16F32M32N32K16SwizzleA; }; +template<> struct Dispatcher { using Type = WarpGemmMfmaF16F16F32M32N32K8SwizzleBTransposedCDistribution; }; +template<> struct Dispatcher { using Type = WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution; }; // fp16 2:4 structural sparsity // ADataType, BDataType, AccDataType, MPerWave, NPerWave, KPerWave, TransposeC, SwizzleA, UseStructuredSparsity -template<> struct WarpGemmDispatcher { using Type = WarpGemmSmfmacF16F16F32M32N32K16; }; -template<> struct WarpGemmDispatcher { using Type = WarpGemmSmfmacF16F16F32M16N16K32; }; +template<> struct Dispatcher { using Type = WarpGemmSmfmacF16F16F32M32N32K16; }; +template<> struct Dispatcher { using Type = WarpGemmSmfmacF16F16F32M16N16K32; }; // bf16 // ADataType, BDataType, AccDataType, MPerWave, NPerWave, KPerWave, TransposeC, SwizzleA, UseStructuredSparsity -template<> struct WarpGemmDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8; }; -template<> struct WarpGemmDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8TransposedCDistribution; }; -template<> struct WarpGemmDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16<>; }; -template<> struct WarpGemmDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution<>; }; -template<> struct WarpGemmDispatcher { - using Type = WarpGemmMfmaBf16Bf16F32M32N32K16; }; -template<> struct WarpGemmDispatcher { - using Type = WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution; }; -template<> struct WarpGemmDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32<>; }; -template<> struct WarpGemmDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution<>; }; -template<> struct WarpGemmDispatcher { - using Type = WarpGemmMfmaBf16Bf16F32M16N16K32; }; -template<> struct WarpGemmDispatcher { - using Type = WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution; }; -template<> struct WarpGemmDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M4N64K16; }; -template<> struct WarpGemmDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M64N4K16; }; +template<> struct Dispatcher { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8; }; +template<> struct Dispatcher { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8TransposedCDistribution; }; +template<> struct Dispatcher { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16<>; }; +template<> struct Dispatcher { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution<>; }; +template<> struct Dispatcher { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16; }; +template<> struct Dispatcher { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution; }; +template<> struct Dispatcher { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32<>; }; +template<> struct Dispatcher { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution<>; }; +template<> struct Dispatcher { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32; }; +template<> struct Dispatcher { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution; }; +template<> struct Dispatcher { using Type = WarpGemmMfmaBf16Bf16F32M4N64K16; }; +template<> struct Dispatcher { using Type = WarpGemmMfmaBf16Bf16F32M64N4K16; }; // WMMA cases #if defined(__gfx11__) || defined(__gfx12__) -template struct WarpGemmDispatcher { using Type = WarpGemmWmma_f32_16x16x16_bf16_bf16; }; +template struct Dispatcher { using Type = WarpGemmWmma_f32_16x16x16_bf16_bf16; }; #else -template<> struct WarpGemmDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M16N16K16; }; -template<> struct WarpGemmDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution; }; +template<> struct Dispatcher { using Type = WarpGemmMfmaBf16Bf16F32M16N16K16; }; +template<> struct Dispatcher { using Type = WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution; }; #endif -template<> struct WarpGemmDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8SwizzleA; }; -template<> struct WarpGemmDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA; }; -template<> struct WarpGemmDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8SwizzleBTransposedCDistribution; }; -template<> struct WarpGemmDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution; }; +template<> struct Dispatcher { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8SwizzleA; }; +template<> struct Dispatcher { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA; }; +template<> struct Dispatcher { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8SwizzleBTransposedCDistribution; }; +template<> struct Dispatcher { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution; }; // fp8 // ADataType, BDataType, AccDataType, MPerWave, NPerWave, KPerWave, TransposeC, SwizzleA, UseStructuredSparsity -template<> struct WarpGemmDispatcher { using Type = WarpGemmMfma_f32_32x32x16_fp8_fp8; }; -template<> struct WarpGemmDispatcher { using Type = WarpGemmMfma_f32_32x32x32_fp8_fp8; }; -template<> struct WarpGemmDispatcher { using Type = WarpGemmMfma_f32_16x16x32_fp8_fp8; }; -template<> struct WarpGemmDispatcher { using Type = WarpGemmMfma_f32_16x16x64_fp8_fp8; }; -template<> struct WarpGemmDispatcher { using Type = WarpGemmMfma_f32_32x32x16_fp8_fp8_CTransposed; }; -template<> struct WarpGemmDispatcher { using Type = WarpGemmMfma_f32_16x16x32_fp8_fp8_CTransposed; }; -template<> struct WarpGemmDispatcher { using Type = WarpGemmMfma_f32_32x32x16_fp8_bf8; }; -template<> struct WarpGemmDispatcher { using Type = WarpGemmMfma_f32_32x32x16_fp8_bf8_CTransposed; }; -template<> struct WarpGemmDispatcher { using Type = WarpGemmMfma_f32_16x16x32_fp8_bf8; }; -template<> struct WarpGemmDispatcher { using Type = WarpGemmMfma_f32_32x32x32_fp8_bf8; }; -template<> struct WarpGemmDispatcher { using Type = WarpGemmMfma_f32_32x32x16_bf8_fp8; }; -template<> struct WarpGemmDispatcher { using Type = WarpGemmMfma_f32_32x32x16_bf8_fp8_CTransposed; }; -template<> struct WarpGemmDispatcher { using Type = WarpGemmMfma_f32_32x32x16_bf8_bf8; }; -template<> struct WarpGemmDispatcher { using Type = WarpGemmMfma_f32_32x32x32_bf8_bf8; }; -template<> struct WarpGemmDispatcher { using Type = WarpGemmMfma_f32_16x16x32_bf8_bf8; }; -template<> struct WarpGemmDispatcher { using Type = WarpGemmMfma_f32_16x16x32_bf8_bf8_CTransposed; }; -template<> struct WarpGemmDispatcher { using Type = WarpGemmMfma_f32_16x16x64_bf8_bf8; }; -template<> struct WarpGemmDispatcher { using Type = WarpGemmMfma_f32_32x32x16_bf8_bf8_CTransposed; }; -template<> struct WarpGemmDispatcher { using Type = WarpGemmMfma_f32_16x16x128_fp8_fp8<>; }; -template<> struct WarpGemmDispatcher { using Type = WarpGemmMfma_f32_16x16x128_fp8_bf8<>; }; -template<> struct WarpGemmDispatcher { using Type = WarpGemmMfma_f32_16x16x128_bf8_fp8<>; }; -template<> struct WarpGemmDispatcher { using Type = WarpGemmMfma_f32_16x16x128_bf8_bf8<>; }; -template<> struct WarpGemmDispatcher { using Type = WarpGemmMfma_f32_16x16x128_fp8_fp8_CTransposed<>; }; -template<> struct WarpGemmDispatcher { using Type = WarpGemmMfma_f32_16x16x128_fp8_bf8_CTransposed<>; }; -template<> struct WarpGemmDispatcher { using Type = WarpGemmMfma_f32_16x16x128_bf8_fp8_CTransposed<>; }; -template<> struct WarpGemmDispatcher { using Type = WarpGemmMfma_f32_16x16x128_bf8_bf8_CTransposed<>; }; +template<> struct Dispatcher { using Type = WarpGemmMfma_f32_32x32x16_fp8_fp8; }; +template<> struct Dispatcher { using Type = WarpGemmMfma_f32_32x32x32_fp8_fp8; }; +template<> struct Dispatcher { using Type = WarpGemmMfma_f32_16x16x32_fp8_fp8; }; +template<> struct Dispatcher { using Type = WarpGemmMfma_f32_16x16x64_fp8_fp8; }; +template<> struct Dispatcher { using Type = WarpGemmMfma_f32_32x32x16_fp8_fp8_CTransposed; }; +template<> struct Dispatcher { using Type = WarpGemmMfma_f32_16x16x32_fp8_fp8_CTransposed; }; +template<> struct Dispatcher { using Type = WarpGemmMfma_f32_32x32x16_fp8_bf8; }; +template<> struct Dispatcher { using Type = WarpGemmMfma_f32_32x32x16_fp8_bf8_CTransposed; }; +template<> struct Dispatcher { using Type = WarpGemmMfma_f32_16x16x32_fp8_bf8; }; +template<> struct Dispatcher { using Type = WarpGemmMfma_f32_32x32x32_fp8_bf8; }; +template<> struct Dispatcher { using Type = WarpGemmMfma_f32_32x32x16_bf8_fp8; }; +template<> struct Dispatcher { using Type = WarpGemmMfma_f32_32x32x16_bf8_fp8_CTransposed; }; +template<> struct Dispatcher { using Type = WarpGemmMfma_f32_32x32x16_bf8_bf8; }; +template<> struct Dispatcher { using Type = WarpGemmMfma_f32_32x32x32_bf8_bf8; }; +template<> struct Dispatcher { using Type = WarpGemmMfma_f32_16x16x32_bf8_bf8; }; +template<> struct Dispatcher { using Type = WarpGemmMfma_f32_16x16x32_bf8_bf8_CTransposed; }; +template<> struct Dispatcher { using Type = WarpGemmMfma_f32_16x16x64_bf8_bf8; }; +template<> struct Dispatcher { using Type = WarpGemmMfma_f32_32x32x16_bf8_bf8_CTransposed; }; -template<> struct WarpGemmDispatcher { using Type = WarpGemmMfma_f32_16x16x128_fp4<>; }; +// scale mfma based f8f6f4 +template struct Dispatcher { using Type = WarpGemmMfma_f32_16x16x128_fp8_fp8; }; +template struct Dispatcher { using Type = WarpGemmMfma_f32_16x16x128_fp8_bf8; }; +template struct Dispatcher { using Type = WarpGemmMfma_f32_16x16x128_bf8_fp8; }; +template struct Dispatcher { using Type = WarpGemmMfma_f32_16x16x128_bf8_bf8; }; +template struct Dispatcher { using Type = WarpGemmMfma_f32_16x16x128_fp8_fp8_CTransposed; }; +template struct Dispatcher { using Type = WarpGemmMfma_f32_16x16x128_fp8_bf8_CTransposed; }; +template struct Dispatcher { using Type = WarpGemmMfma_f32_16x16x128_bf8_fp8_CTransposed; }; +template struct Dispatcher { using Type = WarpGemmMfma_f32_16x16x128_bf8_bf8_CTransposed; }; +template<> struct Dispatcher { using Type = WarpGemmMfma_f32_16x16x128_fp4<>; }; -template<> struct WarpGemmDispatcher { using Type = WarpGemmMfma_f32_32x32x64_fp8_fp8<>; }; -template<> struct WarpGemmDispatcher { using Type = WarpGemmMfma_f32_32x32x64_fp8_bf8<>; }; -template<> struct WarpGemmDispatcher { using Type = WarpGemmMfma_f32_32x32x64_bf8_fp8<>; }; -template<> struct WarpGemmDispatcher { using Type = WarpGemmMfma_f32_32x32x64_bf8_bf8<>; }; -template<> struct WarpGemmDispatcher { - using Type = WarpGemmMfma_f32_32x32x64_fp8_fp8; }; -template<> struct WarpGemmDispatcher { - using Type = WarpGemmMfma_f32_32x32x64_fp8_bf8; }; -template<> struct WarpGemmDispatcher { - using Type = WarpGemmMfma_f32_32x32x64_bf8_fp8; }; -template<> struct WarpGemmDispatcher { - using Type = WarpGemmMfma_f32_32x32x64_bf8_bf8; }; - -template<> struct WarpGemmDispatcher { - using Type = WarpGemmMfma_f32_16x16x128_fp8_fp8; }; -template<> struct WarpGemmDispatcher { - using Type = WarpGemmMfma_f32_16x16x128_fp8_bf8; }; -template<> struct WarpGemmDispatcher { - using Type = WarpGemmMfma_f32_16x16x128_bf8_fp8; }; -template<> struct WarpGemmDispatcher { - using Type = WarpGemmMfma_f32_16x16x128_bf8_bf8; }; - -template<> struct WarpGemmDispatcher { - using Type = WarpGemmMfma_f32_16x16x128_fp4; }; +template<> struct Dispatcher { using Type = WarpGemmMfma_f32_32x32x64_fp8_fp8<>; }; +template<> struct Dispatcher { using Type = WarpGemmMfma_f32_32x32x64_fp8_bf8<>; }; +template<> struct Dispatcher { using Type = WarpGemmMfma_f32_32x32x64_bf8_fp8<>; }; +template<> struct Dispatcher { using Type = WarpGemmMfma_f32_32x32x64_bf8_bf8<>; }; +template<> struct Dispatcher { using Type = WarpGemmMfma_f32_32x32x64_fp8_fp8; }; +template<> struct Dispatcher { using Type = WarpGemmMfma_f32_32x32x64_fp8_bf8; }; +template<> struct Dispatcher { using Type = WarpGemmMfma_f32_32x32x64_bf8_fp8; }; +template<> struct Dispatcher { using Type = WarpGemmMfma_f32_32x32x64_bf8_bf8; }; //WMMA cases -template struct WarpGemmDispatcher { using Type =WarpGemmWmma_f32_16x16x16_f8_f8; }; -template struct WarpGemmDispatcher { using Type =WarpGemmWmma_f32_16x16x16_bf8_bf8; }; -template struct WarpGemmDispatcher { using Type =WarpGemmWmma_f32_16x16x16_f8_bf8; }; -template struct WarpGemmDispatcher { using Type =WarpGemmWmma_f32_16x16x16_bf8_f8; }; +template struct Dispatcher { using Type = WarpGemmWmma_f32_16x16x16_f8_f8; }; +template struct Dispatcher { using Type = WarpGemmWmma_f32_16x16x16_bf8_bf8; }; +template struct Dispatcher { using Type = WarpGemmWmma_f32_16x16x16_f8_bf8; }; +template struct Dispatcher { using Type = WarpGemmWmma_f32_16x16x16_bf8_f8; }; // int8 // ADataType, BDataType, AccDataType, MPerWave, NPerWave, KPerWave, TransposeC, SwizzleA, UseStructuredSparsity -template<> struct WarpGemmDispatcher { using Type = WarpGemmMfma_i32_32x32x16_i8_i8; }; -template<> struct WarpGemmDispatcher { using Type = WarpGemmMfma_i32_32x32x16_i8_i8_CTransposed; }; -template<> struct WarpGemmDispatcher { using Type = WarpGemmMfma_i32_16x16x32_i8_i8; }; -template<> struct WarpGemmDispatcher { using Type = WarpGemmMfma_i32_16x16x32_i8_i8_CTransposed; }; +template<> struct Dispatcher { using Type = WarpGemmMfma_i32_32x32x16_i8_i8; }; +template<> struct Dispatcher { using Type = WarpGemmMfma_i32_32x32x16_i8_i8_CTransposed; }; +template<> struct Dispatcher { using Type = WarpGemmMfma_i32_16x16x32_i8_i8; }; +template<> struct Dispatcher { using Type = WarpGemmMfma_i32_16x16x32_i8_i8_CTransposed; }; // WMMA cases -template struct WarpGemmDispatcher { using Type = WarpGemmWmma_i32_16x16x16_i8_i8;}; +template struct Dispatcher { using Type = WarpGemmWmma_i32_16x16x16_i8_i8;}; // clang-format on +} // namespace warp_gemm_dispatcher } // namespace impl template -using WarpGemmDispatcher = typename impl::WarpGemmDispatcher::Type; +using WarpGemmDispatcher = typename impl::warp_gemm_dispatcher::Dispatcher< // + AType, + BType, + AccType, + MPerWave, + NPerWave, + KPerWave, + TransposeC, + SwizzleA, + UseStructuredSparsity, + AttrNumAccess>::Type; } // namespace ck_tile