diff --git a/example/ck_tile/01_unified_attention/CMakeLists.txt b/example/ck_tile/01_unified_attention/CMakeLists.txt index 2150ea09b7..11c413d192 100644 --- a/example/ck_tile/01_unified_attention/CMakeLists.txt +++ b/example/ck_tile/01_unified_attention/CMakeLists.txt @@ -178,6 +178,15 @@ # --- Unified Attention target (kept) --- +# +set(INST_TARGETS ${SUPPORTED_GPU_TARGETS}) +# Currently only gfx9 archs are supported by FMHA +list(FILTER INST_TARGETS INCLUDE REGEX "gfx9") +if(NOT INST_TARGETS) + message(WARNING "Skipping Tile Engine FMHA compilation: No supported GPU targets (gfx9) found in SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}") + return() +endif() + set(EXAMPLE_FMHA_FWD_V3 "tile_example_unified_attention") message(DEBUG "adding example ${EXAMPLE_FMHA_FWD_V3}") diff --git a/example/ck_tile/01_unified_attention/example_unified_attention.cpp b/example/ck_tile/01_unified_attention/example_unified_attention.cpp index dc0e389f0a..0885179c77 100644 --- a/example/ck_tile/01_unified_attention/example_unified_attention.cpp +++ b/example/ck_tile/01_unified_attention/example_unified_attention.cpp @@ -154,7 +154,6 @@ struct Problem float scale_k; float scale_v; mask_info mask; - TensorLayout output_layout; std::vector query_lens; std::vector kv_lens; }; @@ -350,8 +349,8 @@ bool run_impl(const Problem& problem, const RunConfig& run_config) args.stride_v_cache_3 = args.stride_k_cache_3; args.o_ptr = o_buf.GetDeviceBuffer(); - args.output_stride_0 = query_stride_0; - args.output_stride_1 = query_stride_1; + args.output_stride_0 = args.query_stride_0; + args.output_stride_1 = args.query_stride_1; // Optional cumulative seqlen overrides (exclude PAD) auto make_effective_vec = [&](const std::vector& opt_vec, ck_tile::index_t fallback) { @@ -386,19 +385,19 @@ bool run_impl(const Problem& problem, const RunConfig& run_config) }; calculate_cumulative(eff_query_lens, cu_query_lens); - ck_tile::DeviceMem seq_lens_buf(kv_lens.size()); + ck_tile::DeviceMem 
seq_lens_buf(eff_kv_lens.size() * sizeof(eff_kv_lens[0])); ck_tile::DeviceMem query_start_len_buf(cu_query_lens.size()); - seq_lens_buf.ToDevice(kv_lens.data()); + seq_lens_buf.ToDevice(eff_kv_lens.data()); query_start_len_buf.ToDevice(cu_query_lens.data()); args.seq_lens_ptr =reinterpret_cast(seq_lens_buf.GetDeviceBuffer()); args.query_start_len_ptr =reinterpret_cast(query_start_len_buf.GetDeviceBuffer()); - auto max_kv_len = std::max_element(problem.kv_lens.begin(), problem.kv_lens.end()); + int max_kv_len = eff_kv_lens.empty() ? 0 : *std::max_element(eff_kv_lens.begin(), eff_kv_lens.end()); - index_t max_num_blocks_per_seq = (max_kv_len + problem.BLOCK_SIZE - 1) / problem.BLOCK_SIZE + ck_tile::index_t max_num_blocks_per_seq = (max_kv_len + problem.BLOCK_SIZE - 1) / problem.BLOCK_SIZE; // Create block_tables ck_tile::DeviceMem block_tables_buf(problem.batch * max_num_blocks_per_seq * sizeof(ck_tile::index_t)); @@ -433,30 +432,24 @@ bool run_impl(const Problem& problem, const RunConfig& run_config) return false; } - std::size_t flop = [&] { - if(problem.mask.type == mask_enum::no_mask) - { - return 4 * problem.batch * problem.nhead_q * problem.seqlen_q * problem.seqlen_k * - problem.hdim; - } - else - { - /// FIXME: Use a more accurate method; for now, we’re just dividing the flop by 2. - return 2 * problem.batch * problem.nhead_q * problem.seqlen_q * problem.seqlen_k * - problem.hdim; - } - }(); + // std::size_t flop = [&] { + // if(problem.mask.type == mask_enum::no_mask) + // { + // return 4 * args.num_tokens * problem.nhead_q * + // problem.hdim; + // } + // else + // { + // /// FIXME: Use a more accurate method; for now, we’re just dividing the flop by 2. 
+ // return 2 * args.num_tokens * problem.nhead_q * + // problem.hdim; + // } + // }(); + // TODO fix this + std::size_t flop = 1; float tflops = static_cast(flop) / 1.e9 / time; std::cout << "[" << problem.data_type << "|"; - if(problem.input_layout == problem.output_layout) - { - std::cout << problem.input_layout; - } - else - { - std::cout << problem.input_layout << "-" << problem.output_layout; - } std::cout << "] b:" << problem.batch << ", h:" << problem.nhead_q << "/" << problem.nhead_kv << ", s:" << problem.seqlen_q << "/" << problem.seqlen_k << ", d:" << problem.hdim << ", scale_s:" << problem.softmax_scale << ", mask:" << problem.mask << std::fixed @@ -469,85 +462,70 @@ bool run_impl(const Problem& problem, const RunConfig& run_config) } // transpose tensor descriptors from bhsd to bshd if necessary - if(problem.input_layout != TensorLayout::bshd) - { - q = q.transpose({0, 2, 1, 3}); - k = k.transpose({0, 2, 1, 3}); - v = v.transpose({0, 2, 1, 3}); - } + // if(problem.input_layout != TensorLayout::bshd) + // { + // q = q.transpose({0, 2, 1, 3}); + // k = k.transpose({0, 2, 1, 3}); + // v = v.transpose({0, 2, 1, 3}); + // } - ck_tile::HostTensor o_ref(problem.get_output_shape()); - if(problem.output_layout != TensorLayout::bshd) - { - o_ref = o_ref.transpose({0, 2, 1, 3}); - } + ck_tile::HostTensor o_ref(problem.get_output_shape()); + // if(problem.output_layout != TensorLayout::bshd) + // { + // o_ref = o_ref.transpose({0, 2, 1, 3}); + // } // If variable lengths are provided, compute per-batch references // with the effective lengths; else compute a single full reference. - if(has_varlen_q || has_varlen_k) + // Variable-length aware verification: zero-fill padded region and only compute valid part. + o_ref.SetZero(); + + for(int b = 0; b < problem.batch; ++b) { - // Variable-length aware verification: zero-fill padded region and only compute valid part. 
- o_ref.SetZero(); + const ck_tile::index_t seqlen_q_eff = eff_q_vec[b]; + const ck_tile::index_t seqlen_kv_eff = eff_kv_vec[b]; - for(int b = 0; b < problem.batch; ++b) - { - const ck_tile::index_t seqlen_q_eff = eff_q_vec[b]; - const ck_tile::index_t seqlen_kv_eff = eff_kv_vec[b]; + if(seqlen_q_eff <= 0 || seqlen_kv_eff <= 0) + continue; - if(seqlen_q_eff <= 0 || seqlen_kv_eff <= 0) - continue; + // Slice current batch from inputs (bshd) and build single-batch tensors + ck_tile::HostTensor q_b({1, seqlen_q_eff, problem.nhead_q, problem.hdim}); + ck_tile::HostTensor k_b({1, seqlen_kv_eff, problem.nhead_kv, problem.hdim}); + ck_tile::HostTensor v_b({1, seqlen_kv_eff, problem.nhead_kv, problem.hdim}); + ck_tile::HostTensor o_b({1, seqlen_q_eff, problem.nhead_q, problem.hdim}); - // Slice current batch from inputs (bshd) and build single-batch tensors - ck_tile::HostTensor q_b({1, seqlen_q_eff, problem.nhead_q, problem.hdim}); - ck_tile::HostTensor k_b({1, seqlen_kv_eff, problem.nhead_kv, problem.hdim}); - ck_tile::HostTensor v_b({1, seqlen_kv_eff, problem.nhead_kv, problem.hdim}); - ck_tile::HostTensor o_b({1, seqlen_q_eff, problem.nhead_q, problem.hdim}); + // Copy effective region + q_b.ForEach([&](auto& self, auto idx) { + // idx: [0, s, h, d] + self(idx) = q(b, idx[1], idx[2], idx[3]); + }); + k_b.ForEach([&](auto& self, auto idx) { self(idx) = k(b, idx[1], idx[2], idx[3]); }); + v_b.ForEach([&](auto& self, auto idx) { self(idx) = v(b, idx[1], idx[2], idx[3]); }); - // Copy effective region - q_b.ForEach([&](auto& self, auto idx) { - // idx: [0, s, h, d] - self(idx) = q(b, idx[1], idx[2], idx[3]); - }); - k_b.ForEach([&](auto& self, auto idx) { self(idx) = k(b, idx[1], idx[2], idx[3]); }); - v_b.ForEach([&](auto& self, auto idx) { self(idx) = v(b, idx[1], idx[2], idx[3]); }); - - // Compute reference for this batch segment (host::fmha_fwd expects bshd tensors) - host::fmha_fwd(q_b, - k_b, - v_b, - problem.mask, - o_b, - ck_tile::identity{}, - 
ck_tile::identity{}, - ck_tile::identity{}, - ck_tile::scales{problem.softmax_scale}); - - // Scatter into o_ref's bshd descriptor memory - for(int s = 0; s < seqlen_q_eff; ++s) - { - for(int h = 0; h < problem.nhead_q; ++h) - { - for(int d = 0; d < problem.hdim; ++d) - { - o_ref(b, s, h, d) = o_b(0, s, h, d); - } - } - } - } - } - else - { - // No varlen override: compute the full reference once - host::fmha_fwd(q, - k, - v, + // Compute reference for this batch segment (host::fmha_fwd expects bshd tensors) + host::fmha_fwd(q_b, + k_b, + v_b, problem.mask, - o_ref, + o_b, ck_tile::identity{}, ck_tile::identity{}, ck_tile::identity{}, ck_tile::scales{problem.softmax_scale}); + + // Scatter into o_ref's bshd descriptor memory + for(int s = 0; s < seqlen_q_eff; ++s) + { + for(int h = 0; h < problem.nhead_q; ++h) + { + for(int d = 0; d < problem.hdim; ++d) + { + o_ref(b, s, h, d) = o_b(0, s, h, d); + } + } + } } + ck_tile::HostTensor o(problem.get_output_shape()); o_buf.FromDevice(o.data()); diff --git a/example/ck_tile/01_unified_attention/instances/unified_attention_d128_bf16_nmask.cpp b/example/ck_tile/01_unified_attention/instances/unified_attention_d128_bf16_nmask.cpp index a6806b95d7..391103891a 100644 --- a/example/ck_tile/01_unified_attention/instances/unified_attention_d128_bf16_nmask.cpp +++ b/example/ck_tile/01_unified_attention/instances/unified_attention_d128_bf16_nmask.cpp @@ -7,7 +7,7 @@ namespace ck_tile { using kernel_traits = - unified_attention_kernel_traits; + unified_attention_kernel_traits; INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits) diff --git a/example/ck_tile/01_unified_attention/instances/unified_attention_d128_fp16_mask.cpp b/example/ck_tile/01_unified_attention/instances/unified_attention_d128_fp16_mask.cpp index a710efd2cb..f2cc00f835 100644 --- a/example/ck_tile/01_unified_attention/instances/unified_attention_d128_fp16_mask.cpp +++ b/example/ck_tile/01_unified_attention/instances/unified_attention_d128_fp16_mask.cpp @@ -7,7 +7,7 @@ 
namespace ck_tile { using kernel_traits = - unified_attention_kernel_traits; + unified_attention_kernel_traits; INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits) diff --git a/example/ck_tile/01_unified_attention/unified_attention_impl.hpp b/example/ck_tile/01_unified_attention/unified_attention_impl.hpp index 65f17fa251..f83209a2c4 100644 --- a/example/ck_tile/01_unified_attention/unified_attention_impl.hpp +++ b/example/ck_tile/01_unified_attention/unified_attention_impl.hpp @@ -20,7 +20,7 @@ #include "unified_attention.hpp" #include "mask.hpp" -#define INST_unified_attention_DISPATCH(kernel_traits) \ +#define INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits) \ template <> \ std::pair unified_attention_kernel_dispatch( \ const unified_attention_args& args, const stream_config& config) \ @@ -73,7 +73,6 @@ struct unified_attention_kernel_traits using unified_attention_traits = TileUnifiedAttentionTraits; @@ -110,6 +109,7 @@ struct unified_attention_kernel_traits template float unified_attention_kernel_launch(const unified_attention_args& args, const stream_config& config) { + index_t total_num_q_blocks = args.num_tokens / Kernel::BLOCK_Q + args.num_seqs; auto kargs = Kernel::MakeKargs(args.q_ptr, args.k_ptr, @@ -123,6 +123,7 @@ float unified_attention_kernel_launch(const unified_attention_args& args, const args.scale_k, args.scale_v, args.scale_out, + total_num_q_blocks, args.query_stride_0, args.query_stride_1, args.stride_k_cache_0, @@ -141,9 +142,6 @@ float unified_attention_kernel_launch(const unified_attention_args& args, const args.num_seqs ); - index_t total_num_q_blocks = args.num_tokens / Kernel::BLOCK_Q + args.num_seqs; - - dim3 grids = Kernel::GridSize2D(args.num_head_q / args.num_queries_per_kv, total_num_q_blocks); constexpr dim3 blocks = Kernel::BlockSize(); constexpr index_t kBlockPerCu = Kernel::kBlockPerCu; diff --git a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp 
b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp index b5d46c754f..0ce14afdd0 100644 --- a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp +++ b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp @@ -156,12 +156,11 @@ struct UnifiedAttentionKernel stride_v_cache_3, output_stride_0, output_stride_1}, - { block_tables_ptr, seq_lens_ptr, query_start_len_ptr, num_seqs - }}; + }; return kargs; } @@ -344,7 +343,7 @@ struct UnifiedAttentionKernel index_t query_len_padded = integer_divide_ceil(cur_batch_query_len, BLOCK_Q) * BLOCK_Q; - const bool is_query_len_padded = (cur_batch_query_len % BLOCK_Q == 0); + // const bool is_query_len_padded = (cur_batch_query_len % BLOCK_Q == 0); // Q/K/V DRAM and DRAM window const auto q_dram = [&]() { @@ -359,7 +358,7 @@ struct UnifiedAttentionKernel q_dram_base, // block sizes make_tuple(BLOCK_Q, 1, HEAD_SIZE_PADDED), - sequence{} + sequence{} ); // pads to (seq_len_padded, num_head_q, HEAD_SIZE_PADDED) const auto q_dram_merged = transform_tensor_view( @@ -486,7 +485,7 @@ struct UnifiedAttentionKernel o_dram_base, // block sizes make_tuple(BLOCK_Q, 1, HEAD_SIZE_PADDED), - sequence{} + sequence{} ); // pads to (seq_len_padded, num_head_q, HEAD_SIZE_PADDED) const auto o_dram_merged = transform_tensor_view( diff --git a/include/ck_tile/ops/unified_attention/pipeline/tile_unified_attention_traits.hpp b/include/ck_tile/ops/unified_attention/pipeline/tile_unified_attention_traits.hpp index f10b064487..b27a09a1b4 100644 --- a/include/ck_tile/ops/unified_attention/pipeline/tile_unified_attention_traits.hpp +++ b/include/ck_tile/ops/unified_attention/pipeline/tile_unified_attention_traits.hpp @@ -12,13 +12,11 @@ namespace ck_tile { template struct TileUnifiedAttentionTraits { static constexpr bool kPadSeqLenQ = kPadSeqLenQ_; static constexpr bool kPadHeadDim = kPadHeadDim_; - static constexpr bool kStoreLSE = kStoreLSE_; static constexpr index_t kBlockPerCu = 
kBlockPerCu_; }; } diff --git a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp index 0b7d313757..3dee9b4ad8 100644 --- a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp +++ b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp @@ -256,7 +256,6 @@ struct UnifiedAttentionPipeline using VDataType = ck_tile::remove_cvref_t; using SaccDataType = ck_tile::remove_cvref_t; using SMPLComputeDataType = ck_tile::remove_cvref_t; - using LSEDataType = ck_tile::remove_cvref_t; using PDataType = ck_tile::remove_cvref_t; using OaccDataType = ck_tile::remove_cvref_t; using ODataType = ck_tile::remove_cvref_t; @@ -372,11 +371,9 @@ struct UnifiedAttentionPipeline template @@ -1206,14 +1203,12 @@ struct UnifiedAttentionPipeline template + typename VDramBlockWindowTmp> CK_TILE_DEVICE auto operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile const void* block_tables_ptr, index_t block_table_offset, - LSEDramBlockWindowTmp& lse_dram_block_window_tmp, // M0*1 tile FmhaMask mask, float scale_s, void* smem_ptr) const @@ -1228,7 +1223,6 @@ struct UnifiedAttentionPipeline identity{}, block_tables_ptr, block_table_offset, - lse_dram_block_window_tmp, identity{}, identity{}, identity{}, diff --git a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline_problem.hpp b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline_problem.hpp index 3f676bf01d..d21d8316af 100644 --- a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline_problem.hpp +++ b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline_problem.hpp @@ -51,8 +51,6 @@ struct UnifiedAttentionPipelineProblem static constexpr bool kPadHeadDim = 
Traits::kPadHeadDim; static constexpr bool kHasLogitsSoftCap = Traits::kHasLogitsSoftCap; static constexpr bool kSkipMinSeqlenQ = Traits::kSkipMinSeqlenQ; - static constexpr auto BiasEnum = Traits::BiasEnum; - static constexpr bool kStoreLSE = Traits::kStoreLSE; static constexpr bool kHasDropout = Traits::kHasDropout; static constexpr bool kDoFp8StaticQuant = Traits::kDoFp8StaticQuant; static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;