Drop XCD-based pid remapping from UnifiedAttentionKernel.

Removes RemapTileIndices() and its call site, and changes GetTileIndex() to
decompose pid with the kv-head index varying fastest:
(pid % num_head_kv, pid / num_head_kv), where
num_head_kv = num_head_q / num_queries_per_kv.

NOTE(review): leading whitespace and blank context lines were reconstructed
from a line-collapsed copy of this patch; if context fails to match, apply
with `git apply --ignore-whitespace`.

diff --git a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp
index 2727c563c0..396bd6d2b8 100644
--- a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp
+++ b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp
@@ -200,52 +200,15 @@ struct UnifiedAttentionKernel
         return left - 1;
     }
 
-    CK_TILE_DEVICE static constexpr auto RemapTileIndices(const ck_tile::index_t pid,
-                                                          const Kargs& kargs)
-    {
-        using namespace ck_tile;
-
-        constexpr index_t NUM_XCDS = 8;
-        const index_t GRID_MN = kargs.total_num_q_blocks * (kargs.num_head_q / kargs.num_queries_per_kv);
-
-        // Number of pids per XCD in the new arrangement
-        const index_t pids_per_xcd = (GRID_MN + NUM_XCDS - 1) / NUM_XCDS;
-
-        // When GRID_MN cannot divide NUM_XCDS, some xcds will have
-        // pids_per_xcd pids, the other will have pids_per_xcd - 1 pids.
-        // We calculate the number of xcds that have pids_per_xcd pids as tall_xcds
-        index_t tall_xcds = GRID_MN % NUM_XCDS;
-        tall_xcds = tall_xcds == 0 ? NUM_XCDS : tall_xcds;
-
-        // Compute current XCD and local pid within the XCD
-        const index_t xcd = pid % NUM_XCDS;
-        const index_t local_pid = pid / NUM_XCDS;
-
-        // Calculate new pid based on the new grouping
-        index_t remapped_pid = 0; // Initialize to avoid constexpr error
-        if(xcd < tall_xcds)
-        {
-            remapped_pid = xcd * pids_per_xcd + local_pid;
-        }
-        else
-        {
-            remapped_pid =
-                tall_xcds * pids_per_xcd + (xcd - tall_xcds) * (pids_per_xcd - 1) + local_pid;
-        }
-
-        return remapped_pid;
-    }
 
     CK_TILE_DEVICE static constexpr auto GetTileIndex(const ck_tile::index_t pid,
                                                       const Kargs& kargs)
     {
         using namespace ck_tile;
 
-        ck_tile::index_t total_num_q_blocks = kargs.total_num_q_blocks;
-        // const index_t num_tile_n1 = ck_tile::integer_divide_ceil(kargs.hdim_v,
-        // UnifiedAttentionPipeline::kN1);
+        ck_tile::index_t num_head_kv = kargs.num_head_q / kargs.num_queries_per_kv;
 
-        return ck_tile::make_tuple(pid / total_num_q_blocks, pid % total_num_q_blocks);
+        return ck_tile::make_tuple(pid % num_head_kv, pid / num_head_kv);
     }
 
     CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); }
@@ -274,7 +237,6 @@ struct UnifiedAttentionKernel
         // const index_t num_head_q = kargs.num_head_q;
         // const index_t num_head_k = num_head_q / num_queries_per_kv;
 
-        pid = RemapTileIndices(pid, kargs);
         // divide problem
         const auto [kv_head_idx, q_block_global_idx] = GetTileIndex(pid, kargs);
 