From 2f3874b51511916095f035eb23fe8979f236a892 Mon Sep 17 00:00:00 2001 From: zanzhang Date: Thu, 18 Dec 2025 09:39:49 +0800 Subject: [PATCH] update ck_tile moe --- .../ops/flatmm/kernel/moe_flatmm_kernel.hpp | 64 +++++++++++-------- 1 file changed, 38 insertions(+), 26 deletions(-) diff --git a/include/ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp b/include/ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp index 99a8916db0..9867a01e0f 100644 --- a/include/ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp +++ b/include/ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp @@ -672,31 +672,48 @@ struct MoeFlatmmKernel } }(); - auto scale_m_desc = kargs.scale_m; - constexpr int AGranularityK = decltype(scale_m_desc)::GranularityK; - const auto& scale_a_tensor_view = [&]() { - constexpr int MThreadPerXdl = BlockGemmShape::WarpTile::at(I0); - constexpr int KThreadPerXdl = 64 / BlockGemmShape::WarpTile::at(I0); - index_t scale_m_packs = kargs.M / (MXFP4M_Pack * MThreadPerXdl); - index_t scale_k_packs = kargs.K / (MXFP4K_Pack * AGranularityK * KThreadPerXdl); - // Pack 2x2 e8m0 over M/K dimension into 1 int32_t to trigger dword width load - const auto scale_a_naive_desc = make_naive_tensor_descriptor_packed( - make_tuple(scale_m_packs, scale_k_packs, KThreadPerXdl, MThreadPerXdl)); - const auto scale_a_desc = transform_tensor_descriptor( - scale_a_naive_desc, - make_tuple(make_merge_transform(make_tuple(scale_m_packs, MThreadPerXdl)), - make_merge_transform(make_tuple(scale_k_packs, KThreadPerXdl))), - make_tuple(sequence<0, 3>{}, sequence<1, 2>{}), - make_tuple(sequence<0>{}, sequence<1>{})); - return make_tensor_view( - reinterpret_cast(scale_m_desc.ptr), scale_a_desc); + auto scale_m_desc = kargs.scale_m; + if constexpr(AQUANT_Pipeline) + { + constexpr int AGranularityK = decltype(scale_m_desc)::GranularityK == 0 ? 1 : decltype(scale_m_desc)::GranularityK; + + constexpr int MThreadPerXdl = BlockGemmShape::WarpTile::at(I0); + constexpr int KThreadPerXdl = 64 / BlockGemmShape::WarpTile::at(I0); + index_t scale_m_packs = kargs.M / (MXFP4M_Pack * MThreadPerXdl); + index_t scale_k_packs = kargs.K / (MXFP4K_Pack * AGranularityK * KThreadPerXdl); + // Pack 2x2 e8m0 over M/K dimension into 1 int32_t to trigger dword width load + const auto scale_a_naive_desc = make_naive_tensor_descriptor_packed( + make_tuple(scale_m_packs, scale_k_packs, KThreadPerXdl, MThreadPerXdl)); + const auto scale_a_desc = transform_tensor_descriptor( + scale_a_naive_desc, + make_tuple(make_merge_transform(make_tuple(scale_m_packs, MThreadPerXdl)), + make_merge_transform(make_tuple(scale_k_packs, KThreadPerXdl))), + make_tuple(sequence<0, 3>{}, sequence<1, 2>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + return make_tensor_view( + reinterpret_cast(scale_m_desc.ptr), scale_a_desc); + } + else + { + constexpr int AGranularityK = 32; + constexpr int MThreadPerXdl = BlockGemmShape::WarpTile::at(I0); + constexpr int KThreadPerXdl = 64 / BlockGemmShape::WarpTile::at(I0); + index_t scale_m_packs = kargs.M / (MXFP4M_Pack * MThreadPerXdl); + index_t scale_k_packs = kargs.K / (MXFP4K_Pack * AGranularityK * KThreadPerXdl); + return make_naive_tensor_view( + reinterpret_cast(scale_m_desc.ptr), + make_tuple(scale_m_packs * MThreadPerXdl, scale_k_packs * KThreadPerXdl), + make_tuple(scale_k_packs * KThreadPerXdl, 1), + number<8>{}, + number<1>{}); + } }(); - auto scale_n = kargs.scale_n; - constexpr int BGranularityK = decltype(scale_n)::GranularityK; const auto scale_b_flat_view = [&]() { + auto scale_n = kargs.scale_n; + constexpr int BGranularityK = decltype(scale_n)::GranularityK == 0 ? 1 : decltype(scale_n)::GranularityK; if constexpr(AQUANT_Pipeline) { index_t scale_k = @@ -833,17 +850,12 @@ struct MoeFlatmmKernel constexpr int GranularityK = 32; // fixed config for MXF4_Pipeline auto a_scale_block_window = - // make_tile_window(views.at(I3), - // make_tuple(number{}, - // number{}), - // {coord_m, 0}); make_tile_window( views.at(I3), make_tuple(number{}, number{}), {coord_m / M_Pack, 0}); - // constexpr int GranularityK = 32; // fixed config for MXF4_Pipeline constexpr int XDLPerLoadScaleB = BMXFP4_Pipeline ? 4 : 1; // GranularityK32 / XDL16x16x32_K8 = 4 @@ -943,7 +955,7 @@ struct MoeFlatmmKernel static_cast(kargs.a_ptr) + splitk_batch_offset.a_k_split_offset; const BDataType* b_flat_ptr = static_cast(kargs.b_ptr) + - (splitk_batch_offset.b_k_split_offset + expert_stride * expert_id) / 2; + (splitk_batch_offset.b_k_split_offset + expert_stride * expert_id) / WeightPackedSize; EDataType* e_ptr = static_cast(kargs.e_ptr); const AccDataType* exp_weight_ptr =