From 89cfdb35e05b0ff790b294e072a2a5d1d3b871e6 Mon Sep 17 00:00:00 2001 From: Tianxing Wu Date: Thu, 23 Oct 2025 12:02:18 +0000 Subject: [PATCH] Fixed block Q with M --- .../pipeline/unified_attention_pipeline.hpp | 36 ++++++++----------- 1 file changed, 15 insertions(+), 21 deletions(-) diff --git a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp index ec4b3355de..48f1f4deb8 100644 --- a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp +++ b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp @@ -384,7 +384,6 @@ struct UnifiedAttentionPipeline [[maybe_unused]] const KElementFunction& k_element_func, const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile [[maybe_unused]] const VElementFunction& v_element_func, - index_t num_queries_per_kv, const void* block_tables_ptr, index_t block_table_offset, [[maybe_unused]] const SAccElementFunction& s_acc_element_func, @@ -395,16 +394,13 @@ struct UnifiedAttentionPipeline void* smem_ptr) const { using namespace ck_tile; - // TODO do we make the num_queries_per_kv and num_head conexpr??? - const index_t BLOCK_Q = BLOCK_M / num_queries_per_kv; - static_assert( std::is_same_v> && std::is_same_v> && std::is_same_v>, "wrong!"); - static_assert(BLOCK_Q == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + static_assert(BLOCK_M == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && BLOCK_SIZE == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && HEAD_SIZE_PADDED == KDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] && HEAD_SIZE_PADDED == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && @@ -414,29 +410,29 @@ struct UnifiedAttentionPipeline static_assert(sizeof(SaccDataType) * BLOCK_SIZE <= GetSmemSize()); auto s_lds = make_tensor_view( reinterpret_cast(static_cast(smem_ptr)), - MakeSimpleLdsDesc()); + MakeSimpleLdsDesc()); [[maybe_unused]] auto s_lds_window = - make_tile_window(s_lds, make_tuple(number{}, number{}), {0, 0}); + make_tile_window(s_lds, make_tuple(number{}, number{}), {0, 0}); auto p_lds = make_tensor_view( reinterpret_cast(static_cast(smem_ptr) + Policy::template GetSmemSize()), - MakeSimpleLdsDesc()); + MakeSimpleLdsDesc()); [[maybe_unused]] auto p_lds_window = - make_tile_window(p_lds, make_tuple(number{}, number{}), {0, 0}); + make_tile_window(p_lds, make_tuple(number{}, number{}), {0, 0}); auto o_lds = make_tensor_view( reinterpret_cast(static_cast(smem_ptr)), - MakeSimpleLdsDesc()); + MakeSimpleLdsDesc()); [[maybe_unused]] auto o_lds_window = - make_tile_window(o_lds, make_tuple(number{}, number{}), {0, 0}); + make_tile_window(o_lds, make_tuple(number{}, number{}), {0, 0}); auto m_lds = make_tensor_view( reinterpret_cast(static_cast(smem_ptr) + Policy::template GetSmemSize()), - MakeSimpleLdsDesc1D()); + MakeSimpleLdsDesc1D()); [[maybe_unused]] auto m_lds_window = - make_tile_window(m_lds, make_tuple(number{}), {0}); + make_tile_window(m_lds, make_tuple(number{}), {0}); const index_t warp_group_id = get_warp_id() / 4; @@ -543,7 +539,7 @@ struct UnifiedAttentionPipeline const auto q_origin = q_dram_window.get_window_origin(); const auto [seqlen_k_start, seqlen_k_end] = - mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number{}, number{}); + mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number{}, number{}); const auto num_total_loop = integer_divide_ceil(seqlen_k_end - seqlen_k_start, BLOCK_SIZE); index_t kv_token_start = seqlen_k_start; @@ -812,7 +808,7 @@ struct UnifiedAttentionPipeline gemm_0(sp(sp_reg_idx).sp_compute, get_slice_tile(q_tile, sequence<0, (k0_loops - 1) * HEAD_SIZE_PADDED>{}, - sequence{}), + sequence{}), get_slice_tile(kv_tile.k_tile, sequence<0, (k0_loops - 1) * HEAD_SIZE_PADDED>{}, sequence{})); @@ -822,7 +818,7 @@ struct UnifiedAttentionPipeline gemm_1(o_acc, get_slice_tile(sp(sp_reg_idx).p, sequence<0, (k1_loops - 1) * HEAD_SIZE_PADDED>{}, - sequence{}), + sequence{}), get_slice_tile(kv_tile.v_tile, sequence<0, (k1_loops - 1) * HEAD_SIZE_PADDED>{}, sequence{})); @@ -836,7 +832,7 @@ struct UnifiedAttentionPipeline gemm_0(sp(sp_reg_idx).sp_compute, get_slice_tile(q_tile, sequence<0, (k0_loops - 1) * HEAD_SIZE_PADDED>{}, - sequence{}), + sequence{}), get_slice_tile(kv_tile.k_tile, sequence<0, (k0_loops - 1) * HEAD_SIZE_PADDED>{}, sequence{})); @@ -846,7 +842,7 @@ struct UnifiedAttentionPipeline gemm_1(o_acc, get_slice_tile(sp(sp_reg_idx).p, sequence<0, (k1_loops - 1) * HEAD_SIZE_PADDED>{}, - sequence{}), + sequence{}), get_slice_tile(kv_tile.v_tile, sequence<0, (k1_loops - 1) * HEAD_SIZE_PADDED>{}, sequence{})); @@ -894,7 +890,7 @@ struct UnifiedAttentionPipeline if constexpr(kPadSeqLenK || FmhaMask::IsMasking) { bool need_perpixel_check = mask.IsEdgeTile( - q_origin.at(number<0>{}), kv_token_start, number{}, number{}); + q_origin.at(number<0>{}), kv_token_start, number{}, number{}); if(need_perpixel_check) { set_tile_if(sp(sp_reg_idx).sp_compute, @@ -1209,7 +1205,6 @@ struct UnifiedAttentionPipeline CK_TILE_DEVICE auto operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile - index_t num_queries_per_kv, const void* block_tables_ptr, index_t block_table_offset, FmhaMask mask, @@ -1224,7 +1219,6 @@ struct UnifiedAttentionPipeline identity{}, v_dram_block_window_tmp, identity{}, - num_queries_per_kv, block_tables_ptr, block_table_offset, identity{},