mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-01 12:17:00 +00:00
add handling for -1 k heads arg
This commit is contained in:
@@ -32,7 +32,7 @@ auto parse_cmd_args(int argc, char* argv[]) -> std::pair<bool, ck_tile::ArgParse
|
||||
.insert("b", "3", "batch size")
|
||||
.insert("h", "32", "num of head, for q")
|
||||
.insert("h_k",
|
||||
"8",
|
||||
"-1",
|
||||
"num of head, for k/v, -1 means equal to h\n"
|
||||
"if not equal to h, then this is GQA/MQA case")
|
||||
.insert("s", "1024", "max_seqlen_q")
|
||||
@@ -103,6 +103,10 @@ struct Problem
|
||||
BLOCK_SIZE = args.get_int("bs");
|
||||
nhead_q = args.get_int("h");
|
||||
nhead_kv = args.get_int("h_k");
|
||||
if(nhead_kv < 0)
|
||||
{
|
||||
nhead_kv = nhead_q;
|
||||
}
|
||||
|
||||
hdim = args.get_int("d");
|
||||
query_lens = args.get_int_vec("query_lens");
|
||||
|
||||
Reference in New Issue
Block a user