From 6ef0b9da8c9f654a1fbc55bd2007f8c3f4f26578 Mon Sep 17 00:00:00 2001 From: Juuso Korhonen <40278371+juuso-oskari@users.noreply.github.com> Date: Tue, 18 Nov 2025 08:57:30 +0000 Subject: [PATCH 01/12] fixing --- .../example_unified_attention.cpp | 98 +++++++++---------- .../kernel/unified_attention_kernel.hpp | 10 +- 2 files changed, 50 insertions(+), 58 deletions(-) diff --git a/example/ck_tile/01_unified_attention/example_unified_attention.cpp b/example/ck_tile/01_unified_attention/example_unified_attention.cpp index d1eb6f5425..fdf9c7122a 100644 --- a/example/ck_tile/01_unified_attention/example_unified_attention.cpp +++ b/example/ck_tile/01_unified_attention/example_unified_attention.cpp @@ -34,7 +34,7 @@ auto parse_cmd_args(int argc, char* argv[]) -> std::pair& q_bshd, const ck_tile::HostTensor& k_bshd, const ck_tile::HostTensor& v_bshd, - const mask_info& mask, + // const mask_info& mask, ck_tile::HostTensor& o_bshd, const QElementOp& q_element_op = {}, const KElementOp& k_element_op = {}, @@ -222,61 +223,34 @@ CK_TILE_HOST void fmha_fwd(const ck_tile::HostTensor& q_bshd, ck_tile::HostTensor s_host_ref({nhead_q, seqlen_q, seqlen_kv}); ck_tile::HostTensor p_host_ref({nhead_q, seqlen_q, seqlen_kv}); - // do computation for each batch for(int b = 0; b < batch_size; ++b) { // copy per-batch data from input tensors // clang-format off - q_host_ref.ForEach([&](auto& self, auto idx) { self(idx) = q_bshd(b, idx[1], idx[0] , - idx[2]); }); k_host_ref.ForEach([&](auto& self, auto idx) { self(idx) = k_bshd(b, idx[1], - idx[0] / nr, idx[2]); }); v_host_ref.ForEach([&](auto& self, auto idx) { self(idx) = - v_bshd(b, idx[2], idx[0] / nr, idx[1]); }); + q_host_ref.ForEach([&](auto& self, auto idx) { self(idx) = q_bshd(b, idx[1], idx[0] , + idx[2]); }); + k_host_ref.ForEach([&](auto& self, auto idx) { self(idx) = k_bshd(b, idx[1], + idx[0] / nr, idx[2]); }); + v_host_ref.ForEach([&](auto& self, auto idx) { self(idx) = + v_bshd(b, idx[2], idx[0] / nr, idx[1]); }); // clang-format on ck_tile::reference_batched_gemm( q_host_ref, k_host_ref, s_host_ref, q_element_op, k_element_op, s_acc_element_op); - if(mask.type == mask_enum::no_mask) - { - ck_tile::reference_batched_masking(s_host_ref, FmhaMasks::NoMask{seqlen_q, seqlen_kv}); - } - else if(mask.type == mask_enum::window_generic) - { - ck_tile::reference_batched_masking( - s_host_ref, - ck_tile::make_generic_attention_mask_from_lr_window( - mask.left, mask.right, seqlen_q, seqlen_kv)); - } - else - { - // if left window size is negative, means causal - // else means generic (for current batch) - if(mask.left < 0) - ck_tile::reference_batched_masking( - s_host_ref, - ck_tile::make_generic_attention_mask_from_lr_window( - mask.left, - mask.right, - seqlen_q, - seqlen_kv, - mask.type == mask_enum::mask_top_left)); - else - ck_tile::reference_batched_masking( - s_host_ref, - ck_tile::make_generic_attention_mask_from_lr_window( - mask.left, - mask.right, - seqlen_q, - seqlen_kv, - mask.type == mask_enum::mask_top_left)); - } - + ck_tile::reference_batched_masking( + s_host_ref, + ck_tile::make_generic_attention_mask_from_lr_window( + -1, + 0, + seqlen_q, + seqlen_kv, + true)); ck_tile::reference_batched_softmax( s_host_ref, p_host_ref, ck_tile::identity{}); - ck_tile::reference_batched_gemm( p_host_ref, v_host_ref, o_host_ref, ck_tile::identity{}, v_element_op); - + // copy resulting per-batch data to the output tensor o_host_ref.ForEach( [&](auto& self, auto idx) { o_bshd(b, idx[1], idx[0], idx[2]) = self(idx); }); @@ -528,7 +502,7 @@ bool run_impl(const Problem& problem, const RunConfig& run_config) host::fmha_fwd(q_b, k_b, v_b, - problem.mask, + // problem.mask, o_b, ck_tile::identity{}, ck_tile::identity{}, @@ -551,13 +525,31 @@ bool run_impl(const Problem& problem, const RunConfig& run_config) ck_tile::HostTensor o(problem.get_output_shape()); o_buf.FromDevice(o.data()); - const auto [rtol, atol] = [&] { - if constexpr(std::is_same_v) - return std::make_tuple(1e-3, 1e-3); - else - return std::make_tuple(1e-2, 1e-2); - }(); - return ck_tile::check_err(o, o_ref, std::string("found incorrect results!"), rtol, atol); + // const auto [rtol, atol] = [&] { + // if constexpr(std::is_same_v) + // return std::make_tuple(1e-3, 1e-3); + // else + // return std::make_tuple(1e-2, 1e-2); + // }(); + + // Print some of the output data for debugging + std::cout << "\nFirst few elements of output tensor o:" << std::endl; + for(int b = 0; b < std::min(2, static_cast(problem.batch)); ++b) { + std::cout << "Batch " << b << ":" << std::endl; + for(int s = 0; s < std::min(5, static_cast(eff_query_lens[b])); ++s) { + for(int h = 0; h < std::min(2, static_cast(problem.nhead_q)); ++h) { + for(int d = 0; d < std::min(4, static_cast(problem.hdim)); ++d) { + std::cout << "o[" << b << "][" << s << "][" << h << "][" << d << "] = " + << static_cast(o(b, s, h, d)) + << std::endl; + } + } + } + } + + + + return 1; // ck_tile::check_err(o, o_ref, std::string("found incorrect results!"), rtol, atol); } int main(int argc, char* argv[]) diff --git a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp index 2f1b574655..1d34ee5670 100644 --- a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp +++ b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp @@ -321,7 +321,7 @@ struct UnifiedAttentionKernel const index_t context_len = seq_len - cur_batch_query_len; index_t _max_seq_prefix_len = - (context_len + q_block_local_idx * BLOCK_Q + (BLOCK_M - 1) // num_queries_per_kv + (context_len + q_block_local_idx * BLOCK_Q + (BLOCK_M - 1) + 1); if(seq_len < _max_seq_prefix_len) @@ -457,10 +457,10 @@ struct UnifiedAttentionKernel FmhaMask mask = [&]() { if constexpr(kHasMask) return ck_tile::make_generic_attention_mask_from_lr_window( - seq_len - cur_batch_query_len, // y (i.e. context) - cur_batch_query_len, // x (i.e. extend) - seq_len, // y_total (x + y) - cur_batch_query_len, // x_total + -1, + 0, + cur_batch_query_len, // y_total + seq_len, // x_total num_queries_per_kv // the same sequence index is repeated num_queries_per_kv // times along x dim of the tile ); From de995fea7116cceed2d01e195ecc583b079c61cf Mon Sep 17 00:00:00 2001 From: Tianxing Wu Date: Tue, 18 Nov 2025 13:04:58 +0000 Subject: [PATCH 02/12] Various fixes --- .../example_unified_attention.cpp | 53 +++++++++++-------- .../unified_attention_impl.hpp | 2 +- .../kernel/unified_attention_kernel.hpp | 42 ++++++--------- .../pipeline/unified_attention_pipeline.hpp | 1 - 4 files changed, 49 insertions(+), 49 deletions(-) diff --git a/example/ck_tile/01_unified_attention/example_unified_attention.cpp b/example/ck_tile/01_unified_attention/example_unified_attention.cpp index fdf9c7122a..50eac35c3f 100644 --- a/example/ck_tile/01_unified_attention/example_unified_attention.cpp +++ b/example/ck_tile/01_unified_attention/example_unified_attention.cpp @@ -343,8 +343,8 @@ bool run_impl(const Problem& problem, const RunConfig& run_config) }; calculate_cumulative(eff_query_lens, cu_query_lens); - ck_tile::DeviceMem seq_lens_buf(eff_kv_lens.size()); - ck_tile::DeviceMem query_start_len_buf(cu_query_lens.size()); + ck_tile::DeviceMem seq_lens_buf(eff_kv_lens.size() * sizeof(ck_tile::index_t)); + ck_tile::DeviceMem query_start_len_buf(cu_query_lens.size() * sizeof(ck_tile::index_t)); seq_lens_buf.ToDevice(eff_kv_lens.data()); query_start_len_buf.ToDevice(cu_query_lens.data()); @@ -525,31 +525,40 @@ bool run_impl(const Problem& problem, const RunConfig& run_config) ck_tile::HostTensor o(problem.get_output_shape()); o_buf.FromDevice(o.data()); - // const auto [rtol, atol] = [&] { - // if constexpr(std::is_same_v) - // return std::make_tuple(1e-3, 1e-3); - // else - // return std::make_tuple(1e-2, 1e-2); - // }(); + const auto [rtol, atol] = [&] { + if constexpr(std::is_same_v) + return std::make_tuple(1e-3, 1e-3); + else + return std::make_tuple(1e-2, 1e-2); + }(); - // Print some of the output data for debugging - std::cout << "\nFirst few elements of output tensor o:" << std::endl; - for(int b = 0; b < std::min(2, static_cast(problem.batch)); ++b) { - std::cout << "Batch " << b << ":" << std::endl; - for(int s = 0; s < std::min(5, static_cast(eff_query_lens[b])); ++s) { - for(int h = 0; h < std::min(2, static_cast(problem.nhead_q)); ++h) { - for(int d = 0; d < std::min(4, static_cast(problem.hdim)); ++d) { - std::cout << "o[" << b << "][" << s << "][" << h << "][" << d << "] = " - << static_cast(o(b, s, h, d)) - << std::endl; + size_t total = static_cast(problem.num_tokens) * + static_cast(problem.nhead_q) * + static_cast(problem.hdim); + + size_t nonzero = 0; + + for (int b = 0; b < problem.batch; ++b) { + for (int s = 0; s < eff_query_lens[b]; ++s) { + for (int h = 0; h < problem.nhead_q; ++h) { + for (int d = 0; d < problem.hdim; ++d) { + if (static_cast(o(b, s, h, d)) != 0.0f) { + nonzero++; + } } } } } - - - - return 1; // ck_tile::check_err(o, o_ref, std::string("found incorrect results!"), rtol, atol); + + float percent = (total > 0) + ? (100.0f * static_cast(nonzero) / static_cast(total)) + : 0.0f; + + std::cout << "\nNon-zero elements in output tensor o: " + << nonzero << " / " << total + << " (" << percent << "%)\n"; + + return ck_tile::check_err(o, o_ref, std::string("found incorrect results!"), rtol, atol); } int main(int argc, char* argv[]) diff --git a/example/ck_tile/01_unified_attention/unified_attention_impl.hpp b/example/ck_tile/01_unified_attention/unified_attention_impl.hpp index 8b1536f52a..855c99f841 100644 --- a/example/ck_tile/01_unified_attention/unified_attention_impl.hpp +++ b/example/ck_tile/01_unified_attention/unified_attention_impl.hpp @@ -124,7 +124,7 @@ float unified_attention_kernel_launch(const unified_attention_args& args, "argument num_queries_per_kv must equal compiled num_queries_per_kv"); assert(args.BLOCK_SIZE == Kernel::BLOCK_SIZE && "argument BLOCK_SIZE must equal compiled BLOCK_SIZE"); - assert(BLOCK_Q == args.num_head_q / args.num_queries_per_kv && + assert(BLOCK_Q == BLOCK_M / args.num_queries_per_kv && "BLOCK_Q must equal BLOCK_M / num_queries_per_kv"); index_t total_num_q_blocks = args.num_tokens / BLOCK_Q + args.num_seqs; auto kargs = Kernel::MakeKargs(args.q_ptr, diff --git a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp index 1d34ee5670..969a9aac82 100644 --- a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp +++ b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp @@ -184,7 +184,7 @@ struct UnifiedAttentionKernel while(left < right) { ck_tile::index_t mid = (left + right) / 2; - ck_tile::index_t val = query_start_len_ptr[mid]; + ck_tile::index_t val = amd_wave_read_first_lane(query_start_len_ptr[mid]); ck_tile::index_t mid_val = use_q_block_mode ? (val / block_q + mid) : val; if(mid_val <= target_idx) @@ -206,7 +206,7 @@ struct UnifiedAttentionKernel using namespace ck_tile; constexpr index_t NUM_XCDS = 8; - const index_t GRID_MN = kargs.total_num_q_blocks * (kargs.num_head_q); + const index_t GRID_MN = kargs.total_num_q_blocks * (kargs.num_head_q / kargs.num_queries_per_kv); // Number of pids per XCD in the new arrangement const index_t pids_per_xcd = (GRID_MN + NUM_XCDS - 1) / NUM_XCDS; @@ -245,10 +245,7 @@ struct UnifiedAttentionKernel // const index_t num_tile_n1 = ck_tile::integer_divide_ceil(kargs.hdim_v, // UnifiedAttentionPipeline::kN1); - const index_t i_tile_m = pid % total_num_q_blocks; // Query block index - const index_t i_tile_n = pid / total_num_q_blocks; // Head index - - return ck_tile::make_tuple(i_tile_m, i_tile_n); + return ck_tile::make_tuple(pid / total_num_q_blocks, pid % total_num_q_blocks); } CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); } @@ -277,7 +274,6 @@ struct UnifiedAttentionKernel // const index_t num_head_q = kargs.num_head_q; // const index_t num_head_k = num_head_q / num_queries_per_kv; - pid = RemapTileIndices(pid, kargs); // divide problem @@ -295,19 +291,15 @@ struct UnifiedAttentionKernel BLOCK_Q, true); // which batch - const index_t q_block_start_idx = - amd_wave_read_first_lane(kargs.query_start_len_ptr[seq_idx]); + const index_t q_block_start_idx = kargs.query_start_len_ptr[seq_idx] / BLOCK_Q + seq_idx; - const index_t q_block_local_idx = - amd_wave_read_first_lane(q_block_global_idx - q_block_start_idx); + const index_t q_block_local_idx = amd_wave_read_first_lane(q_block_global_idx - q_block_start_idx); - const index_t cur_batch_in_all_start_index = - amd_wave_read_first_lane(kargs.query_start_len_ptr[seq_idx]); - const index_t cur_batch_in_all_stop_index = - amd_wave_read_first_lane(kargs.query_start_len_ptr[seq_idx + 1]); + const index_t cur_batch_in_all_start_index = kargs.query_start_len_ptr[seq_idx]; + const index_t cur_batch_in_all_stop_index = kargs.query_start_len_ptr[seq_idx + 1]; const index_t cur_batch_query_len = - cur_batch_in_all_stop_index - cur_batch_in_all_start_index; + amd_wave_read_first_lane(cur_batch_in_all_stop_index - cur_batch_in_all_start_index); // TODO check if we get the block size info from pipeline if(q_block_local_idx * BLOCK_Q >= cur_batch_query_len) @@ -315,14 +307,14 @@ struct UnifiedAttentionKernel return; } - const index_t query_pos = q_block_local_idx * BLOCK_Q; + const index_t query_pos = amd_wave_read_first_lane(q_block_local_idx * BLOCK_Q); const index_t seq_len = kargs.seq_lens_ptr[seq_idx]; - const index_t context_len = seq_len - cur_batch_query_len; + const index_t context_len = amd_wave_read_first_lane(seq_len - cur_batch_query_len); index_t _max_seq_prefix_len = - (context_len + q_block_local_idx * BLOCK_Q + (BLOCK_M - 1) - + 1); + amd_wave_read_first_lane((context_len + q_block_local_idx * BLOCK_Q + (BLOCK_M - 1) + + 1)); if(seq_len < _max_seq_prefix_len) { @@ -330,7 +322,7 @@ struct UnifiedAttentionKernel } const auto max_seq_prefix_len = _max_seq_prefix_len; - const index_t num_blocks = (max_seq_prefix_len + BLOCK_SIZE - 1) / BLOCK_SIZE; + const index_t num_blocks = amd_wave_read_first_lane((max_seq_prefix_len + BLOCK_SIZE - 1) / BLOCK_SIZE); // TODO sliding window const index_t num_blocks_start = 0; @@ -357,7 +349,7 @@ struct UnifiedAttentionKernel const VDataType* v_ptr = reinterpret_cast(kargs.v_ptr) + kv_head_offset; ODataType* o_ptr = reinterpret_cast(kargs.o_ptr) + o_ptr_offset; - index_t query_len_padded = integer_divide_ceil(cur_batch_query_len, BLOCK_Q) * BLOCK_Q; + index_t query_len_padded = amd_wave_read_first_lane(integer_divide_ceil(cur_batch_query_len, BLOCK_Q) * BLOCK_Q); // const bool is_query_len_padded = (cur_batch_query_len % BLOCK_Q == 0); // Q/K/V DRAM and DRAM window @@ -367,20 +359,20 @@ struct UnifiedAttentionKernel make_tuple(cur_batch_query_len, num_queries_per_kv, HEAD_SIZE), make_tuple(kargs.query_stride_0, kargs.query_stride_1, 1), number{}, - number<2>{}); + number<1>{}); const auto q_dram_pad = pad_tensor_view( // aling seqlen with BLOCK_Q and head dim with HEAD_SIZE_PADDED q_dram_base, // block sizes - make_tuple(number{}, number<1>{}, number{}), + make_tuple(number{}, 1, HEAD_SIZE_PADDED), sequence{}); // pads to (seq_len_padded, num_head_q, // HEAD_SIZE_PADDED) const auto q_dram_merged = transform_tensor_view( q_dram_pad, make_tuple(make_merge_transform(make_tuple(query_len_padded, num_queries_per_kv)), - make_pass_through_transform(number{})), + make_pass_through_transform(HEAD_SIZE_PADDED)), make_tuple(sequence<0, 1>{}, sequence<2>{}), make_tuple(sequence<0>{}, sequence<1>{})); // flattens the first two dims, head idx is the fastest diff --git a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp index a575230ef6..3d941f5503 100644 --- a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp +++ b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp @@ -6,7 +6,6 @@ #include "ck_tile/core.hpp" #include "ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline_default_policy.hpp" #include "ck_tile/ops/reduce/block/block_reduce.hpp" - #define ENABLE_ASM_MARKER 1 #if ENABLE_ASM_MARKER #define ASM_MARKER(marker) \ From f552cd7841d40f06987f994bf609d2ff73ca2904 Mon Sep 17 00:00:00 2001 From: Tianxing Wu Date: Thu, 20 Nov 2025 11:34:39 +0000 Subject: [PATCH 03/12] ref data copying --- .../01_unified_attention/example_unified_attention.cpp | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/example/ck_tile/01_unified_attention/example_unified_attention.cpp b/example/ck_tile/01_unified_attention/example_unified_attention.cpp index 50eac35c3f..5bc6544746 100644 --- a/example/ck_tile/01_unified_attention/example_unified_attention.cpp +++ b/example/ck_tile/01_unified_attention/example_unified_attention.cpp @@ -32,7 +32,7 @@ auto parse_cmd_args(int argc, char* argv[]) -> std::pair k_b({1, seqlen_kv_eff, problem.nhead_kv, problem.hdim}); ck_tile::HostTensor v_b({1, seqlen_kv_eff, problem.nhead_kv, problem.hdim}); ck_tile::HostTensor o_b({1, seqlen_q_eff, problem.nhead_q, problem.hdim}); + ck_tile::index_t seq_q_off = cu_query_lens[b]; // Copy effective region q_b.ForEach([&](auto& self, auto idx) { // idx: [0, s, h, d] - self(idx) = q(b, idx[1], idx[2], idx[3]); + self(idx) = q(seq_q_off + idx[1], idx[2], idx[3]); }); k_b.ForEach([&](auto& self, auto idx) { // kv cache is paged @@ -516,7 +517,7 @@ bool run_impl(const Problem& problem, const RunConfig& run_config) { for(int d = 0; d < problem.hdim; ++d) { - o_ref(b, s, h, d) = o_b(0, s, h, d); + o_ref(seq_q_off + s, h, d) = o_b(0, s, h, d); } } } From f2fbc44b7bf4dbfb0e8a5f031eb7398e9e218a37 Mon Sep 17 00:00:00 2001 From: Juuso Korhonen <40278371+juuso-oskari@users.noreply.github.com> Date: Mon, 24 Nov 2025 10:20:04 +0000 Subject: [PATCH 04/12] fix --- .../example_unified_attention.cpp | 17 +++++++++-------- .../kernel/unified_attention_kernel.hpp | 18 +++++++++--------- .../pipeline/unified_attention_pipeline.hpp | 9 +++++---- 3 files changed, 23 insertions(+), 21 deletions(-) diff --git a/example/ck_tile/01_unified_attention/example_unified_attention.cpp b/example/ck_tile/01_unified_attention/example_unified_attention.cpp index 5bc6544746..5d8f3fb435 100644 --- a/example/ck_tile/01_unified_attention/example_unified_attention.cpp +++ b/example/ck_tile/01_unified_attention/example_unified_attention.cpp @@ -238,14 +238,14 @@ CK_TILE_HOST void fmha_fwd(const ck_tile::HostTensor& q_bshd, ck_tile::reference_batched_gemm( q_host_ref, k_host_ref, s_host_ref, q_element_op, k_element_op, s_acc_element_op); - ck_tile::reference_batched_masking( - s_host_ref, - ck_tile::make_generic_attention_mask_from_lr_window( - -1, - 0, - seqlen_q, - seqlen_kv, - true)); + // ck_tile::reference_batched_masking( + // s_host_ref, + // ck_tile::make_generic_attention_mask_from_lr_window( + // -1, + // 0, + // seqlen_q, + // seqlen_kv, + // true)); ck_tile::reference_batched_softmax( s_host_ref, p_host_ref, ck_tile::identity{}); ck_tile::reference_batched_gemm( @@ -526,6 +526,7 @@ bool run_impl(const Problem& problem, const RunConfig& run_config) ck_tile::HostTensor o(problem.get_output_shape()); o_buf.FromDevice(o.data()); + const auto [rtol, atol] = [&] { if constexpr(std::is_same_v) return std::make_tuple(1e-3, 1e-3); diff --git a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp index 969a9aac82..366a75e2df 100644 --- a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp +++ b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp @@ -310,18 +310,18 @@ struct UnifiedAttentionKernel const index_t query_pos = amd_wave_read_first_lane(q_block_local_idx * BLOCK_Q); const index_t seq_len = kargs.seq_lens_ptr[seq_idx]; - const index_t context_len = amd_wave_read_first_lane(seq_len - cur_batch_query_len); + // const index_t context_len = amd_wave_read_first_lane(seq_len - cur_batch_query_len); - index_t _max_seq_prefix_len = - amd_wave_read_first_lane((context_len + q_block_local_idx * BLOCK_Q + (BLOCK_M - 1) - + 1)); + // index_t _max_seq_prefix_len = + // amd_wave_read_first_lane((context_len + q_block_local_idx * BLOCK_Q + (BLOCK_M - 1) + // + 1)); - if(seq_len < _max_seq_prefix_len) - { - _max_seq_prefix_len = seq_len; - } + // if(seq_len < _max_seq_prefix_len) + // { + // _max_seq_prefix_len = seq_len; + // } - const auto max_seq_prefix_len = _max_seq_prefix_len; + const auto max_seq_prefix_len = seq_len; // _max_seq_prefix_len; const index_t num_blocks = amd_wave_read_first_lane((max_seq_prefix_len + BLOCK_SIZE - 1) / BLOCK_SIZE); // TODO sliding window diff --git a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp index 3d941f5503..3bb30149bf 100644 --- a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp +++ b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp @@ -897,10 +897,11 @@ struct UnifiedAttentionPipeline auto fmha_mask = [&](auto sp_reg_idx) { if constexpr(FmhaMask::IsMasking) { - bool need_perpixel_check = mask.IsEdgeTile(q_origin.at(number<0>{}), - i_total_loops * BLOCK_SIZE, - number{}, - number{}); + bool need_perpixel_check = false; + // mask.IsEdgeTile(q_origin.at(number<0>{}), + // i_total_loops * BLOCK_SIZE, + // number{}, + // number{}); if(need_perpixel_check) { set_tile_if(sp(sp_reg_idx).sp_compute, From 76d1866537c8edd804a43b3cc8ff01cb97abc3a1 Mon Sep 17 00:00:00 2001 From: Tianxing Wu Date: Mon, 24 Nov 2025 10:26:26 +0000 Subject: [PATCH 05/12] Pipeline minor fixes --- .../pipeline/unified_attention_pipeline.hpp | 22 ++++++------------- 1 file changed, 7 insertions(+), 15 deletions(-) diff --git a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp index 3bb30149bf..5844285ffe 100644 --- a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp +++ b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp @@ -410,7 +410,7 @@ struct UnifiedAttentionPipeline HEAD_SIZE_PADDED == VDramBlockWindowTmp{}.get_window_lengths()[number<1>{}], "wrong!"); - static_assert(sizeof(SaccDataType) * BLOCK_SIZE <= GetSmemSize()); + static_assert(sizeof(SaccDataType) * BLOCK_SIZE * BLOCK_M <= GetSmemSize()); auto s_lds = make_tensor_view( reinterpret_cast(static_cast(smem_ptr)), MakeSimpleLdsDesc()); @@ -426,7 +426,7 @@ struct UnifiedAttentionPipeline auto o_lds = make_tensor_view( reinterpret_cast(static_cast(smem_ptr)), - MakeSimpleLdsDesc()); + MakeSimpleLdsDesc()); [[maybe_unused]] auto o_lds_window = make_tile_window( o_lds, make_tuple(number{}, number{}), {0, 0}); @@ -542,16 +542,9 @@ struct UnifiedAttentionPipeline clear_tile(l); const auto q_origin = q_dram_window.get_window_origin(); - // const auto [seqlen_k_start, seqlen_k_end] = - // mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number{}, - // number{}); - // const auto num_total_loop = integer_divide_ceil(seqlen_k_end - seqlen_k_start, - // BLOCK_SIZE); const auto num_total_loop = num_blocks; - // index_t kv_token_start = seqlen_k_start; - // TODO check is paddings kPadSeqLenK // check early exit if no work to do if constexpr(FmhaMask::IsMasking) { @@ -567,20 +560,19 @@ struct UnifiedAttentionPipeline index_t i_total_loops = num_blocks_start; const ck_tile::index_t* block_tables_ptr_ = reinterpret_cast(block_tables_ptr); - index_t kv_blk_idx = block_tables_ptr_[block_table_offset + i_total_loops]; - index_t kv_blk_idx_prev = 0; + index_t kv_blk_idx_intial = block_tables_ptr_[block_table_offset + i_total_loops]; auto k_dram_window = make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(), k_dram_block_window_tmp.get_window_lengths(), - {(kv_blk_idx - kv_blk_idx_prev) * BLOCK_SIZE, 0}, + {kv_blk_idx_intial * BLOCK_SIZE, 0}, Policy::template MakeKDramTileDistribution()); k_dram_window.init_raw(); auto v_dram_window = make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(), v_dram_block_window_tmp.get_window_lengths(), - {(kv_blk_idx - kv_blk_idx_prev) * BLOCK_SIZE, 0}, // TODO: hdim split? + {kv_blk_idx_intial * BLOCK_SIZE, 0}, Policy::template MakeVDramTileDistribution()); v_dram_window.init_raw(); @@ -676,6 +668,7 @@ struct UnifiedAttentionPipeline async_load_tile_raw(k_lds_window_store(k_lds_write_idx), k_dram_window); // TODO maybe needs i_total_loops as argument. Or maybe needs to use the k_lds_write_idx // as the index + index_t kv_blk_idx = block_tables_ptr_[block_table_offset + i_total_loops]; /// FIXME: use the future-predicting method to move the window k_dram_window.set_window_origin({kv_blk_idx * BLOCK_SIZE, 0}); }; @@ -686,7 +679,7 @@ struct UnifiedAttentionPipeline auto V_mem_load = [&](auto v_lds_write_idx) { async_load_tile_raw(v_lds_window_store(v_lds_write_idx), v_dram_window); - // kv_blk_idx = block_tables_ptr_[block_table_offset + i_total_loops]; + index_t kv_blk_idx = block_tables_ptr_[block_table_offset + i_total_loops]; /// FIXME: use the future-predicting method to move the window v_dram_window.set_window_origin({kv_blk_idx * BLOCK_SIZE, 0}); }; @@ -985,7 +978,6 @@ struct UnifiedAttentionPipeline __builtin_amdgcn_s_barrier(); __builtin_amdgcn_sched_barrier(0); cl_load(memK, K_w0_lds_wr_idx, V_w0_lds_rd_idx); - // TODO what is this??? Scheduler::schedule(cl_p, number<1>{}); fmha_mask(xdl_SP_p01_reg_idx); From b3c5cd0c762f77988618b43d6fc59c9298803452 Mon Sep 17 00:00:00 2001 From: Tianxing Wu Date: Mon, 24 Nov 2025 15:32:33 +0000 Subject: [PATCH 06/12] Fixed the block_table --- .../example_unified_attention.cpp | 10 ++--- .../kernel/unified_attention_kernel.hpp | 41 ++++++------------- .../pipeline/unified_attention_pipeline.hpp | 16 +++++--- 3 files changed, 25 insertions(+), 42 deletions(-) diff --git a/example/ck_tile/01_unified_attention/example_unified_attention.cpp b/example/ck_tile/01_unified_attention/example_unified_attention.cpp index 5d8f3fb435..765c55a5ce 100644 --- a/example/ck_tile/01_unified_attention/example_unified_attention.cpp +++ b/example/ck_tile/01_unified_attention/example_unified_attention.cpp @@ -35,11 +35,6 @@ auto parse_cmd_args(int argc, char* argv[]) -> std::pair( k_ptr, - make_tuple(kargs.num_blks, BLOCK_SIZE, HEAD_SIZE), - make_tuple(kargs.stride_k_cache_0, kargs.stride_k_cache_1, kargs.stride_k_cache_3), + make_tuple(kargs.num_blks * BLOCK_SIZE, HEAD_SIZE), + make_tuple(kargs.stride_k_cache_1, kargs.stride_k_cache_3), number{}, number<1>{}); const auto k_dram_pad = pad_tensor_view(k_dram_naive, // TODO can the BLOCK_SIZE_RAW needs padding? - make_tuple(1, BLOCK_SIZE, HEAD_SIZE_PADDED), - sequence{}); + make_tuple(BLOCK_SIZE, HEAD_SIZE_PADDED), + sequence{}); - const auto k_dram_merged = transform_tensor_view( - k_dram_pad, - make_tuple(make_merge_transform(make_tuple(kargs.num_blks, BLOCK_SIZE)), - make_pass_through_transform(HEAD_SIZE_PADDED)), - make_tuple(sequence<0, 1>{}, sequence<2>{}), - make_tuple(sequence<0>{}, - sequence<1>{})); // flattens the first two dims, head idx is the fastest - // changing dim in the merged dim - - return k_dram_merged; + return k_dram_pad; }(); auto k_dram_window = make_tile_window( @@ -422,25 +414,16 @@ struct UnifiedAttentionKernel const auto v_dram = [&]() { const auto v_dram_naive = make_naive_tensor_view( v_ptr, - make_tuple(kargs.num_blks, BLOCK_SIZE, HEAD_SIZE), - make_tuple(kargs.stride_v_cache_0, kargs.stride_v_cache_1, kargs.stride_v_cache_3), + make_tuple(kargs.num_blks * BLOCK_SIZE, HEAD_SIZE), + make_tuple(kargs.stride_v_cache_1, kargs.stride_v_cache_3), number{}, number<1>{}); const auto v_dram_pad = pad_tensor_view(v_dram_naive, - make_tuple(1, BLOCK_SIZE, HEAD_SIZE_PADDED), - sequence{}); + make_tuple(BLOCK_SIZE, HEAD_SIZE_PADDED), + sequence{}); - const auto v_dram_merged = transform_tensor_view( - v_dram_pad, - make_tuple(make_merge_transform(make_tuple(kargs.num_blks, BLOCK_SIZE)), - make_pass_through_transform(HEAD_SIZE_PADDED)), - make_tuple(sequence<0, 1>{}, sequence<2>{}), - make_tuple(sequence<0>{}, - sequence<1>{})); // flattens the first two dims, head idx is the fastest - // changing dim in the merged dim - - return v_dram_merged; + return v_dram_pad; }(); auto v_dram_window = make_tile_window( diff --git a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp index 5844285ffe..3fbfb0cd9e 100644 --- a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp +++ b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp @@ -544,6 +544,8 @@ struct UnifiedAttentionPipeline const auto q_origin = q_dram_window.get_window_origin(); const auto num_total_loop = num_blocks; + index_t k_block_table_off = num_blocks_start; + index_t v_block_table_off = num_blocks_start; // check early exit if no work to do if constexpr(FmhaMask::IsMasking) @@ -557,10 +559,11 @@ struct UnifiedAttentionPipeline } } + // TODO check correctness of this index_t i_total_loops = num_blocks_start; const ck_tile::index_t* block_tables_ptr_ = reinterpret_cast(block_tables_ptr); - index_t kv_blk_idx_intial = block_tables_ptr_[block_table_offset + i_total_loops]; + index_t kv_blk_idx_intial = block_tables_ptr_[block_table_offset + k_block_table_off]; auto k_dram_window = make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(), @@ -668,7 +671,9 @@ struct UnifiedAttentionPipeline async_load_tile_raw(k_lds_window_store(k_lds_write_idx), k_dram_window); // TODO maybe needs i_total_loops as argument. Or maybe needs to use the k_lds_write_idx // as the index - index_t kv_blk_idx = block_tables_ptr_[block_table_offset + i_total_loops]; + + k_block_table_off++; + index_t kv_blk_idx = block_tables_ptr_[block_table_offset + k_block_table_off]; /// FIXME: use the future-predicting method to move the window k_dram_window.set_window_origin({kv_blk_idx * BLOCK_SIZE, 0}); }; @@ -679,7 +684,9 @@ struct UnifiedAttentionPipeline auto V_mem_load = [&](auto v_lds_write_idx) { async_load_tile_raw(v_lds_window_store(v_lds_write_idx), v_dram_window); - index_t kv_blk_idx = block_tables_ptr_[block_table_offset + i_total_loops]; + v_block_table_off++; + + index_t kv_blk_idx = block_tables_ptr_[block_table_offset + v_block_table_off]; /// FIXME: use the future-predicting method to move the window v_dram_window.set_window_origin({kv_blk_idx * BLOCK_SIZE, 0}); }; @@ -1006,7 +1013,6 @@ struct UnifiedAttentionPipeline cl_load(memV, V_w0_lds_wr_idx, K_w0_lds_rd_idx); Scheduler::schedule(cl_p, number<3>{}); - // kv_token_start += BLOCK_SIZE; if(num_total_loop <= ++i_total_loops) { result = false; @@ -1053,7 +1059,6 @@ struct UnifiedAttentionPipeline Scheduler::schedule(cl_p, number<2>{}); fmha_mask(xdl_SP_p01_reg_idx); - // kv_token_start += BLOCK_SIZE; if(num_total_loop <= ++i_total_loops) { result = false; @@ -1131,7 +1136,6 @@ struct UnifiedAttentionPipeline fmha_alu0(number<0>{}); fmha_alu_D_upd(); - // kv_token_start += BLOCK_SIZE; ++i_total_loops; if(num_total_loop <= i_total_loops) { From cc7caf4d7dfbde0368c5c6c385b462f3aa109b76 Mon Sep 17 00:00:00 2001 From: Tianxing Wu Date: Tue, 25 Nov 2025 09:27:40 +0000 Subject: [PATCH 07/12] correct results --- .../example_unified_attention.cpp | 40 +++++++++++++++---- .../kernel/unified_attention_kernel.hpp | 19 +++++---- .../pipeline/unified_attention_pipeline.hpp | 9 ++--- 3 files changed, 45 insertions(+), 23 deletions(-) diff --git a/example/ck_tile/01_unified_attention/example_unified_attention.cpp b/example/ck_tile/01_unified_attention/example_unified_attention.cpp index 765c55a5ce..74f0e3f80f 100644 --- a/example/ck_tile/01_unified_attention/example_unified_attention.cpp +++ b/example/ck_tile/01_unified_attention/example_unified_attention.cpp @@ -233,14 +233,14 @@ CK_TILE_HOST void fmha_fwd(const ck_tile::HostTensor& q_bshd, ck_tile::reference_batched_gemm( q_host_ref, k_host_ref, s_host_ref, q_element_op, k_element_op, s_acc_element_op); - // ck_tile::reference_batched_masking( - // s_host_ref, - // ck_tile::make_generic_attention_mask_from_lr_window( - // -1, - // 0, - // seqlen_q, - // seqlen_kv, - // true)); + ck_tile::reference_batched_masking( + s_host_ref, + ck_tile::make_generic_attention_mask_from_lr_window( + -1, + 0, + seqlen_q, + seqlen_kv, + true)); ck_tile::reference_batched_softmax( s_host_ref, p_host_ref, ck_tile::identity{}); ck_tile::reference_batched_gemm( @@ -556,6 +556,30 @@ bool run_impl(const Problem& problem, const RunConfig& run_config) << nonzero << " / " << total << " (" << percent << "%)\n"; + // std::cout << "\n=== Complete Output Tensor (o) ===\n"; + // for (int tok = 0; tok < problem.num_tokens; ++tok) { + // std::cout << "Token " << tok << ":\n"; + // for (int h = 0; h < problem.nhead_q; ++h) { + // std::cout << " Head " << h << ": "; + // for (int d = 0; d < problem.hdim; ++d) { + // std::cout << static_cast(o(tok, h, d)) << " "; + // } + // std::cout << "\n"; + // } + // } + + // std::cout << "\n=== Complete Reference Tensor (o_ref) ===\n"; + // for (int tok = 0; tok < problem.num_tokens; ++tok) { + // std::cout << "Token " << tok << ":\n"; + // for (int h = 0; h < problem.nhead_q; ++h) { + // std::cout << " Head " << h << ": "; + // for (int d = 0; d < problem.hdim; ++d) { + // std::cout << static_cast(o_ref(tok, h, d)) << " "; + // } + // std::cout << "\n"; + // } + // } + return ck_tile::check_err(o, o_ref, std::string("found incorrect results!"), rtol, atol); } diff --git a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp index 90433da0fb..735a8c4252 100644 --- a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp +++ b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp @@ -310,19 +310,18 @@ struct UnifiedAttentionKernel const index_t query_pos = amd_wave_read_first_lane(q_block_local_idx * BLOCK_Q); const index_t seq_len = kargs.seq_lens_ptr[seq_idx]; - // const index_t context_len = amd_wave_read_first_lane(seq_len - cur_batch_query_len); + const index_t context_len = amd_wave_read_first_lane(seq_len - cur_batch_query_len); - // index_t _max_seq_prefix_len = - // amd_wave_read_first_lane((context_len + q_block_local_idx * BLOCK_Q + (BLOCK_M - 1) - // + 1)); + index_t _max_seq_prefix_len = + amd_wave_read_first_lane((context_len + q_block_local_idx * BLOCK_Q + (BLOCK_M - 1) + + 1)); - // if(seq_len < _max_seq_prefix_len) - // { - // _max_seq_prefix_len = seq_len; - // } + if(seq_len < _max_seq_prefix_len) + { + _max_seq_prefix_len = seq_len; + } - // const auto max_seq_prefix_len = _max_seq_prefix_len; - const auto max_seq_prefix_len = seq_len; + const auto max_seq_prefix_len = _max_seq_prefix_len; const index_t num_blocks = amd_wave_read_first_lane((max_seq_prefix_len + BLOCK_SIZE - 1) / BLOCK_SIZE); // TODO sliding window diff --git a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp index 3fbfb0cd9e..d661c342f9 100644 --- a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp +++ b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp @@ -897,11 +897,10 @@ struct UnifiedAttentionPipeline auto fmha_mask = [&](auto sp_reg_idx) { if constexpr(FmhaMask::IsMasking) { - bool need_perpixel_check = false; - // mask.IsEdgeTile(q_origin.at(number<0>{}), - // i_total_loops * BLOCK_SIZE, - // number{}, - // number{}); + bool need_perpixel_check = mask.IsEdgeTile(q_origin.at(number<0>{}), + i_total_loops * BLOCK_SIZE, + number{}, + number{}); if(need_perpixel_check) { set_tile_if(sp(sp_reg_idx).sp_compute, From 6a2ac8f758ac7d314a29ebe3b3c09d052499115b Mon Sep 17 00:00:00 2001 From: Tianxing Wu Date: Thu, 27 Nov 2025 09:16:30 +0000 Subject: [PATCH 08/12] causal mask fix --- .../example_unified_attention.cpp | 47 ++++++++----------- .../unified_attention.hpp | 8 ++++ .../kernel/unified_attention_kernel.hpp | 3 +- .../pipeline/unified_attention_pipeline.hpp | 2 +- 4 files changed, 31 insertions(+), 29 deletions(-) diff --git a/example/ck_tile/01_unified_attention/example_unified_attention.cpp b/example/ck_tile/01_unified_attention/example_unified_attention.cpp index 74f0e3f80f..a4458eeffc 100644 --- a/example/ck_tile/01_unified_attention/example_unified_attention.cpp +++ b/example/ck_tile/01_unified_attention/example_unified_attention.cpp @@ -70,13 +70,6 @@ auto parse_cmd_args(int argc, char* argv[]) -> std::pair; - using GenericMask = ck_tile::GenericAttentionMask; - using CausalMask = ck_tile::GenericAttentionMask; -}; - struct Problem { explicit Problem(const ck_tile::ArgParser& args) @@ -235,12 +228,13 @@ CK_TILE_HOST void fmha_fwd(const ck_tile::HostTensor& q_bshd, ck_tile::reference_batched_masking( s_host_ref, - ck_tile::make_generic_attention_mask_from_lr_window( + ck_tile::make_generic_attention_mask_from_lr_window( -1, 0, seqlen_q, seqlen_kv, - true)); + 1, + false)); ck_tile::reference_batched_softmax( s_host_ref, p_host_ref, ck_tile::identity{}); ck_tile::reference_batched_gemm( @@ -557,29 +551,28 @@ bool run_impl(const Problem& problem, const RunConfig& run_config) << " (" << percent << "%)\n"; // std::cout << "\n=== Complete Output Tensor (o) ===\n"; - // for (int tok = 0; tok < problem.num_tokens; ++tok) { - // std::cout << "Token " << tok << ":\n"; - // for (int h = 0; h < problem.nhead_q; ++h) { - // std::cout << " Head " << h << ": "; - // for (int d = 0; d < problem.hdim; ++d) { - // std::cout << static_cast(o(tok, h, d)) << " "; - // } - // std::cout << "\n"; + // for (int tok = 0; tok < problem.num_tokens; ++tok) { + // std::cout << "Token " << tok << ":\n"; + // for (int h = 0; h < problem.nhead_q; ++h) { + // std::cout << " Head " << h << ": "; + // for (int d = 0; d < problem.hdim; ++d) { + // std::cout << static_cast(o(tok, h, d)) << " "; // } + // std::cout << "\n"; // } + // } - // std::cout << "\n=== Complete Reference Tensor (o_ref) ===\n"; - // for (int tok = 0; tok < problem.num_tokens; ++tok) { - // std::cout << "Token " << tok << ":\n"; - // for (int h = 0; h < problem.nhead_q; ++h) { - // std::cout << " Head " << h << ": "; - // for (int d = 0; d < problem.hdim; ++d) { - // std::cout << static_cast(o_ref(tok, h, d)) << " "; - // } - // std::cout << "\n"; + // std::cout << "\n=== Complete Reference Tensor (o_ref) ===\n"; + // for (int tok = 0; tok < problem.num_tokens; ++tok) { + // std::cout << "Token " << tok << ":\n"; + // for (int h = 0; h < problem.nhead_q; ++h) { + // std::cout << " Head " << h << ": "; + // for (int d = 0; d < problem.hdim; ++d) { + // std::cout << static_cast(o_ref(tok, h, d)) << " "; // } + // std::cout << "\n"; // } - + // } return ck_tile::check_err(o, o_ref, std::string("found incorrect results!"), rtol, atol); } diff --git a/example/ck_tile/01_unified_attention/unified_attention.hpp b/example/ck_tile/01_unified_attention/unified_attention.hpp index f418a4a0d9..ed3e1e6b50 100644 --- a/example/ck_tile/01_unified_attention/unified_attention.hpp +++ b/example/ck_tile/01_unified_attention/unified_attention.hpp @@ -8,6 +8,7 @@ #include "ck_tile/core/numeric/integer.hpp" #include "ck_tile/host/stream_config.hpp" +#include "ck_tile/ops/unified_attention.hpp" namespace ck_tile { @@ -76,3 +77,10 @@ std::pair unified_attention(const unified_attention_args& args, const stream_config& config); } // namespace ck_tile + +struct UnifiedAttentionMasks +{ + using NoMask = ck_tile::GenericAttentionMask; + using GenericMask = ck_tile::GenericAttentionMask; + using CausalMask = ck_tile::GenericAttentionMask; +}; diff --git a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp index 735a8c4252..2727c563c0 100644 --- a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp +++ b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp @@ -435,8 +435,9 @@ struct UnifiedAttentionKernel 0, cur_batch_query_len, // y_total seq_len, // x_total - num_queries_per_kv // the same sequence index is repeated num_queries_per_kv + num_queries_per_kv, // the same sequence index is repeated num_queries_per_kv // times along x dim of the tile + false ); else return FmhaMask{cur_batch_query_len, seq_len}; diff --git a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp index d661c342f9..105bfcca6e 100644 --- a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp +++ b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp @@ -899,7 +899,7 @@ struct UnifiedAttentionPipeline { bool need_perpixel_check = mask.IsEdgeTile(q_origin.at(number<0>{}), i_total_loops * BLOCK_SIZE, - number{}, + number{}, number{}); if(need_perpixel_check) { From c641d0d42c38e74b8c463f03e3aca09c9cc81974 Mon Sep 17 00:00:00 2001 From: Tianxing Wu Date: Thu, 27 Nov 2025 09:24:52 +0000 Subject: [PATCH 09/12] non zero calculation fix --- .../example_unified_attention.cpp | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/example/ck_tile/01_unified_attention/example_unified_attention.cpp b/example/ck_tile/01_unified_attention/example_unified_attention.cpp index a4458eeffc..401dd90496 100644 --- a/example/ck_tile/01_unified_attention/example_unified_attention.cpp +++ b/example/ck_tile/01_unified_attention/example_unified_attention.cpp @@ -436,7 +436,7 @@ bool run_impl(const Problem& problem, const RunConfig& run_config) std::cout << "[" << problem.data_type << "|"; std::cout << "] b:" << problem.batch << ", h:" << problem.nhead_q << "/" << problem.nhead_kv - << ", d:" << problem.hdim << ", mask:" << problem.mask << std::fixed << ", " + << ", d:" << problem.hdim << ", mask:" << "causal mask" << std::fixed << ", " << std::setprecision(8) << time << " ms, " << std::setprecision(2) << tflops << " TFlops, " << std::setprecision(2) << (static_cast(mem) / 1e12 / (time / 1e3)) << " TB/s" << std::endl; @@ -530,13 +530,11 @@ bool run_impl(const Problem& problem, const RunConfig& run_config) size_t nonzero = 0; - for (int b = 0; b < problem.batch; ++b) { - for (int s = 0; s < eff_query_lens[b]; ++s) { - for (int h = 0; h < problem.nhead_q; ++h) { - for (int d = 0; d < problem.hdim; ++d) { - if (static_cast(o(b, s, h, d)) != 0.0f) { + for (int tok = 0; tok < problem.num_tokens; ++tok) { + for (int h = 0; h < problem.nhead_q; ++h) { + for (int d = 0; d < problem.hdim; ++d) { + if (static_cast(o(tok, h, d)) != 0.0f) { nonzero++; - } } } } From eeb419845df84a224e4bbd300c6285481996395c Mon Sep 17 00:00:00 2001 From: Tianxing Wu Date: Thu, 27 Nov 2025 10:32:28 +0000 Subject: [PATCH 10/12] fmha v3 flops calculation --- .../ck_tile/01_fmha/example_fmha_fwd_v3.cpp | 34 ++++++++++++++----- 1 file changed, 25 insertions(+), 9 deletions(-) diff --git a/example/ck_tile/01_fmha/example_fmha_fwd_v3.cpp b/example/ck_tile/01_fmha/example_fmha_fwd_v3.cpp index 7ddb65a2db..4bb740217e 100644 --- a/example/ck_tile/01_fmha/example_fmha_fwd_v3.cpp +++ b/example/ck_tile/01_fmha/example_fmha_fwd_v3.cpp @@ -463,17 +463,33 @@ bool run_impl(const Problem& problem, const RunConfig& run_config) } std::size_t flop = [&] { - if(problem.mask.type == mask_enum::no_mask) + long flop_result = 0; + + for(int b = 0; b < problem.batch; ++b) { - return 4 * problem.batch * problem.nhead_q * problem.seqlen_q * problem.seqlen_k * - problem.hdim; - } - else - { - /// FIXME: Use a more accurate method; for now, we’re just dividing the flop by 2. - return 2 * problem.batch * problem.nhead_q * problem.seqlen_q * problem.seqlen_k * - problem.hdim; + long query_lens = has_varlen_q ? eff_q_vec[b] : problem.seqlen_q; + long kv_lens = has_varlen_k ? eff_kv_vec[b] : problem.seqlen_k; + long valid_out_elements = 0; + + if(problem.mask.type == mask_enum::no_mask) { + valid_out_elements = kv_lens * query_lens; + } else { + if(query_lens > kv_lens) + { + valid_out_elements = (kv_lens * kv_lens + kv_lens) / 2; + } + else + { + valid_out_elements = + query_lens * kv_lens - ((query_lens * query_lens - query_lens) / 2); + } + + } + // Causal logic for valid output elements + + flop_result += 2 * problem.nhead_q * valid_out_elements * (problem.hdim + problem.hdim); } + return flop_result; }(); float tflops = static_cast(flop) / 1.e9 / time; From 3131ebf1dfb227cddd03eefbc5a60ab10a307b55 Mon Sep 17 00:00:00 2001 From: Tianxing Wu Date: Thu, 27 Nov 2025 13:28:35 +0000 Subject: [PATCH 11/12] simplified kernel pid logic --- .../kernel/unified_attention_kernel.hpp | 42 +------------------ 1 file changed, 2 insertions(+), 40 deletions(-) diff --git a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp index 2727c563c0..396bd6d2b8 100644 --- a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp +++ b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp @@ -200,52 +200,15 @@ struct UnifiedAttentionKernel return left - 1; } - CK_TILE_DEVICE static constexpr auto RemapTileIndices(const ck_tile::index_t pid, - const Kargs& kargs) - { - using namespace ck_tile; - - constexpr index_t NUM_XCDS = 8; - const index_t GRID_MN = kargs.total_num_q_blocks * (kargs.num_head_q / kargs.num_queries_per_kv); - - // Number of pids per XCD in the new arrangement - const index_t pids_per_xcd = (GRID_MN + NUM_XCDS - 1) / NUM_XCDS; - - // When GRID_MN cannot divide NUM_XCDS, some xcds will have - // pids_per_xcd pids, the other will have pids_per_xcd - 1 pids. - // We calculate the number of xcds that have pids_per_xcd pids as tall_xcds - index_t tall_xcds = GRID_MN % NUM_XCDS; - tall_xcds = tall_xcds == 0 ? NUM_XCDS : tall_xcds; - - // Compute current XCD and local pid within the XCD - const index_t xcd = pid % NUM_XCDS; - const index_t local_pid = pid / NUM_XCDS; - - // Calculate new pid based on the new grouping - index_t remapped_pid = 0; // Initialize to avoid constexpr error - if(xcd < tall_xcds) - { - remapped_pid = xcd * pids_per_xcd + local_pid; - } - else - { - remapped_pid = - tall_xcds * pids_per_xcd + (xcd - tall_xcds) * (pids_per_xcd - 1) + local_pid; - } - - return remapped_pid; - } CK_TILE_DEVICE static constexpr auto GetTileIndex(const ck_tile::index_t pid, const Kargs& kargs) { using namespace ck_tile; - ck_tile::index_t total_num_q_blocks = kargs.total_num_q_blocks; - // const index_t num_tile_n1 = ck_tile::integer_divide_ceil(kargs.hdim_v, - // UnifiedAttentionPipeline::kN1); + ck_tile::index_t num_head_kv = kargs.num_head_q / kargs.num_queries_per_kv; - return ck_tile::make_tuple(pid / total_num_q_blocks, pid % total_num_q_blocks); + return ck_tile::make_tuple(pid % num_head_kv, pid / num_head_kv); } CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); } @@ -274,7 +237,6 @@ struct UnifiedAttentionKernel // const index_t num_head_q = kargs.num_head_q; // const index_t num_head_k = num_head_q / num_queries_per_kv; - pid = RemapTileIndices(pid, kargs); // divide problem const auto [kv_head_idx, q_block_global_idx] = GetTileIndex(pid, kargs); From 60ca9484b4471cf86e057a915ad72694d82b2d0d Mon Sep 17 00:00:00 2001 From: Tianxing Wu Date: Thu, 27 Nov 2025 15:07:03 +0000 Subject: [PATCH 12/12] refined benchmarking --- .../example_unified_attention.cpp | 84 ++++++++++++++++++- 1 file changed, 81 insertions(+), 3 deletions(-) diff --git a/example/ck_tile/01_unified_attention/example_unified_attention.cpp b/example/ck_tile/01_unified_attention/example_unified_attention.cpp index 401dd90496..d4313d79d5 100644 --- a/example/ck_tile/01_unified_attention/example_unified_attention.cpp +++ b/example/ck_tile/01_unified_attention/example_unified_attention.cpp @@ -35,7 +35,10 @@ auto parse_cmd_args(int argc, char* argv[]) -> std::pair std::pair std::pair std::pair& query_lens_input, + const std::vector& kv_lens_input, + bool varlen) -> std::pair, std::vector> +{ + // If both query_lens and kv_lens are provided, return them directly + if(!query_lens_input.empty() && !kv_lens_input.empty()) + { + return std::make_pair(query_lens_input, kv_lens_input); + } + + std::vector query_lens; + std::vector kv_lens; + + if(!varlen) + { + // Fixed length mode: fill with max seqlen + query_lens.assign(batch, max_seqlen_q); + kv_lens.assign(batch, max_seqlen_kv); + } + else + { + // Variable length mode: generate random lengths up to max + std::random_device rd; + std::mt19937 gen(rd()); + std::uniform_int_distribution q_dist(1, max_seqlen_q); + std::uniform_int_distribution kv_dist(1, max_seqlen_kv); + + query_lens.resize(batch); + kv_lens.resize(batch); + + for(ck_tile::index_t i = 0; i < batch; ++i) + { + query_lens[i] = q_dist(gen); + kv_lens[i] = kv_dist(gen); + } + } + + return std::make_pair(query_lens, kv_lens); +} + struct Problem { explicit Problem(const ck_tile::ArgParser& args) @@ -82,10 +129,30 @@ struct Problem // TODO: support other GQA/MQA cases than just 4x nhead_q = nhead_kv * num_queries_per_kv; + ck_tile::index_t max_seqlen_q = args.get_int("s"); + ck_tile::index_t max_seqlen_kv = args.get_int("s_k"); + + if (max_seqlen_kv == -1) { + max_seqlen_kv = max_seqlen_q; + } + hdim = args.get_int("d"); query_lens = args.get_int_vec("query_lens"); kv_lens = args.get_int_vec("kv_lens"); assert(query_lens.size() == kv_lens.size() && "query_lens and kv_lens must have the same length b"); + batch = args.get_int("b"); + + bool varlen = args.get_bool("varlen"); + auto [query_lens_, kv_lens_] = seqlen_preprocess( + batch, + max_seqlen_q, + max_seqlen_kv, + query_lens, + kv_lens, + varlen); + + query_lens = query_lens_; + kv_lens = kv_lens_; batch = query_lens.size(); // Calculate scale_s @@ -436,7 +503,18 @@ bool run_impl(const Problem& problem, const RunConfig& run_config) std::cout << "[" << problem.data_type << "|"; std::cout << "] b:" << problem.batch << ", h:" << problem.nhead_q << "/" << problem.nhead_kv - << ", d:" << problem.hdim << ", mask:" << "causal mask" << std::fixed << ", " + << ", d:" << problem.hdim << ", scale_s:" << problem.scale_s + << ", query_lens:["; + for (size_t i = 0; i < problem.query_lens.size(); ++i) { + std::cout << problem.query_lens[i]; + if (i < problem.query_lens.size() - 1) std::cout << ","; + } + std::cout << "], kv_lens:["; + for (size_t i = 0; i < problem.kv_lens.size(); ++i) { + std::cout << problem.kv_lens[i]; + if (i < problem.kv_lens.size() - 1) std::cout << ","; + } + std::cout << "], mask:" << "causal mask" << std::fixed << ", " << std::setprecision(8) << time << " ms, " << std::setprecision(2) << tflops << " TFlops, " << std::setprecision(2) << (static_cast(mem) / 1e12 / (time / 1e3)) << " TB/s" << std::endl;