mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
mfma_32x32x64_fp8/bf8 (#2148)
* support for mfma_32x32x64_fp8
* clang-formatted
* Fixing sparsity in codegen
[ROCm/composable_kernel commit: d58f2b8bd0]
This commit is contained in:
@@ -228,6 +228,18 @@ using WarpGemmMfma_f32_16x16x128_bf8_fp8 = WarpGemmImpl<WarpGemmAtrributeMfma<
|
||||
using WarpGemmMfma_f32_16x16x128_bf8_bf8 = WarpGemmImpl<WarpGemmAtrributeMfma<
|
||||
WarpGemmAttributeMfmaImpl_f32_16x16x128_bf8_bf8<WGAttrCtlEnum::Default_>>>;
|
||||
|
||||
using WarpGemmMfma_f32_32x32x64_fp8_fp8 = WarpGemmImpl<
|
||||
WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImpl_f32_32x32x64_fp8_fp8<WGAttrCtlEnum::Default_>>>;
|
||||
|
||||
using WarpGemmMfma_f32_32x32x64_fp8_bf8 = WarpGemmImpl<
|
||||
WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImpl_f32_32x32x64_fp8_bf8<WGAttrCtlEnum::Default_>>>;
|
||||
|
||||
using WarpGemmMfma_f32_32x32x64_bf8_fp8 = WarpGemmImpl<
|
||||
WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImpl_f32_32x32x64_bf8_fp8<WGAttrCtlEnum::Default_>>>;
|
||||
|
||||
using WarpGemmMfma_f32_32x32x64_bf8_bf8 = WarpGemmImpl<
|
||||
WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImpl_f32_32x32x64_bf8_bf8<WGAttrCtlEnum::Default_>>>;
|
||||
|
||||
using WarpGemmMfma_f32_32x32x16_fp8_fp8_CTransposed =
|
||||
WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution<
|
||||
WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_fp8<WGAttrCtlEnum::Default_>>>;
|
||||
|
||||
@@ -1440,6 +1440,104 @@ template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
|
||||
using WarpGemmAttributeMfmaImpl_f32_16x16x128_bf8_bf8 =
|
||||
WarpGemmAttributeMfmaImpl_f32_16x16x128_f8_bf8_base<bf8_t, bf8_t, Ctrl_>;
|
||||
|
||||
template <typename AType_, typename BType_, WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
|
||||
struct WarpGemmAttributeMfmaImpl_f32_32x32x64_f8_bf8_base
|
||||
{
|
||||
static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
|
||||
using ADataType = AType_;
|
||||
using BDataType = BType_;
|
||||
using CDataType = float;
|
||||
|
||||
using AVecType = ext_vector_t<ADataType, 32>;
|
||||
using BVecType = ext_vector_t<BDataType, 32>;
|
||||
using CVecType = ext_vector_t<CDataType, 16>;
|
||||
|
||||
static constexpr index_t kM = 32;
|
||||
static constexpr index_t kN = 32;
|
||||
static constexpr index_t kK = 64;
|
||||
|
||||
static constexpr index_t kAMBlock = 1;
|
||||
static constexpr index_t kBNBlock = 1;
|
||||
|
||||
static constexpr index_t kAMLane = 32;
|
||||
static constexpr index_t kBNLane = 32;
|
||||
static constexpr index_t kABKLane = 2;
|
||||
static constexpr index_t kABKPerLane = 32;
|
||||
|
||||
static constexpr index_t kCMLane = 2;
|
||||
static constexpr index_t kCNLane = 32;
|
||||
static constexpr index_t kCM0PerLane = 4;
|
||||
static constexpr index_t kCM1PerLane = 4;
|
||||
|
||||
// c_vec += a_vec * b_vec
|
||||
template <bool post_nop_ = false>
|
||||
CK_TILE_DEVICE void operator()(CVecType& c_vec,
|
||||
const AVecType& a_vec,
|
||||
const BVecType& b_vec,
|
||||
bool_constant<post_nop_> = {}) const
|
||||
{
|
||||
//__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(a, b, c, cbsz, blgp, opsel, scale_a,
|
||||
// opsel, scale_b)
|
||||
#if defined(__gfx950__)
|
||||
if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
|
||||
c_vec = __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
|
||||
a_vec, b_vec, c_vec, 0, 0, 0, 0, 0, 0);
|
||||
else if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, bf8_t>)
|
||||
c_vec = __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
|
||||
a_vec, b_vec, c_vec, 0, 1, 0, 0, 0, 0);
|
||||
else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, fp8_t>)
|
||||
c_vec = __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
|
||||
a_vec, b_vec, c_vec, 1, 0, 0, 0, 0, 0);
|
||||
else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>)
|
||||
c_vec = __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
|
||||
a_vec, b_vec, c_vec, 1, 1, 0, 0, 0, 0);
|
||||
#else
|
||||
ck_tile::ignore = c_vec;
|
||||
ck_tile::ignore = a_vec;
|
||||
ck_tile::ignore = b_vec;
|
||||
#endif
|
||||
}
|
||||
|
||||
// c_vec = a_vec * b_vec
|
||||
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
|
||||
{
|
||||
#if defined(__gfx950__)
|
||||
if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
|
||||
return bit_cast<CVecType>(__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
|
||||
a_vec, b_vec, CVecType{0.f}, 0, 0, 0, 0, 0, 0));
|
||||
else if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, bf8_t>)
|
||||
return bit_cast<CVecType>(__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
|
||||
a_vec, b_vec, CVecType{0.f}, 0, 1, 0, 0, 0, 0));
|
||||
else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, fp8_t>)
|
||||
return bit_cast<CVecType>(__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
|
||||
a_vec, b_vec, CVecType{0.f}, 1, 0, 0, 0, 0, 0));
|
||||
else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>)
|
||||
return bit_cast<CVecType>(__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
|
||||
a_vec, b_vec, CVecType{0.f}, 1, 1, 0, 0, 0, 0));
|
||||
#else
|
||||
ck_tile::ignore = a_vec;
|
||||
ck_tile::ignore = b_vec;
|
||||
return CVecType{0.f};
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
|
||||
using WarpGemmAttributeMfmaImpl_f32_32x32x64_fp8_fp8 =
|
||||
WarpGemmAttributeMfmaImpl_f32_32x32x64_f8_bf8_base<fp8_t, fp8_t, Ctrl_>;
|
||||
|
||||
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
|
||||
using WarpGemmAttributeMfmaImpl_f32_32x32x64_fp8_bf8 =
|
||||
WarpGemmAttributeMfmaImpl_f32_32x32x64_f8_bf8_base<fp8_t, bf8_t, Ctrl_>;
|
||||
|
||||
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
|
||||
using WarpGemmAttributeMfmaImpl_f32_32x32x64_bf8_fp8 =
|
||||
WarpGemmAttributeMfmaImpl_f32_32x32x64_f8_bf8_base<bf8_t, fp8_t, Ctrl_>;
|
||||
|
||||
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
|
||||
using WarpGemmAttributeMfmaImpl_f32_32x32x64_bf8_bf8 =
|
||||
WarpGemmAttributeMfmaImpl_f32_32x32x64_f8_bf8_base<bf8_t, bf8_t, Ctrl_>;
|
||||
|
||||
// int8
|
||||
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
|
||||
struct WarpGemmAttributeMfmaImpl_i32_32x32x16_i8
|
||||
|
||||
@@ -74,6 +74,11 @@ template<> struct WarpGemmMfmaDispatcher<ck_tile::fp8_t, ck_tile::bf8_t, float,
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf8_t, ck_tile::fp8_t, float, 16, 16, 128, false> { using Type = WarpGemmMfma_f32_16x16x128_bf8_fp8; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf8_t, ck_tile::bf8_t, float, 16, 16, 128, false> { using Type = WarpGemmMfma_f32_16x16x128_bf8_bf8; };
|
||||
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::fp8_t, ck_tile::fp8_t, float, 32, 32, 64, false> { using Type = WarpGemmMfma_f32_32x32x64_fp8_fp8; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::fp8_t, ck_tile::bf8_t, float, 32, 32, 64, false> { using Type = WarpGemmMfma_f32_32x32x64_fp8_bf8; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf8_t, ck_tile::fp8_t, float, 32, 32, 64, false> { using Type = WarpGemmMfma_f32_32x32x64_bf8_fp8; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf8_t, ck_tile::bf8_t, float, 32, 32, 64, false> { using Type = WarpGemmMfma_f32_32x32x64_bf8_bf8; };
|
||||
|
||||
// clang-format on
|
||||
} // namespace impl
|
||||
|
||||
|
||||
@@ -282,14 +282,14 @@ class GemmCodeGenerator:
|
||||
|
||||
def _generate_common_header(self):
|
||||
"""Generate common header with datatypes and layout"""
|
||||
ctype = self.config.datatype
|
||||
atype = self.config.datatype
|
||||
btype = self.config.datatype
|
||||
self.ctype = self.config.datatype
|
||||
self.atype = self.config.datatype
|
||||
self.btype = self.config.datatype
|
||||
if self.config.datatype in ['fp8', 'bf8']:
|
||||
ctype = 'fp16'
|
||||
self.ctype = 'fp16'
|
||||
elif self.config.datatype in ['int4']:
|
||||
atype = 'fp16'
|
||||
ctype = 'fp16'
|
||||
self.atype = 'fp16'
|
||||
self.ctype = 'fp16'
|
||||
|
||||
content = f"""// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
@@ -298,10 +298,10 @@ class GemmCodeGenerator:
|
||||
#include "ck_tile/core.hpp"
|
||||
|
||||
// Data types
|
||||
using ADataType = {DATA_TYPE_MAP[atype]};
|
||||
using BDataType = {DATA_TYPE_MAP[btype]};
|
||||
using ADataType = {DATA_TYPE_MAP[self.atype]};
|
||||
using BDataType = {DATA_TYPE_MAP[self.btype]};
|
||||
using AccDataType = float;
|
||||
using CDataType = {DATA_TYPE_MAP[ctype]};
|
||||
using CDataType = {DATA_TYPE_MAP[self.ctype]};
|
||||
|
||||
// Layout configurations
|
||||
using ALayout = {LAYOUT_MAP[self.config.layouts[0]]};
|
||||
@@ -499,7 +499,7 @@ struct GemmDispatcher {
|
||||
|
||||
static void init(bool structured_sparsity) {
|
||||
auto& kernel_map = get_kernel_map();
|
||||
if(!kernel_map.empty()) return;
|
||||
if(!kernel_map.empty()) return;
|
||||
\n"""
|
||||
# Add tile/warp instantiations
|
||||
tile_params = set(itertools.product(
|
||||
@@ -516,12 +516,25 @@ struct GemmDispatcher {
|
||||
|
||||
|
||||
for group in self.all_kernels:
|
||||
content += f""" kernel_map["{group}"] = [=](ck_tile::DeviceMem& c_m_n_dev_buf,
|
||||
ck_tile::HostTensor<CDataType>& c_m_n_host_result,
|
||||
ck_tile::HostTensor<CDataType>& c_m_n_dev_result,
|
||||
int verify, ck_tile::GemmHostArgs& args,
|
||||
const ck_tile::stream_config& stream) {{
|
||||
"""
|
||||
content += f""" kernel_map["{group}"] = [=](ck_tile::DeviceMem& c_m_n_dev_buf,
|
||||
ck_tile::HostTensor<CDataType>& c_m_n_host_result,
|
||||
ck_tile::HostTensor<CDataType>& c_m_n_dev_result,
|
||||
int verify, ck_tile::GemmHostArgs& args,
|
||||
const ck_tile::stream_config& stream) {{
|
||||
if(structured_sparsity){{ // SMFMA"""
|
||||
for tile in tile_params:
|
||||
# Check if we have valid tile/warp combinations
|
||||
# (tile_m/(warp_m*warp_tile_m)) * warp_m * warp_tile_m == tile_m
|
||||
if ((tile[0]/(tile[3] * tile[7]) * tile[3] * tile[7]) != tile[0]) or \
|
||||
((tile[1]/(tile[4] * tile[8]) * tile[4] * tile[8]) != tile[1]):
|
||||
continue
|
||||
sparse = self.atype == 'fp16' and \
|
||||
((tile[6] == 32 and tile[7] == 32 and tile[8] == 16) or
|
||||
(tile[6] == 16 and tile[7] == 16 and tile[8] == 32))
|
||||
content += f"""
|
||||
run_kernel<{group}::GemmKernel<{tile[0]}, {tile[1]}, {tile[2]}, {tile[3]}, {tile[4]}, {tile[5]}, {tile[6]}, {tile[7]}, {tile[8]}, {BOOL_MAP(sparse)}>>(c_m_n_dev_buf, c_m_n_host_result, c_m_n_dev_result, verify, args, stream);"""
|
||||
content += f"""
|
||||
}} else {{"""
|
||||
for tile in tile_params:
|
||||
# Check if we have valid tile/warp combinations
|
||||
# (tile_m/(warp_m*warp_tile_m)) * warp_m * warp_tile_m == tile_m
|
||||
@@ -529,13 +542,10 @@ struct GemmDispatcher {
|
||||
((tile[1]/(tile[4] * tile[8]) * tile[4] * tile[8]) != tile[1]):
|
||||
continue
|
||||
content += f"""
|
||||
if(structured_sparsity) {{
|
||||
run_kernel<{group}::GemmKernel<{tile[0]}, {tile[1]}, {tile[2]}, {tile[3]}, {tile[4]}, {tile[5]}, {tile[6]}, {tile[7]}, {tile[8]}, {1}>>(c_m_n_dev_buf, c_m_n_host_result, c_m_n_dev_result, verify, args, stream);
|
||||
}} else {{
|
||||
run_kernel<{group}::GemmKernel<{tile[0]}, {tile[1]}, {tile[2]}, {tile[3]}, {tile[4]}, {tile[5]}, {tile[6]}, {tile[7]}, {tile[8]}, {0}>>(c_m_n_dev_buf, c_m_n_host_result, c_m_n_dev_result, verify, args, stream);
|
||||
}}"""
|
||||
run_kernel<{group}::GemmKernel<{tile[0]}, {tile[1]}, {tile[2]}, {tile[3]}, {tile[4]}, {tile[5]}, {tile[6]}, {tile[7]}, {tile[8]}, {BOOL_MAP(False)}>>(c_m_n_dev_buf, c_m_n_host_result, c_m_n_dev_result, verify, args, stream);"""
|
||||
content += f"""
|
||||
}};\n"""
|
||||
}}
|
||||
}};\n"""
|
||||
|
||||
content += """ }
|
||||
|
||||
|
||||
Reference in New Issue
Block a user