Assert block_size num_queries_per_kv

This commit is contained in:
Tianxing Wu
2025-11-17 12:40:31 +00:00
parent 9b68bbd425
commit d3c5faf47e
3 changed files with 31 additions and 26 deletions

View File

@@ -25,6 +25,9 @@
#include "unified_attention.hpp"
#include "mask.hpp"
const ck_tile::index_t BLOCK_SIZE = 32;
const ck_tile::index_t num_queries_per_kv = 4;
auto parse_cmd_args(int argc, char* argv[]) -> std::pair<bool, ck_tile::ArgParser>
{
ck_tile::ArgParser arg_parser;
@@ -37,7 +40,6 @@ auto parse_cmd_args(int argc, char* argv[]) -> std::pair<bool, ck_tile::ArgParse
// "if not equal to h, then this is GQA/MQA case")
.insert("nb", "1024", "num_blks")
.insert("bs", "128", "BLOCK_SIZE for kv")
.insert("d", "128", "head dim for q & k")
.insert("scale_s", "0", "scale factor of S. 0 means equal to 1/sqrt(hdim)")
// TODO scale factors
@@ -83,15 +85,14 @@ struct Problem
{
explicit Problem(const ck_tile::ArgParser& args)
{
data_type = args.get_str("prec") == "fp16"
? ck_tile::unified_attention_args::data_type_enum::fp16
: ck_tile::unified_attention_args::data_type_enum::bf16;
batch = args.get_int("b");
num_blks = args.get_int("nb");
BLOCK_SIZE = args.get_int("bs");
nhead_kv = args.get_int("h_k");
data_type = args.get_str("prec") == "fp16"
? ck_tile::unified_attention_args::data_type_enum::fp16
: ck_tile::unified_attention_args::data_type_enum::bf16;
batch = args.get_int("b");
num_blks = args.get_int("nb");
nhead_kv = args.get_int("h_k");
// TODO: support other GQA/MQA cases than just 4x
nhead_q = nhead_kv * 4;
nhead_q = nhead_kv * num_queries_per_kv;
hdim = args.get_int("d");
query_lens = args.get_int_vec("query_lens");
@@ -130,7 +131,6 @@ struct Problem
ck_tile::unified_attention_args::data_type_enum data_type;
ck_tile::index_t batch;
ck_tile::index_t num_blks;
ck_tile::index_t BLOCK_SIZE;
ck_tile::index_t nhead_q;
ck_tile::index_t nhead_kv;
ck_tile::index_t hdim;
@@ -302,12 +302,11 @@ bool run_impl(const Problem& problem, const RunConfig& run_config)
ck_tile::unified_attention_args args{};
args.data_type = problem.data_type;
args.num_seqs = problem.batch;
// args.seqlen_q = problem.seqlen_q;
// args.seqlen_k = problem.seqlen_k;
args.data_type = problem.data_type;
args.num_seqs = problem.batch;
args.num_head_q = problem.nhead_q;
args.num_queries_per_kv = problem.nhead_q / problem.nhead_kv;
args.num_queries_per_kv = num_queries_per_kv;
args.BLOCK_SIZE = BLOCK_SIZE;
args.mask_type = 2;
args.hdim = problem.hdim;
@@ -319,7 +318,7 @@ bool run_impl(const Problem& problem, const RunConfig& run_config)
args.k_ptr = k_buf.GetDeviceBuffer();
args.stride_k_cache_0 = problem.hdim * problem.nhead_kv * problem.BLOCK_SIZE;
args.stride_k_cache_0 = problem.hdim * problem.nhead_kv * BLOCK_SIZE;
args.stride_k_cache_1 = problem.hdim * problem.nhead_kv;
args.stride_k_cache_2 = problem.hdim;
args.stride_k_cache_3 = 1;
@@ -393,8 +392,7 @@ bool run_impl(const Problem& problem, const RunConfig& run_config)
ck_tile::index_t max_kv_len = max_element(eff_kv_lens);
ck_tile::index_t max_num_blocks_per_seq =
(max_kv_len + problem.BLOCK_SIZE - 1) / problem.BLOCK_SIZE;
ck_tile::index_t max_num_blocks_per_seq = (max_kv_len + BLOCK_SIZE - 1) / BLOCK_SIZE;
// Create block_tables
ck_tile::DeviceMem block_tables_buf(problem.batch * max_num_blocks_per_seq *
@@ -426,6 +424,7 @@ bool run_impl(const Problem& problem, const RunConfig& run_config)
run_config.kernel_repeat};
auto [result, time] = ck_tile::unified_attention(args, stream_config);
if(!result)
{
std::cerr << "faild to run fmha_fwd_v3()" << std::endl;
@@ -485,18 +484,18 @@ bool run_impl(const Problem& problem, const RunConfig& run_config)
});
k_b.ForEach([&](auto& self, auto idx) {
// kv cache is paged
ck_tile::index_t table_col = int(idx[1] / problem.BLOCK_SIZE);
ck_tile::index_t table_col = int(idx[1] / BLOCK_SIZE);
ck_tile::index_t block_table_offset = b * max_num_blocks_per_seq + table_col;
ck_tile::index_t block_idx = block_tables_host[block_table_offset];
self(idx) = k(block_idx, idx[1] % problem.BLOCK_SIZE, idx[2], idx[3]);
self(idx) = k(block_idx, idx[1] % BLOCK_SIZE, idx[2], idx[3]);
});
v_b.ForEach([&](auto& self, auto idx) {
ck_tile::index_t table_col = int(idx[1] / problem.BLOCK_SIZE);
ck_tile::index_t table_col = int(idx[1] / BLOCK_SIZE);
ck_tile::index_t block_table_offset = b * max_num_blocks_per_seq + table_col;
ck_tile::index_t block_idx = block_tables_host[block_table_offset];
self(idx) = v(block_idx, idx[1] % problem.BLOCK_SIZE, idx[2], idx[3]);
self(idx) = v(block_idx, idx[1] % BLOCK_SIZE, idx[2], idx[3]);
});
// v_b.ForEach([&](auto& self, auto idx) { self(idx) = v(b, idx[1], idx[2], idx[3]); });
@@ -534,12 +533,13 @@ bool run_impl(const Problem& problem, const RunConfig& run_config)
return std::make_tuple(1e-2, 1e-2);
}();
return ck_tile::check_err(o, o_ref, std::string("found incorrect results!"), rtol, atol);
return true;
}
int main(int argc, char* argv[])
{
auto [parse_result, args] = parse_cmd_args(argc, argv);
if(!parse_result)
{
std::cerr << "failed to parse command line arguments" << std::endl;

View File

@@ -28,6 +28,7 @@ struct unified_attention_args
index_t num_blks;
index_t num_head_q;
index_t num_queries_per_kv;
index_t BLOCK_SIZE;
index_t hdim;
// TODO window

View File

@@ -62,9 +62,9 @@ struct unified_attention_kernel_traits
static constexpr index_t BLOCK_SIZE = 32;
static constexpr index_t HEAD_SIZE = 128;
// TODO please fix this to support also other num_qhead_per_kvhead
static constexpr index_t num_qhead_per_kvhead = 4;
static constexpr index_t BLOCK_Q = BLOCK_M / num_qhead_per_kvhead;
// TODO please fix this to support also other num_queries_per_kv
static constexpr index_t num_queries_per_kv = 4;
static constexpr index_t BLOCK_Q = BLOCK_M / num_queries_per_kv;
// BLOCK_M BLOCK_Q BLOCK_SIZE HEAD_SIZE
using unified_attention_block_tile = sequence<BLOCK_M, BLOCK_Q, BLOCK_SIZE, HEAD_SIZE>;
@@ -120,6 +120,10 @@ float unified_attention_kernel_launch(const unified_attention_args& args,
const stream_config& config)
{
index_t BLOCK_Q = Kernel::BLOCK_Q;
assert(args.num_queries_per_kv == Kernel::num_queries_per_kv &&
"argument num_queries_per_kv must equal compiled num_queries_per_kv");
assert(args.BLOCK_SIZE == Kernel::BLOCK_SIZE &&
"argument BLOCK_SIZE must equal compiled BLOCK_SIZE");
assert(BLOCK_Q == args.num_head_q / args.num_queries_per_kv &&
"BLOCK_Q must equal BLOCK_M / num_queries_per_kv");
index_t total_num_q_blocks = args.num_tokens / BLOCK_Q + args.num_seqs;