From e3f4d5d28f00bfd031933eda6a612f86dac378cc Mon Sep 17 00:00:00 2001 From: Thomas Ning Date: Mon, 21 Apr 2025 10:21:35 -0700 Subject: [PATCH] MFMA 16x16x32fp8 (#2103) * add mfma_16x16x32_fp8 * clang format code * Finished the fix for gemm basic * clang foramt * rebuild CI * recover gemm.hpp * add MFMA 16*16*32bf8 --------- Co-authored-by: solin [ROCm/composable_kernel commit: a738e43445f9f82227220922fcd2d683cc9ef626] --- .../gemm/pipeline/gemm_pipeline_problem.hpp | 2 + .../ops/gemm/pipeline/tile_gemm_traits.hpp | 3 +- include/ck_tile/ops/gemm/warp/warp_gemm.hpp | 14 ++ .../warp/warp_gemm_attribute_mfma_impl.hpp | 167 +++++++++++++++++- .../ops/gemm/warp/warp_gemm_dispatcher.hpp | 4 + 5 files changed, 188 insertions(+), 2 deletions(-) diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp index cba3677332..0b38e7789e 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp @@ -32,6 +32,8 @@ struct GemmPipelineProblemBase static constexpr bool TransposeC = Traits::TransposeC; + static constexpr bool UseStructuredSparsity = Traits::UseStructuredSparsity; + static constexpr index_t kBlockSize = BlockGemmShape::NumWarps * get_warp_size(); static constexpr bool kPadM = Traits::kPadM; diff --git a/include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp b/include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp index 0dae2eeca5..a31004b425 100644 --- a/include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp +++ b/include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp @@ -26,7 +26,8 @@ struct TileGemmTraits using BLayout = BLayout_; using CLayout = CLayout_; - static constexpr bool TransposeC = false; + static constexpr bool TransposeC = false; + static constexpr bool UseStructuredSparsity = false; }; template >>; +using WarpGemmMfma_f32_16x16x64_fp8_fp8 = WarpGemmImpl, + 2>>; + +using WarpGemmMfma_f32_16x16x32_fp8_fp8 = WarpGemmImpl< + WarpGemmAtrributeMfma>>; + +using WarpGemmMfma_f32_16x16x64_bf8_bf8 = WarpGemmImpl, + 2>>; + +using WarpGemmMfma_f32_16x16x32_bf8_bf8 = WarpGemmImpl< + WarpGemmAtrributeMfma>>; + using WarpGemmMfma_f32_32x32x16_fp8_fp8_CTransposed = WarpGemmImpl>>; diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp index 21a865e792..64c7543ffe 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp @@ -623,6 +623,165 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M64N4K4 }; // FP8 +template +struct WarpGemmAttributeMfmaImpl_f32_16x16x32_f8_base +{ + static constexpr WGAttrCtlEnum Ctrl = Ctrl_; + using ADataType = AType_; + using BDataType = BType_; + using CDataType = float; + + using AVecType = ext_vector_t; + using BVecType = ext_vector_t; + using CVecType = ext_vector_t; + + static constexpr index_t kM = 16; + static constexpr index_t kN = 16; + static constexpr index_t kK = 32; + + static constexpr index_t kAMBlock = 1; + static constexpr index_t kBNBlock = 1; + + static constexpr index_t kAMLane = 16; + static constexpr index_t kBNLane = 16; + static constexpr index_t kABKLane = 4; + static constexpr index_t kABKPerLane = 8; + + static constexpr index_t kCMLane = 4; + static constexpr index_t kCNLane = 16; + static constexpr index_t kCM0PerLane = 1; + static constexpr index_t kCM1PerLane = 4; + + // c_vec += a_vec * b_vec + template + CK_TILE_DEVICE void operator()(CVecType& c_vec, + const AVecType& a_vec, + const BVecType& b_vec, + bool_constant = {}) const + { + if constexpr(Ctrl == WGAttrCtlEnum::Raw_vvv) + { + if constexpr(std::is_same_v && std::is_same_v) + { + DISPATCH_MFMA_("mfma_f32_16x16x32_fp8_fp8", "+v", "v", "v", "v") + } + else if constexpr(std::is_same_v && std::is_same_v) + { + DISPATCH_MFMA_("mfma_f32_116x16x32_fp8_bf8", "+v", "v", "v", "v") + } + else if constexpr(std::is_same_v && std::is_same_v) + { + DISPATCH_MFMA_("mfma_f32_16x16x32_bf8_fp8", "+v", "v", "v", "v") + } + else if constexpr(std::is_same_v && std::is_same_v) + { + DISPATCH_MFMA_("mfma_f32_16x16x32_bf8_bf8", "+v", "v", "v", "v") + } + } + else if constexpr(Ctrl == WGAttrCtlEnum::Raw_vaa) + { + if constexpr(std::is_same_v && std::is_same_v) + { + DISPATCH_MFMA_("mfma_f32_16x16x32_fp8_fp8", "+v", "a", "a", "v") + } + else if constexpr(std::is_same_v && std::is_same_v) + { + DISPATCH_MFMA_("mfma_f32_16x16x32_fp8_bf8", "+v", "a", "a", "v") + } + else if constexpr(std::is_same_v && std::is_same_v) + { + DISPATCH_MFMA_("mfma_f32_16x16x32_bf8_fp8", "+v", "a", "a", "v") + } + else if constexpr(std::is_same_v && std::is_same_v) + { + DISPATCH_MFMA_("mfma_f32_16x16x32_bf8_bf8", "+v", "a", "a", "v") + } + } + else if constexpr(Ctrl == WGAttrCtlEnum::Raw_vav) + { + if constexpr(std::is_same_v && std::is_same_v) + { + DISPATCH_MFMA_("mfma_f32_16x16x32_fp8_fp8", "+v", "a", "v", "v") + } + else if constexpr(std::is_same_v && std::is_same_v) + { + DISPATCH_MFMA_("mfma_f32_16x16x32_fp8_bf8", "+v", "a", "v", "v") + } + else if constexpr(std::is_same_v && std::is_same_v) + { + DISPATCH_MFMA_("mfma_f32_16x16x32_bf8_fp8", "+v", "a", "v", "v") + } + else if constexpr(std::is_same_v && std::is_same_v) + { + DISPATCH_MFMA_("mfma_f32_16x16x32_bf8_bf8", "+v", "a", "v", "v") + } + } + else if constexpr(Ctrl == WGAttrCtlEnum::Raw_vva) + { + if constexpr(std::is_same_v && std::is_same_v) + { + DISPATCH_MFMA_("mfma_f32_16x16x32_fp8_fp8", "+v", "v", "a", "v") + } + else if constexpr(std::is_same_v && std::is_same_v) + { + DISPATCH_MFMA_("mfma_f32_16x16x32_fp8_bf8", "+v", "v", "a", "v") + } + else if constexpr(std::is_same_v && std::is_same_v) + { + DISPATCH_MFMA_("mfma_f32_16x16x32_bf8_fp8", "+v", "v", "a", "v") + } + else if constexpr(std::is_same_v && std::is_same_v) + { + DISPATCH_MFMA_("mfma_f32_16x16x32_bf8_bf8", "+v", "v", "a", "v") + } + } + else + { +#if defined(__gfx94__) + if constexpr(std::is_same_v && std::is_same_v) + c_vec = __builtin_amdgcn_mfma_f32_16x16x32_fp8_fp8( + bit_cast(a_vec), bit_cast(b_vec), c_vec, 0, 0, 0); + else if constexpr(std::is_same_v && std::is_same_v) + c_vec = __builtin_amdgcn_mfma_f32_16x16x32_fp8_bf8( + bit_cast(a_vec), bit_cast(b_vec), c_vec, 0, 0, 0); + else if constexpr(std::is_same_v && std::is_same_v) + c_vec = __builtin_amdgcn_mfma_f32_16x16x32_bf8_fp8( + bit_cast(a_vec), bit_cast(b_vec), c_vec, 0, 0, 0); + else if constexpr(std::is_same_v && std::is_same_v) + c_vec = __builtin_amdgcn_mfma_f32_16x16x32_bf8_bf8( + bit_cast(a_vec), bit_cast(b_vec), c_vec, 0, 0, 0); +#else + ck_tile::ignore = c_vec; + ck_tile::ignore = a_vec; + ck_tile::ignore = b_vec; +#endif + } + } + + // c_vec = a_vec * b_vec + CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const + { +#if defined(__gfx94__) + if constexpr(std::is_same_v && std::is_same_v) + return bit_cast(__builtin_amdgcn_mfma_f32_16x16x32_fp8_fp8( + bit_cast(a_vec), bit_cast(b_vec), CVecType{0.f}, 0, 0, 0)); + else if constexpr(std::is_same_v && std::is_same_v) + return bit_cast(__builtin_amdgcn_mfma_f32_16x16x32_fp8_bf8( + bit_cast(a_vec), bit_cast(b_vec), CVecType{0.f}, 0, 0, 0)); + else if constexpr(std::is_same_v && std::is_same_v) + return bit_cast(__builtin_amdgcn_mfma_f32_16x16x32_bf8_fp8( + bit_cast(a_vec), bit_cast(b_vec), CVecType{0.f}, 0, 0, 0)); + else if constexpr(std::is_same_v && std::is_same_v) + return bit_cast(__builtin_amdgcn_mfma_f32_316x16x32_bf8_bf8( + bit_cast(a_vec), bit_cast(b_vec), CVecType{0.f}, 0, 0, 0)); +#else + ck_tile::ignore = a_vec; + ck_tile::ignore = b_vec; + return CVecType{0.f}; +#endif + } +}; + template struct WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base { @@ -809,11 +968,17 @@ struct WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base template using WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_fp8 = WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base; - +template +using WarpGemmAttributeMfmaImpl_f32_16x16x32_fp8_fp8 = + WarpGemmAttributeMfmaImpl_f32_16x16x32_f8_base; template using WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_bf8 = WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base; +template +using WarpGemmAttributeMfmaImpl_f32_16x16x32_bf8_bf8 = + WarpGemmAttributeMfmaImpl_f32_16x16x32_f8_base; + template using WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_fp8 = WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base; diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp index 6320b33598..f437ee10c5 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp @@ -57,12 +57,16 @@ template<> struct WarpGemmMfmaDispatcher struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_32x32x16_fp8_fp8; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_16x16x32_fp8_fp8; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_16x16x64_fp8_fp8; }; template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_32x32x16_fp8_fp8_CTransposed; }; template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_32x32x16_fp8_bf8; }; template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_32x32x16_fp8_bf8_CTransposed; }; template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_32x32x16_bf8_fp8; }; template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_32x32x16_bf8_fp8_CTransposed; }; template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_32x32x16_bf8_bf8; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_16x16x32_bf8_bf8; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_16x16x64_bf8_bf8; }; template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_32x32x16_bf8_bf8_CTransposed; }; // clang-format on