From 5d0b1b733c535a5e618bae02785fa89060ad4957 Mon Sep 17 00:00:00 2001 From: Khushbu Agarwal Date: Thu, 1 May 2025 13:36:24 -0700 Subject: [PATCH] mfma_32x32x64_fp8/bf8 (#2148) * support for mfma_32x32x64_fp8 * clang-formatted * Fixing sparsity in codegen [ROCm/composable_kernel commit: d58f2b8bd0c2adad65a731403673d545d8483acb] --- include/ck_tile/ops/gemm/warp/warp_gemm.hpp | 12 +++ .../warp/warp_gemm_attribute_mfma_impl.hpp | 98 +++++++++++++++++++ .../ops/gemm/warp/warp_gemm_dispatcher.hpp | 5 + tile_engine/ops/gemm/gemm_instance_builder.py | 54 +++++----- 4 files changed, 147 insertions(+), 22 deletions(-) diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm.hpp index 22962b9404..e75aca1d91 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm.hpp @@ -228,6 +228,18 @@ using WarpGemmMfma_f32_16x16x128_bf8_fp8 = WarpGemmImpl>>; +using WarpGemmMfma_f32_32x32x64_fp8_fp8 = WarpGemmImpl< + WarpGemmAtrributeMfma>>; + +using WarpGemmMfma_f32_32x32x64_fp8_bf8 = WarpGemmImpl< + WarpGemmAtrributeMfma>>; + +using WarpGemmMfma_f32_32x32x64_bf8_fp8 = WarpGemmImpl< + WarpGemmAtrributeMfma>>; + +using WarpGemmMfma_f32_32x32x64_bf8_bf8 = WarpGemmImpl< + WarpGemmAtrributeMfma>>; + using WarpGemmMfma_f32_32x32x16_fp8_fp8_CTransposed = WarpGemmImpl>>; diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp index cd32f35180..96c3c3d29f 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp @@ -1440,6 +1440,104 @@ template using WarpGemmAttributeMfmaImpl_f32_16x16x128_bf8_bf8 = WarpGemmAttributeMfmaImpl_f32_16x16x128_f8_bf8_base; +template +struct WarpGemmAttributeMfmaImpl_f32_32x32x64_f8_bf8_base +{ + static constexpr WGAttrCtlEnum Ctrl = Ctrl_; + using ADataType = AType_; + using BDataType = BType_; + using CDataType = float; + + using AVecType = ext_vector_t; + using BVecType = ext_vector_t; + using CVecType = ext_vector_t; + + static constexpr index_t kM = 32; + static constexpr index_t kN = 32; + static constexpr index_t kK = 64; + + static constexpr index_t kAMBlock = 1; + static constexpr index_t kBNBlock = 1; + + static constexpr index_t kAMLane = 32; + static constexpr index_t kBNLane = 32; + static constexpr index_t kABKLane = 2; + static constexpr index_t kABKPerLane = 32; + + static constexpr index_t kCMLane = 2; + static constexpr index_t kCNLane = 32; + static constexpr index_t kCM0PerLane = 4; + static constexpr index_t kCM1PerLane = 4; + + // c_vec += a_vec * b_vec + template + CK_TILE_DEVICE void operator()(CVecType& c_vec, + const AVecType& a_vec, + const BVecType& b_vec, + bool_constant = {}) const + { + //__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(a, b, c, cbsz, blgp, opsel, scale_a, + // opsel, scale_b) +#if defined(__gfx950__) + if constexpr(std::is_same_v && std::is_same_v) + c_vec = __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4( + a_vec, b_vec, c_vec, 0, 0, 0, 0, 0, 0); + else if constexpr(std::is_same_v && std::is_same_v) + c_vec = __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4( + a_vec, b_vec, c_vec, 0, 1, 0, 0, 0, 0); + else if constexpr(std::is_same_v && std::is_same_v) + c_vec = __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4( + a_vec, b_vec, c_vec, 1, 0, 0, 0, 0, 0); + else if constexpr(std::is_same_v && std::is_same_v) + c_vec = __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4( + a_vec, b_vec, c_vec, 1, 1, 0, 0, 0, 0); +#else + ck_tile::ignore = c_vec; + ck_tile::ignore = a_vec; + ck_tile::ignore = b_vec; +#endif + } + + // c_vec = a_vec * b_vec + CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const + { +#if defined(__gfx950__) + if constexpr(std::is_same_v && std::is_same_v) + return bit_cast(__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4( + a_vec, b_vec, CVecType{0.f}, 0, 0, 0, 0, 0, 0)); + else if constexpr(std::is_same_v && std::is_same_v) + return bit_cast(__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4( + a_vec, b_vec, CVecType{0.f}, 0, 1, 0, 0, 0, 0)); + else if constexpr(std::is_same_v && std::is_same_v) + return bit_cast(__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4( + a_vec, b_vec, CVecType{0.f}, 1, 0, 0, 0, 0, 0)); + else if constexpr(std::is_same_v && std::is_same_v) + return bit_cast(__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4( + a_vec, b_vec, CVecType{0.f}, 1, 1, 0, 0, 0, 0)); +#else + ck_tile::ignore = a_vec; + ck_tile::ignore = b_vec; + return CVecType{0.f}; +#endif + } +}; + +template +using WarpGemmAttributeMfmaImpl_f32_32x32x64_fp8_fp8 = + WarpGemmAttributeMfmaImpl_f32_32x32x64_f8_bf8_base; + +template +using WarpGemmAttributeMfmaImpl_f32_32x32x64_fp8_bf8 = + WarpGemmAttributeMfmaImpl_f32_32x32x64_f8_bf8_base; + +template +using WarpGemmAttributeMfmaImpl_f32_32x32x64_bf8_fp8 = + WarpGemmAttributeMfmaImpl_f32_32x32x64_f8_bf8_base; + +template +using WarpGemmAttributeMfmaImpl_f32_32x32x64_bf8_bf8 = + WarpGemmAttributeMfmaImpl_f32_32x32x64_f8_bf8_base; + // int8 template struct WarpGemmAttributeMfmaImpl_i32_32x32x16_i8 diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp index 0e3342c479..64bd61a3dc 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp @@ -74,6 +74,11 @@ template<> struct WarpGemmMfmaDispatcher struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_16x16x128_bf8_fp8; }; template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_16x16x128_bf8_bf8; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_32x32x64_fp8_fp8; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_32x32x64_fp8_bf8; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_32x32x64_bf8_fp8; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_32x32x64_bf8_bf8; }; + // clang-format on } // namespace impl diff --git a/tile_engine/ops/gemm/gemm_instance_builder.py b/tile_engine/ops/gemm/gemm_instance_builder.py index b441bdd2d6..a748c35feb 100755 --- a/tile_engine/ops/gemm/gemm_instance_builder.py +++ b/tile_engine/ops/gemm/gemm_instance_builder.py @@ -282,14 +282,14 @@ class GemmCodeGenerator: def _generate_common_header(self): """Generate common header with datatypes and layout""" - ctype = self.config.datatype - atype = self.config.datatype - btype = self.config.datatype + self.ctype = self.config.datatype + self.atype = self.config.datatype + self.btype = self.config.datatype if self.config.datatype in ['fp8', 'bf8']: - ctype = 'fp16' + self.ctype = 'fp16' elif self.config.datatype in ['int4']: - atype = 'fp16' - ctype = 'fp16' + self.atype = 'fp16' + self.ctype = 'fp16' content = f"""// SPDX-License-Identifier: MIT // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. @@ -298,10 +298,10 @@ class GemmCodeGenerator: #include "ck_tile/core.hpp" // Data types -using ADataType = {DATA_TYPE_MAP[atype]}; -using BDataType = {DATA_TYPE_MAP[btype]}; +using ADataType = {DATA_TYPE_MAP[self.atype]}; +using BDataType = {DATA_TYPE_MAP[self.btype]}; using AccDataType = float; -using CDataType = {DATA_TYPE_MAP[ctype]}; +using CDataType = {DATA_TYPE_MAP[self.ctype]}; // Layout configurations using ALayout = {LAYOUT_MAP[self.config.layouts[0]]}; @@ -499,7 +499,7 @@ struct GemmDispatcher { static void init(bool structured_sparsity) { auto& kernel_map = get_kernel_map(); - if(!kernel_map.empty()) return; + if(!kernel_map.empty()) return; \n""" # Add tile/warp instantiations tile_params = set(itertools.product( @@ -516,12 +516,25 @@ struct GemmDispatcher { for group in self.all_kernels: - content += f""" kernel_map["{group}"] = [=](ck_tile::DeviceMem& c_m_n_dev_buf, - ck_tile::HostTensor& c_m_n_host_result, - ck_tile::HostTensor& c_m_n_dev_result, - int verify, ck_tile::GemmHostArgs& args, - const ck_tile::stream_config& stream) {{ - """ + content += f""" kernel_map["{group}"] = [=](ck_tile::DeviceMem& c_m_n_dev_buf, + ck_tile::HostTensor& c_m_n_host_result, + ck_tile::HostTensor& c_m_n_dev_result, + int verify, ck_tile::GemmHostArgs& args, + const ck_tile::stream_config& stream) {{ + if(structured_sparsity){{ // SMFMA""" + for tile in tile_params: + # Check if we have valid tile/warp combinations + # (tile_m/(warp_m*warp_tile_m)) * warp_m * warp_tile_m == tile_m + if ((tile[0]/(tile[3] * tile[7]) * tile[3] * tile[7]) != tile[0]) or \ + ((tile[1]/(tile[4] * tile[8]) * tile[4] * tile[8]) != tile[1]): + continue + sparse = self.atype == 'fp16' and \ + ((tile[6] == 32 and tile[7] == 32 and tile[8] == 16) or + (tile[6] == 16 and tile[7] == 16 and tile[8] == 32)) + content += f""" + run_kernel<{group}::GemmKernel<{tile[0]}, {tile[1]}, {tile[2]}, {tile[3]}, {tile[4]}, {tile[5]}, {tile[6]}, {tile[7]}, {tile[8]}, {BOOL_MAP(sparse)}>>(c_m_n_dev_buf, c_m_n_host_result, c_m_n_dev_result, verify, args, stream);""" + content += f""" + }} else {{""" for tile in tile_params: # Check if we have valid tile/warp combinations # (tile_m/(warp_m*warp_tile_m)) * warp_m * warp_tile_m == tile_m @@ -529,13 +542,10 @@ struct GemmDispatcher { ((tile[1]/(tile[4] * tile[8]) * tile[4] * tile[8]) != tile[1]): continue content += f""" - if(structured_sparsity) {{ - run_kernel<{group}::GemmKernel<{tile[0]}, {tile[1]}, {tile[2]}, {tile[3]}, {tile[4]}, {tile[5]}, {tile[6]}, {tile[7]}, {tile[8]}, {1}>>(c_m_n_dev_buf, c_m_n_host_result, c_m_n_dev_result, verify, args, stream); - }} else {{ - run_kernel<{group}::GemmKernel<{tile[0]}, {tile[1]}, {tile[2]}, {tile[3]}, {tile[4]}, {tile[5]}, {tile[6]}, {tile[7]}, {tile[8]}, {0}>>(c_m_n_dev_buf, c_m_n_host_result, c_m_n_dev_result, verify, args, stream); - }}""" + run_kernel<{group}::GemmKernel<{tile[0]}, {tile[1]}, {tile[2]}, {tile[3]}, {tile[4]}, {tile[5]}, {tile[6]}, {tile[7]}, {tile[8]}, {BOOL_MAP(False)}>>(c_m_n_dev_buf, c_m_n_host_result, c_m_n_dev_result, verify, args, stream);""" content += f""" - }};\n""" + }} + }};\n""" content += """ }