mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-30 19:57:40 +00:00
Refactor for clearer BLOCK_Q logic
This commit is contained in:
@@ -170,13 +170,6 @@ struct UnifiedAttentionKernel
|
||||
return dim3(num_kv_heads * total_num_q_blocks);
|
||||
}
|
||||
|
||||
// CK_TILE_HOST static constexpr auto GridSize3D(ck_tile::index_t num_kv_heads,
|
||||
// ck_tile::index_t total_num_q_blocks)
|
||||
// {
|
||||
// // TODO: fix 3D grid
|
||||
// return dim2(num_kv_heads, total_num_q_blocks);
|
||||
// }
|
||||
|
||||
// Binary search to find the sequence index for a given target index
|
||||
CK_TILE_DEVICE static constexpr ck_tile::index_t
|
||||
find_seq_idx(const int32_t* query_start_len_ptr,
|
||||
@@ -277,6 +270,8 @@ struct UnifiedAttentionKernel
|
||||
|
||||
const index_t num_queries_per_kv = kargs.num_queries_per_kv;
|
||||
|
||||
assert(BLOCK_M / num_queries_per_kv == BLOCK_Q);
|
||||
|
||||
// const index_t BLOCK_Q = BLOCK_M / num_queries_per_kv;
|
||||
// for simplicity, we handle the batch stride by just modifying the pointer
|
||||
// const index_t num_head_q = kargs.num_head_q;
|
||||
|
||||
Reference in New Issue
Block a user