From d3c5faf47e280f01365d371acaa5bc8584155cf1 Mon Sep 17 00:00:00 2001 From: Tianxing Wu Date: Mon, 17 Nov 2025 12:40:31 +0000 Subject: [PATCH] Assert block_size num_queries_per_kv --- .../example_unified_attention.cpp | 46 +++++++++---------- .../unified_attention.hpp | 1 + .../unified_attention_impl.hpp | 10 ++-- 3 files changed, 31 insertions(+), 26 deletions(-) diff --git a/example/ck_tile/01_unified_attention/example_unified_attention.cpp b/example/ck_tile/01_unified_attention/example_unified_attention.cpp index 0e3b425287..2195ad03a3 100644 --- a/example/ck_tile/01_unified_attention/example_unified_attention.cpp +++ b/example/ck_tile/01_unified_attention/example_unified_attention.cpp @@ -25,6 +25,9 @@ #include "unified_attention.hpp" #include "mask.hpp" +const ck_tile::index_t BLOCK_SIZE = 32; +const ck_tile::index_t num_queries_per_kv = 4; + auto parse_cmd_args(int argc, char* argv[]) -> std::pair { ck_tile::ArgParser arg_parser; @@ -37,7 +40,6 @@ auto parse_cmd_args(int argc, char* argv[]) -> std::pair; @@ -120,6 +120,10 @@ float unified_attention_kernel_launch(const unified_attention_args& args, const stream_config& config) { index_t BLOCK_Q = Kernel::BLOCK_Q; + assert(args.num_queries_per_kv == Kernel::num_queries_per_kv && + "argument num_queries_per_kv must equal compiled num_queries_per_kv"); + assert(args.BLOCK_SIZE == Kernel::BLOCK_SIZE && + "argument BLOCK_SIZE must equal compiled BLOCK_SIZE"); assert(BLOCK_Q == args.num_head_q / args.num_queries_per_kv && "BLOCK_Q must equal BLOCK_M / num_queries_per_kv"); index_t total_num_q_blocks = args.num_tokens / BLOCK_Q + args.num_seqs;