diff --git a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp index 31bf24fa31..ffceec8aa2 100644 --- a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp +++ b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp @@ -59,9 +59,7 @@ struct FmhaFwdV3Kernel const void* v_ptr; // [num_blks, blk_size, num_kv_heads, head_size] void* o_ptr; - ck_tile::index_t hdim_q; - ck_tile::index_t hdim_v; - + ck_tile::index_t num_blks; ck_tile::index_t num_head_q; // for MQA/GQA, nhead could be different. This parameter is nhead_q / nhead_k // if this param is larger than 1, indicate MQA/GQA case @@ -89,7 +87,7 @@ struct FmhaFwdV3Kernel }; - struct UnifiedAttentionVarlenKargs + struct UnifiedAttentionVarlenKargs: UnifiedAttentionCommonKargs { const int32_t* block_tables_ptr; const int32_t* seq_lens_ptr; // seq len in each batch @@ -98,20 +96,15 @@ struct FmhaFwdV3Kernel ck_tile::index_t num_seqs; // number of batches for q }; - struct Kargs { - UnifiedAttentionCommonKargs unifiedAttentionCommonKargs; - UnifiedAttentionVarlenKargs unifiedAttentionVarlenKargs; - }; - // using Kargs = FmhaFwdGroupModeKargs; + using Kargs = UnifiedAttentionVarlenKargs; CK_TILE_HOST static constexpr Kargs MakeKargs( const void* q_ptr, const void* k_ptr, const void* v_ptr, void* o_ptr, - ck_tile::index_t hdim_q, - ck_tile::index_t hdim_v, + ck_tile::index_t num_blks, ck_tile::index_t num_head_q, const ck_tile::index_t num_queries_per_kv, float scale_s, @@ -135,15 +128,14 @@ struct FmhaFwdV3Kernel const int32_t* block_tables_ptr, const int32_t* seq_lens_ptr, const int32_t* query_start_len_ptr, - ck_tile::index_t num_seqs, + ck_tile::index_t num_seqs ) { Kargs kargs{{q_ptr, k_ptr, v_ptr, o_ptr, - hdim_q, - hdim_v, + num_blks, num_head_q, num_queries_per_kv, static_cast(scale_s * ck_tile::log2e_v<>), @@ -221,10 +213,10 @@ struct FmhaFwdV3Kernel RemapTileIndices(const ck_tile::index_t pid, const Kargs& kargs) { using namespace ck_tile; - + constexpr index_t NUM_XCDS = 8; - const index_t GRID_MN = kargs.unifiedAttentionCommonKargs.total_num_q_blocks * - (kargs.unifiedAttentionCommonKargs.num_head_q); + const index_t GRID_MN = kargs.total_num_q_blocks * + (kargs.num_head_q); // Number of pids per XCD in the new arrangement const index_t pids_per_xcd = (GRID_MN + NUM_XCDS - 1) / NUM_XCDS; @@ -259,7 +251,7 @@ struct FmhaFwdV3Kernel { using namespace ck_tile; - ck_tile::index_t total_num_q_blocks = kargs.unifiedAttentionCommonKargs.total_num_q_blocks; + ck_tile::index_t total_num_q_blocks = kargs.total_num_q_blocks; // const index_t num_tile_n1 = ck_tile::integer_divide_ceil(kargs.hdim_v, // FmhaPipeline::kN1); @@ -282,9 +274,12 @@ struct FmhaFwdV3Kernel __shared__ char smem_ptr[GetSmemSize()]; ck_tile::index_t pid = blockIdx.x; - index_t num_queries_per_kv = kargs.unifiedAttentionCommonKargs.num_queries_per_kv; - const index_t BLOCK_M = BLOCK_Q * kargs.unifiedAttentionCommonKargs.num_queries_per_kv; + const index_t BLOCK_M = BLOCK_Q * kargs.num_queries_per_kv; + // for simplicity, batch stride we just modify the pointer + const index_t num_head_q = kargs.num_head_q; + const index_t num_queries_per_kv = kargs.num_queries_per_kv; + const index_t num_head_k = num_head_q / num_queries_per_kv; pid = RemapTileIndices(pid, kargs); @@ -297,15 +292,15 @@ struct FmhaFwdV3Kernel // one q_block spans BLOCK_Q = BLOCK_M // num_queries_per_kv number of query token groups. One query token group shares one kv token const index_t seq_idx = find_seq_idx( - kargs.unifiedAttentionVarlenKargs.query_start_len_ptr, q_block_global_idx, kargs.unifiedAttentionVarlenKargs.num_seqs, kargs.unifiedAttentionCommonKargs.BLOCK_Q, true + kargs.query_start_len_ptr, q_block_global_idx, kargs.num_seqs, BLOCK_Q, true ); // which batch - const index_t q_block_start_idx = amd_wave_read_first_lane(kargs.unifiedAttentionVarlenKargs.query_start_len_ptr[seq_idx]); + const index_t q_block_start_idx = amd_wave_read_first_lane(kargs.query_start_len_ptr[seq_idx]); const index_t q_block_local_idx = amd_wave_read_first_lane(q_block_global_idx - q_block_start_idx); - const index_t cur_batch_in_all_start_index = amd_wave_read_first_lane(kargs.unifiedAttentionVarlenKargs.query_start_len_ptr[seq_idx]); - const index_t cur_batch_in_all_stop_index = amd_wave_read_first_lane(kargs.unifiedAttentionVarlenKargs.query_start_len_ptr[seq_idx + 1]); + const index_t cur_batch_in_all_start_index = amd_wave_read_first_lane(kargs.query_start_len_ptr[seq_idx]); + const index_t cur_batch_in_all_stop_index = amd_wave_read_first_lane(kargs.query_start_len_ptr[seq_idx + 1]); const index_t cur_batch_query_len = cur_batch_in_all_stop_index - cur_batch_in_all_start_index; @@ -315,7 +310,7 @@ struct FmhaFwdV3Kernel } const index_t query_pos = q_block_local_idx * BLOCK_Q; - const index_t seq_len = kargs.unifiedAttentionVarlenKargs.seq_lens_ptr[seq_idx]; // should be cu_seqlens_q rather + const index_t seq_len = kargs.seq_lens_ptr[seq_idx]; const index_t context_len = seq_len - cur_batch_query_len; @@ -326,10 +321,6 @@ struct FmhaFwdV3Kernel + 1 ); - // for simplicity, batch stride we just modify the pointer - index_t num_head_q = kargs.unifiedAttentionCommonKargs.num_head_q; - index_t num_queries_per_kv = kargs.unifiedAttentionCommonKargs.num_queries_per_kv; - // Q/K/V DRAM and DRAM window index_t q_ptr_offset_0 = cur_batch_in_all_start_index * kargs.unifiedAttentionCommonKargs.query_stride_0; // move the pointer to the batch start index_t q_ptr_offset_1 = kv_head_idx * num_queries_per_kv * kargs.unifiedAttentionCommonKargs.query_stride_1; // move the pointer to the correct head group start @@ -371,6 +362,7 @@ struct FmhaFwdV3Kernel make_tuple(sequence<0>{}, sequence<1>{}) ); + // TODO are we padding the tensor view or the block here? return q_dram_merged; }(); @@ -385,7 +377,7 @@ struct FmhaFwdV3Kernel const auto k_dram = [&]() { const auto k_dram_naive = make_naive_tensor_view( k_ptr, - make_tuple(kargs.seqlen_k, kargs.hdim_q), + make_tuple(num_b, BLOCK_SIZE, num_head_k, HEAD_SIZE), make_tuple(kargs.stride_k, 1), number{}, number<1>{});