diff --git a/include/ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp b/include/ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp index c606160bc4..8dceacdcc8 100644 --- a/include/ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp +++ b/include/ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp @@ -265,30 +265,39 @@ struct GemmSpatiallyLocalTilePartitioner return integer_divide_ceil(K, KPerBlock); } - CK_TILE_HOST_DEVICE static auto RemapXCD(int pid, int GRID_MN, int NUM_XCDS = 8) { - // Number of pids per XCD in the new arrangement - int pids_per_xcd = (GRID_MN + NUM_XCDS - 1) / NUM_XCDS; + /** + * @brief XCDs access ids in round robin format, this function remaps the 1D ids to continguous XCD segments + * + * @param block_1d_id grid 1D id + * @param total_num_tiles size of the 1D grid + * @param NUM_XCDS number of XCDs + * @return index_t The id after XCD remap + */ + CK_TILE_HOST_DEVICE static auto RemapXCD(index_t block_1d_id, index_t total_num_tiles, index_t NUM_XCDS = 8) noexcept + -> const index_t { + // Number of ids per XCD in the new arrangement + index_t ids_per_xcd = (total_num_tiles + NUM_XCDS - 1) / NUM_XCDS; - // When GRID_MN cannot divide NUM_XCDS, some xcds will have - // pids_per_xcd pids, the other will have pids_per_xcd - 1 pids. - // We calculate the number of xcds that have pids_per_xcd pids as tall_xcds - int tall_xcds = GRID_MN % NUM_XCDS; + // When total_num_tiles cannot divide NUM_XCDS, some xcds will have + // ids_per_xcd ids, the other will have ids_per_xcd - 1 ids. + // We calculate the number of xcds that have ids_per_xcd ids as tall_xcds + index_t tall_xcds = total_num_tiles % NUM_XCDS; tall_xcds = (tall_xcds == 0) ? NUM_XCDS : tall_xcds; - // Compute current XCD and local pid within the XCD - int xcd = pid % NUM_XCDS; - int local_pid = pid / NUM_XCDS; + // Compute current XCD and local id within the XCD + index_t xcd = block_1d_id % NUM_XCDS; + index_t local_id = block_1d_id / NUM_XCDS; - // Calculate new pid based on the new grouping + // Calculate new id based on the new grouping if (xcd < tall_xcds) { - pid = xcd * pids_per_xcd + local_pid; + block_1d_id = xcd * ids_per_xcd + local_id; } else { - pid = tall_xcds * pids_per_xcd - + (xcd - tall_xcds) * (pids_per_xcd - 1) - + local_pid; + block_1d_id = tall_xcds * ids_per_xcd + + (xcd - tall_xcds) * (ids_per_xcd - 1) + + local_id; } - return pid; + return block_1d_id; } /** @@ -302,7 +311,6 @@ struct GemmSpatiallyLocalTilePartitioner { const auto M0 = integer_divide_ceil(M, MPerBlock); const auto N0 = integer_divide_ceil(N, NPerBlock); - // index_t block_1d_id = RemapXCD(_block_1d_id, M0 * N0) if(M0 == 1) {