refactor to clearer BLOCK Q logic

This commit is contained in:
Juuso Korhonen
2025-11-17 08:27:19 +00:00
parent 57a0ec8cc1
commit 5e43fd2dfc
3 changed files with 22 additions and 26 deletions

View File

@@ -30,11 +30,11 @@ auto parse_cmd_args(int argc, char* argv[]) -> std::pair<bool, ck_tile::ArgParse
ck_tile::ArgParser arg_parser;
arg_parser.insert("prec", "fp16", "data type. fp16/bf16")
.insert("b", "3", "batch size")
.insert("h", "32", "num of head, for q")
.insert("h_k",
"-1",
"num of head, for k/v, -1 means equal to h\n"
"if not equal to h, then this is GQA/MQA case")
.insert("h", "8", "num head for k/v. num head for q is 4 times this")
// .insert("h_k",
// "-1",
// "num of head, for k/v, -1 means equal to h\n"
// "if not equal to h, then this is GQA/MQA case")
.insert("s", "1024", "max_seqlen_q")
.insert("nb", "1024", "num_blks")
.insert("bs", "128", "BLOCK_SIZE for kv")
@@ -101,12 +101,9 @@ struct Problem
max_context_len = args.get_int("s_k");
num_blks = args.get_int("nb");
BLOCK_SIZE = args.get_int("bs");
nhead_q = args.get_int("h");
nhead_kv = args.get_int("h_k");
if(nhead_kv < 0)
{
nhead_kv = nhead_q;
}
nhead_kv = args.get_int("h");
// TODO: support other GQA/MQA cases than just 4x
nhead_q = nhead_kv * 4;
hdim = args.get_int("d");
query_lens = args.get_int_vec("query_lens");

View File

@@ -58,14 +58,20 @@ struct unified_attention_kernel_traits
static constexpr auto date_type = DataType;
static constexpr bool is_masking = IsMasking;
// TODO please fix this
static constexpr index_t BLOCK_M = 256;
static constexpr index_t num_head_per_kv = 4;
static constexpr index_t BLOCK_Q = BLOCK_M / num_head_per_kv;
static constexpr index_t BLOCK_SIZE = 32;
static constexpr index_t HEAD_SIZE = 128;
// TODO please fix this to support also other num_qhead_per_kvhead
static constexpr index_t num_qhead_per_kvhead = 4;
static constexpr index_t BLOCK_Q = BLOCK_M / num_qhead_per_kvhead;
// BLOCK_M BLOCK_Q BLOCK_SIZE HEAD_SIZE
using unified_attention_block_tile = sequence<BLOCK_M, BLOCK_Q, 32, 128>;
using unified_attention_block_tile = sequence<BLOCK_M, BLOCK_Q, BLOCK_SIZE, HEAD_SIZE>;
using unified_attention_warp_gemm_shape = sequence<32, 32, 16>;
// need to have 8 warps per workgroup to have warp specialization
using unified_attention_block_warps = sequence<8, 1, 1>;
using unified_attention_shape = TileUnifiedAttentionShape<unified_attention_block_tile,
@@ -115,11 +121,9 @@ template <typename Kernel>
float unified_attention_kernel_launch(const unified_attention_args& args,
const stream_config& config)
{
index_t BLOCK_Q = Kernel::BLOCK_M / args.num_queries_per_kv;
index_t BLOCK_Q = Kernel::BLOCK_Q;
assert(BLOCK_Q == args.num_head_q / args.num_queries_per_kv && "BLOCK_Q must equal BLOCK_M / num_queries_per_kv");
index_t total_num_q_blocks = args.num_tokens / BLOCK_Q + args.num_seqs;
auto kargs = Kernel::MakeKargs(args.q_ptr,
args.k_ptr,
args.v_ptr,

View File

@@ -170,13 +170,6 @@ struct UnifiedAttentionKernel
return dim3(num_kv_heads * total_num_q_blocks);
}
// CK_TILE_HOST static constexpr auto GridSize3D(ck_tile::index_t num_kv_heads,
// ck_tile::index_t total_num_q_blocks)
// {
// // TODO: fix 3D grid
// return dim2(num_kv_heads, total_num_q_blocks);
// }
// Binary search to find the sequence index for a given target index
CK_TILE_DEVICE static constexpr ck_tile::index_t
find_seq_idx(const int32_t* query_start_len_ptr,
@@ -277,6 +270,8 @@ struct UnifiedAttentionKernel
const index_t num_queries_per_kv = kargs.num_queries_per_kv;
assert(BLOCK_M / num_queries_per_kv == BLOCK_Q);
// const index_t BLOCK_Q = BLOCK_M / num_queries_per_kv;
// for simplicity, batch stride we just modify the pointer
// const index_t num_head_q = kargs.num_head_q;