mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-01 12:17:00 +00:00
xcd remap
This commit is contained in:
@@ -754,16 +754,25 @@ struct MoeFlatmmKernel
|
||||
template <class MoeFlatmmKernelArgs>
|
||||
CK_TILE_DEVICE void operator()(MoeFlatmmKernelArgs kargs) const
|
||||
{
|
||||
int partition_idx = blockIdx.x;
|
||||
int total_work_tile_cnt = TilePartitioner::GridSize(kargs.M, kargs.N);
|
||||
// total number of tokens: sorted tokens + delimiter tokens + trailing padding tokens
|
||||
// we launch the grid based on the total number of tokens which needs to be static
|
||||
int partition_idx = blockIdx.x;
|
||||
auto max_token_id = kargs.p_max_token_id[0]; // sorted tokens + delimiter tokens
|
||||
int total_valid_tile_cnt = TilePartitioner::GridSize(max_token_id, kargs.N);
|
||||
auto tilePartitioner = TilePartitioner{max_token_id, kargs.N};
|
||||
do
|
||||
{
|
||||
if(partition_idx >= total_valid_tile_cnt)
|
||||
{
|
||||
return; // early exit for trailing padding tokens
|
||||
}
|
||||
partition_idx = tilePartitioner.RemapXCD(partition_idx, total_valid_tile_cnt);
|
||||
const auto [block_offset_m, block_offset_n] =
|
||||
TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(partition_idx);
|
||||
tilePartitioner.GetOutputTileIndex(partition_idx);
|
||||
|
||||
this->operator()(kargs, block_offset_m, block_offset_n);
|
||||
partition_idx += gridDim.x;
|
||||
} while(UsePersistentKernel && partition_idx < total_work_tile_cnt);
|
||||
} while(UsePersistentKernel && partition_idx < total_valid_tile_cnt);
|
||||
}
|
||||
|
||||
template <class MoeFlatmmKernelArgs>
|
||||
@@ -773,7 +782,6 @@ struct MoeFlatmmKernel
|
||||
// const auto [iM, iN] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(blockIdx.x);
|
||||
const index_t coord_m = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock);
|
||||
const index_t coord_n = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock);
|
||||
const index_t max_token_id = kargs.p_max_token_id[0];
|
||||
// allocate LDS
|
||||
__shared__ char smem_ptr_ping[GetSmemPingSize()];
|
||||
__shared__ char smem_ptr_pong[GetSmemPongSize()];
|
||||
|
||||
@@ -265,6 +265,55 @@ struct GemmSpatiallyLocalTilePartitioner
|
||||
return integer_divide_ceil(K, KPerBlock);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief XCDs access ids in round robin format, this function remaps the 1D ids to continguous
|
||||
* XCD segments
|
||||
*
|
||||
* @param block_1d_id grid 1D id
|
||||
* @param total_num_tiles size of the 1D grid
|
||||
* @param NUM_XCDS number of XCDs
|
||||
* @return index_t The id after XCD remap
|
||||
*/
|
||||
CK_TILE_HOST_DEVICE static auto RemapXCD(index_t block_1d_id,
|
||||
index_t total_num_tiles,
|
||||
index_t NUM_XCDS = 8) noexcept -> const index_t
|
||||
{
|
||||
// Number of ids per XCD in the new arrangement
|
||||
index_t ids_per_xcd = (total_num_tiles + NUM_XCDS - 1) / NUM_XCDS;
|
||||
|
||||
// When total_num_tiles cannot divide NUM_XCDS, some xcds will have
|
||||
// ids_per_xcd ids, the other will have ids_per_xcd - 1 ids.
|
||||
// We calculate the number of xcds that have ids_per_xcd ids as tall_xcds
|
||||
index_t tall_xcds = total_num_tiles % NUM_XCDS;
|
||||
tall_xcds = (tall_xcds == 0) ? NUM_XCDS : tall_xcds;
|
||||
|
||||
// Compute current XCD and local id within the XCD
|
||||
index_t xcd = block_1d_id % NUM_XCDS;
|
||||
index_t local_id = block_1d_id / NUM_XCDS;
|
||||
|
||||
// Calculate new id based on the new grouping
|
||||
if(xcd < tall_xcds)
|
||||
{
|
||||
block_1d_id = xcd * ids_per_xcd + local_id;
|
||||
}
|
||||
else
|
||||
{
|
||||
block_1d_id =
|
||||
tall_xcds * ids_per_xcd + (xcd - tall_xcds) * (ids_per_xcd - 1) + local_id;
|
||||
}
|
||||
|
||||
/**
|
||||
* original ids: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]
|
||||
* XCD 0 gets: [0, 8], XCD 1 gets: [1, 9], ...
|
||||
*
|
||||
* post-remap ids: [0, 2, 4, 6, 8, 10, 12, 14, 1, 3, 5, 7, 9, 11, 13, 15]
|
||||
* XCD 0 gets: [0, 1], XCD 1 gets: [2, 3], ...
|
||||
*
|
||||
* after remap the ids are continguous on each XCD
|
||||
*/
|
||||
return block_1d_id;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Calculate workgroup 1D index mapping into 2D output C-tile space.
|
||||
*
|
||||
|
||||
Reference in New Issue
Block a user