mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-30 11:47:48 +00:00
simplified kernel pid logic
This commit is contained in:
@@ -200,52 +200,15 @@ struct UnifiedAttentionKernel
|
||||
return left - 1;
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE static constexpr auto RemapTileIndices(const ck_tile::index_t pid,
|
||||
const Kargs& kargs)
|
||||
{
|
||||
using namespace ck_tile;
|
||||
|
||||
constexpr index_t NUM_XCDS = 8;
|
||||
const index_t GRID_MN = kargs.total_num_q_blocks * (kargs.num_head_q / kargs.num_queries_per_kv);
|
||||
|
||||
// Number of pids per XCD in the new arrangement
|
||||
const index_t pids_per_xcd = (GRID_MN + NUM_XCDS - 1) / NUM_XCDS;
|
||||
|
||||
// When GRID_MN cannot divide NUM_XCDS, some xcds will have
|
||||
// pids_per_xcd pids, the other will have pids_per_xcd - 1 pids.
|
||||
// We calculate the number of xcds that have pids_per_xcd pids as tall_xcds
|
||||
index_t tall_xcds = GRID_MN % NUM_XCDS;
|
||||
tall_xcds = tall_xcds == 0 ? NUM_XCDS : tall_xcds;
|
||||
|
||||
// Compute current XCD and local pid within the XCD
|
||||
const index_t xcd = pid % NUM_XCDS;
|
||||
const index_t local_pid = pid / NUM_XCDS;
|
||||
|
||||
// Calculate new pid based on the new grouping
|
||||
index_t remapped_pid = 0; // Initialize to avoid constexpr error
|
||||
if(xcd < tall_xcds)
|
||||
{
|
||||
remapped_pid = xcd * pids_per_xcd + local_pid;
|
||||
}
|
||||
else
|
||||
{
|
||||
remapped_pid =
|
||||
tall_xcds * pids_per_xcd + (xcd - tall_xcds) * (pids_per_xcd - 1) + local_pid;
|
||||
}
|
||||
|
||||
return remapped_pid;
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE static constexpr auto GetTileIndex(const ck_tile::index_t pid,
|
||||
const Kargs& kargs)
|
||||
{
|
||||
using namespace ck_tile;
|
||||
|
||||
ck_tile::index_t total_num_q_blocks = kargs.total_num_q_blocks;
|
||||
// const index_t num_tile_n1 = ck_tile::integer_divide_ceil(kargs.hdim_v,
|
||||
// UnifiedAttentionPipeline::kN1);
|
||||
ck_tile::index_t num_head_kv = kargs.num_head_q / kargs.num_queries_per_kv;
|
||||
|
||||
return ck_tile::make_tuple(pid / total_num_q_blocks, pid % total_num_q_blocks);
|
||||
return ck_tile::make_tuple(pid % num_head_kv, pid / num_head_kv);
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); }
|
||||
@@ -274,7 +237,6 @@ struct UnifiedAttentionKernel
|
||||
// const index_t num_head_q = kargs.num_head_q;
|
||||
|
||||
// const index_t num_head_k = num_head_q / num_queries_per_kv;
|
||||
pid = RemapTileIndices(pid, kargs);
|
||||
|
||||
// divide problem
|
||||
const auto [kv_head_idx, q_block_global_idx] = GetTileIndex(pid, kargs);
|
||||
|
||||
Reference in New Issue
Block a user