diff --git a/include/ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp b/include/ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp index 946b81c146..8ccd2c1bcb 100644 --- a/include/ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp +++ b/include/ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp @@ -638,7 +638,8 @@ struct MoeFlatmmKernel index_t FlatScaleK = scale_k * N_Pack * BlockGemmShape::WarpTile::at(I1); index_t FlatScaleN = kargs.N / N_Pack / BlockGemmShape::WarpTile::at(I1); - using ScaleType = std::conditional_t; + // using ScaleType = std::conditional_t; + using ScaleType = std::conditional_t; const auto scale_b_flat_view = make_naive_tensor_view( reinterpret_cast(scale_n.ptr) + expert_id * kargs.N * scale_k, diff --git a/include/ck_tile/ops/flatmm/pipeline/mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp b/include/ck_tile/ops/flatmm/pipeline/mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp index b42db7434c..9925cc0691 100644 --- a/include/ck_tile/ops/flatmm/pipeline/mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp +++ b/include/ck_tile/ops/flatmm/pipeline/mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp @@ -189,12 +189,12 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1 struct DequantizeMxFP4 { - CK_TILE_DEVICE auto operator()(statically_indexed_array& dequant_B_n, - const auto& quant_weight_tensor, - const auto& scale_tensor, - auto xdl_nIter, - auto xdl_kIter) { - + CK_TILE_DEVICE auto operator()([[maybe_unused]] statically_indexed_array& dequant_B_n, + [[maybe_unused]] const auto& quant_weight_tensor, + [[maybe_unused]] const auto& scale_tensor, + [[maybe_unused]] auto xdl_nIter, + [[maybe_unused]] auto xdl_kIter) { +#if 0 auto quant_idx_k = xdl_kIter % number{}; auto scale_idx_n = xdl_nIter % number{}; @@ -258,6 +258,7 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1 .at(i), bit_cast(uscale))); }); +#endif #endif return 0; } @@ -281,9 +282,13 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1 constexpr int ScalarCnt = WG::BWarpTensor::get_thread_buffer_size(); constexpr int PackedCnt = ScalarCnt / MXFP4PackedSize; + /* constexpr int float_mantissa = 23; uint32_t uscale = uint32_t(scale.data) << float_mantissa; + */ + + float scale_f32 = type_cast(scale.data); using ComputeV2Type = std::conditional_t, fp16x2_t, bf16x2_t>; @@ -308,7 +313,7 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1 pk_int4_to_compute_v2( bit_cast>(quant_weight_tensor[quant_idx_k]) .at(i), - bit_cast(uscale))); + scale_f32)); }); return 0; }