mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 21:21:22 +00:00
Remove additional check when converting e8m0 -> float
This commit is contained in:
@@ -722,8 +722,11 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1
|
||||
|
||||
auto scale = scale_tensor.get_thread_buffer()[scale_offset];
|
||||
|
||||
constexpr int ScalarCnt = WG::BWarpTensor::get_thread_buffer_size();
|
||||
constexpr int PackedCnt = ScalarCnt / MXFP4PackedSize;
|
||||
constexpr int ScalarCnt = WG::BWarpTensor::get_thread_buffer_size();
|
||||
constexpr int PackedCnt = ScalarCnt / MXFP4PackedSize;
|
||||
constexpr int float_mantissa = 23;
|
||||
|
||||
uint32_t uscale = uint32_t(scale.data) << float_mantissa;
|
||||
|
||||
auto pk_mxfp4_to_compute_v2 = [](auto pk_mxfp4, float fscale) {
|
||||
if constexpr(std::is_same_v<ComputeType, half_t>)
|
||||
@@ -747,7 +750,7 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1
|
||||
i,
|
||||
pk_mxfp4_to_compute_v2(
|
||||
quant_weight_tensor.get_thread_buffer()[quant_idx_k * PackedCnt + i],
|
||||
static_cast<float>(scale)));
|
||||
bit_cast<float>(uscale)));
|
||||
});
|
||||
};
|
||||
|
||||
|
||||
Reference in New Issue
Block a user