mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-03 13:48:30 +00:00
refine code
This commit is contained in:
@@ -265,30 +265,39 @@ struct GemmSpatiallyLocalTilePartitioner
|
||||
return integer_divide_ceil(K, KPerBlock);
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static auto RemapXCD(int pid, int GRID_MN, int NUM_XCDS = 8) {
|
||||
// Number of pids per XCD in the new arrangement
|
||||
int pids_per_xcd = (GRID_MN + NUM_XCDS - 1) / NUM_XCDS;
|
||||
/**
|
||||
* @brief XCDs access ids in round robin format, this function remaps the 1D ids to continguous XCD segments
|
||||
*
|
||||
* @param block_1d_id grid 1D id
|
||||
* @param total_num_tiles size of the 1D grid
|
||||
* @param NUM_XCDS number of XCDs
|
||||
* @return index_t The id after XCD remap
|
||||
*/
|
||||
CK_TILE_HOST_DEVICE static auto RemapXCD(index_t block_1d_id, index_t total_num_tiles, index_t NUM_XCDS = 8) noexcept
|
||||
-> const index_t {
|
||||
// Number of ids per XCD in the new arrangement
|
||||
index_t ids_per_xcd = (total_num_tiles + NUM_XCDS - 1) / NUM_XCDS;
|
||||
|
||||
// When GRID_MN cannot divide NUM_XCDS, some xcds will have
|
||||
// pids_per_xcd pids, the other will have pids_per_xcd - 1 pids.
|
||||
// We calculate the number of xcds that have pids_per_xcd pids as tall_xcds
|
||||
int tall_xcds = GRID_MN % NUM_XCDS;
|
||||
// When total_num_tiles cannot divide NUM_XCDS, some xcds will have
|
||||
// ids_per_xcd ids, the other will have ids_per_xcd - 1 ids.
|
||||
// We calculate the number of xcds that have ids_per_xcd ids as tall_xcds
|
||||
index_t tall_xcds = total_num_tiles % NUM_XCDS;
|
||||
tall_xcds = (tall_xcds == 0) ? NUM_XCDS : tall_xcds;
|
||||
|
||||
// Compute current XCD and local pid within the XCD
|
||||
int xcd = pid % NUM_XCDS;
|
||||
int local_pid = pid / NUM_XCDS;
|
||||
// Compute current XCD and local id within the XCD
|
||||
index_t xcd = block_1d_id % NUM_XCDS;
|
||||
index_t local_id = block_1d_id / NUM_XCDS;
|
||||
|
||||
// Calculate new pid based on the new grouping
|
||||
// Calculate new id based on the new grouping
|
||||
if (xcd < tall_xcds) {
|
||||
pid = xcd * pids_per_xcd + local_pid;
|
||||
block_1d_id = xcd * ids_per_xcd + local_id;
|
||||
} else {
|
||||
pid = tall_xcds * pids_per_xcd
|
||||
+ (xcd - tall_xcds) * (pids_per_xcd - 1)
|
||||
+ local_pid;
|
||||
block_1d_id = tall_xcds * ids_per_xcd
|
||||
+ (xcd - tall_xcds) * (ids_per_xcd - 1)
|
||||
+ local_id;
|
||||
}
|
||||
|
||||
return pid;
|
||||
return block_1d_id;
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -302,7 +311,6 @@ struct GemmSpatiallyLocalTilePartitioner
|
||||
{
|
||||
const auto M0 = integer_divide_ceil(M, MPerBlock);
|
||||
const auto N0 = integer_divide_ceil(N, NPerBlock);
|
||||
// index_t block_1d_id = RemapXCD(_block_1d_id, M0 * N0)
|
||||
|
||||
if(M0 == 1)
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user