From 72871f527673eb50f9c31cfca9b07f876bda89ea Mon Sep 17 00:00:00 2001
From: "assistant-librarian[bot]"
<210906412+assistant-librarian[bot]@users.noreply.github.com>
Date: Wed, 18 Feb 2026 11:32:15 -0800
Subject: [PATCH] moe flatmm xcd remap (#4297)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
co-authors: @Chi-Chu319 @juuso-oskari
Added XCD remapping for flatmm moe
batch | Mixtral (tflops, wip_355) | Mixtral-7B (tflops, our branch) | perf boost
-- | -- | -- | --
64 | 865.424 | 995.455 | 15.0%
256 | 886.336 | 1020.96 | 15.2%
1024 | 890.808 | 1022.53 | 14.8%
---
🔁 Imported from
[ROCm/composable_kernel#3161](https://github.com/ROCm/composable_kernel/pull/3161)
🧑💻 Originally authored by @Chi-Chu319
---------
Co-authored-by: Tianxing Wu
Co-authored-by: Tianxing Wu
Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com>
Co-authored-by: systems-assistant[bot]
Co-authored-by: illsilin_amdeng
---
.../ops/flatmm/kernel/moe_flatmm_kernel.hpp | 20 +++++---
.../ops/gemm/kernel/gemm_tile_partitioner.hpp | 48 +++++++++++++++++++
2 files changed, 61 insertions(+), 7 deletions(-)
diff --git a/include/ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp b/include/ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp
index 604089b7c4..a211d3b88e 100644
--- a/include/ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp
+++ b/include/ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp
@@ -901,16 +901,25 @@ struct MoeFlatmmKernel
template
CK_TILE_DEVICE void operator()(MoeFlatmmKernelArgs kargs) const
{
- int partition_idx = blockIdx.x;
- int total_work_tile_cnt = TilePartitioner::GridSize(kargs.M, kargs.N);
+ // total number of tokens: sorted tokens + delimiter tokens + trailing padding tokens
+ // we launch the grid based on the total number of tokens which needs to be static
+ int partition_idx = blockIdx.x;
+ auto max_token_id = kargs.p_max_token_id[0]; // sorted tokens + delimiter tokens
+ int total_valid_tile_cnt = TilePartitioner::GridSize(max_token_id, kargs.N);
+ auto tilePartitioner = TilePartitioner{max_token_id, kargs.N};
do
{
+ if(partition_idx >= total_valid_tile_cnt)
+ {
+ return; // early exit for trailing padding tokens
+ }
+ // partition_idx stays the launch-order id that the loop strides by gridDim.x; the
+ // remapped id is kept separately so a later iteration never remaps a remapped id
+ const index_t remapped_idx = tilePartitioner.RemapXCD(partition_idx, total_valid_tile_cnt);
const auto [block_offset_m, block_offset_n] =
- TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(partition_idx);
+ tilePartitioner.GetOutputTileIndex(remapped_idx);
this->operator()(kargs, block_offset_m, block_offset_n);
partition_idx += gridDim.x;
- } while(UsePersistentKernel && partition_idx < total_work_tile_cnt);
+ } while(UsePersistentKernel && partition_idx < total_valid_tile_cnt);
}
template
@@ -920,7 +929,6 @@ struct MoeFlatmmKernel
// const auto [iM, iN] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(blockIdx.x);
const index_t coord_m = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock);
const index_t coord_n = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock);
- const index_t max_token_id = kargs.p_max_token_id[0];
// allocate LDS
__shared__ char smem_ptr_ping[GetSmemPingSize()];
__shared__ char smem_ptr_pong[GetSmemPongSize()];
@@ -948,8 +956,6 @@ struct MoeFlatmmKernel
return gather_token_id;
};
- if(coord_m >= max_token_id)
- return;
static_for<0, DramMRepeat, 1>{}([&](auto m0) {
const auto row_idx =
coord_m + m0 * (TilePartitioner::MPerBlock / DramMRepeat) + a_coord[I0];
diff --git a/include/ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp b/include/ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp
index ac7a2966aa..6114bb2eeb 100644
--- a/include/ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp
+++ b/include/ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp
@@ -265,6 +265,54 @@ struct GemmSpatiallyLocalTilePartitioner
return integer_divide_ceil(K, KPerBlock);
}
+ /**
+ * @brief XCDs access ids in round robin format, this function remaps the 1D ids to contiguous
+ * XCD segments
+ *
+ * Id `i` is treated as belonging to XCD `i % NUM_XCDS`. All ids of one XCD are mapped onto a
+ * consecutive range, and the ranges of lower-numbered XCDs come first.
+ *
+ * NOTE(review): nothing here checks that NUM_XCDS > 0 (it is used as a divisor) or that
+ * 0 <= block_1d_id < total_num_tiles; callers are presumably expected to guarantee both.
+ *
+ * @param block_1d_id grid 1D id
+ * @param total_num_tiles size of the 1D grid
+ * @param NUM_XCDS number of XCDs (defaults to 8)
+ * @return index_t The id after XCD remap
+ */
+ CK_TILE_HOST_DEVICE static auto
+ RemapXCD(index_t block_1d_id, index_t total_num_tiles, index_t NUM_XCDS = 8) noexcept -> index_t
+ {
+ // Number of ids per XCD in the new arrangement (ceiling division)
+ index_t ids_per_xcd = (total_num_tiles + NUM_XCDS - 1) / NUM_XCDS;
+
+ // When total_num_tiles is not divisible by NUM_XCDS, some XCDs will have
+ // ids_per_xcd ids, the others will have ids_per_xcd - 1 ids.
+ // We calculate the number of XCDs that have ids_per_xcd ids as tall_xcds
+ index_t tall_xcds = total_num_tiles % NUM_XCDS;
+ tall_xcds = (tall_xcds == 0) ? NUM_XCDS : tall_xcds;
+
+ // Compute current XCD and local id within the XCD
+ index_t xcd = block_1d_id % NUM_XCDS;
+ index_t local_id = block_1d_id / NUM_XCDS;
+
+ // Calculate new id based on the new grouping
+ if(xcd < tall_xcds)
+ {
+ block_1d_id = xcd * ids_per_xcd + local_id;
+ }
+ else
+ {
+ block_1d_id =
+ tall_xcds * ids_per_xcd + (xcd - tall_xcds) * (ids_per_xcd - 1) + local_id;
+ }
+
+ /**
+ * Example with total_num_tiles = 16 and NUM_XCDS = 8:
+ *
+ * original ids: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]
+ * XCD 0 gets: [0, 8], XCD 1 gets: [1, 9], ...
+ *
+ * post-remap ids: [0, 2, 4, 6, 8, 10, 12, 14, 1, 3, 5, 7, 9, 11, 13, 15]
+ * XCD 0 gets: [0, 1], XCD 1 gets: [2, 3], ...
+ *
+ * after remap the ids are contiguous on each XCD
+ */
+ return block_1d_id;
+ }
+
/**
* @brief Calculate workgroup 1D index mapping into 2D output C-tile space.
*