diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp index bd65f53383..8272b015f9 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp @@ -1555,6 +1555,9 @@ struct WarpGemmAttributeMfmaImpl_f32_16x16x128_f8f6f4 static constexpr index_t kCM0PerLane = 1; static constexpr index_t kCM1PerLane = 4; + // Unity scale: every byte of kDefaultScale is 0x7F (127), so 2^(0x7F - 127) = 1.0 per byte (presumably one biased-exponent scale per byte; confirm) + static constexpr index_t kDefaultScale = 0x7F7F7F7F; + // c_vec += a_vec * b_vec template CK_TILE_DEVICE void operator()(CVecType& c_vec, @@ -1624,13 +1627,13 @@ struct WarpGemmAttributeMfmaImpl_f32_16x16x128_f8f6f4 const BVecType& b_vec, bool_constant = {}) const { - operator()<0, 0>(c_vec, a_vec, 0, b_vec, 0); + operator()<0, 0>(c_vec, a_vec, kDefaultScale, b_vec, kDefaultScale); } // c_vec = a_vec * b_vec CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const { - return operator()<0, 0>(a_vec, 0, b_vec, 0); + return operator()<0, 0>(a_vec, kDefaultScale, b_vec, kDefaultScale); } };