From 72871f527673eb50f9c31cfca9b07f876bda89ea Mon Sep 17 00:00:00 2001 From: "assistant-librarian[bot]" <210906412+assistant-librarian[bot]@users.noreply.github.com> Date: Wed, 18 Feb 2026 11:32:15 -0800 Subject: [PATCH] moe flatmm xcd remap (#4297) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit co-authors: @Chi-Chu319 @juuso-oskari Added XCD remapping for flatmm moe batch | Mixtral (tflops, wip_355) | Mixtral-7B  (tflops, our branch) | perf boost -- | -- | -- | -- 64 | 865.424 | 995.455 | 15.0% 256 | 886.336 | 1020.96 | 15.2% 1024 | 890.808 | 1022.53 | 14.8% --- 🔁 Imported from [ROCm/composable_kernel#3161](https://github.com/ROCm/composable_kernel/pull/3161) 🧑‍💻 Originally authored by @Chi-Chu319 --------- Co-authored-by: Tianxing Wu Co-authored-by: Tianxing Wu Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com> Co-authored-by: systems-assistant[bot] Co-authored-by: illsilin_amdeng --- .../ops/flatmm/kernel/moe_flatmm_kernel.hpp | 20 +++++--- .../ops/gemm/kernel/gemm_tile_partitioner.hpp | 48 +++++++++++++++++++ 2 files changed, 61 insertions(+), 7 deletions(-) diff --git a/include/ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp b/include/ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp index 604089b7c4..a211d3b88e 100644 --- a/include/ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp +++ b/include/ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp @@ -901,16 +901,25 @@ struct MoeFlatmmKernel template CK_TILE_DEVICE void operator()(MoeFlatmmKernelArgs kargs) const { - int partition_idx = blockIdx.x; - int total_work_tile_cnt = TilePartitioner::GridSize(kargs.M, kargs.N); + // total number of tokens: sorted tokens + delimiter tokens + trailing padding tokens + // we launch the grid based on the total number of tokens which needs to be static + int partition_idx = blockIdx.x; + auto max_token_id = kargs.p_max_token_id[0]; // sorted tokens + delimiter tokens + int total_valid_tile_cnt = 
TilePartitioner::GridSize(max_token_id, kargs.N); + auto tilePartitioner = TilePartitioner{max_token_id, kargs.N}; do { + if(partition_idx >= total_valid_tile_cnt) + { + return; // early exit for trailing padding tokens + } + partition_idx = tilePartitioner.RemapXCD(partition_idx, total_valid_tile_cnt); const auto [block_offset_m, block_offset_n] = - TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(partition_idx); + tilePartitioner.GetOutputTileIndex(partition_idx); this->operator()(kargs, block_offset_m, block_offset_n); partition_idx += gridDim.x; - } while(UsePersistentKernel && partition_idx < total_work_tile_cnt); + } while(UsePersistentKernel && partition_idx < total_valid_tile_cnt); } template @@ -920,7 +929,6 @@ struct MoeFlatmmKernel // const auto [iM, iN] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(blockIdx.x); const index_t coord_m = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock); const index_t coord_n = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock); - const index_t max_token_id = kargs.p_max_token_id[0]; // allocate LDS __shared__ char smem_ptr_ping[GetSmemPingSize()]; __shared__ char smem_ptr_pong[GetSmemPongSize()]; @@ -948,8 +956,6 @@ struct MoeFlatmmKernel return gather_token_id; }; - if(coord_m >= max_token_id) - return; static_for<0, DramMRepeat, 1>{}([&](auto m0) { const auto row_idx = coord_m + m0 * (TilePartitioner::MPerBlock / DramMRepeat) + a_coord[I0]; diff --git a/include/ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp b/include/ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp index ac7a2966aa..6114bb2eeb 100644 --- a/include/ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp +++ b/include/ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp @@ -265,6 +265,54 @@ struct GemmSpatiallyLocalTilePartitioner return integer_divide_ceil(K, KPerBlock); } + /** + * @brief XCDs access ids in round robin format, this function remaps the 1D ids to contiguous + * XCD segments + * + * @param block_1d_id 
grid 1D id + * @param total_num_tiles size of the 1D grid + * @param NUM_XCDS number of XCDs + * @return index_t The id after XCD remap + */ + CK_TILE_HOST_DEVICE static auto + RemapXCD(index_t block_1d_id, index_t total_num_tiles, index_t NUM_XCDS = 8) noexcept -> index_t + { + // Number of ids per XCD in the new arrangement + index_t ids_per_xcd = (total_num_tiles + NUM_XCDS - 1) / NUM_XCDS; + + // When total_num_tiles is not divisible by NUM_XCDS, some xcds will have + // ids_per_xcd ids, the others will have ids_per_xcd - 1 ids. + // We calculate the number of xcds that have ids_per_xcd ids as tall_xcds + index_t tall_xcds = total_num_tiles % NUM_XCDS; + tall_xcds = (tall_xcds == 0) ? NUM_XCDS : tall_xcds; + + // Compute current XCD and local id within the XCD + index_t xcd = block_1d_id % NUM_XCDS; + index_t local_id = block_1d_id / NUM_XCDS; + + // Calculate new id based on the new grouping + if(xcd < tall_xcds) + { + block_1d_id = xcd * ids_per_xcd + local_id; + } + else + { + block_1d_id = + tall_xcds * ids_per_xcd + (xcd - tall_xcds) * (ids_per_xcd - 1) + local_id; + } + + /** + * original ids: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] + * XCD 0 gets: [0, 8], XCD 1 gets: [1, 9], ... + * + * post-remap ids: [0, 2, 4, 6, 8, 10, 12, 14, 1, 3, 5, 7, 9, 11, 13, 15] + * XCD 0 gets: [0, 1], XCD 1 gets: [2, 3], ... + * + * after remap the ids are contiguous on each XCD + */ + return block_1d_id; + } + /** * @brief Calculate workgroup 1D index mapping into 2D output C-tile space. *