diff --git a/example/ck_tile/01_fmha/example_fmha_fwd_v3.cpp b/example/ck_tile/01_fmha/example_fmha_fwd_v3.cpp index 7ddb65a2db..4bb740217e 100644 --- a/example/ck_tile/01_fmha/example_fmha_fwd_v3.cpp +++ b/example/ck_tile/01_fmha/example_fmha_fwd_v3.cpp @@ -463,17 +463,33 @@ bool run_impl(const Problem& problem, const RunConfig& run_config) } std::size_t flop = [&] { - if(problem.mask.type == mask_enum::no_mask) + long flop_result = 0; + + for(int b = 0; b < problem.batch; ++b) { - return 4 * problem.batch * problem.nhead_q * problem.seqlen_q * problem.seqlen_k * - problem.hdim; - } - else - { - /// FIXME: Use a more accurate method; for now, we’re just dividing the flop by 2. - return 2 * problem.batch * problem.nhead_q * problem.seqlen_q * problem.seqlen_k * - problem.hdim; + long query_lens = has_varlen_q ? eff_q_vec[b] : problem.seqlen_q; + long kv_lens = has_varlen_k ? eff_kv_vec[b] : problem.seqlen_k; + long valid_out_elements = 0; + + if(problem.mask.type == mask_enum::no_mask) { + valid_out_elements = kv_lens * query_lens; + } else { + if(query_lens > kv_lens) + { + valid_out_elements = (kv_lens * kv_lens + kv_lens) / 2; + } + else + { + valid_out_elements = + query_lens * kv_lens - ((query_lens * query_lens - query_lens) / 2); + } + + } + // Causal logic for valid output elements + + flop_result += 2 * problem.nhead_q * valid_out_elements * (problem.hdim + problem.hdim); } + return flop_result; }(); float tflops = static_cast(flop) / 1.e9 / time; diff --git a/example/ck_tile/01_unified_attention/example_unified_attention.cpp b/example/ck_tile/01_unified_attention/example_unified_attention.cpp index 58d1d27d68..335d23b154 100644 --- a/example/ck_tile/01_unified_attention/example_unified_attention.cpp +++ b/example/ck_tile/01_unified_attention/example_unified_attention.cpp @@ -32,15 +32,13 @@ auto parse_cmd_args(int argc, char* argv[]) -> std::pair std::pair std::pair std::pair& query_lens_input, + const std::vector& kv_lens_input, + bool varlen) -> std::pair, std::vector> { - using NoMask = ck_tile::GenericAttentionMask; - using GenericMask = ck_tile::GenericAttentionMask; - using CausalMask = ck_tile::GenericAttentionMask; -}; + // If both query_lens and kv_lens are provided, return them directly + if(!query_lens_input.empty() && !kv_lens_input.empty()) + { + return std::make_pair(query_lens_input, kv_lens_input); + } + + std::vector query_lens; + std::vector kv_lens; + + if(!varlen) + { + // Fixed length mode: fill with max seqlen + query_lens.assign(batch, max_seqlen_q); + kv_lens.assign(batch, max_seqlen_kv); + } + else + { + // Variable length mode: generate random lengths up to max + std::random_device rd; + std::mt19937 gen(rd()); + std::uniform_int_distribution q_dist(1, max_seqlen_q); + std::uniform_int_distribution kv_dist(1, max_seqlen_kv); + + query_lens.resize(batch); + kv_lens.resize(batch); + + for(ck_tile::index_t i = 0; i < batch; ++i) + { + query_lens[i] = q_dist(gen); + kv_lens[i] = kv_dist(gen); + } + } + + return std::make_pair(query_lens, kv_lens); +} struct Problem { @@ -94,10 +129,31 @@ struct Problem // TODO: support other GQA/MQA cases than just 4x nhead_q = nhead_kv * num_queries_per_kv; + ck_tile::index_t max_seqlen_q = args.get_int("s"); + ck_tile::index_t max_seqlen_kv = args.get_int("s_k"); + + if (max_seqlen_kv == -1) { + max_seqlen_kv = max_seqlen_q; + } + hdim = args.get_int("d"); query_lens = args.get_int_vec("query_lens"); kv_lens = args.get_int_vec("kv_lens"); - batch = std::max(query_lens.size(), kv_lens.size()); + assert(query_lens.size() == kv_lens.size() && "query_lens and kv_lens must have the same length b"); + batch = args.get_int("b"); + + bool varlen = args.get_bool("varlen"); + auto [query_lens_, kv_lens_] = seqlen_preprocess( + batch, + max_seqlen_q, + max_seqlen_kv, + query_lens, + kv_lens, + varlen); + + query_lens = query_lens_; + kv_lens = kv_lens_; + batch = query_lens.size(); // Calculate scale_s scale_s = args.get_float("scale_s"); @@ -108,7 +164,7 @@ struct Problem scale = args.get_float("scale"); scale_k = args.get_float("scale_k"); scale_v = args.get_float("scale_v"); - + num_tokens = 0; for(const auto& len : query_lens) { num_tokens += len; @@ -198,7 +254,7 @@ template & q_bshd, const ck_tile::HostTensor& k_bshd, const ck_tile::HostTensor& v_bshd, - const mask_info& mask, + // const mask_info& mask, ck_tile::HostTensor& o_bshd, const QElementOp& q_element_op = {}, const KElementOp& k_element_op = {}, @@ -222,61 +278,35 @@ CK_TILE_HOST void fmha_fwd(const ck_tile::HostTensor& q_bshd, ck_tile::HostTensor s_host_ref({nhead_q, seqlen_q, seqlen_kv}); ck_tile::HostTensor p_host_ref({nhead_q, seqlen_q, seqlen_kv}); - // do computation for each batch for(int b = 0; b < batch_size; ++b) { // copy per-batch data from input tensors // clang-format off - q_host_ref.ForEach([&](auto& self, auto idx) { self(idx) = q_bshd(b, idx[1], idx[0] , - idx[2]); }); k_host_ref.ForEach([&](auto& self, auto idx) { self(idx) = k_bshd(b, idx[1], - idx[0] / nr, idx[2]); }); v_host_ref.ForEach([&](auto& self, auto idx) { self(idx) = - v_bshd(b, idx[2], idx[0] / nr, idx[1]); }); + q_host_ref.ForEach([&](auto& self, auto idx) { self(idx) = q_bshd(b, idx[1], idx[0] , + idx[2]); }); + k_host_ref.ForEach([&](auto& self, auto idx) { self(idx) = k_bshd(b, idx[1], + idx[0] / nr, idx[2]); }); + v_host_ref.ForEach([&](auto& self, auto idx) { self(idx) = + v_bshd(b, idx[2], idx[0] / nr, idx[1]); }); // clang-format on ck_tile::reference_batched_gemm( q_host_ref, k_host_ref, s_host_ref, q_element_op, k_element_op, s_acc_element_op); - if(mask.type == mask_enum::no_mask) - { - ck_tile::reference_batched_masking(s_host_ref, FmhaMasks::NoMask{seqlen_q, seqlen_kv}); - } - else if(mask.type == mask_enum::window_generic) - { - ck_tile::reference_batched_masking( - s_host_ref, - ck_tile::make_generic_attention_mask_from_lr_window( - mask.left, mask.right, seqlen_q, seqlen_kv)); - } - else - { - // if left window size is negative, means causal - // else means generic (for current batch) - if(mask.left < 0) - ck_tile::reference_batched_masking( - s_host_ref, - ck_tile::make_generic_attention_mask_from_lr_window( - mask.left, - mask.right, - seqlen_q, - seqlen_kv, - mask.type == mask_enum::mask_top_left)); - else - ck_tile::reference_batched_masking( - s_host_ref, - ck_tile::make_generic_attention_mask_from_lr_window( - mask.left, - mask.right, - seqlen_q, - seqlen_kv, - mask.type == mask_enum::mask_top_left)); - } - + ck_tile::reference_batched_masking( + s_host_ref, + ck_tile::make_generic_attention_mask_from_lr_window( + -1, + 0, + seqlen_q, + seqlen_kv, + 1, + false)); ck_tile::reference_batched_softmax( s_host_ref, p_host_ref, ck_tile::identity{}); - ck_tile::reference_batched_gemm( p_host_ref, v_host_ref, o_host_ref, ck_tile::identity{}, v_element_op); - + // copy resulting per-batch data to the output tensor o_host_ref.ForEach( [&](auto& self, auto idx) { o_bshd(b, idx[1], idx[0], idx[2]) = self(idx); }); @@ -303,6 +333,7 @@ bool run_impl(const Problem& problem, const RunConfig& run_config) ck_tile::unified_attention_args args{}; + args.scale_s = problem.scale_s; args.data_type = problem.data_type; args.num_seqs = problem.batch; args.num_head_q = problem.nhead_q; @@ -369,8 +400,8 @@ bool run_impl(const Problem& problem, const RunConfig& run_config) }; calculate_cumulative(eff_query_lens, cu_query_lens); - ck_tile::DeviceMem seq_lens_buf(eff_kv_lens.size()); - ck_tile::DeviceMem query_start_len_buf(cu_query_lens.size()); + ck_tile::DeviceMem seq_lens_buf(eff_kv_lens.size() * sizeof(ck_tile::index_t)); + ck_tile::DeviceMem query_start_len_buf(cu_query_lens.size() * sizeof(ck_tile::index_t)); seq_lens_buf.ToDevice(eff_kv_lens.data()); query_start_len_buf.ToDevice(cu_query_lens.data()); @@ -428,7 +459,7 @@ bool run_impl(const Problem& problem, const RunConfig& run_config) if(!result) { - std::cerr << "faild to run fmha_fwd_v3()" << std::endl; + std::cerr << "faild to run unified_attention()" << std::endl; return false; } @@ -471,7 +502,18 @@ bool run_impl(const Problem& problem, const RunConfig& run_config) std::cout << "[" << problem.data_type << "|"; std::cout << "] b:" << problem.batch << ", h:" << problem.nhead_q << "/" << problem.nhead_kv - << ", d:" << problem.hdim << ", mask:" << problem.mask << std::fixed << ", " + << ", d:" << problem.hdim << ", scale_s:" << problem.scale_s + << ", query_lens:["; + for (size_t i = 0; i < problem.query_lens.size(); ++i) { + std::cout << problem.query_lens[i]; + if (i < problem.query_lens.size() - 1) std::cout << ","; + } + std::cout << "], kv_lens:["; + for (size_t i = 0; i < problem.kv_lens.size(); ++i) { + std::cout << problem.kv_lens[i]; + if (i < problem.kv_lens.size() - 1) std::cout << ","; + } + std::cout << "], mask:" << "causal mask" << std::fixed << ", " << std::setprecision(8) << time << " ms, " << std::setprecision(2) << tflops << " TFlops, " << std::setprecision(2) << (static_cast(mem) / 1e12 / (time / 1e3)) << " TB/s" << std::endl; @@ -500,11 +542,12 @@ bool run_impl(const Problem& problem, const RunConfig& run_config) ck_tile::HostTensor k_b({1, seqlen_kv_eff, problem.nhead_kv, problem.hdim}); ck_tile::HostTensor v_b({1, seqlen_kv_eff, problem.nhead_kv, problem.hdim}); ck_tile::HostTensor o_b({1, seqlen_q_eff, problem.nhead_q, problem.hdim}); + ck_tile::index_t seq_q_off = cu_query_lens[b]; // Copy effective region q_b.ForEach([&](auto& self, auto idx) { // idx: [0, s, h, d] - self(idx) = q(b, idx[1], idx[2], idx[3]); + self(idx) = q(seq_q_off + idx[1], idx[2], idx[3]); }); k_b.ForEach([&](auto& self, auto idx) { // kv cache is paged @@ -527,7 +570,7 @@ bool run_impl(const Problem& problem, const RunConfig& run_config) host::fmha_fwd(q_b, k_b, v_b, - problem.mask, + // problem.mask, o_b, ck_tile::identity{}, ck_tile::identity{}, @@ -541,7 +584,7 @@ bool run_impl(const Problem& problem, const RunConfig& run_config) { for(int d = 0; d < problem.hdim; ++d) { - o_ref(b, s, h, d) = o_b(0, s, h, d); + o_ref(seq_q_off + s, h, d) = o_b(0, s, h, d); } } } @@ -550,13 +593,62 @@ bool run_impl(const Problem& problem, const RunConfig& run_config) ck_tile::HostTensor o(problem.get_output_shape()); o_buf.FromDevice(o.data()); + const auto [rtol, atol] = [&] { if constexpr(std::is_same_v) return std::make_tuple(1e-3, 1e-3); else return std::make_tuple(1e-2, 1e-2); }(); - return ck_tile::check_err(o, o_ref, std::string("found incorrect results!"), rtol, atol); + + size_t total = static_cast(problem.num_tokens) * + static_cast(problem.nhead_q) * + static_cast(problem.hdim); + + size_t nonzero = 0; + + for (int tok = 0; tok < problem.num_tokens; ++tok) { + for (int h = 0; h < problem.nhead_q; ++h) { + for (int d = 0; d < problem.hdim; ++d) { + if (static_cast(o(tok, h, d)) != 0.0f) { + nonzero++; + } + } + } + } + + float percent = (total > 0) + ? (100.0f * static_cast(nonzero) / static_cast(total)) + : 0.0f; + + std::cout << "\nNon-zero elements in output tensor o: " + << nonzero << " / " << total + << " (" << percent << "%)\n"; + + // std::cout << "\n=== Complete Output Tensor (o) ===\n"; + // for (int tok = 0; tok < problem.num_tokens; ++tok) { + // std::cout << "Token " << tok << ":\n"; + // for (int h = 0; h < problem.nhead_q; ++h) { + // std::cout << " Head " << h << ": "; + // for (int d = 0; d < problem.hdim; ++d) { + // std::cout << static_cast(o(tok, h, d)) << " "; + // } + // std::cout << "\n"; + // } + // } + + // std::cout << "\n=== Complete Reference Tensor (o_ref) ===\n"; + // for (int tok = 0; tok < problem.num_tokens; ++tok) { + // std::cout << "Token " << tok << ":\n"; + // for (int h = 0; h < problem.nhead_q; ++h) { + // std::cout << " Head " << h << ": "; + // for (int d = 0; d < problem.hdim; ++d) { + // std::cout << static_cast(o_ref(tok, h, d)) << " "; + // } + // std::cout << "\n"; + // } + // } + return ck_tile::check_err(o, o_ref, std::string("found incorrect results!"), rtol, atol); } int main(int argc, char* argv[]) diff --git a/example/ck_tile/01_unified_attention/unified_attention.hpp b/example/ck_tile/01_unified_attention/unified_attention.hpp index f418a4a0d9..ed3e1e6b50 100644 --- a/example/ck_tile/01_unified_attention/unified_attention.hpp +++ b/example/ck_tile/01_unified_attention/unified_attention.hpp @@ -8,6 +8,7 @@ #include "ck_tile/core/numeric/integer.hpp" #include "ck_tile/host/stream_config.hpp" +#include "ck_tile/ops/unified_attention.hpp" namespace ck_tile { @@ -76,3 +77,10 @@ std::pair unified_attention(const unified_attention_args& args, const stream_config& config); } // namespace ck_tile + +struct UnifiedAttentionMasks +{ + using NoMask = ck_tile::GenericAttentionMask; + using GenericMask = ck_tile::GenericAttentionMask; + using CausalMask = ck_tile::GenericAttentionMask; +}; diff --git a/example/ck_tile/01_unified_attention/unified_attention_impl.hpp b/example/ck_tile/01_unified_attention/unified_attention_impl.hpp index 8b1536f52a..855c99f841 100644 --- a/example/ck_tile/01_unified_attention/unified_attention_impl.hpp +++ b/example/ck_tile/01_unified_attention/unified_attention_impl.hpp @@ -124,7 +124,7 @@ float unified_attention_kernel_launch(const unified_attention_args& args, "argument num_queries_per_kv must equal compiled num_queries_per_kv"); assert(args.BLOCK_SIZE == Kernel::BLOCK_SIZE && "argument BLOCK_SIZE must equal compiled BLOCK_SIZE"); - assert(BLOCK_Q == args.num_head_q / args.num_queries_per_kv && + assert(BLOCK_Q == BLOCK_M / args.num_queries_per_kv && "BLOCK_Q must equal BLOCK_M / num_queries_per_kv"); index_t total_num_q_blocks = args.num_tokens / BLOCK_Q + args.num_seqs; auto kargs = Kernel::MakeKargs(args.q_ptr, diff --git a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp index 2f1b574655..396bd6d2b8 100644 --- a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp +++ b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp @@ -184,7 +184,7 @@ struct UnifiedAttentionKernel while(left < right) { ck_tile::index_t mid = (left + right) / 2; - ck_tile::index_t val = query_start_len_ptr[mid]; + ck_tile::index_t val = amd_wave_read_first_lane(query_start_len_ptr[mid]); ck_tile::index_t mid_val = use_q_block_mode ? (val / block_q + mid) : val; if(mid_val <= target_idx) @@ -200,55 +200,15 @@ struct UnifiedAttentionKernel return left - 1; } - CK_TILE_DEVICE static constexpr auto RemapTileIndices(const ck_tile::index_t pid, - const Kargs& kargs) - { - using namespace ck_tile; - - constexpr index_t NUM_XCDS = 8; - const index_t GRID_MN = kargs.total_num_q_blocks * (kargs.num_head_q); - - // Number of pids per XCD in the new arrangement - const index_t pids_per_xcd = (GRID_MN + NUM_XCDS - 1) / NUM_XCDS; - - // When GRID_MN cannot divide NUM_XCDS, some xcds will have - // pids_per_xcd pids, the other will have pids_per_xcd - 1 pids. - // We calculate the number of xcds that have pids_per_xcd pids as tall_xcds - index_t tall_xcds = GRID_MN % NUM_XCDS; - tall_xcds = tall_xcds == 0 ? NUM_XCDS : tall_xcds; - - // Compute current XCD and local pid within the XCD - const index_t xcd = pid % NUM_XCDS; - const index_t local_pid = pid / NUM_XCDS; - - // Calculate new pid based on the new grouping - index_t remapped_pid = 0; // Initialize to avoid constexpr error - if(xcd < tall_xcds) - { - remapped_pid = xcd * pids_per_xcd + local_pid; - } - else - { - remapped_pid = - tall_xcds * pids_per_xcd + (xcd - tall_xcds) * (pids_per_xcd - 1) + local_pid; - } - - return remapped_pid; - } CK_TILE_DEVICE static constexpr auto GetTileIndex(const ck_tile::index_t pid, const Kargs& kargs) { using namespace ck_tile; - ck_tile::index_t total_num_q_blocks = kargs.total_num_q_blocks; - // const index_t num_tile_n1 = ck_tile::integer_divide_ceil(kargs.hdim_v, - // UnifiedAttentionPipeline::kN1); + ck_tile::index_t num_head_kv = kargs.num_head_q / kargs.num_queries_per_kv; - const index_t i_tile_m = pid % total_num_q_blocks; // Query block index - const index_t i_tile_n = pid / total_num_q_blocks; // Head index - - return ck_tile::make_tuple(i_tile_m, i_tile_n); + return ck_tile::make_tuple(pid % num_head_kv, pid / num_head_kv); } CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); } @@ -278,8 +238,6 @@ struct UnifiedAttentionKernel // const index_t num_head_k = num_head_q / num_queries_per_kv; - pid = RemapTileIndices(pid, kargs); - // divide problem const auto [kv_head_idx, q_block_global_idx] = GetTileIndex(pid, kargs); @@ -295,19 +253,15 @@ struct UnifiedAttentionKernel BLOCK_Q, true); // which batch - const index_t q_block_start_idx = - amd_wave_read_first_lane(kargs.query_start_len_ptr[seq_idx]); + const index_t q_block_start_idx = kargs.query_start_len_ptr[seq_idx] / BLOCK_Q + seq_idx; - const index_t q_block_local_idx = - amd_wave_read_first_lane(q_block_global_idx - q_block_start_idx); + const index_t q_block_local_idx = amd_wave_read_first_lane(q_block_global_idx - q_block_start_idx); - const index_t cur_batch_in_all_start_index = - amd_wave_read_first_lane(kargs.query_start_len_ptr[seq_idx]); - const index_t cur_batch_in_all_stop_index = - amd_wave_read_first_lane(kargs.query_start_len_ptr[seq_idx + 1]); + const index_t cur_batch_in_all_start_index = kargs.query_start_len_ptr[seq_idx]; + const index_t cur_batch_in_all_stop_index = kargs.query_start_len_ptr[seq_idx + 1]; const index_t cur_batch_query_len = - cur_batch_in_all_stop_index - cur_batch_in_all_start_index; + amd_wave_read_first_lane(cur_batch_in_all_stop_index - cur_batch_in_all_start_index); // TODO check if we get the block size info from pipeline if(q_block_local_idx * BLOCK_Q >= cur_batch_query_len) @@ -315,14 +269,14 @@ struct UnifiedAttentionKernel return; } - const index_t query_pos = q_block_local_idx * BLOCK_Q; + const index_t query_pos = amd_wave_read_first_lane(q_block_local_idx * BLOCK_Q); const index_t seq_len = kargs.seq_lens_ptr[seq_idx]; - const index_t context_len = seq_len - cur_batch_query_len; + const index_t context_len = amd_wave_read_first_lane(seq_len - cur_batch_query_len); index_t _max_seq_prefix_len = - (context_len + q_block_local_idx * BLOCK_Q + (BLOCK_M - 1) // num_queries_per_kv - + 1); + amd_wave_read_first_lane((context_len + q_block_local_idx * BLOCK_Q + (BLOCK_M - 1) + + 1)); if(seq_len < _max_seq_prefix_len) { @@ -330,7 +284,7 @@ struct UnifiedAttentionKernel } const auto max_seq_prefix_len = _max_seq_prefix_len; - const index_t num_blocks = (max_seq_prefix_len + BLOCK_SIZE - 1) / BLOCK_SIZE; + const index_t num_blocks = amd_wave_read_first_lane((max_seq_prefix_len + BLOCK_SIZE - 1) / BLOCK_SIZE); // TODO sliding window const index_t num_blocks_start = 0; @@ -357,7 +311,7 @@ struct UnifiedAttentionKernel const VDataType* v_ptr = reinterpret_cast(kargs.v_ptr) + kv_head_offset; ODataType* o_ptr = reinterpret_cast(kargs.o_ptr) + o_ptr_offset; - index_t query_len_padded = integer_divide_ceil(cur_batch_query_len, BLOCK_Q) * BLOCK_Q; + index_t query_len_padded = amd_wave_read_first_lane(integer_divide_ceil(cur_batch_query_len, BLOCK_Q) * BLOCK_Q); // const bool is_query_len_padded = (cur_batch_query_len % BLOCK_Q == 0); // Q/K/V DRAM and DRAM window @@ -367,20 +321,20 @@ struct UnifiedAttentionKernel make_tuple(cur_batch_query_len, num_queries_per_kv, HEAD_SIZE), make_tuple(kargs.query_stride_0, kargs.query_stride_1, 1), number{}, - number<2>{}); + number<1>{}); const auto q_dram_pad = pad_tensor_view( // aling seqlen with BLOCK_Q and head dim with HEAD_SIZE_PADDED q_dram_base, // block sizes - make_tuple(number{}, number<1>{}, number{}), + make_tuple(number{}, 1, HEAD_SIZE_PADDED), sequence{}); // pads to (seq_len_padded, num_head_q, // HEAD_SIZE_PADDED) const auto q_dram_merged = transform_tensor_view( q_dram_pad, make_tuple(make_merge_transform(make_tuple(query_len_padded, num_queries_per_kv)), - make_pass_through_transform(number{})), + make_pass_through_transform(HEAD_SIZE_PADDED)), make_tuple(sequence<0, 1>{}, sequence<2>{}), make_tuple(sequence<0>{}, sequence<1>{})); // flattens the first two dims, head idx is the fastest @@ -402,26 +356,17 @@ struct UnifiedAttentionKernel // HEAD dim is skipped as defined in the ptrs const auto k_dram_naive = make_naive_tensor_view( k_ptr, - make_tuple(kargs.num_blks, BLOCK_SIZE, HEAD_SIZE), - make_tuple(kargs.stride_k_cache_0, kargs.stride_k_cache_1, kargs.stride_k_cache_3), + make_tuple(kargs.num_blks * BLOCK_SIZE, HEAD_SIZE), + make_tuple(kargs.stride_k_cache_1, kargs.stride_k_cache_3), number{}, number<1>{}); const auto k_dram_pad = pad_tensor_view(k_dram_naive, // TODO can the BLOCK_SIZE_RAW needs padding? - make_tuple(1, BLOCK_SIZE, HEAD_SIZE_PADDED), - sequence{}); + make_tuple(BLOCK_SIZE, HEAD_SIZE_PADDED), + sequence{}); - const auto k_dram_merged = transform_tensor_view( - k_dram_pad, - make_tuple(make_merge_transform(make_tuple(kargs.num_blks, BLOCK_SIZE)), - make_pass_through_transform(HEAD_SIZE_PADDED)), - make_tuple(sequence<0, 1>{}, sequence<2>{}), - make_tuple(sequence<0>{}, - sequence<1>{})); // flattens the first two dims, head idx is the fastest - // changing dim in the merged dim - - return k_dram_merged; + return k_dram_pad; }(); auto k_dram_window = make_tile_window( @@ -430,25 +375,16 @@ struct UnifiedAttentionKernel const auto v_dram = [&]() { const auto v_dram_naive = make_naive_tensor_view( v_ptr, - make_tuple(kargs.num_blks, BLOCK_SIZE, HEAD_SIZE), - make_tuple(kargs.stride_v_cache_0, kargs.stride_v_cache_1, kargs.stride_v_cache_3), + make_tuple(kargs.num_blks * BLOCK_SIZE, HEAD_SIZE), + make_tuple(kargs.stride_v_cache_1, kargs.stride_v_cache_3), number{}, number<1>{}); const auto v_dram_pad = pad_tensor_view(v_dram_naive, - make_tuple(1, BLOCK_SIZE, HEAD_SIZE_PADDED), - sequence{}); + make_tuple(BLOCK_SIZE, HEAD_SIZE_PADDED), + sequence{}); - const auto v_dram_merged = transform_tensor_view( - v_dram_pad, - make_tuple(make_merge_transform(make_tuple(kargs.num_blks, BLOCK_SIZE)), - make_pass_through_transform(HEAD_SIZE_PADDED)), - make_tuple(sequence<0, 1>{}, sequence<2>{}), - make_tuple(sequence<0>{}, - sequence<1>{})); // flattens the first two dims, head idx is the fastest - // changing dim in the merged dim - - return v_dram_merged; + return v_dram_pad; }(); auto v_dram_window = make_tile_window( @@ -457,12 +393,13 @@ struct UnifiedAttentionKernel FmhaMask mask = [&]() { if constexpr(kHasMask) return ck_tile::make_generic_attention_mask_from_lr_window( - seq_len - cur_batch_query_len, // y (i.e. context) - cur_batch_query_len, // x (i.e. extend) - seq_len, // y_total (x + y) - cur_batch_query_len, // x_total - num_queries_per_kv // the same sequence index is repeated num_queries_per_kv + -1, + 0, + cur_batch_query_len, // y_total + seq_len, // x_total + num_queries_per_kv, // the same sequence index is repeated num_queries_per_kv // times along x dim of the tile + false ); else return FmhaMask{cur_batch_query_len, seq_len}; diff --git a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp index a575230ef6..105bfcca6e 100644 --- a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp +++ b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp @@ -6,7 +6,6 @@ #include "ck_tile/core.hpp" #include "ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline_default_policy.hpp" #include "ck_tile/ops/reduce/block/block_reduce.hpp" - #define ENABLE_ASM_MARKER 1 #if ENABLE_ASM_MARKER #define ASM_MARKER(marker) \ @@ -411,7 +410,7 @@ struct UnifiedAttentionPipeline HEAD_SIZE_PADDED == VDramBlockWindowTmp{}.get_window_lengths()[number<1>{}], "wrong!"); - static_assert(sizeof(SaccDataType) * BLOCK_SIZE <= GetSmemSize()); + static_assert(sizeof(SaccDataType) * BLOCK_SIZE * BLOCK_M <= GetSmemSize()); auto s_lds = make_tensor_view( reinterpret_cast(static_cast(smem_ptr)), MakeSimpleLdsDesc()); @@ -427,7 +426,7 @@ struct UnifiedAttentionPipeline auto o_lds = make_tensor_view( reinterpret_cast(static_cast(smem_ptr)), - MakeSimpleLdsDesc()); + MakeSimpleLdsDesc()); [[maybe_unused]] auto o_lds_window = make_tile_window( o_lds, make_tuple(number{}, number{}), {0, 0}); @@ -543,16 +542,11 @@ struct UnifiedAttentionPipeline clear_tile(l); const auto q_origin = q_dram_window.get_window_origin(); - // const auto [seqlen_k_start, seqlen_k_end] = - // mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number{}, - // number{}); - // const auto num_total_loop = integer_divide_ceil(seqlen_k_end - seqlen_k_start, - // BLOCK_SIZE); const auto num_total_loop = num_blocks; - // index_t kv_token_start = seqlen_k_start; + index_t k_block_table_off = num_blocks_start; + index_t v_block_table_off = num_blocks_start; - // TODO check is paddings kPadSeqLenK // check early exit if no work to do if constexpr(FmhaMask::IsMasking) { @@ -565,23 +559,23 @@ struct UnifiedAttentionPipeline } } + // TODO check correctness of this index_t i_total_loops = num_blocks_start; const ck_tile::index_t* block_tables_ptr_ = reinterpret_cast(block_tables_ptr); - index_t kv_blk_idx = block_tables_ptr_[block_table_offset + i_total_loops]; - index_t kv_blk_idx_prev = 0; + index_t kv_blk_idx_intial = block_tables_ptr_[block_table_offset + k_block_table_off]; auto k_dram_window = make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(), k_dram_block_window_tmp.get_window_lengths(), - {(kv_blk_idx - kv_blk_idx_prev) * BLOCK_SIZE, 0}, + {kv_blk_idx_intial * BLOCK_SIZE, 0}, Policy::template MakeKDramTileDistribution()); k_dram_window.init_raw(); auto v_dram_window = make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(), v_dram_block_window_tmp.get_window_lengths(), - {(kv_blk_idx - kv_blk_idx_prev) * BLOCK_SIZE, 0}, // TODO: hdim split? + {kv_blk_idx_intial * BLOCK_SIZE, 0}, Policy::template MakeVDramTileDistribution()); v_dram_window.init_raw(); @@ -677,6 +671,9 @@ struct UnifiedAttentionPipeline async_load_tile_raw(k_lds_window_store(k_lds_write_idx), k_dram_window); // TODO maybe needs i_total_loops as argument. Or maybe needs to use the k_lds_write_idx // as the index + + k_block_table_off++; + index_t kv_blk_idx = block_tables_ptr_[block_table_offset + k_block_table_off]; /// FIXME: use the future-predicting method to move the window k_dram_window.set_window_origin({kv_blk_idx * BLOCK_SIZE, 0}); }; @@ -687,7 +684,9 @@ struct UnifiedAttentionPipeline auto V_mem_load = [&](auto v_lds_write_idx) { async_load_tile_raw(v_lds_window_store(v_lds_write_idx), v_dram_window); - // kv_blk_idx = block_tables_ptr_[block_table_offset + i_total_loops]; + v_block_table_off++; + + index_t kv_blk_idx = block_tables_ptr_[block_table_offset + v_block_table_off]; /// FIXME: use the future-predicting method to move the window v_dram_window.set_window_origin({kv_blk_idx * BLOCK_SIZE, 0}); }; @@ -900,7 +899,7 @@ struct UnifiedAttentionPipeline { bool need_perpixel_check = mask.IsEdgeTile(q_origin.at(number<0>{}), i_total_loops * BLOCK_SIZE, - number{}, + number{}, number{}); if(need_perpixel_check) { @@ -985,7 +984,6 @@ struct UnifiedAttentionPipeline __builtin_amdgcn_s_barrier(); __builtin_amdgcn_sched_barrier(0); cl_load(memK, K_w0_lds_wr_idx, V_w0_lds_rd_idx); - // TODO what is this??? Scheduler::schedule(cl_p, number<1>{}); fmha_mask(xdl_SP_p01_reg_idx); @@ -1014,7 +1012,6 @@ struct UnifiedAttentionPipeline cl_load(memV, V_w0_lds_wr_idx, K_w0_lds_rd_idx); Scheduler::schedule(cl_p, number<3>{}); - // kv_token_start += BLOCK_SIZE; if(num_total_loop <= ++i_total_loops) { result = false; @@ -1061,7 +1058,6 @@ struct UnifiedAttentionPipeline Scheduler::schedule(cl_p, number<2>{}); fmha_mask(xdl_SP_p01_reg_idx); - // kv_token_start += BLOCK_SIZE; if(num_total_loop <= ++i_total_loops) { result = false; @@ -1139,7 +1135,6 @@ struct UnifiedAttentionPipeline fmha_alu0(number<0>{}); fmha_alu_D_upd(); - // kv_token_start += BLOCK_SIZE; ++i_total_loops; if(num_total_loop <= i_total_loops) {