diff --git a/example/ck_tile/01_unified_attention/unified_attention_impl.hpp b/example/ck_tile/01_unified_attention/unified_attention_impl.hpp index b621b0920b..94c2fe0a0d 100644 --- a/example/ck_tile/01_unified_attention/unified_attention_impl.hpp +++ b/example/ck_tile/01_unified_attention/unified_attention_impl.hpp @@ -177,38 +177,9 @@ float unified_attention_kernel_launch(const unified_attention_args& args, assert(BLOCK_Q == BLOCK_M / args.num_queries_per_kv && "BLOCK_Q must equal BLOCK_M / num_queries_per_kv"); index_t total_num_q_blocks = args.num_tokens / BLOCK_Q + args.num_seqs; - auto kargs = + auto kargs_lambda = [&]() { - if(Kernel::QuantEnum == UnifiedAttentionQuantScaleEnum::NO_SCALE) - { - return Kernel::MakeKargs(args.q_ptr, - args.k_ptr, - args.v_ptr, - args.o_ptr, - args.num_blks, - args.num_head_q, - args.num_queries_per_kv, - args.scale_s, - total_num_q_blocks, - args.query_stride_0, - args.query_stride_1, - args.stride_k_cache_0, - args.stride_k_cache_1, - args.stride_k_cache_2, - args.stride_k_cache_3, - args.stride_v_cache_0, - args.stride_v_cache_1, - args.stride_v_cache_2, - args.stride_v_cache_3, - args.output_stride_0, - args.output_stride_1, - args.block_tables_ptr, - args.block_table_stride, - args.seq_lens_ptr, - args.query_start_len_ptr, - args.num_seqs); - } - else + if constexpr(Kernel::kIsQuantized) { return Kernel::MakeKargs(args.q_ptr, args.k_ptr, @@ -241,8 +212,40 @@ float unified_attention_kernel_launch(const unified_attention_args& args, args.query_start_len_ptr, args.num_seqs); } - } dim3 grids = - Kernel::GridSize2D(args.num_head_q / args.num_queries_per_kv, total_num_q_blocks); + else + { + return Kernel::MakeKargs(args.q_ptr, + args.k_ptr, + args.v_ptr, + args.o_ptr, + args.num_blks, + args.num_head_q, + args.num_queries_per_kv, + args.scale_s, + total_num_q_blocks, + args.query_stride_0, + args.query_stride_1, + args.stride_k_cache_0, + args.stride_k_cache_1, + args.stride_k_cache_2, + args.stride_k_cache_3, + args.stride_v_cache_0, + args.stride_v_cache_1, + args.stride_v_cache_2, + args.stride_v_cache_3, + args.output_stride_0, + args.output_stride_1, + args.block_tables_ptr, + args.block_table_stride, + args.seq_lens_ptr, + args.query_start_len_ptr, + args.num_seqs); + } + }; + + auto kargs = kargs_lambda(); + + dim3 grids = Kernel::GridSize2D(args.num_head_q / args.num_queries_per_kv, total_num_q_blocks); constexpr dim3 blocks = Kernel::BlockSize(); constexpr index_t kBlockPerCu = Kernel::kBlockPerCu; diff --git a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp index 072de8106a..0bf91fbb77 100644 --- a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp +++ b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp @@ -25,6 +25,7 @@ struct UnifiedAttentionKernel static constexpr ck_tile::index_t kBlockPerCu = UnifiedAttentionPipeline::kBlockPerCu; static_assert(kBlockPerCu > 0); + using QDataType = ck_tile::remove_cvref_t; using KDataType = ck_tile::remove_cvref_t; using VDataType = ck_tile::remove_cvref_t; @@ -38,6 +39,7 @@ struct UnifiedAttentionKernel static constexpr bool kPadHeadDimQ = UnifiedAttentionPipeline::kPadHeadDimQ; static constexpr bool kPadHeadDimV = UnifiedAttentionPipeline::kPadHeadDimV; static constexpr auto QuantEnum = UnifiedAttentionPipeline::Problem::QuantEnum; + // TODO add yjese static constexpr index_t HEAD_SIZE = UnifiedAttentionPipeline::HEAD_SIZE; @@ -50,13 +52,21 @@ struct UnifiedAttentionKernel static constexpr index_t BLOCK_Q = UnifiedAttentionPipeline::BLOCK_Q; // BLOCK size for K seqlen static constexpr index_t BLOCK_SIZE = UnifiedAttentionPipeline::BLOCK_SIZE; - - template // to avoid duplicated base class prblem, introduce an template + static constexpr bool kIsQuantized = (QuantEnum != UnifiedAttentionQuantScaleEnum::NO_SCALE); + template // to avoid duplicated base class problem, introduce an template // arg struct UnifiedAttentionEmptyKargs { }; + struct UnifiedAttentionQuantKargs + { + float scale_q; + float scale_k; + float scale_v; + float scale_out; + }; + // kargs use aggregate initializer, so no constructor will provided // use inheritance to minimize karg size // user need to use MakeKargs() function to create kargs. @@ -92,19 +102,10 @@ struct UnifiedAttentionKernel ck_tile::index_t output_stride_1; }; - struct UnifiedAttentionQuantKargs - { - float scale_q; - float scale_k; - float scale_v; - float scale_out; - }; struct UnifiedAttentionVarlenKargs : UnifiedAttentionCommonKargs, - std::conditional_t> + std::conditional_t> { const int32_t* block_tables_ptr; ck_tile::index_t block_table_stride; @@ -116,7 +117,7 @@ struct UnifiedAttentionKernel using Kargs = UnifiedAttentionVarlenKargs; - template + template CK_TILE_HOST static constexpr std::enable_if_t MakeKargs(const void* q_ptr, const void* k_ptr, @@ -153,7 +154,6 @@ struct UnifiedAttentionKernel num_head_q, num_queries_per_kv, static_cast(scale_s * ck_tile::log2e_v<>), - scale_s, total_num_q_blocks, query_stride_0, query_stride_1, @@ -167,6 +167,7 @@ struct UnifiedAttentionKernel stride_v_cache_3, output_stride_0, output_stride_1}, + {}, block_tables_ptr, block_table_stride, seq_lens_ptr, @@ -176,7 +177,7 @@ struct UnifiedAttentionKernel return kargs; } - template + template CK_TILE_HOST static constexpr std::enable_if_t MakeKargs(const void* q_ptr, const void* k_ptr, @@ -488,7 +489,7 @@ struct UnifiedAttentionKernel // both quantized auto o_acc_tile = [&]() { - if(QuantEnum != UnifiedAttentionQuantScaleEnum::NO_SCALE) + if constexpr(kIsQuantized) { if(std::is_same_v) {