Fix mfma instruction

This commit is contained in:
Enrico Degregori
2026-06-18 08:08:11 +00:00
parent aa5ed1a749
commit 2b68eb63f3

View File

@@ -1725,9 +1725,6 @@ struct WarpGemmAttributeMfmaImpl_f32_16x16x128_f8f6f4
static constexpr index_t kScaleGranularity = 32;
// To get unity scale: 2^(kDefaultScale - 127) = 1.0
static constexpr index_t kDefaultScale = 0x7F7F7F7F;
// c_vec += a_vec * b_vec
template <typename... Params>
CK_TILE_DEVICE void operator()(CVecType& c_vec,
@@ -1803,14 +1800,14 @@ struct WarpGemmAttributeMfmaImpl_f32_16x16x128_f8f6f4
CK_TILE_DEVICE void
operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const
{
operator()<Params...>(c_vec, a_vec, kDefaultScale, b_vec, kDefaultScale);
operator()<Params...>(c_vec, a_vec, 0, b_vec, 0);
}
// c_vec = a_vec * b_vec
template <typename... Params>
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
{
return operator()<Params...>(a_vec, kDefaultScale, b_vec, kDefaultScale);
return operator()<Params...>(a_vec, 0, b_vec, 0);
}
};