Fix type error: bit-cast a_vec/b_vec and widen them to int32x8_t for the MFMA scale builtin

This commit is contained in:
Gino Lu
2025-09-08 19:09:55 -05:00
parent 754ae0461b
commit 9661bb400b

View File

@@ -1533,12 +1533,24 @@ struct WarpGemmAttributeMfmaScaleImpl_f32_16x16x128_fp4
//__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(a, b, c, cbsz, blgp, opsel, scale_a,
// opsel, scale_b)
#if defined(__gfx950__)
c_vec = __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
a_vec, b_vec, c_vec, 4, 4, opselA, a_scale, opselB, b_scale);
auto arg_a = bit_cast<int32x4_t>(a_vec);
auto arg_b = bit_cast<int32x4_t>(b_vec);
c_vec = __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
int32x8_t{arg_a[0], arg_a[1], arg_a[2], arg_a[3], 0, 0, 0, 0},
int32x8_t{arg_b[0], arg_b[1], arg_b[2], arg_b[3], 0, 0, 0, 0},
CVecType{0.f},
4,
4,
opselA,
a_scale,
opselB,
b_scale);
#else
ck_tile::ignore = c_vec;
ck_tile::ignore = a_vec;
ck_tile::ignore = b_vec;
ck_tile::ignore = a_scale;
ck_tile::ignore = b_scale;
#endif
}
@@ -1550,12 +1562,23 @@ struct WarpGemmAttributeMfmaScaleImpl_f32_16x16x128_fp4
const int32_t& b_scale) const
{
#if defined(__gfx950__)
if constexpr(std::is_same_v<ADataType, pk_fp4_t> && std::is_same_v<BDataType, pk_fp4_t>)
return bit_cast<CVecType>(__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
a_vec, b_vec, CVecType{0.f}, 4, 4, opselA, a_scale, opselB, b_scale));
auto arg_a = bit_cast<int32x4_t>(a_vec);
auto arg_b = bit_cast<int32x4_t>(b_vec);
return bit_cast<CVecType>(__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
int32x8_t{arg_a[0], arg_a[1], arg_a[2], arg_a[3], 0, 0, 0, 0},
int32x8_t{arg_b[0], arg_b[1], arg_b[2], arg_b[3], 0, 0, 0, 0},
CVecType{0.f},
4,
4,
opselA,
a_scale,
opselB,
b_scale));
#else
ck_tile::ignore = a_vec;
ck_tile::ignore = b_vec;
ck_tile::ignore = a_scale;
ck_tile::ignore = b_scale;
return CVecType{0.f};
#endif
}