remove additional check when e8m0->float

This commit is contained in:
Feng Shijie
2025-08-15 06:20:46 +00:00
parent 714b341797
commit 7899fb4a8d

View File

@@ -722,8 +722,11 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1
auto scale = scale_tensor.get_thread_buffer()[scale_offset];
constexpr int ScalarCnt = WG::BWarpTensor::get_thread_buffer_size();
constexpr int PackedCnt = ScalarCnt / MXFP4PackedSize;
constexpr int ScalarCnt = WG::BWarpTensor::get_thread_buffer_size();
constexpr int PackedCnt = ScalarCnt / MXFP4PackedSize;
constexpr int float_mantissa = 23;
uint32_t uscale = uint32_t(scale.data) << float_mantissa;
auto pk_mxfp4_to_compute_v2 = [](auto pk_mxfp4, float fscale) {
if constexpr(std::is_same_v<ComputeType, half_t>)
@@ -747,7 +750,7 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1
i,
pk_mxfp4_to_compute_v2(
quant_weight_tensor.get_thread_buffer()[quant_idx_k * PackedCnt + i],
static_cast<float>(scale)));
bit_cast<float>(uscale)));
});
};