From 79f2db722e222794329a429c654bfcaa92be1df0 Mon Sep 17 00:00:00 2001 From: yadaish Date: Wed, 19 Nov 2025 04:13:41 +0000 Subject: [PATCH] support a16_wint4 moe --- .../18_flatmm/mixed_prec/a16w4_moe_flatmm.cpp | 19 ++- include/ck_tile/core/numeric/pk_int4.hpp | 18 +++ .../ops/flatmm/kernel/moe_flatmm_kernel.hpp | 2 +- ...ec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp | 147 +++++++++++++++++- 4 files changed, 179 insertions(+), 7 deletions(-) diff --git a/example/ck_tile/18_flatmm/mixed_prec/a16w4_moe_flatmm.cpp b/example/ck_tile/18_flatmm/mixed_prec/a16w4_moe_flatmm.cpp index 62fb6bbcb2..709c772f6a 100644 --- a/example/ck_tile/18_flatmm/mixed_prec/a16w4_moe_flatmm.cpp +++ b/example/ck_tile/18_flatmm/mixed_prec/a16w4_moe_flatmm.cpp @@ -86,7 +86,8 @@ float a16w4_moe_gemm(const MoeFlatmmHostArgs& args, const ck_tile::stream_config FlatmmConfig::NumWaveGroups, true>; // Preshuffle_ - constexpr bool MXFP4_Pipeline = std::is_same_v; + // TODO(yadai): rename to W4_Pipeline + constexpr bool MXFP4_Pipeline = std::is_same_v | std::is_same_v; if constexpr(!MXFP4_Pipeline && moe_kind == ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_up) { @@ -444,6 +445,22 @@ int run_a16w4_moe_flatmm_example(int argc, char* argv[]) FlatmmConfig, ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_up>(argc, argv, Row{}, Col{}, Row{}); } + else if(mixed_prec == "fp16xint4") + { + return run_a16w4_moe_gemm_example_with_layouts< + ck_tile::half_t, + ck_tile::pk_int4_t, + FlatmmConfig, + ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_up>(argc, argv, Row{}, Col{}, Row{}); + } + else if(mixed_prec == "bf16xint4") + { + return run_a16w4_moe_gemm_example_with_layouts< + ck_tile::bfloat16_t, + ck_tile::pk_int4_t, + FlatmmConfig, + ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_up>(argc, argv, Row{}, Col{}, Row{}); + } else { throw std::runtime_error("Unsupported precision type for gemm1_gate_up!"); diff --git a/include/ck_tile/core/numeric/pk_int4.hpp b/include/ck_tile/core/numeric/pk_int4.hpp index fc1caf13ff..088407b40c 100644 --- a/include/ck_tile/core/numeric/pk_int4.hpp +++ b/include/ck_tile/core/numeric/pk_int4.hpp @@ -151,6 +151,16 @@ CK_TILE_HOST_DEVICE fp16x2_t pk_int4_t_to_halfx2_t(const pk_int4_t& x) return pk_add_f16(bit_cast(lo), bit_cast(SUB)); } + + +CK_TILE_HOST_DEVICE fp16x2_t pk_int4_t_to_halfx2_t(const pk_int4_t& x, float scale) +{ + auto float_vec2 = pk_int4_t_to_fp32x2_t(x); + float_vec2.x = float_vec2.x * scale; + float_vec2.y = float_vec2.y * scale; + return fp16x2_t{type_convert(float_vec2.x), type_convert(float_vec2.y)}; +} + CK_TILE_HOST_DEVICE bf16x2_t pk_int4_t_to_bfloat16x2_t(const pk_int4_t& x) { uint8_t x_u8 = ck_tile::bit_cast(x); @@ -166,6 +176,14 @@ CK_TILE_HOST_DEVICE bf16x2_t pk_int4_t_to_bfloat16x2_t(const pk_int4_t& x) return res; } +CK_TILE_HOST_DEVICE bf16x2_t pk_int4_t_to_bfloat16x2_t(const pk_int4_t& x, float scale) +{ + auto float_vec2 = pk_int4_t_to_fp32x2_t(x); + float_vec2.x = float_vec2.x * scale; + float_vec2.y = float_vec2.y * scale; + return bf16x2_t{type_convert(float_vec2.x), type_convert(float_vec2.y)}; +} + CK_TILE_HOST_DEVICE int8x2_t pk_int4_t_to_int8x2_t(const pk_int4_t& x) { uint8_t x_u8 = ck_tile::bit_cast(x); diff --git a/include/ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp b/include/ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp index 411cfe81ed..0d5433f608 100644 --- a/include/ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp +++ b/include/ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp @@ -241,7 +241,7 @@ struct MoeFlatmmKernel IsGateUp ? TilePartitioner::NPerBlock / 2 : TilePartitioner::NPerBlock; // MXF4_Pipeline only has the of scale B and granularityK is 32 - static constexpr bool MXFP4_Pipeline = std::is_same_v; + static constexpr bool MXFP4_Pipeline = std::is_same_v || std::is_same_v; static constexpr int MXFP4N_Pack = 2; static constexpr int MXFP4K_Pack = 2; diff --git a/include/ck_tile/ops/flatmm/pipeline/mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp b/include/ck_tile/ops/flatmm/pipeline/mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp index 89c40a0d69..73d1252e57 100644 --- a/include/ck_tile/ops/flatmm/pipeline/mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp +++ b/include/ck_tile/ops/flatmm/pipeline/mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp @@ -187,6 +187,135 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1 // For the basic gemm pipelien DoubleSmemBuffer set to be false naturally. static constexpr bool DoubleSmemBuffer = false; + struct DequantizeMxFP4 { + + CK_TILE_DEVICE auto operator()(statically_indexed_array& dequant_B_n, + const auto& quant_weight_tensor, + const auto& scale_tensor, + auto xdl_nIter, + auto xdl_kIter) { + + auto quant_idx_k = xdl_kIter % number{}; + + auto scale_idx_n = xdl_nIter % number{}; + auto scale_idx_k = (xdl_kIter % number{}) / number{}; + auto scale_offset = scale_idx_n + scale_idx_k * number{}; + + auto scale = scale_tensor.get_thread_buffer()[scale_offset]; + + constexpr int ScalarCnt = WG::BWarpTensor::get_thread_buffer_size(); + constexpr int PackedCnt = ScalarCnt / MXFP4PackedSize; + constexpr int float_mantissa = 23; + + uint32_t uscale = uint32_t(scale.data) << float_mantissa; + + using ComputeV2Type = + std::conditional_t, fp16x2_t, bf16x2_t>; + +#if defined(__gfx950__) + auto pk_mxfp4x4_to_compute_v2 = [](auto pk_mxfp4x4, float fscale, auto byte_idx) { + if constexpr(std::is_same_v) + { + return __builtin_amdgcn_cvt_scalef32_pk_f16_fp4( + pk_mxfp4x4, fscale, int(byte_idx)); + } + else if constexpr(std::is_same_v) + { + return __builtin_amdgcn_cvt_scalef32_pk_bf16_fp4( + pk_mxfp4x4, fscale, int(byte_idx)); + } + else + { + static_assert(sizeof(pk_mxfp4x4) == 0, "unsupported compute type"); + } + }; + static_for<0, PackedCnt, 1>{}([&](auto i) { + dequant_B_n[xdl_nIter].get_thread_buffer().template set_as( + i, + pk_mxfp4x4_to_compute_v2( + quant_weight_tensor[quant_idx_k], bit_cast(uscale), i)); + }); +#else + auto pk_mxfp4_to_compute_v2 = [](auto pk_mxfp4, float fscale) { + if constexpr(std::is_same_v) + { + return pk_fp4_to_fp16x2(pk_mxfp4, fscale); + } + else if constexpr(std::is_same_v) + { + return pk_fp4_to_bf16x2(pk_mxfp4, fscale); + } + else + { + static_assert(sizeof(pk_mxfp4) == 0, "unsupported compute type"); + } + }; + static_for<0, PackedCnt, 1>{}([&](auto i) { + dequant_B_n[xdl_nIter].get_thread_buffer().template set_as( + i, + pk_mxfp4_to_compute_v2( + bit_cast>(quant_weight_tensor[quant_idx_k]) + .at(i), + bit_cast(uscale))); + }); +#endif + return 0; + } + }; + + struct DequantizeINT4 { + + CK_TILE_DEVICE auto operator()(statically_indexed_array& dequant_B_n, + const auto& quant_weight_tensor, + const auto& scale_tensor, + auto xdl_nIter, + auto xdl_kIter) { + + auto quant_idx_k = xdl_kIter % number{}; + + auto scale_idx_n = xdl_nIter % number{}; + auto scale_idx_k = (xdl_kIter % number{}) / number{}; + auto scale_offset = scale_idx_n + scale_idx_k * number{}; + + auto scale = scale_tensor.get_thread_buffer()[scale_offset]; + + constexpr int ScalarCnt = WG::BWarpTensor::get_thread_buffer_size(); + constexpr int PackedCnt = ScalarCnt / MXFP4PackedSize; + constexpr int float_mantissa = 23; + + uint32_t uscale = uint32_t(scale.data) << float_mantissa; + + using ComputeV2Type = + std::conditional_t, fp16x2_t, bf16x2_t>; + + auto pk_int4_to_compute_v2 = [](auto pk_int4, float fscale) { + if constexpr(std::is_same_v) + { + return pk_int4_t_to_halfx2_t(pk_int4, fscale); + } + else if constexpr(std::is_same_v) + { + return pk_int4_t_to_bfloat16x2_t(pk_int4, fscale); + } + else + { + static_assert(sizeof(pk_int4) == 0, "unsupported compute type"); + } + }; + static_for<0, PackedCnt, 1>{}([&](auto i) { + dequant_B_n[xdl_nIter].get_thread_buffer().template set_as( + i, + pk_int4_to_compute_v2( + bit_cast>(quant_weight_tensor[quant_idx_k]) + .at(i), + bit_cast(uscale))); + }); + return 0; + } + }; + + using DequantOp = typename std::conditional, DequantizeMxFP4, DequantizeINT4>::type; + CK_TILE_HOST_DEVICE static constexpr auto SchedulerPerM(index_t dsread_perM, index_t dswrite_perM, index_t load_perM) { @@ -747,6 +876,8 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1 statically_indexed_array dequant_B_n; + + /* auto dequant_mxfp4 = [&](const auto& quant_weight_tensor, const auto& scale_tensor, auto xdl_nIter, @@ -816,6 +947,7 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1 }); #endif }; + */ // MAIN LOOP index_t iCounter = (num_loop - 1) / 2; @@ -877,7 +1009,8 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1 merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); if constexpr(mIter == 0) - dequant_mxfp4( + DequantOp{}( + dequant_B_n, b_warp_tensor_ping(nIter)(kIter / number{}), scale_b_warp_tensor_ping(nIter / number{})( kIter / number{}), @@ -997,7 +1130,8 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1 merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); if constexpr(mIter == 0) - dequant_mxfp4( + DequantOp{}( + dequant_B_n, b_warp_tensor_pong(nIter)(kIter / number{}), scale_b_warp_tensor_pong(nIter / number{})( kIter / number{}), @@ -1124,7 +1258,8 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1 merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); if constexpr(mIter == 0) - dequant_mxfp4( + DequantOp{}( + dequant_B_n, b_warp_tensor_ping(nIter)(kIter / number{}), scale_b_warp_tensor_ping(nIter / number{})( kIter / number{}), @@ -1185,7 +1320,8 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1 merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); if constexpr(mIter == 0) - dequant_mxfp4( + DequantOp{}( + dequant_B_n, b_warp_tensor_pong(nIter)(kIter / number{}), scale_b_warp_tensor_pong(nIter / number{})( kIter / number{}), @@ -1236,7 +1372,8 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1 merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); if constexpr(mIter == 0) - dequant_mxfp4( + DequantOp{}( + dequant_B_n, b_warp_tensor_ping(nIter)(kIter / number{}), scale_b_warp_tensor_ping(nIter / number{})( kIter / number{}),