scale bf16

This commit is contained in:
yadaish
2025-12-01 05:30:02 +00:00
parent 2d7a35de3e
commit 2182364ebb
2 changed files with 14 additions and 8 deletions

View File

@@ -638,7 +638,8 @@ struct MoeFlatmmKernel
index_t FlatScaleK = scale_k * N_Pack * BlockGemmShape::WarpTile::at(I1);
index_t FlatScaleN = kargs.N / N_Pack / BlockGemmShape::WarpTile::at(I1);
using ScaleType = std::conditional_t<MXFP4_Pipeline, e8m0_t, float>;
// using ScaleType = std::conditional_t<MXFP4_Pipeline, e8m0_t, float>;
using ScaleType = std::conditional_t<MXFP4_Pipeline, bfloat16_t, float>;
const auto scale_b_flat_view = make_naive_tensor_view<address_space_enum::global>(
reinterpret_cast<const ScaleType*>(scale_n.ptr) + expert_id * kargs.N * scale_k,

View File

@@ -189,12 +189,12 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1
struct DequantizeMxFP4 {
CK_TILE_DEVICE auto operator()(statically_indexed_array<typename WG::BWarpTensor, NIterPerWarp>& dequant_B_n,
const auto& quant_weight_tensor,
const auto& scale_tensor,
auto xdl_nIter,
auto xdl_kIter) {
CK_TILE_DEVICE auto operator()([[maybe_unused]] statically_indexed_array<typename WG::BWarpTensor, NIterPerWarp>& dequant_B_n,
[[maybe_unused]] const auto& quant_weight_tensor,
[[maybe_unused]] const auto& scale_tensor,
[[maybe_unused]] auto xdl_nIter,
[[maybe_unused]] auto xdl_kIter) {
#if 0
auto quant_idx_k = xdl_kIter % number<XDL_PerWeightK>{};
auto scale_idx_n = xdl_nIter % number<XDL_PerScaleN>{};
@@ -258,6 +258,7 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1
.at(i),
bit_cast<float>(uscale)));
});
#endif
#endif
return 0;
}
@@ -281,9 +282,13 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1
constexpr int ScalarCnt = WG::BWarpTensor::get_thread_buffer_size();
constexpr int PackedCnt = ScalarCnt / MXFP4PackedSize;
/*
constexpr int float_mantissa = 23;
uint32_t uscale = uint32_t(scale.data) << float_mantissa;
*/
float scale_f32 = type_cast<float>(scale.data);
using ComputeV2Type =
std::conditional_t<std::is_same_v<ComputeType, half_t>, fp16x2_t, bf16x2_t>;
@@ -308,7 +313,7 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1
pk_int4_to_compute_v2(
bit_cast<thread_buffer<pk_int4_t, 4>>(quant_weight_tensor[quant_idx_k])
.at(i),
bit_cast<float>(uscale)));
scale_f32));
});
return 0;
}