simplified kernel pid logic

This commit is contained in:
Tianxing Wu
2025-11-27 13:28:35 +00:00
parent eeb419845d
commit 3131ebf1df

View File

@@ -200,52 +200,15 @@ struct UnifiedAttentionKernel
return left - 1;
}
// Remaps a linear program id so that all pids sharing the same residue
// (pid % NUM_XCDS) end up in one contiguous range of the returned index space.
//
// pid   : linear program id; assumed to lie in [0, grid size) -- not checked here.
// kargs : supplies total_num_q_blocks, num_head_q and num_queries_per_kv, whose
//         combination gives the grid size that is being partitioned.
//
// Returns the remapped pid as an index_t.
CK_TILE_DEVICE static constexpr auto RemapTileIndices(const ck_tile::index_t pid,
                                                      const Kargs& kargs)
{
    using namespace ck_tile;

    constexpr index_t NUM_XCDS = 8;

    // Total number of pids being distributed.
    const index_t heads_kv  = kargs.num_head_q / kargs.num_queries_per_kv;
    const index_t grid_size = kargs.total_num_q_blocks * heads_kv;

    // Largest group size (ceiling division). When grid_size is not a multiple
    // of NUM_XCDS, the first `full_groups` groups hold `group_cap` pids and the
    // remaining groups hold one fewer.
    const index_t group_cap   = (grid_size + NUM_XCDS - 1) / NUM_XCDS;
    const index_t leftover    = grid_size % NUM_XCDS;
    const index_t full_groups = (leftover != 0) ? leftover : NUM_XCDS;

    // Group this pid belongs to, and its position inside that group.
    const index_t group = pid % NUM_XCDS;
    const index_t slot  = pid / NUM_XCDS;

    // Full-size groups are laid out first, followed by the shorter ones.
    const index_t remapped_pid =
        (group < full_groups)
            ? group * group_cap + slot
            : full_groups * group_cap + (group - full_groups) * (group_cap - 1) + slot;

    return remapped_pid;
}
// Splits a linear pid into a two-element tuple of indices derived from kargs.
// The visible call site destructures the result as
// [kv_head_idx, q_block_global_idx].
//
// NOTE(review): this body contains two consecutive return statements. As
// written, only the first one executes; the second is unreachable, which also
// leaves num_head_kv without any effect on the result. This looks like two
// revisions of the function merged together -- confirm which decomposition is
// intended and drop the other.
CK_TILE_DEVICE static constexpr auto GetTileIndex(const ck_tile::index_t pid,
const Kargs& kargs)
{
using namespace ck_tile;
// Divisor/modulus used by the first (reachable) return.
ck_tile::index_t total_num_q_blocks = kargs.total_num_q_blocks;
// const index_t num_tile_n1 = ck_tile::integer_divide_ceil(kargs.hdim_v,
// UnifiedAttentionPipeline::kN1);
// Divisor/modulus used only by the second (unreachable) return.
ck_tile::index_t num_head_kv = kargs.num_head_q / kargs.num_queries_per_kv;
// Reachable: (pid / total_num_q_blocks, pid % total_num_q_blocks).
return ck_tile::make_tuple(pid / total_num_q_blocks, pid % total_num_q_blocks);
// Unreachable as written: (pid % num_head_kv, pid / num_head_kv).
return ck_tile::make_tuple(pid % num_head_kv, pid / num_head_kv);
}
CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); }
@@ -274,7 +237,6 @@ struct UnifiedAttentionKernel
// const index_t num_head_q = kargs.num_head_q;
// const index_t num_head_k = num_head_q / num_queries_per_kv;
pid = RemapTileIndices(pid, kargs);
// divide problem
const auto [kv_head_idx, q_block_global_idx] = GetTileIndex(pid, kargs);