From 5e43fd2dfcd0aaa3ac39eb7b19e4eecfd5be8506 Mon Sep 17 00:00:00 2001 From: Juuso Korhonen <40278371+juuso-oskari@users.noreply.github.com> Date: Mon, 17 Nov 2025 08:27:19 +0000 Subject: [PATCH] refactor to clearer BLOCK Q logic --- .../example_unified_attention.cpp | 19 ++++++++---------- .../unified_attention_impl.hpp | 20 +++++++++++-------- .../kernel/unified_attention_kernel.hpp | 9 ++------- 3 files changed, 22 insertions(+), 26 deletions(-) diff --git a/example/ck_tile/01_unified_attention/example_unified_attention.cpp b/example/ck_tile/01_unified_attention/example_unified_attention.cpp index b8c65d7c0a..d34d8d9593 100644 --- a/example/ck_tile/01_unified_attention/example_unified_attention.cpp +++ b/example/ck_tile/01_unified_attention/example_unified_attention.cpp @@ -30,11 +30,11 @@ auto parse_cmd_args(int argc, char* argv[]) -> std::pair; + using unified_attention_block_tile = sequence; using unified_attention_warp_gemm_shape = sequence<32, 32, 16>; + // need to have 8 warps per workgroup to have warp specialization using unified_attention_block_warps = sequence<8, 1, 1>; using unified_attention_shape = TileUnifiedAttentionShape float unified_attention_kernel_launch(const unified_attention_args& args, const stream_config& config) { - - index_t BLOCK_Q = Kernel::BLOCK_M / args.num_queries_per_kv; - + index_t BLOCK_Q = Kernel::BLOCK_Q; + assert(BLOCK_Q == args.num_head_q / args.num_queries_per_kv && "BLOCK_Q must equal BLOCK_M / num_queries_per_kv"); index_t total_num_q_blocks = args.num_tokens / BLOCK_Q + args.num_seqs; - auto kargs = Kernel::MakeKargs(args.q_ptr, args.k_ptr, args.v_ptr, diff --git a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp index 1cf7698b61..2f1b574655 100644 --- a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp +++ b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp @@ -170,13 +170,6 @@ struct UnifiedAttentionKernel return dim3(num_kv_heads * total_num_q_blocks); } - // CK_TILE_HOST static constexpr auto GridSize3D(ck_tile::index_t num_kv_heads, - // ck_tile::index_t total_num_q_blocks) - // { - // // TODO: fix 3D grid - // return dim2(num_kv_heads, total_num_q_blocks); - // } - // Binary search to find the sequence index for a given target index CK_TILE_DEVICE static constexpr ck_tile::index_t find_seq_idx(const int32_t* query_start_len_ptr, @@ -277,6 +270,8 @@ struct UnifiedAttentionKernel const index_t num_queries_per_kv = kargs.num_queries_per_kv; + assert(BLOCK_M / num_queries_per_kv == BLOCK_Q); + // const index_t BLOCK_Q = BLOCK_M / num_queries_per_kv; // for simplicity, batch stride we just modify the pointer // const index_t num_head_q = kargs.num_head_q;