refine code

This commit is contained in:
Tianxing Wu
2025-09-11 16:45:53 +00:00
parent 5ee7e2cf97
commit 797031beea

View File

@@ -265,30 +265,39 @@ struct GemmSpatiallyLocalTilePartitioner
return integer_divide_ceil(K, KPerBlock);
}
CK_TILE_HOST_DEVICE static auto RemapXCD(int pid, int GRID_MN, int NUM_XCDS = 8) {
// Number of pids per XCD in the new arrangement
int pids_per_xcd = (GRID_MN + NUM_XCDS - 1) / NUM_XCDS;
/**
* @brief XCDs access ids in round robin format, this function remaps the 1D ids to continguous XCD segments
*
* @param block_1d_id grid 1D id
* @param total_num_tiles size of the 1D grid
* @param NUM_XCDS number of XCDs
* @return index_t The id after XCD remap
*/
CK_TILE_HOST_DEVICE static auto RemapXCD(index_t block_1d_id, index_t total_num_tiles, index_t NUM_XCDS = 8) noexcept
-> const index_t {
// Number of ids per XCD in the new arrangement
index_t ids_per_xcd = (total_num_tiles + NUM_XCDS - 1) / NUM_XCDS;
// When GRID_MN cannot divide NUM_XCDS, some xcds will have
// pids_per_xcd pids, the other will have pids_per_xcd - 1 pids.
// We calculate the number of xcds that have pids_per_xcd pids as tall_xcds
int tall_xcds = GRID_MN % NUM_XCDS;
// When total_num_tiles cannot divide NUM_XCDS, some xcds will have
// ids_per_xcd ids, the other will have ids_per_xcd - 1 ids.
// We calculate the number of xcds that have ids_per_xcd ids as tall_xcds
index_t tall_xcds = total_num_tiles % NUM_XCDS;
tall_xcds = (tall_xcds == 0) ? NUM_XCDS : tall_xcds;
// Compute current XCD and local pid within the XCD
int xcd = pid % NUM_XCDS;
int local_pid = pid / NUM_XCDS;
// Compute current XCD and local id within the XCD
index_t xcd = block_1d_id % NUM_XCDS;
index_t local_id = block_1d_id / NUM_XCDS;
// Calculate new pid based on the new grouping
// Calculate new id based on the new grouping
if (xcd < tall_xcds) {
pid = xcd * pids_per_xcd + local_pid;
block_1d_id = xcd * ids_per_xcd + local_id;
} else {
pid = tall_xcds * pids_per_xcd
+ (xcd - tall_xcds) * (pids_per_xcd - 1)
+ local_pid;
block_1d_id = tall_xcds * ids_per_xcd
+ (xcd - tall_xcds) * (ids_per_xcd - 1)
+ local_id;
}
return pid;
return block_1d_id;
}
/**
@@ -302,7 +311,6 @@ struct GemmSpatiallyLocalTilePartitioner
{
const auto M0 = integer_divide_ceil(M, MPerBlock);
const auto N0 = integer_divide_ceil(N, NPerBlock);
// index_t block_1d_id = RemapXCD(_block_1d_id, M0 * N0)
if(M0 == 1)
{