diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp index fcf23e751b..b28c28eaca 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp @@ -1440,16 +1440,16 @@ template using WarpGemmAttributeMfmaImpl_f32_16x16x128_bf8_bf8 = WarpGemmAttributeMfmaImpl_f32_16x16x128_f8_bf8_base; -template -struct WarpGemmAttributeMfmaScaleImpl_f32_16x16x128_base +template +struct WarpGemmAttributeMfmaScaleImpl_f32_16x16x128_fp4 { static constexpr WGAttrCtlEnum Ctrl = Ctrl_; - using ADataType = AType_; - using BDataType = BType_; + using ADataType = pk_fp4_t; + using BDataType = pk_fp4_t; using CDataType = float; - using AVecType = ext_vector_t; - using BVecType = ext_vector_t; + using AVecType = ext_vector_t; + using BVecType = ext_vector_t; using CVecType = ext_vector_t; static constexpr index_t kM = 16; @@ -1482,9 +1482,8 @@ struct WarpGemmAttributeMfmaScaleImpl_f32_16x16x128_base //__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(a, b, c, cbsz, blgp, opsel, scale_a, // opsel, scale_b) #if defined(__gfx950__) - if constexpr(std::is_same_v && std::is_same_v) - c_vec = __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4( - a_vec, b_vec, c_vec, 4, 4, opselA, a_scale, opselB, b_scale); + c_vec = __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4( + a_vec, b_vec, c_vec, 4, 4, opselA, a_scale, opselB, b_scale); #else ck_tile::ignore = c_vec; ck_tile::ignore = a_vec; @@ -1512,11 +1511,6 @@ struct WarpGemmAttributeMfmaScaleImpl_f32_16x16x128_base }; -template -using WarpGemmAttributeMfmaScaleImpl_f32_16x16x128_fp4 = - WarpGemmAttributeMfmaScaleImpl_f32_16x16x128_base; - - template struct WarpGemmAttributeMfmaImpl_f32_32x32x64_f8_bf8_base {