From bc6385f389abbd5b744cf4457df5b1f262263aac Mon Sep 17 00:00:00 2001 From: Tianxing Wu Date: Mon, 13 Oct 2025 10:01:38 +0000 Subject: [PATCH 1/2] Some refactor --- .../kernel/unified_attention_kernel.hpp | 62 +++++++++---------- 1 file changed, 30 insertions(+), 32 deletions(-) diff --git a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp index 67d6372c31..bb38df6b26 100644 --- a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp +++ b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp @@ -58,9 +58,7 @@ struct FmhaFwdV3Kernel const void* v_ptr; // [num_blks, blk_size, num_kv_heads, head_size] void* o_ptr; - ck_tile::index_t hdim_q; - ck_tile::index_t hdim_v; - + ck_tile::index_t num_blks; ck_tile::index_t num_head_q; // for MQA/GQA, nhead could be different. This parameter is nhead_q / nhead_k // if this param is larger than 1, indicate MQA/GQA case @@ -88,7 +86,7 @@ struct FmhaFwdV3Kernel }; - struct UnifiedAttentionVarlenKargs + struct UnifiedAttentionVarlenKargs: UnifiedAttentionCommonKargs { const int32_t* block_tables_ptr; const int32_t* seq_lens_ptr; // seq len in each batch @@ -97,20 +95,15 @@ struct FmhaFwdV3Kernel ck_tile::index_t num_seqs; // number of batches for q }; - struct Kargs { - UnifiedAttentionCommonKargs unifiedAttentionCommonKargs; - UnifiedAttentionVarlenKargs unifiedAttentionVarlenKargs; - }; - // using Kargs = FmhaFwdGroupModeKargs; + using Kargs = UnifiedAttentionVarlenKargs; CK_TILE_HOST static constexpr Kargs MakeKargs( const void* q_ptr, const void* k_ptr, const void* v_ptr, void* o_ptr, - ck_tile::index_t hdim_q, - ck_tile::index_t hdim_v, + ck_tile::index_t num_blks, ck_tile::index_t num_head_q, const ck_tile::index_t num_queries_per_kv, float scale_s, @@ -134,15 +127,14 @@ struct FmhaFwdV3Kernel const int32_t* block_tables_ptr, const int32_t* seq_lens_ptr, const int32_t* query_start_len_ptr, - ck_tile::index_t num_seqs, + ck_tile::index_t num_seqs ) { Kargs kargs{{q_ptr, k_ptr, v_ptr, o_ptr, - hdim_q, - hdim_v, + num_blks, num_head_q, num_queries_per_kv, static_cast(scale_s * ck_tile::log2e_v<>), @@ -221,9 +213,13 @@ struct FmhaFwdV3Kernel { using namespace ck_tile; + const index_t num_head_q = kargs.num_head_q; + const index_t num_queries_per_kv = kargs.num_queries_per_kv; + const index_t num_head_k = kargs.num_queries_per_kv; + constexpr index_t NUM_XCDS = 8; - const index_t GRID_MN = kargs.unifiedAttentionCommonKargs.total_num_q_blocks * - (kargs.unifiedAttentionCommonKargs.num_head_q); + const index_t GRID_MN = kargs.total_num_q_blocks * + (kargs.num_head_q); // Number of pids per XCD in the new arrangement const index_t pids_per_xcd = (GRID_MN + NUM_XCDS - 1) / NUM_XCDS; @@ -258,7 +254,7 @@ struct FmhaFwdV3Kernel { using namespace ck_tile; - ck_tile::index_t total_num_q_blocks = kargs.unifiedAttentionCommonKargs.total_num_q_blocks; + ck_tile::index_t total_num_q_blocks = kargs.total_num_q_blocks; // const index_t num_tile_n1 = ck_tile::integer_divide_ceil(kargs.hdim_v, // FmhaPipeline::kN1); @@ -281,9 +277,9 @@ struct FmhaFwdV3Kernel __shared__ char smem_ptr[GetSmemSize()]; ck_tile::index_t pid = blockIdx.x; - index_t num_queries_per_kv = kargs.unifiedAttentionCommonKargs.num_queries_per_kv; + index_t num_queries_per_kv = kargs.num_queries_per_kv; - const index_t BLOCK_M = BLOCK_Q * kargs.unifiedAttentionCommonKargs.num_queries_per_kv; + const index_t BLOCK_M = BLOCK_Q * kargs.num_queries_per_kv; pid = RemapTileIndices(pid, kargs); @@ -296,15 +292,15 @@ struct FmhaFwdV3Kernel // one q_block spans BLOCK_Q = BLOCK_M // num_queries_per_kv number of query token groups. One query token group shares one kv token const index_t seq_idx = find_seq_idx( - kargs.unifiedAttentionVarlenKargs.query_start_len_ptr, q_block_global_idx, kargs.unifiedAttentionVarlenKargs.num_seqs, kargs.unifiedAttentionCommonKargs.BLOCK_Q, true + kargs.query_start_len_ptr, q_block_global_idx, kargs.num_seqs, BLOCK_Q, true ); // which batch - const index_t q_block_start_idx = amd_wave_read_first_lane(kargs.unifiedAttentionVarlenKargs.query_start_len_ptr[seq_idx]); + const index_t q_block_start_idx = amd_wave_read_first_lane(kargs.query_start_len_ptr[seq_idx]); const index_t q_block_local_idx = amd_wave_read_first_lane(q_block_global_idx - q_block_start_idx); - const index_t cur_batch_in_all_start_index = amd_wave_read_first_lane(kargs.unifiedAttentionVarlenKargs.query_start_len_ptr[seq_idx]); - const index_t cur_batch_in_all_stop_index = amd_wave_read_first_lane(kargs.unifiedAttentionVarlenKargs.query_start_len_ptr[seq_idx + 1]); + const index_t cur_batch_in_all_start_index = amd_wave_read_first_lane(kargs.query_start_len_ptr[seq_idx]); + const index_t cur_batch_in_all_stop_index = amd_wave_read_first_lane(kargs.query_start_len_ptr[seq_idx + 1]); const index_t cur_batch_query_len = cur_batch_in_all_stop_index - cur_batch_in_all_start_index; @@ -314,7 +310,7 @@ struct FmhaFwdV3Kernel } const index_t query_pos = q_block_local_idx * BLOCK_Q; - const index_t seq_len = kargs.unifiedAttentionVarlenKargs.seq_lens_ptr[seq_idx]; + const index_t seq_len = kargs.seq_lens_ptr[seq_idx]; const index_t context_len = seq_len - cur_batch_query_len; @@ -326,21 +322,22 @@ struct FmhaFwdV3Kernel ); // for simplicity, batch stride we just modify the pointer - index_t num_head_q = kargs.unifiedAttentionCommonKargs.num_head_q; - index_t num_queries_per_kv = kargs.unifiedAttentionCommonKargs.num_queries_per_kv; + const index_t num_head_q = kargs.num_head_q; + const index_t num_queries_per_kv = kargs.num_queries_per_kv; + const index_t num_head_k = num_head_q / num_queries_per_kv; // Q/K/V DRAM and DRAM window - const QDataType* q_ptr = reinterpret_cast(kargs.unifiedAttentionCommonKargs.q_ptr); - const KDataType* k_ptr = reinterpret_cast(kargs.unifiedAttentionCommonKargs.k_ptr); - const VDataType* v_ptr = reinterpret_cast(kargs.unifiedAttentionCommonKargs.v_ptr); - ODataType* o_ptr = reinterpret_cast(kargs.unifiedAttentionCommonKargs.o_ptr); + const QDataType* q_ptr = reinterpret_cast(kargs.q_ptr); + const KDataType* k_ptr = reinterpret_cast(kargs.k_ptr); + const VDataType* v_ptr = reinterpret_cast(kargs.v_ptr); + ODataType* o_ptr = reinterpret_cast(kargs.o_ptr); // Q/K/V DRAM and DRAM window const auto q_dram = [&]() { const auto q_dram_base = make_naive_tensor_view( q_ptr, make_tuple(seq_len, num_head_q, HEAD_SIZE), - make_tuple(kargs.unifiedAttentionCommonKargs.query_stride_0, kargs.unifiedAttentionCommonKargs.query_stride_1, 1), + make_tuple(kargs.query_stride_0, kargs.query_stride_1, 1), number{}, number<1>{}); @@ -378,6 +375,7 @@ struct FmhaFwdV3Kernel make_tuple(sequence<0>{}, sequence<1>{}) ); + // TODO are we padding the tensor view or the block here? const auto q_dram_pad = pad_tensor_view( // aling cu_seqlen with BLOCK_Q and head dim with HEAD_SIZE_PADDED q_dram_merged, // block sizes @@ -399,7 +397,7 @@ struct FmhaFwdV3Kernel const auto k_dram = [&]() { const auto k_dram_naive = make_naive_tensor_view( k_ptr, - make_tuple(kargs.seqlen_k, kargs.hdim_q), + make_tuple(num_b, BLOCK_SIZE, num_head_k, HEAD_SIZE), make_tuple(kargs.stride_k, 1), number{}, number<1>{}); From 36a65b19687444c3b55d1b4cdcf2eb61d7f37198 Mon Sep 17 00:00:00 2001 From: Tianxing Wu Date: Mon, 13 Oct 2025 10:05:23 +0000 Subject: [PATCH 2/2] refactor --- .../kernel/unified_attention_kernel.hpp | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) diff --git a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp index bb38df6b26..eef8ca4f79 100644 --- a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp +++ b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp @@ -212,10 +212,6 @@ struct FmhaFwdV3Kernel RemapTileIndices(const ck_tile::index_t pid, const Kargs& kargs) { using namespace ck_tile; - - const index_t num_head_q = kargs.num_head_q; - const index_t num_queries_per_kv = kargs.num_queries_per_kv; - const index_t num_head_k = kargs.num_queries_per_kv; constexpr index_t NUM_XCDS = 8; const index_t GRID_MN = kargs.total_num_q_blocks * @@ -277,9 +273,12 @@ struct FmhaFwdV3Kernel __shared__ char smem_ptr[GetSmemSize()]; ck_tile::index_t pid = blockIdx.x; - index_t num_queries_per_kv = kargs.num_queries_per_kv; const index_t BLOCK_M = BLOCK_Q * kargs.num_queries_per_kv; + // for simplicity, batch stride we just modify the pointer + const index_t num_head_q = kargs.num_head_q; + const index_t num_queries_per_kv = kargs.num_queries_per_kv; + const index_t num_head_k = num_head_q / num_queries_per_kv; pid = RemapTileIndices(pid, kargs); @@ -321,11 +320,6 @@ struct FmhaFwdV3Kernel + 1 ); - // for simplicity, batch stride we just modify the pointer - const index_t num_head_q = kargs.num_head_q; - const index_t num_queries_per_kv = kargs.num_queries_per_kv; - const index_t num_head_k = num_head_q / num_queries_per_kv; - // Q/K/V DRAM and DRAM window const QDataType* q_ptr = reinterpret_cast(kargs.q_ptr); const KDataType* k_ptr = reinterpret_cast(kargs.k_ptr);