From f2fbc44b7bf4dbfb0e8a5f031eb7398e9e218a37 Mon Sep 17 00:00:00 2001 From: Juuso Korhonen <40278371+juuso-oskari@users.noreply.github.com> Date: Mon, 24 Nov 2025 10:20:04 +0000 Subject: [PATCH] fix --- .../example_unified_attention.cpp | 17 +++++++++-------- .../kernel/unified_attention_kernel.hpp | 18 +++++++++--------- .../pipeline/unified_attention_pipeline.hpp | 9 +++++---- 3 files changed, 23 insertions(+), 21 deletions(-) diff --git a/example/ck_tile/01_unified_attention/example_unified_attention.cpp b/example/ck_tile/01_unified_attention/example_unified_attention.cpp index 5bc6544746..5d8f3fb435 100644 --- a/example/ck_tile/01_unified_attention/example_unified_attention.cpp +++ b/example/ck_tile/01_unified_attention/example_unified_attention.cpp @@ -238,14 +238,14 @@ CK_TILE_HOST void fmha_fwd(const ck_tile::HostTensor& q_bshd, ck_tile::reference_batched_gemm( q_host_ref, k_host_ref, s_host_ref, q_element_op, k_element_op, s_acc_element_op); - ck_tile::reference_batched_masking( - s_host_ref, - ck_tile::make_generic_attention_mask_from_lr_window( - -1, - 0, - seqlen_q, - seqlen_kv, - true)); + // ck_tile::reference_batched_masking( + // s_host_ref, + // ck_tile::make_generic_attention_mask_from_lr_window( + // -1, + // 0, + // seqlen_q, + // seqlen_kv, + // true)); ck_tile::reference_batched_softmax( s_host_ref, p_host_ref, ck_tile::identity{}); ck_tile::reference_batched_gemm( @@ -526,6 +526,7 @@ bool run_impl(const Problem& problem, const RunConfig& run_config) ck_tile::HostTensor o(problem.get_output_shape()); o_buf.FromDevice(o.data()); + const auto [rtol, atol] = [&] { if constexpr(std::is_same_v) return std::make_tuple(1e-3, 1e-3); diff --git a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp index 969a9aac82..366a75e2df 100644 --- a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp +++ b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp @@ -310,18 +310,18 @@ struct UnifiedAttentionKernel const index_t query_pos = amd_wave_read_first_lane(q_block_local_idx * BLOCK_Q); const index_t seq_len = kargs.seq_lens_ptr[seq_idx]; - const index_t context_len = amd_wave_read_first_lane(seq_len - cur_batch_query_len); + // const index_t context_len = amd_wave_read_first_lane(seq_len - cur_batch_query_len); - index_t _max_seq_prefix_len = - amd_wave_read_first_lane((context_len + q_block_local_idx * BLOCK_Q + (BLOCK_M - 1) - + 1)); + // index_t _max_seq_prefix_len = + // amd_wave_read_first_lane((context_len + q_block_local_idx * BLOCK_Q + (BLOCK_M - 1) + // + 1)); - if(seq_len < _max_seq_prefix_len) - { - _max_seq_prefix_len = seq_len; - } + // if(seq_len < _max_seq_prefix_len) + // { + // _max_seq_prefix_len = seq_len; + // } - const auto max_seq_prefix_len = _max_seq_prefix_len; + const auto max_seq_prefix_len = seq_len; // _max_seq_prefix_len; const index_t num_blocks = amd_wave_read_first_lane((max_seq_prefix_len + BLOCK_SIZE - 1) / BLOCK_SIZE); // TODO sliding window diff --git a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp index 3d941f5503..3bb30149bf 100644 --- a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp +++ b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp @@ -897,10 +897,11 @@ struct UnifiedAttentionPipeline auto fmha_mask = [&](auto sp_reg_idx) { if constexpr(FmhaMask::IsMasking) { - bool need_perpixel_check = mask.IsEdgeTile(q_origin.at(number<0>{}), - i_total_loops * BLOCK_SIZE, - number{}, - number{}); + bool need_perpixel_check = false; + // mask.IsEdgeTile(q_origin.at(number<0>{}), + // i_total_loops * BLOCK_SIZE, + // number{}, + // number{}); if(need_perpixel_check) { set_tile_if(sp(sp_reg_idx).sp_compute,