diff --git a/example/ck_tile/42_unified_attention/example_unified_attention.cpp b/example/ck_tile/42_unified_attention/example_unified_attention.cpp index e43a8df76e..03e5697ba0 100644 --- a/example/ck_tile/42_unified_attention/example_unified_attention.cpp +++ b/example/ck_tile/42_unified_attention/example_unified_attention.cpp @@ -26,7 +26,7 @@ #include "mask.hpp" // const ck_tile::index_t page_blk_size = 32; -const ck_tile::index_t num_queries_per_kv = 1; +// num_queries_per_kv is now a runtime arg (see parse_cmd_args) auto parse_cmd_args(int argc, char* argv[]) -> std::pair { @@ -34,10 +34,8 @@ auto parse_cmd_args(int argc, char* argv[]) -> std::pair; \ + return unified_attention_kernel_dispatch(args, config); \ + } + std::pair unified_attention(const unified_attention_args& args, const stream_config& config) { - if(args.data_type == unified_attention_args::data_type_enum::fp16) + const bool is_mask = (args.mask_type != static_cast(mask_enum::no_mask)); + + // Route based on (data_type, mask, hdim, num_queries_per_kv). + // Only d128 MHA (8 warps, kBlockM=256) instances available. + // Decode-tuned instances require pipeline changes (NumWarpGroups must == 2, + // which means exactly 8 warps; fewer warps are not supported). + if(args.hdim == 128 && args.num_queries_per_kv == 1) { - if(args.mask_type == static_cast(mask_enum::no_mask)) + if(args.data_type == unified_attention_args::data_type_enum::fp16) { - using kernel_traits = - unified_attention_kernel_traits; - - return unified_attention_kernel_dispatch(args, config); + if(!is_mask) DISPATCH_UNIFIED_ATTENTION(unified_attention_args::data_type_enum::fp16, false, 128, 256, 1) + else DISPATCH_UNIFIED_ATTENTION(unified_attention_args::data_type_enum::fp16, true, 128, 256, 1) } - else + else if(args.data_type == unified_attention_args::data_type_enum::bf16) { - using kernel_traits = - unified_attention_kernel_traits; - - return unified_attention_kernel_dispatch(args, config); - } - } - else if(args.data_type == unified_attention_args::data_type_enum::bf16) - { - if(args.mask_type == static_cast(mask_enum::no_mask)) - { - using kernel_traits = - unified_attention_kernel_traits; - - return unified_attention_kernel_dispatch(args, config); - } - else - { - using kernel_traits = - unified_attention_kernel_traits; - - return unified_attention_kernel_dispatch(args, config); + if(!is_mask) DISPATCH_UNIFIED_ATTENTION(unified_attention_args::data_type_enum::bf16, false, 128, 256, 1) + else DISPATCH_UNIFIED_ATTENTION(unified_attention_args::data_type_enum::bf16, true, 128, 256, 1) } } + std::cerr << "unified_attention: no matching kernel instance for hdim=" << args.hdim + << " num_queries_per_kv=" << args.num_queries_per_kv + << " data_type=" << args.data_type << " mask_type=" << args.mask_type << std::endl; return std::make_pair(false, -1.f); } +#undef DISPATCH_UNIFIED_ATTENTION + } // namespace ck_tile diff --git a/example/ck_tile/42_unified_attention/unified_attention_impl.hpp b/example/ck_tile/42_unified_attention/unified_attention_impl.hpp index 8087c4b8e6..bbd5ffb912 100644 --- a/example/ck_tile/42_unified_attention/unified_attention_impl.hpp +++ b/example/ck_tile/42_unified_attention/unified_attention_impl.hpp @@ -52,24 +52,29 @@ struct unified_attention_problem_traits +// Parameterized kernel traits: DataType, IsMasking, HeadSize, BlockM, NumQueriesPerKV +template struct unified_attention_kernel_traits { static constexpr auto date_type = DataType; static constexpr bool is_masking = IsMasking; - static constexpr index_t kBlockM = 256; + static constexpr index_t kBlockM = BlockM_; static constexpr index_t BLOCK_SIZE = 32; - static constexpr index_t HEAD_SIZE = 128; + static constexpr index_t HEAD_SIZE = HeadSize_; - // TODO please fix this to support also other num_queries_per_kv - static constexpr index_t num_queries_per_kv = 1; + static constexpr index_t num_queries_per_kv = NumQPerKV_; static constexpr index_t kBlockQ = kBlockM / num_queries_per_kv; // kBlockM kBlockQ BLOCK_SIZE HEAD_SIZE using unified_attention_block_tile = sequence; + using unified_attention_warp_gemm_shape = sequence<32, 32, 16>; - // need to have 8 warps per workgroup to have warp specialization + // 8 warps for warp specialization; kBlockM must be 8 * 32 = 256 using unified_attention_block_warps = sequence<8, 1, 1>; using unified_attention_shape = TileUnifiedAttentionShape; }; +// Decode-tuned traits: fewer warps, smaller kBlockM for low-token workloads. +// NOTE: Currently cannot compile due to pipeline constraint (NumWarpGroups must == 2). +// Kept for future pipeline work. +template +struct unified_attention_decode_kernel_traits +{ + static constexpr auto date_type = DataType; + static constexpr bool is_masking = IsMasking; + + static constexpr index_t kBlockM = BlockM_; + static constexpr index_t BLOCK_SIZE = 32; + static constexpr index_t HEAD_SIZE = HeadSize_; + + static constexpr index_t num_queries_per_kv = NumQPerKV_; + static constexpr index_t kBlockQ = kBlockM / num_queries_per_kv; + + using unified_attention_block_tile = sequence; + using unified_attention_warp_gemm_shape = sequence<32, 32, 16>; + using unified_attention_block_warps = sequence<2, 1, 1>; + + using unified_attention_shape = TileUnifiedAttentionShape; + + using unified_attention_traits = TileUnifiedAttentionTraits; + using unified_attention_mask = GenericAttentionMask; + + using unified_attention_pipeline_problem = UnifiedAttentionPipelineProblem< + typename unified_attention_problem_traits::qkvp_dtype, + typename unified_attention_problem_traits::qkvp_dtype, + typename unified_attention_problem_traits::qkvp_dtype, + typename unified_attention_problem_traits::acc_dtype, + typename unified_attention_problem_traits::acc_dtype, + typename unified_attention_problem_traits::acc_dtype, + typename unified_attention_problem_traits::lse_dtype, + typename unified_attention_problem_traits::qkvp_dtype, + typename unified_attention_problem_traits::acc_dtype, + typename unified_attention_problem_traits::o_dtype, + unified_attention_shape, + unified_attention_mask, + unified_attention_traits>; + + using unified_attention_pipeline = UnifiedAttentionPipeline; + + using epilogue = Default2DEpilogue< + Default2DEpilogueProblem::acc_dtype, + typename unified_attention_problem_traits::o_dtype, + true, true, true>>; + + using kernel = UnifiedAttentionKernel; +}; + template float unified_attention_kernel_launch(const unified_attention_args& args, const stream_config& config) { - index_t kBlockQ = Kernel::kBlockQ; - assert(args.num_queries_per_kv == Kernel::num_queries_per_kv && - "argument num_queries_per_kv must equal compiled num_queries_per_kv"); - assert(args.BLOCK_SIZE == Kernel::BLOCK_SIZE && - "argument BLOCK_SIZE must equal compiled BLOCK_SIZE"); - assert(kBlockQ == kBlockM / args.num_queries_per_kv && - "kBlockQ must equal kBlockM / num_queries_per_kv"); + constexpr index_t kBlockQ = Kernel::kBlockQ; index_t total_num_q_blocks = args.num_tokens / kBlockQ + args.num_seqs; auto kargs = Kernel::MakeKargs(args.q_ptr, args.k_ptr, diff --git a/include/ck_tile/host/reference/reference_batched_masking.hpp b/include/ck_tile/host/reference/reference_batched_masking.hpp index a172a0013e..749b88690d 100644 --- a/include/ck_tile/host/reference/reference_batched_masking.hpp +++ b/include/ck_tile/host/reference/reference_batched_masking.hpp @@ -20,7 +20,14 @@ CK_TILE_HOST void reference_batched_masking(HostTensor& c_b_m_n, cons { for(int m = 0; m < M; ++m) { - if(mask.IsOutOfSinkBound(m, n)) + const bool is_out_of_bound = [&]() { + if constexpr(requires { mask.IsOutOfSinkBound(m, n); }) + return mask.IsOutOfSinkBound(m, n); + else + return mask.IsOutOfBound(m, n); + }(); + + if(is_out_of_bound) c_b_m_n(batch, m, n) = -ck_tile::numeric::infinity(); } } diff --git a/include/ck_tile/ops/unified_attention/block/block_masking.hpp b/include/ck_tile/ops/unified_attention/block/block_masking.hpp index 33ca84d2c5..68eb6efc89 100644 --- a/include/ck_tile/ops/unified_attention/block/block_masking.hpp +++ b/include/ck_tile/ops/unified_attention/block/block_masking.hpp @@ -257,6 +257,19 @@ struct GenericAttentionMask index_t repeat_idx; }; +template +struct is_generic_attention_mask : std::false_type +{ +}; + +template +struct is_generic_attention_mask> : std::true_type +{ +}; + +template +static constexpr bool is_generic_attention_mask_v = is_generic_attention_mask::value; + // TODO: prefer use this function in host code // can convert from the FA style left/right to our generic coordinate // if left_size < 0 && right_size = 0, it is normal causal mask diff --git a/include/ck_tile/ops/unified_attention/pipeline/tile_unified_attention_shape.hpp b/include/ck_tile/ops/unified_attention/pipeline/tile_unified_attention_shape.hpp index 68c53401cd..5a3d3262e4 100644 --- a/include/ck_tile/ops/unified_attention/pipeline/tile_unified_attention_shape.hpp +++ b/include/ck_tile/ops/unified_attention/pipeline/tile_unified_attention_shape.hpp @@ -40,9 +40,9 @@ struct TileUnifiedAttentionShape using Gemm1WarpTile = remove_cvref_t; static constexpr index_t NumGemm0Warps = - reduce_on_sequence(Gemm0BlockWarps{}, multiplies{}, number<1>{}); + reduce_on_sequence(Gemm0BlockWarps{}, multiplies<>{}, number<1>{}); static constexpr index_t NumGemm1Warps = - reduce_on_sequence(Gemm1BlockWarps{}, multiplies{}, number<1>{}); + reduce_on_sequence(Gemm1BlockWarps{}, multiplies<>{}, number<1>{}); static_assert(NumGemm1Warps % NumGemm0Warps == 0); static constexpr index_t NumWarps = max(NumGemm0Warps, NumGemm1Warps);