From 9661bb400bb1e7742f894379acfa61b6223c8c3c Mon Sep 17 00:00:00 2001 From: Gino Lu Date: Mon, 8 Sep 2025 19:09:55 -0500 Subject: [PATCH] fix type error --- .../warp/warp_gemm_attribute_mfma_impl.hpp | 33 ++++++++++++++++--- 1 file changed, 28 insertions(+), 5 deletions(-) diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp index 8de9508630..b212d9aa4e 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp @@ -1533,12 +1533,24 @@ struct WarpGemmAttributeMfmaScaleImpl_f32_16x16x128_fp4 //__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(a, b, c, cbsz, blgp, opsel, scale_a, // opsel, scale_b) #if defined(__gfx950__) - c_vec = __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4( - a_vec, b_vec, c_vec, 4, 4, opselA, a_scale, opselB, b_scale); + auto arg_a = bit_cast(a_vec); + auto arg_b = bit_cast(b_vec); + c_vec = __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4( + int32x8_t{arg_a[0], arg_a[1], arg_a[2], arg_a[3], 0, 0, 0, 0}, + int32x8_t{arg_b[0], arg_b[1], arg_b[2], arg_b[3], 0, 0, 0, 0}, + CVecType{0.f}, + 4, + 4, + opselA, + a_scale, + opselB, + b_scale); #else ck_tile::ignore = c_vec; ck_tile::ignore = a_vec; ck_tile::ignore = b_vec; + ck_tile::ignore = a_scale; + ck_tile::ignore = b_scale; #endif } @@ -1550,12 +1562,23 @@ struct WarpGemmAttributeMfmaScaleImpl_f32_16x16x128_fp4 const int32_t& b_scale) const { #if defined(__gfx950__) - if constexpr(std::is_same_v && std::is_same_v) - return bit_cast(__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4( - a_vec, b_vec, CVecType{0.f}, 4, 4, opselA, a_scale, opselB, b_scale)); + auto arg_a = bit_cast(a_vec); + auto arg_b = bit_cast(b_vec); + return bit_cast(__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4( + int32x8_t{arg_a[0], arg_a[1], arg_a[2], arg_a[3], 0, 0, 0, 0}, + int32x8_t{arg_b[0], arg_b[1], arg_b[2], arg_b[3], 0, 0, 0, 0}, + CVecType{0.f}, + 4, + 4, + opselA, + a_scale, + opselB, + b_scale)); #else ck_tile::ignore = a_vec; ck_tile::ignore = b_vec; + ck_tile::ignore = a_scale; + ck_tile::ignore = b_scale; return CVecType{0.f}; #endif }