mfma_32x32x64_fp8/bf8 (#2148)

* support for mfma_32x32x64_fp8

* clang-formatted

* Fixing sparsity in codegen
This commit is contained in:
Khushbu Agarwal
2025-05-01 13:36:24 -07:00
committed by GitHub
parent 619fba3134
commit d58f2b8bd0
4 changed files with 147 additions and 22 deletions

View File

@@ -282,14 +282,14 @@ class GemmCodeGenerator:
def _generate_common_header(self):
"""Generate common header with datatypes and layout"""
ctype = self.config.datatype
atype = self.config.datatype
btype = self.config.datatype
self.ctype = self.config.datatype
self.atype = self.config.datatype
self.btype = self.config.datatype
if self.config.datatype in ['fp8', 'bf8']:
ctype = 'fp16'
self.ctype = 'fp16'
elif self.config.datatype in ['int4']:
atype = 'fp16'
ctype = 'fp16'
self.atype = 'fp16'
self.ctype = 'fp16'
content = f"""// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
@@ -298,10 +298,10 @@ class GemmCodeGenerator:
#include "ck_tile/core.hpp"
// Data types
using ADataType = {DATA_TYPE_MAP[atype]};
using BDataType = {DATA_TYPE_MAP[btype]};
using ADataType = {DATA_TYPE_MAP[self.atype]};
using BDataType = {DATA_TYPE_MAP[self.btype]};
using AccDataType = float;
using CDataType = {DATA_TYPE_MAP[ctype]};
using CDataType = {DATA_TYPE_MAP[self.ctype]};
// Layout configurations
using ALayout = {LAYOUT_MAP[self.config.layouts[0]]};
@@ -499,7 +499,7 @@ struct GemmDispatcher {
static void init(bool structured_sparsity) {
auto& kernel_map = get_kernel_map();
if(!kernel_map.empty()) return;
if(!kernel_map.empty()) return;
\n"""
# Add tile/warp instantiations
tile_params = set(itertools.product(
@@ -516,12 +516,25 @@ struct GemmDispatcher {
for group in self.all_kernels:
content += f""" kernel_map["{group}"] = [=](ck_tile::DeviceMem& c_m_n_dev_buf,
ck_tile::HostTensor<CDataType>& c_m_n_host_result,
ck_tile::HostTensor<CDataType>& c_m_n_dev_result,
int verify, ck_tile::GemmHostArgs& args,
const ck_tile::stream_config& stream) {{
"""
content += f""" kernel_map["{group}"] = [=](ck_tile::DeviceMem& c_m_n_dev_buf,
ck_tile::HostTensor<CDataType>& c_m_n_host_result,
ck_tile::HostTensor<CDataType>& c_m_n_dev_result,
int verify, ck_tile::GemmHostArgs& args,
const ck_tile::stream_config& stream) {{
if(structured_sparsity){{ // SMFMA"""
for tile in tile_params:
# Check if we have valid tile/warp combinations
# (tile_m/(warp_m*warp_tile_m)) * warp_m * warp_tile_m == tile_m
if ((tile[0]/(tile[3] * tile[7]) * tile[3] * tile[7]) != tile[0]) or \
((tile[1]/(tile[4] * tile[8]) * tile[4] * tile[8]) != tile[1]):
continue
sparse = self.atype == 'fp16' and \
((tile[6] == 32 and tile[7] == 32 and tile[8] == 16) or
(tile[6] == 16 and tile[7] == 16 and tile[8] == 32))
content += f"""
run_kernel<{group}::GemmKernel<{tile[0]}, {tile[1]}, {tile[2]}, {tile[3]}, {tile[4]}, {tile[5]}, {tile[6]}, {tile[7]}, {tile[8]}, {BOOL_MAP(sparse)}>>(c_m_n_dev_buf, c_m_n_host_result, c_m_n_dev_result, verify, args, stream);"""
content += f"""
}} else {{"""
for tile in tile_params:
# Check if we have valid tile/warp combinations
# (tile_m/(warp_m*warp_tile_m)) * warp_m * warp_tile_m == tile_m
@@ -529,13 +542,10 @@ struct GemmDispatcher {
((tile[1]/(tile[4] * tile[8]) * tile[4] * tile[8]) != tile[1]):
continue
content += f"""
if(structured_sparsity) {{
run_kernel<{group}::GemmKernel<{tile[0]}, {tile[1]}, {tile[2]}, {tile[3]}, {tile[4]}, {tile[5]}, {tile[6]}, {tile[7]}, {tile[8]}, {1}>>(c_m_n_dev_buf, c_m_n_host_result, c_m_n_dev_result, verify, args, stream);
}} else {{
run_kernel<{group}::GemmKernel<{tile[0]}, {tile[1]}, {tile[2]}, {tile[3]}, {tile[4]}, {tile[5]}, {tile[6]}, {tile[7]}, {tile[8]}, {0}>>(c_m_n_dev_buf, c_m_n_host_result, c_m_n_dev_result, verify, args, stream);
}}"""
run_kernel<{group}::GemmKernel<{tile[0]}, {tile[1]}, {tile[2]}, {tile[3]}, {tile[4]}, {tile[5]}, {tile[6]}, {tile[7]}, {tile[8]}, {BOOL_MAP(False)}>>(c_m_n_dev_buf, c_m_n_host_result, c_m_n_dev_result, verify, args, stream);"""
content += f"""
}};\n"""
}}
}};\n"""
content += """ }