diff --git a/example/ck_tile/01_fmha/example_fmha_fwd_v3.cpp b/example/ck_tile/01_fmha/example_fmha_fwd_v3.cpp index 074d69c2f9..0c1b625500 100644 --- a/example/ck_tile/01_fmha/example_fmha_fwd_v3.cpp +++ b/example/ck_tile/01_fmha/example_fmha_fwd_v3.cpp @@ -471,9 +471,12 @@ bool run_impl(const Problem& problem, const RunConfig& run_config) long kv_lens = has_varlen_k ? eff_kv_vec[b] : problem.seqlen_k; long valid_out_elements = 0; - if(problem.mask.type == mask_enum::no_mask) { + if(problem.mask.type == mask_enum::no_mask) + { valid_out_elements = kv_lens * query_lens; - } else { + } + else + { if(query_lens > kv_lens) { valid_out_elements = (kv_lens * kv_lens + kv_lens) / 2; @@ -483,7 +486,6 @@ bool run_impl(const Problem& problem, const RunConfig& run_config) valid_out_elements = query_lens * kv_lens - ((query_lens * query_lens - query_lens) / 2); } - } // Causal logic for valid output elements diff --git a/example/ck_tile/01_unified_attention/example_unified_attention.cpp b/example/ck_tile/01_unified_attention/example_unified_attention.cpp index 2b538a8e97..e43a8df76e 100644 --- a/example/ck_tile/01_unified_attention/example_unified_attention.cpp +++ b/example/ck_tile/01_unified_attention/example_unified_attention.cpp @@ -34,7 +34,10 @@ auto parse_cmd_args(int argc, char* argv[]) -> std::pair std::pair& query_lens_input, - const std::vector& kv_lens_input, - bool varlen) -> std::pair, std::vector> + ck_tile::index_t max_seqlen_q, + ck_tile::index_t max_seqlen_kv, + const std::vector& query_lens_input, + const std::vector& kv_lens_input, + bool varlen) -> std::pair, std::vector> { // If both query_lens and kv_lens are provided, return them directly if(!query_lens_input.empty() && !kv_lens_input.empty()) @@ -107,11 +110,11 @@ auto seqlen_preprocess(ck_tile::index_t batch, query_lens.resize(batch); kv_lens.resize(batch); - + for(ck_tile::index_t i = 0; i < batch; ++i) { query_lens[i] = q_dist(gen); - kv_lens[i] = kv_dist(gen); + kv_lens[i] = kv_dist(gen); } } @@ -131,31 +134,27 @@ struct Problem nhead_q = nhead_kv * num_queries_per_kv; ck_tile::index_t max_seqlen_q = args.get_int("s"); - ck_tile::index_t max_seqlen_kv = args.get_int("s_k"); + ck_tile::index_t max_seqlen_kv = args.get_int("s_k"); - if (max_seqlen_kv == -1) { + if(max_seqlen_kv == -1) + { max_seqlen_kv = max_seqlen_q; } - + hdim = args.get_int("d"); query_lens = args.get_int_vec("query_lens"); kv_lens = args.get_int_vec("kv_lens"); - assert(query_lens.size() == kv_lens.size() && "query_lens and kv_lens must have the same length b"); - batch = args.get_int("b"); + assert(query_lens.size() == kv_lens.size() && + "query_lens and kv_lens must have the same length b"); + batch = args.get_int("b"); page_blk_size = args.get_int("page_blk_size"); - bool varlen = args.get_bool("varlen"); - auto [query_lens_, kv_lens_] = seqlen_preprocess( - batch, - max_seqlen_q, - max_seqlen_kv, - query_lens, - kv_lens, - varlen); + auto [query_lens_, kv_lens_] = + seqlen_preprocess(batch, max_seqlen_q, max_seqlen_kv, query_lens, kv_lens, varlen); query_lens = query_lens_; - kv_lens = kv_lens_; + kv_lens = kv_lens_; batch = query_lens.size(); // Calculate scale_s @@ -164,9 +163,9 @@ struct Problem scale_s = 1.0f / ck_tile::sqrt(static_cast(hdim)); // Initialize other scales - scale = args.get_float("scale"); - scale_k = args.get_float("scale_k"); - scale_v = args.get_float("scale_v"); + scale = args.get_float("scale"); + scale_k = args.get_float("scale_k"); + scale_v = args.get_float("scale_v"); num_tokens = 0; for(const auto& len : query_lens) { @@ -300,17 +299,12 @@ CK_TILE_HOST void fmha_fwd(const ck_tile::HostTensor& q_bshd, ck_tile::reference_batched_masking( s_host_ref, ck_tile::make_generic_attention_mask_from_lr_window( - -1, - 0, - seqlen_q, - seqlen_kv, - 1, - false)); + -1, 0, seqlen_q, seqlen_kv, 1, false)); ck_tile::reference_batched_softmax( s_host_ref, p_host_ref, ck_tile::identity{}); ck_tile::reference_batched_gemm( p_host_ref, v_host_ref, o_host_ref, ck_tile::identity{}, v_element_op); - + // copy resulting per-batch data to the output tensor o_host_ref.ForEach( [&](auto& self, auto idx) { o_bshd(b, idx[1], idx[0], idx[2]) = self(idx); }); @@ -342,7 +336,7 @@ bool run_impl(const Problem& problem, const RunConfig& run_config) args.num_seqs = problem.batch; args.num_head_q = problem.nhead_q; args.num_queries_per_kv = num_queries_per_kv; - args.page_blk_size = problem.page_blk_size; + args.page_blk_size = problem.page_blk_size; args.mask_type = 2; args.hdim = problem.hdim; @@ -428,7 +422,8 @@ bool run_impl(const Problem& problem, const RunConfig& run_config) ck_tile::index_t max_kv_len = max_element(eff_kv_lens); - ck_tile::index_t max_num_blocks_per_seq = (max_kv_len + problem.page_blk_size - 1) / problem.page_blk_size; + ck_tile::index_t max_num_blocks_per_seq = + (max_kv_len + problem.page_blk_size - 1) / problem.page_blk_size; // Create block_tables ck_tile::DeviceMem block_tables_buf(problem.batch * max_num_blocks_per_seq * @@ -506,20 +501,22 @@ bool run_impl(const Problem& problem, const RunConfig& run_config) std::cout << "[" << problem.data_type << "|"; std::cout << "] b:" << problem.batch << ", h:" << problem.nhead_q << "/" << problem.nhead_kv - << ", d:" << problem.hdim << ", scale_s:" << problem.scale_s - << ", query_lens:["; - for (size_t i = 0; i < problem.query_lens.size(); ++i) { + << ", d:" << problem.hdim << ", scale_s:" << problem.scale_s << ", query_lens:["; + for(size_t i = 0; i < problem.query_lens.size(); ++i) + { std::cout << problem.query_lens[i]; - if (i < problem.query_lens.size() - 1) std::cout << ","; + if(i < problem.query_lens.size() - 1) + std::cout << ","; } std::cout << "], kv_lens:["; - for (size_t i = 0; i < problem.kv_lens.size(); ++i) { + for(size_t i = 0; i < problem.kv_lens.size(); ++i) + { std::cout << problem.kv_lens[i]; - if (i < problem.kv_lens.size() - 1) std::cout << ","; + if(i < problem.kv_lens.size() - 1) + std::cout << ","; } - std::cout << "], mask:" << "causal mask" << std::fixed << ", " - << std::setprecision(8) << time << " ms, " << std::setprecision(2) << tflops - << " TFlops, " << std::setprecision(2) + std::cout << "], mask:" << "causal mask" << std::fixed << ", " << std::setprecision(8) << time + << " ms, " << std::setprecision(2) << tflops << " TFlops, " << std::setprecision(2) << (static_cast(mem) / 1e12 / (time / 1e3)) << " TB/s" << std::endl; if(!run_config.verify) @@ -597,37 +594,37 @@ bool run_impl(const Problem& problem, const RunConfig& run_config) ck_tile::HostTensor o(problem.get_output_shape()); o_buf.FromDevice(o.data()); - const auto [rtol, atol] = [&] { if constexpr(std::is_same_v) return std::make_tuple(1e-3, 1e-3); else return std::make_tuple(1e-2, 1e-2); }(); - - size_t total = static_cast(problem.num_tokens) * - static_cast(problem.nhead_q) * + + size_t total = static_cast(problem.num_tokens) * static_cast(problem.nhead_q) * static_cast(problem.hdim); size_t nonzero = 0; - for (int tok = 0; tok < problem.num_tokens; ++tok) { - for (int h = 0; h < problem.nhead_q; ++h) { - for (int d = 0; d < problem.hdim; ++d) { - if (static_cast(o(tok, h, d)) != 0.0f) { - nonzero++; + for(int tok = 0; tok < problem.num_tokens; ++tok) + { + for(int h = 0; h < problem.nhead_q; ++h) + { + for(int d = 0; d < problem.hdim; ++d) + { + if(static_cast(o(tok, h, d)) != 0.0f) + { + nonzero++; } } } } - float percent = (total > 0) - ? (100.0f * static_cast(nonzero) / static_cast(total)) - : 0.0f; + float percent = + (total > 0) ? (100.0f * static_cast(nonzero) / static_cast(total)) : 0.0f; - std::cout << "\nNon-zero elements in output tensor o: " - << nonzero << " / " << total - << " (" << percent << "%)\n"; + std::cout << "\nNon-zero elements in output tensor o: " << nonzero << " / " << total << " (" + << percent << "%)\n"; // std::cout << "\n=== Complete Output Tensor (o) ===\n"; // for (int tok = 0; tok < problem.num_tokens; ++tok) { @@ -652,7 +649,7 @@ bool run_impl(const Problem& problem, const RunConfig& run_config) // std::cout << "\n"; // } // } - return ck_tile::check_err(o, o_ref, std::string("found incorrect results!"), rtol, atol); + return ck_tile::check_err(o, o_ref, std::string("found incorrect results!"), rtol, atol); } int main(int argc, char* argv[]) diff --git a/example/ck_tile/01_unified_attention/unified_attention.hpp b/example/ck_tile/01_unified_attention/unified_attention.hpp index 3a36690158..64f340c556 100644 --- a/example/ck_tile/01_unified_attention/unified_attention.hpp +++ b/example/ck_tile/01_unified_attention/unified_attention.hpp @@ -30,7 +30,7 @@ struct unified_attention_args index_t num_head_q; index_t num_queries_per_kv; index_t page_blk_size; - //index_t BLOCK_SIZE; + // index_t BLOCK_SIZE; index_t hdim; // TODO window diff --git a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp index 4b983d503a..480d0f3ee9 100644 --- a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp +++ b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp @@ -204,7 +204,6 @@ struct UnifiedAttentionKernel return left - 1; } - CK_TILE_DEVICE static constexpr auto GetTileIndex(const ck_tile::index_t pid, const Kargs& kargs) { @@ -259,13 +258,14 @@ struct UnifiedAttentionKernel const index_t q_block_start_idx = kargs.query_start_len_ptr[seq_idx] / BLOCK_Q + seq_idx; - const index_t q_block_local_idx = amd_wave_read_first_lane(q_block_global_idx - q_block_start_idx); + const index_t q_block_local_idx = + amd_wave_read_first_lane(q_block_global_idx - q_block_start_idx); const index_t cur_batch_in_all_start_index = kargs.query_start_len_ptr[seq_idx]; - const index_t cur_batch_in_all_stop_index = kargs.query_start_len_ptr[seq_idx + 1]; + const index_t cur_batch_in_all_stop_index = kargs.query_start_len_ptr[seq_idx + 1]; const index_t cur_batch_query_len = - amd_wave_read_first_lane(cur_batch_in_all_stop_index - cur_batch_in_all_start_index); + amd_wave_read_first_lane(cur_batch_in_all_stop_index - cur_batch_in_all_start_index); // TODO check if we get the block size info from pipeline if(q_block_local_idx * BLOCK_Q >= cur_batch_query_len) @@ -276,11 +276,10 @@ struct UnifiedAttentionKernel const index_t query_pos = amd_wave_read_first_lane(q_block_local_idx * BLOCK_Q); const index_t seq_len = kargs.seq_lens_ptr[seq_idx]; - const index_t context_len = amd_wave_read_first_lane(seq_len - cur_batch_query_len); + const index_t context_len = amd_wave_read_first_lane(seq_len - cur_batch_query_len); - index_t _max_seq_prefix_len = - amd_wave_read_first_lane((context_len + q_block_local_idx * BLOCK_Q + (BLOCK_M - 1) - + 1)); + index_t _max_seq_prefix_len = amd_wave_read_first_lane( + (context_len + q_block_local_idx * BLOCK_Q + (BLOCK_M - 1) + 1)); if(seq_len < _max_seq_prefix_len) { @@ -288,7 +287,8 @@ struct UnifiedAttentionKernel } const auto max_seq_prefix_len = _max_seq_prefix_len; - const index_t num_blocks = amd_wave_read_first_lane((max_seq_prefix_len + BLOCK_SIZE - 1) / BLOCK_SIZE); + const index_t num_blocks = + amd_wave_read_first_lane((max_seq_prefix_len + BLOCK_SIZE - 1) / BLOCK_SIZE); // TODO sliding window const index_t num_blocks_start = 0; @@ -315,7 +315,8 @@ struct UnifiedAttentionKernel const VDataType* v_ptr = reinterpret_cast(kargs.v_ptr) + kv_head_offset; ODataType* o_ptr = reinterpret_cast(kargs.o_ptr) + o_ptr_offset; - index_t query_len_padded = amd_wave_read_first_lane(integer_divide_ceil(cur_batch_query_len, BLOCK_Q) * BLOCK_Q); + index_t query_len_padded = + amd_wave_read_first_lane(integer_divide_ceil(cur_batch_query_len, BLOCK_Q) * BLOCK_Q); // const bool is_query_len_padded = (cur_batch_query_len % BLOCK_Q == 0); // Q/K/V DRAM and DRAM window @@ -397,21 +398,20 @@ struct UnifiedAttentionKernel FmhaMask mask = [&]() { if constexpr(kHasMask) return ck_tile::make_generic_attention_mask_from_lr_window( - -1, + -1, 0, cur_batch_query_len, // y_total - seq_len, // x_total - num_queries_per_kv, // the same sequence index is repeated num_queries_per_kv - // times along x dim of the tile - false - ); + seq_len, // x_total + num_queries_per_kv, // the same sequence index is repeated num_queries_per_kv + // times along x dim of the tile + false); else return FmhaMask{cur_batch_query_len, seq_len}; }(); const index_t kv_page_size_in_blocks = kargs.page_blk_size / BLOCK_SIZE; assert(kv_page_size_in_blocks >= 1); // BLOCK_SIZE <= page_blk_size - + auto o_acc_tile = [&]() { return UnifiedAttentionPipeline{}(q_dram_window, k_dram_window, diff --git a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp index 4cb5637c3a..6cc8ee954a 100644 --- a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp +++ b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp @@ -545,9 +545,8 @@ struct UnifiedAttentionPipeline const auto q_origin = q_dram_window.get_window_origin(); const auto num_total_loop = num_blocks; - index_t k_block_table_off = num_blocks_start; - index_t v_block_table_off = num_blocks_start; - + index_t k_block_idx = 0; + index_t v_block_idx = 0; // check early exit if no work to do if constexpr(FmhaMask::IsMasking) @@ -562,12 +561,13 @@ struct UnifiedAttentionPipeline } // TODO check correctness of this - index_t i_total_loops = num_blocks_start; + index_t i_total_loops = num_blocks_start; const index_t PAGE_BLOCK_SIZE = kv_page_size_in_blocks * BLOCK_SIZE; const ck_tile::index_t* block_tables_ptr_ = reinterpret_cast(block_tables_ptr); - assert(k_block_table_off == v_block_table_off); // because of the following line - index_t kv_blk_idx_initial = block_tables_ptr_[block_table_offset + k_block_table_off]; + assert(k_block_idx == v_block_idx); // because of the following line + block_table_offset += num_blocks_start; + index_t kv_blk_idx_initial = block_tables_ptr_[block_table_offset + k_block_idx]; auto k_dram_window = make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(), @@ -672,56 +672,35 @@ struct UnifiedAttentionPipeline constexpr int V_mem_su_ld_insts = v_dram_window.get_num_of_access(); // Page block index tracking - // const index_t kv_page_size_in_blocks = + // const index_t kv_page_size_in_blocks = // PAGE_BLOCK_SIZE / BLOCK_SIZE; - index_t k_block_i_inside_page = 0; - index_t v_block_i_inside_page = 0; + // index_t kv_block_idx = 0; + // only for block 0 and thread + if(blockIdx.x == 0 && threadIdx.x == 0) {} auto K_mem_load = [&](auto k_lds_write_idx) { async_load_tile_raw(k_lds_window_store(k_lds_write_idx), k_dram_window); - // prefetch next K tile (only if not at the end of loop) - if (k_block_table_off * kv_page_size_in_blocks + k_block_i_inside_page + 1 >= num_total_loop) - { - return; - } - // Update block index inside the page - ++k_block_i_inside_page; - if(k_block_i_inside_page < kv_page_size_in_blocks) - { - // Staying inside the page, just move the window - move_tile_window(k_dram_window, {BLOCK_SIZE, 0}); - } - else - { - // Moving outside the page, fetch new physical page index - k_block_table_off++; - index_t k_page_blk_idx = amd_wave_read_first_lane(block_tables_ptr_[block_table_offset + k_block_table_off]); - k_dram_window.set_window_origin({k_page_blk_idx * PAGE_BLOCK_SIZE, 0}); - k_block_i_inside_page = 0; - } + k_block_idx++; + + index_t k_page_blk_idx = + block_tables_ptr_[block_table_offset + (k_block_idx / kv_page_size_in_blocks)]; + k_dram_window.set_window_origin( + {k_page_blk_idx * PAGE_BLOCK_SIZE + + (k_block_idx % kv_page_size_in_blocks) * BLOCK_SIZE, + 0}); }; auto V_mem_load = [&](auto v_lds_write_idx) { async_load_tile_raw(v_lds_window_store(v_lds_write_idx), v_dram_window); // prefetch next V tile (only if not at the end of loop) - if (v_block_table_off * kv_page_size_in_blocks + v_block_i_inside_page + 1 >= num_total_loop) - { - return; - } - // Update the block index inside the page - ++v_block_i_inside_page; - if(v_block_i_inside_page < kv_page_size_in_blocks) - { - // Staying inside the page, just move the window - move_tile_window(v_dram_window, {BLOCK_SIZE, 0}); - } - else - { - // Moving outside the page, fetch new physical page index - v_block_table_off++; - index_t v_page_blk_idx = amd_wave_read_first_lane(block_tables_ptr_[block_table_offset + v_block_table_off]); - v_dram_window.set_window_origin({v_page_blk_idx * PAGE_BLOCK_SIZE, 0}); - v_block_i_inside_page = 0; - } + v_block_idx++; + + index_t v_page_blk_idx = + block_tables_ptr_[block_table_offset + (v_block_idx / kv_page_size_in_blocks)]; + v_dram_window.set_window_origin( + {v_page_blk_idx * PAGE_BLOCK_SIZE + + (v_block_idx % kv_page_size_in_blocks) * BLOCK_SIZE, + 0}); + // we assume that v load is always after k }; auto K_lds_load = [&](auto k_lds_read_idx) {