add handling for -1 k heads arg

This commit is contained in:
Juuso Korhonen
2025-11-17 07:36:42 +00:00
parent 4a13749f7f
commit 57a0ec8cc1

View File

@@ -32,7 +32,7 @@ auto parse_cmd_args(int argc, char* argv[]) -> std::pair<bool, ck_tile::ArgParse
.insert("b", "3", "batch size")
.insert("h", "32", "num of head, for q")
.insert("h_k",
"8",
"-1",
"num of head, for k/v, -1 means equal to h\n"
"if not equal to h, then this is GQA/MQA case")
.insert("s", "1024", "max_seqlen_q")
@@ -103,6 +103,10 @@ struct Problem
BLOCK_SIZE = args.get_int("bs");
nhead_q = args.get_int("h");
nhead_kv = args.get_int("h_k");
if(nhead_kv < 0)
{
nhead_kv = nhead_q;
}
hdim = args.get_int("d");
query_lens = args.get_int_vec("query_lens");