mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-30 03:37:38 +00:00
refactor to clearer BLOCK Q logic
This commit is contained in:
@@ -30,11 +30,11 @@ auto parse_cmd_args(int argc, char* argv[]) -> std::pair<bool, ck_tile::ArgParse
|
||||
ck_tile::ArgParser arg_parser;
|
||||
arg_parser.insert("prec", "fp16", "data type. fp16/bf16")
|
||||
.insert("b", "3", "batch size")
|
||||
.insert("h", "32", "num of head, for q")
|
||||
.insert("h_k",
|
||||
"-1",
|
||||
"num of head, for k/v, -1 means equal to h\n"
|
||||
"if not equal to h, then this is GQA/MQA case")
|
||||
.insert("h", "8", "num head for k/v. num head for q is 4 times this")
|
||||
// .insert("h_k",
|
||||
// "-1",
|
||||
// "num of head, for k/v, -1 means equal to h\n"
|
||||
// "if not equal to h, then this is GQA/MQA case")
|
||||
.insert("s", "1024", "max_seqlen_q")
|
||||
.insert("nb", "1024", "num_blks")
|
||||
.insert("bs", "128", "BLOCK_SIZE for kv")
|
||||
@@ -101,12 +101,9 @@ struct Problem
|
||||
max_context_len = args.get_int("s_k");
|
||||
num_blks = args.get_int("nb");
|
||||
BLOCK_SIZE = args.get_int("bs");
|
||||
nhead_q = args.get_int("h");
|
||||
nhead_kv = args.get_int("h_k");
|
||||
if(nhead_kv < 0)
|
||||
{
|
||||
nhead_kv = nhead_q;
|
||||
}
|
||||
nhead_kv = args.get_int("h");
|
||||
// TODO: support other GQA/MQA cases than just 4x
|
||||
nhead_q = nhead_kv * 4;
|
||||
|
||||
hdim = args.get_int("d");
|
||||
query_lens = args.get_int_vec("query_lens");
|
||||
|
||||
@@ -58,14 +58,20 @@ struct unified_attention_kernel_traits
|
||||
static constexpr auto date_type = DataType;
|
||||
static constexpr bool is_masking = IsMasking;
|
||||
|
||||
// TODO please fix this
|
||||
|
||||
static constexpr index_t BLOCK_M = 256;
|
||||
static constexpr index_t num_head_per_kv = 4;
|
||||
static constexpr index_t BLOCK_Q = BLOCK_M / num_head_per_kv;
|
||||
static constexpr index_t BLOCK_SIZE = 32;
|
||||
static constexpr index_t HEAD_SIZE = 128;
|
||||
|
||||
|
||||
// TODO please fix this to support also other num_qhead_per_kvhead
|
||||
static constexpr index_t num_qhead_per_kvhead = 4;
|
||||
static constexpr index_t BLOCK_Q = BLOCK_M / num_qhead_per_kvhead;
|
||||
|
||||
// BLOCK_M BLOCK_Q BLOCK_SIZE HEAD_SIZE
|
||||
using unified_attention_block_tile = sequence<BLOCK_M, BLOCK_Q, 32, 128>;
|
||||
using unified_attention_block_tile = sequence<BLOCK_M, BLOCK_Q, BLOCK_SIZE, HEAD_SIZE>;
|
||||
using unified_attention_warp_gemm_shape = sequence<32, 32, 16>;
|
||||
// need to have 8 warps per workgroup to have warp specialization
|
||||
using unified_attention_block_warps = sequence<8, 1, 1>;
|
||||
|
||||
using unified_attention_shape = TileUnifiedAttentionShape<unified_attention_block_tile,
|
||||
@@ -115,11 +121,9 @@ template <typename Kernel>
|
||||
float unified_attention_kernel_launch(const unified_attention_args& args,
|
||||
const stream_config& config)
|
||||
{
|
||||
|
||||
index_t BLOCK_Q = Kernel::BLOCK_M / args.num_queries_per_kv;
|
||||
|
||||
index_t BLOCK_Q = Kernel::BLOCK_Q;
|
||||
assert(BLOCK_Q == args.num_head_q / args.num_queries_per_kv && "BLOCK_Q must equal BLOCK_M / num_queries_per_kv");
|
||||
index_t total_num_q_blocks = args.num_tokens / BLOCK_Q + args.num_seqs;
|
||||
|
||||
auto kargs = Kernel::MakeKargs(args.q_ptr,
|
||||
args.k_ptr,
|
||||
args.v_ptr,
|
||||
|
||||
@@ -170,13 +170,6 @@ struct UnifiedAttentionKernel
|
||||
return dim3(num_kv_heads * total_num_q_blocks);
|
||||
}
|
||||
|
||||
// CK_TILE_HOST static constexpr auto GridSize3D(ck_tile::index_t num_kv_heads,
|
||||
// ck_tile::index_t total_num_q_blocks)
|
||||
// {
|
||||
// // TODO: fix 3D grid
|
||||
// return dim2(num_kv_heads, total_num_q_blocks);
|
||||
// }
|
||||
|
||||
// Binary search to find the sequence index for a given target index
|
||||
CK_TILE_DEVICE static constexpr ck_tile::index_t
|
||||
find_seq_idx(const int32_t* query_start_len_ptr,
|
||||
@@ -277,6 +270,8 @@ struct UnifiedAttentionKernel
|
||||
|
||||
const index_t num_queries_per_kv = kargs.num_queries_per_kv;
|
||||
|
||||
assert(BLOCK_M / num_queries_per_kv == BLOCK_Q);
|
||||
|
||||
// const index_t BLOCK_Q = BLOCK_M / num_queries_per_kv;
|
||||
// for simplicity, batch stride we just modify the pointer
|
||||
// const index_t num_head_q = kargs.num_head_q;
|
||||
|
||||
Reference in New Issue
Block a user