From 082749ba81c2b3d4fc6597d123bbcb13f41cf7a8 Mon Sep 17 00:00:00 2001 From: Tianxing Wu Date: Wed, 5 Nov 2025 12:05:15 +0000 Subject: [PATCH] xcd remap --- .../ops/flatmm/kernel/moe_flatmm_kernel.hpp | 19 +++++-- .../ops/gemm/kernel/gemm_tile_partitioner.hpp | 49 +++++++++++++++++++ 2 files changed, 63 insertions(+), 5 deletions(-) diff --git a/include/ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp b/include/ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp index 411cfe81ed..b3b073afbb 100644 --- a/include/ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp +++ b/include/ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp @@ -754,16 +754,26 @@ struct MoeFlatmmKernel template CK_TILE_DEVICE void operator()(MoeFlatmmKernelArgs kargs) const { - int partition_idx = blockIdx.x; - int total_work_tile_cnt = TilePartitioner::GridSize(kargs.M, kargs.N); + // total number of tokens: sorted tokens + delimiter tokens + trailing padding tokens + // we launch the grid based on the total number of tokens which needs to be static + int partition_idx = blockIdx.x; + auto max_token_id = kargs.p_max_token_id[0]; // sorted tokens + delimiter tokens + int total_valid_tile_cnt = TilePartitioner::GridSize(max_token_id, kargs.N); + auto tilePartitioner = TilePartitioner{max_token_id, kargs.N}; do { + if(partition_idx >= total_valid_tile_cnt) + { + return; // early exit for trailing padding tokens + } + // remap a copy: partition_idx must stay un-remapped for the gridDim.x stride below + const auto tile_idx = tilePartitioner.RemapXCD(partition_idx, total_valid_tile_cnt); const auto [block_offset_m, block_offset_n] = - TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(partition_idx); + tilePartitioner.GetOutputTileIndex(tile_idx); this->operator()(kargs, block_offset_m, block_offset_n); partition_idx += gridDim.x; - } while(UsePersistentKernel && partition_idx < total_work_tile_cnt); + } while(UsePersistentKernel && partition_idx < total_valid_tile_cnt); } template @@ -773,7 +783,6 @@ struct MoeFlatmmKernel // const auto [iM, iN] = TilePartitioner{kargs.M, 
kargs.N}.GetOutputTileIndex(blockIdx.x); const index_t coord_m = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock); const index_t coord_n = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock); - const index_t max_token_id = kargs.p_max_token_id[0]; // allocate LDS __shared__ char smem_ptr_ping[GetSmemPingSize()]; __shared__ char smem_ptr_pong[GetSmemPongSize()]; diff --git a/include/ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp b/include/ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp index 673f5abc34..5ff6810387 100644 --- a/include/ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp +++ b/include/ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp @@ -265,6 +265,55 @@ struct GemmSpatiallyLocalTilePartitioner return integer_divide_ceil(K, KPerBlock); } + /** + * @brief XCDs access ids in round robin format; this function remaps the 1D ids to contiguous + * XCD segments + * + * @param block_1d_id grid 1D id, must be in [0, total_num_tiles) + * @param total_num_tiles size of the 1D grid + * @param NUM_XCDS number of XCDs, must be non-zero (used as a divisor) + * @return index_t The id after XCD remap + */ + CK_TILE_HOST_DEVICE static auto RemapXCD(index_t block_1d_id, + index_t total_num_tiles, + index_t NUM_XCDS = 8) noexcept -> index_t + { + // Number of ids per XCD in the new arrangement + index_t ids_per_xcd = (total_num_tiles + NUM_XCDS - 1) / NUM_XCDS; + + // When total_num_tiles is not divisible by NUM_XCDS, some xcds will have + // ids_per_xcd ids, the others will have ids_per_xcd - 1 ids. + // We calculate the number of xcds that have ids_per_xcd ids as tall_xcds + index_t tall_xcds = total_num_tiles % NUM_XCDS; + tall_xcds = (tall_xcds == 0) ? 
NUM_XCDS : tall_xcds; + + // Compute current XCD and local id within the XCD + index_t xcd = block_1d_id % NUM_XCDS; + index_t local_id = block_1d_id / NUM_XCDS; + + // Calculate new id based on the new grouping + if(xcd < tall_xcds) + { + block_1d_id = xcd * ids_per_xcd + local_id; + } + else + { + block_1d_id = + tall_xcds * ids_per_xcd + (xcd - tall_xcds) * (ids_per_xcd - 1) + local_id; + } + + /** + * original ids: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] + * XCD 0 gets: [0, 8], XCD 1 gets: [1, 9], ... + * + * post-remap ids: [0, 2, 4, 6, 8, 10, 12, 14, 1, 3, 5, 7, 9, 11, 13, 15] + * XCD 0 gets: [0, 1], XCD 1 gets: [2, 3], ... + * + * after remap the ids are contiguous on each XCD + */ + return block_1d_id; + } + /** * @brief Calculate workgroup 1D index mapping into 2D output C-tile space. *