refactor to clearer BLOCK Q logic

This commit is contained in:
Juuso Korhonen
2025-11-17 08:27:19 +00:00
parent 57a0ec8cc1
commit 5e43fd2dfc
3 changed files with 22 additions and 26 deletions

View File

@@ -170,13 +170,6 @@ struct UnifiedAttentionKernel
return dim3(num_kv_heads * total_num_q_blocks);
}
// CK_TILE_HOST static constexpr auto GridSize3D(ck_tile::index_t num_kv_heads,
// ck_tile::index_t total_num_q_blocks)
// {
// // TODO: fix 3D grid
// return dim2(num_kv_heads, total_num_q_blocks);
// }
// Binary search to find the sequence index for a given target index
CK_TILE_DEVICE static constexpr ck_tile::index_t
find_seq_idx(const int32_t* query_start_len_ptr,
@@ -277,6 +270,8 @@ struct UnifiedAttentionKernel
const index_t num_queries_per_kv = kargs.num_queries_per_kv;
assert(BLOCK_M / num_queries_per_kv == BLOCK_Q);
// const index_t BLOCK_Q = BLOCK_M / num_queries_per_kv;
// for simplicity, the batch stride is applied by just offsetting the pointer
// const index_t num_head_q = kargs.num_head_q;