mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
Disable SMFMA gfx90a (#2184)
* sparsity fix for gfx90a
* reverting tile_engine changes
[ROCm/composable_kernel commit: f05e45ba59]
This commit is contained in:
@@ -109,20 +109,11 @@ using WarpGemmMfmaF16F16F32M64N4K16 = WarpGemmImpl<WarpGemmAtrributeMfmaIterateK
|
||||
4>>;
|
||||
|
||||
// fp16 2:4 structured sparsity
|
||||
#if defined(__gfx94__) || defined(__gfx95__)
|
||||
using WarpGemmSmfmacF16F16F32M32N32K16 = WarpGemmSmfmacImpl<WarpGemmAttributeSmfmac<
|
||||
WarpGemmAttributeSmfmacImplF16F16F32M32N32K16<WGAttrCtlEnum::Default_>>>;
|
||||
|
||||
using WarpGemmSmfmacF16F16F32M16N16K32 = WarpGemmSmfmacImpl<WarpGemmAttributeSmfmac<
|
||||
WarpGemmAttributeSmfmacImplF16F16F32M16N16K32<WGAttrCtlEnum::Default_>>>;
|
||||
#else // gfx 90a does not support smfmac
|
||||
using WarpGemmSmfmacF16F16F32M32N32K16 = WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<
|
||||
WarpGemmAttributeMfmaImplF16F16F32M32N32K8<WGAttrCtlEnum::Default_>,
|
||||
2>>;
|
||||
using WarpGemmSmfmacF16F16F32M16N16K32 = WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<
|
||||
WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16<WGAttrCtlEnum::Default_>,
|
||||
2>>;
|
||||
#endif
|
||||
|
||||
// bf16
|
||||
using WarpGemmMfmaBf16Bf16F32M32N32K8 = WarpGemmImpl<
|
||||
|
||||
@@ -49,7 +49,7 @@ struct WarpGemmAttributeSmfmacImplF16F16F32M32N32K16
|
||||
const int32_t& idx,
|
||||
bool_constant<post_nop_> = {}) const
|
||||
{
|
||||
#if defined(__gfx9__)
|
||||
#if defined(__gfx94_) or defined(__gfx95_)
|
||||
c_vec = __builtin_amdgcn_smfmac_f32_32x32x16_f16(a_vec, b_vec, c_vec, idx, 0, 0);
|
||||
#else
|
||||
ck_tile::ignore = c_vec;
|
||||
@@ -100,7 +100,7 @@ struct WarpGemmAttributeSmfmacImplF16F16F32M16N16K32
|
||||
const int32_t& idx,
|
||||
bool_constant<post_nop_> = {}) const
|
||||
{
|
||||
#if defined(__gfx9__)
|
||||
#if defined(__gfx94_) or defined(__gfx95_)
|
||||
c_vec = __builtin_amdgcn_smfmac_f32_16x16x32_f16(a_vec, b_vec, c_vec, idx, 0, 0);
|
||||
#else
|
||||
ck_tile::ignore = c_vec;
|
||||
|
||||
@@ -535,11 +535,7 @@ struct GemmDispatcher {
|
||||
((tile[6] == 32 and tile[7] == 32 and tile[8] == 16) or
|
||||
(tile[6] == 16 and tile[7] == 16 and tile[8] == 32))
|
||||
content += f"""
|
||||
#if defined(__gfx908__)
|
||||
run_kernel<{group}::GemmKernel<{tile[0]}, {tile[1]}, {tile[2]}, {tile[3]}, {tile[4]}, {tile[5]}, {tile[6]}, {tile[7]}, {tile[8]}, {BOOL_MAP(False)}>>(c_m_n_dev_buf, c_m_n_host_result, c_m_n_dev_result, verify, args, stream);
|
||||
#else
|
||||
run_kernel<{group}::GemmKernel<{tile[0]}, {tile[1]}, {tile[2]}, {tile[3]}, {tile[4]}, {tile[5]}, {tile[6]}, {tile[7]}, {tile[8]}, {BOOL_MAP(sparse)}>>(c_m_n_dev_buf, c_m_n_host_result, c_m_n_dev_result, verify, args, stream);
|
||||
#endif"""
|
||||
run_kernel<{group}::GemmKernel<{tile[0]}, {tile[1]}, {tile[2]}, {tile[3]}, {tile[4]}, {tile[5]}, {tile[6]}, {tile[7]}, {tile[8]}, {BOOL_MAP(sparse)}>>(c_m_n_dev_buf, c_m_n_host_result, c_m_n_dev_result, verify, args, stream);"""
|
||||
content += f"""
|
||||
}} else {{"""
|
||||
for tile in tile_params:
|
||||
|
||||
Reference in New Issue
Block a user