diff --git a/example/ck_tile/01_unified_attention/example_unified_attention.cpp b/example/ck_tile/01_unified_attention/example_unified_attention.cpp index fdf9c7122a..50eac35c3f 100644 --- a/example/ck_tile/01_unified_attention/example_unified_attention.cpp +++ b/example/ck_tile/01_unified_attention/example_unified_attention.cpp @@ -343,8 +343,8 @@ bool run_impl(const Problem& problem, const RunConfig& run_config) }; calculate_cumulative(eff_query_lens, cu_query_lens); - ck_tile::DeviceMem seq_lens_buf(eff_kv_lens.size()); - ck_tile::DeviceMem query_start_len_buf(cu_query_lens.size()); + ck_tile::DeviceMem seq_lens_buf(eff_kv_lens.size() * sizeof(ck_tile::index_t)); + ck_tile::DeviceMem query_start_len_buf(cu_query_lens.size() * sizeof(ck_tile::index_t)); seq_lens_buf.ToDevice(eff_kv_lens.data()); query_start_len_buf.ToDevice(cu_query_lens.data()); @@ -525,31 +525,40 @@ bool run_impl(const Problem& problem, const RunConfig& run_config) ck_tile::HostTensor o(problem.get_output_shape()); o_buf.FromDevice(o.data()); - // const auto [rtol, atol] = [&] { - // if constexpr(std::is_same_v) - // return std::make_tuple(1e-3, 1e-3); - // else - // return std::make_tuple(1e-2, 1e-2); - // }(); + const auto [rtol, atol] = [&] { + if constexpr(std::is_same_v) + return std::make_tuple(1e-3, 1e-3); + else + return std::make_tuple(1e-2, 1e-2); + }(); - // Print some of the output data for debugging - std::cout << "\nFirst few elements of output tensor o:" << std::endl; - for(int b = 0; b < std::min(2, static_cast(problem.batch)); ++b) { - std::cout << "Batch " << b << ":" << std::endl; - for(int s = 0; s < std::min(5, static_cast(eff_query_lens[b])); ++s) { - for(int h = 0; h < std::min(2, static_cast(problem.nhead_q)); ++h) { - for(int d = 0; d < std::min(4, static_cast(problem.hdim)); ++d) { - std::cout << "o[" << b << "][" << s << "][" << h << "][" << d << "] = " - << static_cast(o(b, s, h, d)) - << std::endl; + size_t total = static_cast(problem.num_tokens) * + static_cast(problem.nhead_q) * + static_cast(problem.hdim); + + size_t nonzero = 0; + + for (int b = 0; b < problem.batch; ++b) { + for (int s = 0; s < eff_query_lens[b]; ++s) { + for (int h = 0; h < problem.nhead_q; ++h) { + for (int d = 0; d < problem.hdim; ++d) { + if (static_cast(o(b, s, h, d)) != 0.0f) { + nonzero++; + } } } } } - - - - return 1; // ck_tile::check_err(o, o_ref, std::string("found incorrect results!"), rtol, atol); + + float percent = (total > 0) + ? (100.0f * static_cast(nonzero) / static_cast(total)) + : 0.0f; + + std::cout << "\nNon-zero elements in output tensor o: " + << nonzero << " / " << total + << " (" << percent << "%)\n"; + + return ck_tile::check_err(o, o_ref, std::string("found incorrect results!"), rtol, atol); } int main(int argc, char* argv[]) diff --git a/example/ck_tile/01_unified_attention/unified_attention_impl.hpp b/example/ck_tile/01_unified_attention/unified_attention_impl.hpp index 8b1536f52a..855c99f841 100644 --- a/example/ck_tile/01_unified_attention/unified_attention_impl.hpp +++ b/example/ck_tile/01_unified_attention/unified_attention_impl.hpp @@ -124,7 +124,7 @@ float unified_attention_kernel_launch(const unified_attention_args& args, "argument num_queries_per_kv must equal compiled num_queries_per_kv"); assert(args.BLOCK_SIZE == Kernel::BLOCK_SIZE && "argument BLOCK_SIZE must equal compiled BLOCK_SIZE"); - assert(BLOCK_Q == args.num_head_q / args.num_queries_per_kv && + assert(BLOCK_Q == BLOCK_M / args.num_queries_per_kv && "BLOCK_Q must equal BLOCK_M / num_queries_per_kv"); index_t total_num_q_blocks = args.num_tokens / BLOCK_Q + args.num_seqs; auto kargs = Kernel::MakeKargs(args.q_ptr, diff --git a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp index 1d34ee5670..969a9aac82 100644 --- a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp +++ b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp @@ -184,7 +184,7 @@ struct UnifiedAttentionKernel while(left < right) { ck_tile::index_t mid = (left + right) / 2; - ck_tile::index_t val = query_start_len_ptr[mid]; + ck_tile::index_t val = amd_wave_read_first_lane(query_start_len_ptr[mid]); ck_tile::index_t mid_val = use_q_block_mode ? (val / block_q + mid) : val; if(mid_val <= target_idx) @@ -206,7 +206,7 @@ struct UnifiedAttentionKernel using namespace ck_tile; constexpr index_t NUM_XCDS = 8; - const index_t GRID_MN = kargs.total_num_q_blocks * (kargs.num_head_q); + const index_t GRID_MN = kargs.total_num_q_blocks * (kargs.num_head_q / kargs.num_queries_per_kv); // Number of pids per XCD in the new arrangement const index_t pids_per_xcd = (GRID_MN + NUM_XCDS - 1) / NUM_XCDS; @@ -245,10 +245,7 @@ struct UnifiedAttentionKernel // const index_t num_tile_n1 = ck_tile::integer_divide_ceil(kargs.hdim_v, // UnifiedAttentionPipeline::kN1); - const index_t i_tile_m = pid % total_num_q_blocks; // Query block index - const index_t i_tile_n = pid / total_num_q_blocks; // Head index - - return ck_tile::make_tuple(i_tile_m, i_tile_n); + return ck_tile::make_tuple(pid / total_num_q_blocks, pid % total_num_q_blocks); } CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); } @@ -277,7 +274,6 @@ struct UnifiedAttentionKernel // const index_t num_head_q = kargs.num_head_q; // const index_t num_head_k = num_head_q / num_queries_per_kv; - pid = RemapTileIndices(pid, kargs); // divide problem @@ -295,19 +291,15 @@ struct UnifiedAttentionKernel BLOCK_Q, true); // which batch - const index_t q_block_start_idx = - amd_wave_read_first_lane(kargs.query_start_len_ptr[seq_idx]); + const index_t q_block_start_idx = kargs.query_start_len_ptr[seq_idx] / BLOCK_Q + seq_idx; - const index_t q_block_local_idx = - amd_wave_read_first_lane(q_block_global_idx - q_block_start_idx); + const index_t q_block_local_idx = amd_wave_read_first_lane(q_block_global_idx - q_block_start_idx); - const index_t cur_batch_in_all_start_index = - amd_wave_read_first_lane(kargs.query_start_len_ptr[seq_idx]); - const index_t cur_batch_in_all_stop_index = - amd_wave_read_first_lane(kargs.query_start_len_ptr[seq_idx + 1]); + const index_t cur_batch_in_all_start_index = kargs.query_start_len_ptr[seq_idx]; + const index_t cur_batch_in_all_stop_index = kargs.query_start_len_ptr[seq_idx + 1]; const index_t cur_batch_query_len = - cur_batch_in_all_stop_index - cur_batch_in_all_start_index; + amd_wave_read_first_lane(cur_batch_in_all_stop_index - cur_batch_in_all_start_index); // TODO check if we get the block size info from pipeline if(q_block_local_idx * BLOCK_Q >= cur_batch_query_len) @@ -315,14 +307,14 @@ struct UnifiedAttentionKernel return; } - const index_t query_pos = q_block_local_idx * BLOCK_Q; + const index_t query_pos = amd_wave_read_first_lane(q_block_local_idx * BLOCK_Q); const index_t seq_len = kargs.seq_lens_ptr[seq_idx]; - const index_t context_len = seq_len - cur_batch_query_len; + const index_t context_len = amd_wave_read_first_lane(seq_len - cur_batch_query_len); index_t _max_seq_prefix_len = - (context_len + q_block_local_idx * BLOCK_Q + (BLOCK_M - 1) - + 1); + amd_wave_read_first_lane((context_len + q_block_local_idx * BLOCK_Q + (BLOCK_M - 1) + + 1)); if(seq_len < _max_seq_prefix_len) { @@ -330,7 +322,7 @@ struct UnifiedAttentionKernel } const auto max_seq_prefix_len = _max_seq_prefix_len; - const index_t num_blocks = (max_seq_prefix_len + BLOCK_SIZE - 1) / BLOCK_SIZE; + const index_t num_blocks = amd_wave_read_first_lane((max_seq_prefix_len + BLOCK_SIZE - 1) / BLOCK_SIZE); // TODO sliding window const index_t num_blocks_start = 0; @@ -357,7 +349,7 @@ struct UnifiedAttentionKernel const VDataType* v_ptr = reinterpret_cast(kargs.v_ptr) + kv_head_offset; ODataType* o_ptr = reinterpret_cast(kargs.o_ptr) + o_ptr_offset; - index_t query_len_padded = integer_divide_ceil(cur_batch_query_len, BLOCK_Q) * BLOCK_Q; + index_t query_len_padded = amd_wave_read_first_lane(integer_divide_ceil(cur_batch_query_len, BLOCK_Q) * BLOCK_Q); // const bool is_query_len_padded = (cur_batch_query_len % BLOCK_Q == 0); // Q/K/V DRAM and DRAM window @@ -367,20 +359,20 @@ struct UnifiedAttentionKernel make_tuple(cur_batch_query_len, num_queries_per_kv, HEAD_SIZE), make_tuple(kargs.query_stride_0, kargs.query_stride_1, 1), number{}, - number<2>{}); + number<1>{}); const auto q_dram_pad = pad_tensor_view( // aling seqlen with BLOCK_Q and head dim with HEAD_SIZE_PADDED q_dram_base, // block sizes - make_tuple(number{}, number<1>{}, number{}), + make_tuple(number{}, 1, HEAD_SIZE_PADDED), sequence{}); // pads to (seq_len_padded, num_head_q, // HEAD_SIZE_PADDED) const auto q_dram_merged = transform_tensor_view( q_dram_pad, make_tuple(make_merge_transform(make_tuple(query_len_padded, num_queries_per_kv)), - make_pass_through_transform(number{})), + make_pass_through_transform(HEAD_SIZE_PADDED)), make_tuple(sequence<0, 1>{}, sequence<2>{}), make_tuple(sequence<0>{}, sequence<1>{})); // flattens the first two dims, head idx is the fastest diff --git a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp index a575230ef6..3d941f5503 100644 --- a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp +++ b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp @@ -6,7 +6,6 @@ #include "ck_tile/core.hpp" #include "ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline_default_policy.hpp" #include "ck_tile/ops/reduce/block/block_reduce.hpp" - #define ENABLE_ASM_MARKER 1 #if ENABLE_ASM_MARKER #define ASM_MARKER(marker) \