Use the default unity scale (no scaling) for the 16x16x128 MFMA scale arguments

This commit is contained in:
Sami Remes
2026-01-30 12:55:46 -05:00
parent 407df88c02
commit 4d241289c9

View File

@@ -1555,6 +1555,9 @@ struct WarpGemmAttributeMfmaImpl_f32_16x16x128_f8f6f4
static constexpr index_t kCM0PerLane = 1;
static constexpr index_t kCM1PerLane = 4;
// Scale value that yields a unity (1.0) scale factor.
// The formula 2^(scale - 127) holds per byte, not for the whole 32-bit word:
// every byte of 0x7F7F7F7F is 0x7F (127), so each gives 2^(127 - 127) = 1.0.
// NOTE(review): presumably the four bytes are packed scale exponents chosen by
// the opselA/opselB template arguments — confirm against the scaled operator().
static constexpr index_t kDefaultScale = 0x7F7F7F7F;
// c_vec += a_vec * b_vec
template <index_t opselA, index_t opselB, bool post_nop_ = false>
CK_TILE_DEVICE void operator()(CVecType& c_vec,
@@ -1624,13 +1627,13 @@ struct WarpGemmAttributeMfmaImpl_f32_16x16x128_f8f6f4
const BVecType& b_vec,
bool_constant<post_nop_> = {}) const
{
operator()<0, 0>(c_vec, a_vec, 0, b_vec, 0);
operator()<0, 0>(c_vec, a_vec, kDefaultScale, b_vec, kDefaultScale);
}
// c_vec = a_vec * b_vec
// Overload without explicit scale arguments: forwards to a templated
// operator() overload with template arguments <0, 0> and returns its result.
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
{
// NOTE(review): this line and the next look like the removed/added pair of a
// diff whose +/- markers are not visible. Read literally as code, this return
// passes 0 for both scale arguments and makes the following line unreachable.
return operator()<0, 0>(a_vec, 0, b_vec, 0);
// Passes kDefaultScale for both scale arguments.
return operator()<0, 0>(a_vec, kDefaultScale, b_vec, kDefaultScale);
}
};