From b09b6cdce9ea808ef4bbca9034ed1e1f9878d46f Mon Sep 17 00:00:00 2001 From: Feng Shijie Date: Tue, 9 Sep 2025 04:18:48 +0000 Subject: [PATCH] fix scale_m gather load for a8w8 moe --- .../core/tensor/tile_scatter_gather.hpp | 71 +++++++++++++++++++ .../moe_flatmm/kernel/moe_flatmm_kernel.hpp | 61 +++++++++------- 2 files changed, 107 insertions(+), 25 deletions(-) diff --git a/include/ck_tile/core/tensor/tile_scatter_gather.hpp b/include/ck_tile/core/tensor/tile_scatter_gather.hpp index c415525942..082207d1df 100644 --- a/include/ck_tile/core/tensor/tile_scatter_gather.hpp +++ b/include/ck_tile/core/tensor/tile_scatter_gather.hpp @@ -404,6 +404,77 @@ struct tile_scatter_gather }); } + template + CK_TILE_DEVICE auto load(DistributedTensor& dst_tensor, + const Ys2PageIdxMap& ys_to_page_idx_map, + number = {}, + bool_constant = {}) const + { + using Traits = load_store_traits; + using vector_t = typename Traits::vector_t; + using SFC_Ys = typename Traits::SFC_Ys; + + constexpr auto tile_dstr = TileDstr{}; + + // loop over thread tensor space [y0, y1, ...] + static_for<0, NumCoord, 1>{}([&](auto iCoord) { + /// TODO: use structure binding (to be captured later) if compiled in C++20 + auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0]; + auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1]; + + static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) { + constexpr auto iAccess = number{}; + + // data index [y0, y1, ...] 
+ constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess); + const auto idx_gather = ys_to_page_idx_map(idx_ys_start); + const auto page_offset = page_idx_[idx_gather]; + + // read from bottom tensor + const vector_t vec_value = [&]() { + if constexpr(std::is_same_v) + { + return get_bottom_tensor_view().template get_vectorized_elements( + bottom_tensor_thread_coord, + page_offset, + bool_constant{}); + } + else + { + return get_bottom_tensor_view().template get_vectorized_elements( + bottom_tensor_thread_coord, + page_offset, + valids_[idx_gather], + bool_constant{}); + } + }(); + + // write into distributed tensor + static_for<0, Traits::ScalarPerVector, Traits::PackedSize>{}([&](auto j) { + constexpr auto idx_ys = generate_tuple( + [&](auto jj) { + return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j) + : idx_ys_start[jj]; + }, + number{}); + + constexpr index_t d = + tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) / + Traits::PackedSize; + + dst_tensor.get_thread_buffer().template at() = + vec_value.template get_as()[j / Traits::PackedSize]; + }); + + // ys_to_page_idx_map handles all offset calculation. + // So there is no need to move thread coordinate redundantly. 
+ }); + }); + } + // TODO: currently async load only implemented in inline asm template ck_tile::index_t stride_C_, ck_tile::index_t n_padded_zeros_ = 0, ck_tile::index_t k_padded_zeros_ = 0, - ScaleM scale_m_ = {}, - ScaleN scale_n_ = {}, - ExpertBias exp_bias_ = {}) + ScaleM scale_m_ = {}, + ScaleN scale_n_ = {}, + ExpertBias exp_bias_ = {}) : ScaleFlatmmHostArgs(a_ptr_, b_ptr_, {}, // d_ptr_array @@ -524,7 +524,7 @@ struct MoeFlatmmKernel const int expert_id, const KernelArgs& kargs, const SplitKBatchOffset& splitk_batch_offset, - const int n_pad_zeros = 12) + const int n_pad_zeros = 12) { const auto& a_tensor_view = [&]() { if constexpr(std::is_same_v) @@ -596,7 +596,7 @@ struct MoeFlatmmKernel const auto scale_b_flat_view = make_naive_tensor_view( reinterpret_cast(scale_n.ptr) + expert_id * kargs.N * scale_k, - make_tuple(FlatScaleN - kargs.n_padded_zeros / NPerXdl / N_Pack, FlatScaleK), + make_tuple(FlatScaleN - kargs.n_padded_zeros / NPerXdl / N_Pack, FlatScaleK), make_tuple(FlatScaleK, 1), number<8>{}, number<1>{}); @@ -878,14 +878,22 @@ struct MoeFlatmmKernel const auto scale_m_coord = output_acc_tile_distr.calculate_index(); // 2d thread offset, [i_row, i_col] - constexpr ck_tile::index_t ScaleMRepeat = - decltype(output_acc_tile_distr)::DstrEncode::hs_lengthss_[number<0>{}][number<0>{}]; - statically_indexed_array scale_m_offsets; + constexpr index_t kM2 = 4; // Val-dim + constexpr index_t kM1 = get_warp_size() / NPerXdl; // Thr-dim + constexpr index_t kM0 = MPerXdl / kM1 / kM2; // Var-dim - static_for<0, ScaleMRepeat, 1>{}([&](auto m0) { - const auto row_idx = - coord_m + m0 * (TilePartitioner::MPerBlock / ScaleMRepeat) + scale_m_coord[I0]; - scale_m_offsets[m0] = row_to_token_idx(row_idx); + constexpr index_t ScaleMRepeat = MRepeat * kM0 * kM2; + statically_indexed_array scale_m_offsets; + + static_for<0, MRepeat, 1>{}([&](auto mIter) { + static_for<0, kM0, 1>{}([&](auto m0) { + static_for<0, kM2, 1>{}([&](auto m2) { + const auto row_idx = + coord_m + 
mIter * MPerXdl + m0 * kM1 * kM2 + m2 + scale_m_coord[I0]; + scale_m_offsets[mIter * number{} + m0 * number{} + m2] = + row_to_token_idx(row_idx); + }); + }); }); constexpr int DynamicTileOffsetFlag = 0; @@ -924,18 +932,18 @@ struct MoeFlatmmKernel } }; - auto scale_m_window = make_tile_scatter_gather( - make_naive_tensor_view( - kargs.scale_m.ptr, - make_tuple(kargs.M, 1), - make_tuple(scale_stride_m, 0), - number{}, - number<1>{}), - make_tuple(number{}, - number{}), - {0, 0}, // offset m is included in gather offsets - output_acc_tile_distr, - scale_m_offsets); + auto scale_m_window = + make_tile_scatter_gather(make_naive_tensor_view( + kargs.scale_m.ptr, + make_tuple(kargs.M, 1), + make_tuple(scale_stride_m, 0), + number<1>{}, // gather load can't vectorize + number<1>{}), + make_tuple(number{}, + number{}), + {0, 0}, // offset m is included in gather offsets + output_acc_tile_distr, + scale_m_offsets); auto scale_n_window = make_tile_window( make_naive_tensor_view( @@ -1015,7 +1023,10 @@ struct MoeFlatmmKernel if constexpr(!MXFP4_Pipeline) { - scale_m_buffer = load_tile(scale_m_window); + scale_m_window.load(scale_m_buffer, [](auto ys_coord) { + return ys_coord.at(I0) * number{} + ys_coord.at(I2) * number{} + + ys_coord.at(I3); + }); scale_n_buffer = load_tile(scale_n_window); if constexpr(IsGateUp) scale_n_up_buffer = load_tile(scale_n_up_window);