From 4947f0306c063081e992d2845103a1276fc31822 Mon Sep 17 00:00:00 2001 From: Po Yen Chen Date: Sat, 31 Jan 2026 04:02:49 +0800 Subject: [PATCH] [CK_TILE] Fix incompatible vector type arguments for the intrinsic calls (#3672) * Change call to the intrinsics * fix clang format * Undo changes under include/ck/utility * Use named variable as vector size --------- Co-authored-by: illsilin_amdeng [ROCm/composable_kernel commit: 8c1788757a88ee03bc8dbeb69704832c99fa719c] --- .../warp/warp_gemm_attribute_mfma_impl.hpp | 68 +++++++++++++++---- 1 file changed, 56 insertions(+), 12 deletions(-) diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp index bd65f53383..9e23a06b23 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp @@ -612,7 +612,13 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8 else { #if defined(__gfx90a__) || defined(__gfx94__) - c_vec = __builtin_amdgcn_mfma_f32_32x32x8bf16_1k(a_vec, b_vec, c_vec, 0, 0, 0); + c_vec = __builtin_amdgcn_mfma_f32_32x32x8bf16_1k( + bit_cast>(a_vec), + bit_cast>(b_vec), + c_vec, + 0, + 0, + 0); #elif defined(__gfx908__) static_for<0, 2, 1>{}([&](auto k) { c_vec = __builtin_amdgcn_mfma_f32_32x32x4bf16( @@ -637,8 +643,13 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8 CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const { #if defined(__gfx90a__) || defined(__gfx94__) - return bit_cast( - __builtin_amdgcn_mfma_f32_32x32x8bf16_1k(a_vec, b_vec, fp32x16_t{0.f}, 0, 0, 0)); + return bit_cast(__builtin_amdgcn_mfma_f32_32x32x8bf16_1k( + bit_cast>(a_vec), + bit_cast>(b_vec), + fp32x16_t{0.f}, + 0, + 0, + 0)); #elif defined(__gfx908__) CVecType c_vec{0.f}; static_for<0, 2, 1>{}([&](auto k) { @@ -700,7 +711,13 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16 DISPATCH_MFMA_CTRL_("v_mfma_f32_16x16x16bf16_1k", Ctrl) { #if defined(__gfx90a__) || defined(__gfx94__) - c_vec = __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(a_vec, b_vec, c_vec, 0, 0, 0); + c_vec = __builtin_amdgcn_mfma_f32_16x16x16bf16_1k( + bit_cast>(a_vec), + bit_cast>(b_vec), + c_vec, + 0, + 0, + 0); #elif defined(__gfx908__) static_for<0, 2, 1>{}([&](auto k) { c_vec = __builtin_amdgcn_mfma_f32_16x16x8bf16( @@ -725,8 +742,13 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16 CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const { #if defined(__gfx90a__) || defined(__gfx94__) - return bit_cast( - __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(a_vec, b_vec, fp32x4_t{0.f}, 0, 0, 0)); + return bit_cast(__builtin_amdgcn_mfma_f32_16x16x16bf16_1k( + bit_cast>(a_vec), + bit_cast>(b_vec), + fp32x4_t{0.f}, + 0, + 0, + 0)); #elif defined(__gfx908__) CVecType c_vec{0.f}; static_for<0, 2, 1>{}([&](auto k) { @@ -790,7 +812,13 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M4N64K4 else { #if defined(__gfx90a__) || defined(__gfx94__) - c_vec = __builtin_amdgcn_mfma_f32_4x4x4bf16_1k(a_vec, b_vec, c_vec, 0, 0, 0); + c_vec = __builtin_amdgcn_mfma_f32_4x4x4bf16_1k( + bit_cast>(a_vec), + bit_cast>(b_vec), + c_vec, + 0, + 0, + 0); #elif defined(__gfx908__) static_for<0, 2, 1>{}([&](auto k) { c_vec = __builtin_amdgcn_mfma_f32_4x4x2bf16( @@ -815,8 +843,13 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M4N64K4 CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const { #if defined(__gfx90a__) || defined(__gfx94__) - return bit_cast( - __builtin_amdgcn_mfma_f32_4x4x4bf16_1k(a_vec, b_vec, fp32x4_t{0.f}, 0, 0, 0)); + return bit_cast(__builtin_amdgcn_mfma_f32_4x4x4bf16_1k( + bit_cast>(a_vec), + bit_cast>(b_vec), + fp32x4_t{0.f}, + 0, + 0, + 0)); #elif defined(__gfx908__) CVecType c_vec{0.f}; static_for<0, 2, 1>{}([&](auto k) { @@ -880,7 +913,13 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M64N4K4 else { #if defined(__gfx90a__) || defined(__gfx94__) - c_vec = __builtin_amdgcn_mfma_f32_4x4x4bf16_1k(a_vec, b_vec, c_vec, 0, 0, 0); + c_vec = __builtin_amdgcn_mfma_f32_4x4x4bf16_1k( + bit_cast>(a_vec), + bit_cast>(b_vec), + c_vec, + 0, + 0, + 0); #elif defined(__gfx908__) static_for<0, 2, 1>{}([&](auto k) { c_vec = __builtin_amdgcn_mfma_f32_4x4x2bf16( @@ -905,8 +944,13 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M64N4K4 CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const { #if defined(__gfx90a__) || defined(__gfx94__) - return bit_cast( - __builtin_amdgcn_mfma_f32_4x4x4bf16_1k(a_vec, b_vec, fp32x4_t{0.f}, 0, 0, 0)); + return bit_cast(__builtin_amdgcn_mfma_f32_4x4x4bf16_1k( + bit_cast>(a_vec), + bit_cast>(b_vec), + fp32x4_t{0.f}, + 0, + 0, + 0)); #elif defined(__gfx908__) CVecType c_vec{0.f}; static_for<0, 2, 1>{}([&](auto k) {