From 5d2a9e5f16d60fa97412a60a7ecd8cd347135b30 Mon Sep 17 00:00:00 2001 From: Juuso Korhonen <40278371+juuso-oskari@users.noreply.github.com> Date: Mon, 17 Nov 2025 09:46:31 +0000 Subject: [PATCH] deving the test... --- .../example_unified_attention.cpp | 363 +++++++++--------- 1 file changed, 186 insertions(+), 177 deletions(-) diff --git a/example/ck_tile/01_unified_attention/example_unified_attention.cpp b/example/ck_tile/01_unified_attention/example_unified_attention.cpp index d34d8d9593..4f58a13c51 100644 --- a/example/ck_tile/01_unified_attention/example_unified_attention.cpp +++ b/example/ck_tile/01_unified_attention/example_unified_attention.cpp @@ -73,6 +73,13 @@ auto parse_cmd_args(int argc, char* argv[]) -> std::pair; + using GenericMask = ck_tile::GenericAttentionMask; + using CausalMask = ck_tile::GenericAttentionMask; +}; + enum class TensorLayout { bhsd, @@ -204,105 +211,105 @@ auto generate_qkv(const Problem& problem, return std::make_tuple(q, k, v); } -// namespace host { -// template -// CK_TILE_HOST void fmha_fwd(const ck_tile::HostTensor& q_bshd, -// const ck_tile::HostTensor& k_bshd, -// const ck_tile::HostTensor& v_bshd, -// const mask_info& mask, -// ck_tile::HostTensor& o_bshd, -// const QElementOp& q_element_op = {}, -// const KElementOp& k_element_op = {}, -// const VElementOp& v_element_op = {}, -// const SAccElementOp& s_acc_element_op = {}) -// { -// const int batch_size = q_bshd.mDesc.get_lengths()[0]; -// const int seqlen_q = q_bshd.mDesc.get_lengths()[1]; -// const int seqlen_kv = k_bshd.mDesc.get_lengths()[1]; -// const int nhead_q = q_bshd.mDesc.get_lengths()[2]; -// const int nhead_kv = k_bshd.mDesc.get_lengths()[2]; -// const int hdim_qk = q_bshd.mDesc.get_lengths()[3]; -// const int hdim_v = v_bshd.mDesc.get_lengths()[3]; +namespace host { +template +CK_TILE_HOST void fmha_fwd(const ck_tile::HostTensor& q_bshd, + const ck_tile::HostTensor& k_bshd, + const ck_tile::HostTensor& v_bshd, + const mask_info& mask, + ck_tile::HostTensor& o_bshd, + const QElementOp& q_element_op = {}, + const KElementOp& k_element_op = {}, + const VElementOp& v_element_op = {}, + const SAccElementOp& s_acc_element_op = {}) +{ +const int batch_size = q_bshd.mDesc.get_lengths()[0]; +const int seqlen_q = q_bshd.mDesc.get_lengths()[1]; +const int seqlen_kv = k_bshd.mDesc.get_lengths()[1]; +const int nhead_q = q_bshd.mDesc.get_lengths()[2]; +const int nhead_kv = k_bshd.mDesc.get_lengths()[2]; +const int hdim_qk = q_bshd.mDesc.get_lengths()[3]; +const int hdim_v = v_bshd.mDesc.get_lengths()[3]; -// const int nr = nhead_q / nhead_kv; +const int nr = nhead_q / nhead_kv; -// ck_tile::HostTensor q_host_ref({nhead_q, seqlen_q, hdim_qk}); -// ck_tile::HostTensor k_host_ref({nhead_q, seqlen_kv, hdim_qk}); -// ck_tile::HostTensor v_host_ref({nhead_q, hdim_v, seqlen_kv}); -// ck_tile::HostTensor o_host_ref({nhead_q, seqlen_q, hdim_v}); +ck_tile::HostTensor q_host_ref({nhead_q, seqlen_q, hdim_qk}); +ck_tile::HostTensor k_host_ref({nhead_q, seqlen_kv, hdim_qk}); +ck_tile::HostTensor v_host_ref({nhead_q, hdim_v, seqlen_kv}); +ck_tile::HostTensor o_host_ref({nhead_q, seqlen_q, hdim_v}); -// ck_tile::HostTensor s_host_ref({nhead_q, seqlen_q, seqlen_kv}); -// ck_tile::HostTensor p_host_ref({nhead_q, seqlen_q, seqlen_kv}); +ck_tile::HostTensor s_host_ref({nhead_q, seqlen_q, seqlen_kv}); +ck_tile::HostTensor p_host_ref({nhead_q, seqlen_q, seqlen_kv}); -// // do computation for each batch -// for(int b = 0; b < batch_size; ++b) -// { -// // copy per-batch data from input tensors -// // clang-format off -// q_host_ref.ForEach([&](auto& self, auto idx) { self(idx) = q_bshd(b, idx[1], idx[0] , -// idx[2]); }); k_host_ref.ForEach([&](auto& self, auto idx) { self(idx) = k_bshd(b, idx[1], -// idx[0] / nr, idx[2]); }); v_host_ref.ForEach([&](auto& self, auto idx) { self(idx) = -// v_bshd(b, idx[2], idx[0] / nr, idx[1]); }); -// // clang-format on -// ck_tile::reference_batched_gemm( -// q_host_ref, k_host_ref, s_host_ref, q_element_op, k_element_op, s_acc_element_op); +// do computation for each batch +for(int b = 0; b < batch_size; ++b) +{ + // copy per-batch data from input tensors + // clang-format off + q_host_ref.ForEach([&](auto& self, auto idx) { self(idx) = q_bshd(b, idx[1], idx[0] , + idx[2]); }); k_host_ref.ForEach([&](auto& self, auto idx) { self(idx) = k_bshd(b, idx[1], + idx[0] / nr, idx[2]); }); v_host_ref.ForEach([&](auto& self, auto idx) { self(idx) = + v_bshd(b, idx[2], idx[0] / nr, idx[1]); }); + // clang-format on + ck_tile::reference_batched_gemm( + q_host_ref, k_host_ref, s_host_ref, q_element_op, k_element_op, s_acc_element_op); -// if(mask.type == mask_enum::no_mask) -// { -// ck_tile::reference_batched_masking(s_host_ref, FmhaMasks::NoMask{seqlen_q, seqlen_kv}); -// } -// else if(mask.type == mask_enum::window_generic) -// { -// ck_tile::reference_batched_masking( -// s_host_ref, -// ck_tile::make_generic_attention_mask_from_lr_window( -// mask.left, mask.right, seqlen_q, seqlen_kv)); -// } -// else -// { -// // if left window size is negative, means causal -// // else means generic (for current batch) -// if(mask.left < 0) -// ck_tile::reference_batched_masking( -// s_host_ref, -// ck_tile::make_generic_attention_mask_from_lr_window( -// mask.left, -// mask.right, -// seqlen_q, -// seqlen_kv, -// mask.type == mask_enum::mask_top_left)); -// else -// ck_tile::reference_batched_masking( -// s_host_ref, -// ck_tile::make_generic_attention_mask_from_lr_window( -// mask.left, -// mask.right, -// seqlen_q, -// seqlen_kv, -// mask.type == mask_enum::mask_top_left)); -// } + if(mask.type == mask_enum::no_mask) + { + ck_tile::reference_batched_masking(s_host_ref, FmhaMasks::NoMask{seqlen_q, seqlen_kv}); + } + else if(mask.type == mask_enum::window_generic) + { + ck_tile::reference_batched_masking( + s_host_ref, + ck_tile::make_generic_attention_mask_from_lr_window( + mask.left, mask.right, seqlen_q, seqlen_kv)); + } + else + { + // if left window size is negative, means causal + // else means generic (for current batch) + if(mask.left < 0) + ck_tile::reference_batched_masking( + s_host_ref, + ck_tile::make_generic_attention_mask_from_lr_window( + mask.left, + mask.right, + seqlen_q, + seqlen_kv, + mask.type == mask_enum::mask_top_left)); + else + ck_tile::reference_batched_masking( + s_host_ref, + ck_tile::make_generic_attention_mask_from_lr_window( + mask.left, + mask.right, + seqlen_q, + seqlen_kv, + mask.type == mask_enum::mask_top_left)); + } -// ck_tile::reference_batched_softmax( -// s_host_ref, p_host_ref, ck_tile::identity{}); + ck_tile::reference_batched_softmax( + s_host_ref, p_host_ref, ck_tile::identity{}); -// ck_tile::reference_batched_gemm( -// p_host_ref, v_host_ref, o_host_ref, ck_tile::identity{}, v_element_op); + ck_tile::reference_batched_gemm( + p_host_ref, v_host_ref, o_host_ref, ck_tile::identity{}, v_element_op); -// // copy resulting per-batch data to the output tensor -// o_host_ref.ForEach( -// [&](auto& self, auto idx) { o_bshd(b, idx[1], idx[0], idx[2]) = self(idx); }); -// } -// } -// } // namespace host + // copy resulting per-batch data to the output tensor + o_host_ref.ForEach( + [&](auto& self, auto idx) { o_bshd(b, idx[1], idx[0], idx[2]) = self(idx); }); +} +} +} // namespace host template bool run_impl(const Problem& problem, const RunConfig& run_config) @@ -455,111 +462,113 @@ bool run_impl(const Problem& problem, const RunConfig& run_config) return false; } - // std::size_t flop = [&] { - // if(problem.mask.type == mask_enum::no_mask) - // { - // return 4 * args.num_tokens * problem.nhead_q * - // problem.hdim; - // } - // else - // { - // /// FIXME: Use a more accurate method; for now, we’re just dividing the flop by 2. - // return 2 * args.num_tokens * problem.nhead_q * - // problem.hdim; - // } - // }(); + std::size_t flop = [&] { + if(problem.mask.type == mask_enum::no_mask) + { + return 4 * args.num_tokens * problem.nhead_q * + problem.hdim; + } + else + { + /// FIXME: Use a more accurate method; for now, we’re just dividing the flop by 2. + return 2 * args.num_tokens * problem.nhead_q * + problem.hdim; + } + }(); // TODO fix this // std::size_t flop = 1; - // float tflops = static_cast(flop) / 1.e9 / time; + float tflops = static_cast(flop) / 1.e9 / time; - // std::cout << "[" << problem.data_type << "|"; - // std::cout << "] b:" << problem.batch << ", h:" << problem.nhead_q << "/" << problem.nhead_kv - // << ", s:" << problem.seqlen_q << "/" << problem.seqlen_k << ", d:" << problem.hdim - // << ", scale_s:" << problem.sacle_s << ", mask:" << problem.mask << std::fixed - // << ", " << std::setprecision(3) << time << " ms, " << std::setprecision(2) << - // tflops - // << " TFlops" << std::endl; + std::cout << "[" << problem.data_type << "|"; + std::cout << "] b:" << problem.batch << ", h:" << problem.nhead_q << "/" << problem.nhead_kv + << ", d:" << problem.hdim + << ", mask:" << problem.mask << std::fixed + << ", " << std::setprecision(3) << time << " ms, " << std::setprecision(2) << + tflops + << " TFlops" << std::endl; // if(!run_config.verify) // { // return true; // } - // transpose tensor descriptors from bhsd to bshd if necessary - // if(problem.input_layout != TensorLayout::bshd) - // { - // q = q.transpose({0, 2, 1, 3}); - // k = k.transpose({0, 2, 1, 3}); - // v = v.transpose({0, 2, 1, 3}); - // } - - // ck_tile::HostTensor o_ref(problem.get_output_shape()); - // if(problem.output_layout != TensorLayout::bshd) - // { - // o_ref = o_ref.transpose({0, 2, 1, 3}); - // } - - // If variable lengths are provided, compute per-batch references + // variable lengths are provided -> compute per-batch references // with the effective lengths; else compute a single full reference. // Variable-length aware verification: zero-fill padded region and only compute valid part. - // o_ref.SetZero(); + ck_tile::HostTensor o_ref(problem.get_output_shape()); + o_ref.SetZero(); - // for(int b = 0; b < problem.batch; ++b) - // { - // const ck_tile::index_t seqlen_q_eff = eff_q_vec[b]; - // const ck_tile::index_t seqlen_kv_eff = eff_kv_vec[b]; + for(int b = 0; b < problem.batch; ++b) + { + const ck_tile::index_t seqlen_q_eff = eff_query_lens[b]; + const ck_tile::index_t seqlen_kv_eff = eff_kv_lens[b]; - // if(seqlen_q_eff <= 0 || seqlen_kv_eff <= 0) - // continue; + if(seqlen_q_eff <= 0 || seqlen_kv_eff <= 0) + continue; - // // Slice current batch from inputs (bshd) and build single-batch tensors - // ck_tile::HostTensor q_b({1, seqlen_q_eff, problem.nhead_q, problem.hdim}); - // ck_tile::HostTensor k_b({1, seqlen_kv_eff, problem.nhead_kv, problem.hdim}); - // ck_tile::HostTensor v_b({1, seqlen_kv_eff, problem.nhead_kv, problem.hdim}); - // ck_tile::HostTensor o_b({1, seqlen_q_eff, problem.nhead_q, problem.hdim}); + // Slice current batch from inputs (bshd) and build single-batch tensors + ck_tile::HostTensor q_b({1, seqlen_q_eff, problem.nhead_q, problem.hdim}); + ck_tile::HostTensor k_b({1, seqlen_kv_eff, problem.nhead_kv, problem.hdim}); + ck_tile::HostTensor v_b({1, seqlen_kv_eff, problem.nhead_kv, problem.hdim}); + ck_tile::HostTensor o_b({1, seqlen_q_eff, problem.nhead_q, problem.hdim}); - // // Copy effective region - // q_b.ForEach([&](auto& self, auto idx) { - // // idx: [0, s, h, d] - // self(idx) = q(b, idx[1], idx[2], idx[3]); - // }); - // k_b.ForEach([&](auto& self, auto idx) { self(idx) = k(b, idx[1], idx[2], idx[3]); }); - // v_b.ForEach([&](auto& self, auto idx) { self(idx) = v(b, idx[1], idx[2], idx[3]); }); + // Copy effective region + q_b.ForEach([&](auto& self, auto idx) { + // idx: [0, s, h, d] + self(idx) = q(b, idx[1], idx[2], idx[3]); + }); + k_b.ForEach([&](auto& self, auto idx) { + // kv cache is paged + ck_tile::index_t table_col = int(idx[1] / problem.BLOCK_SIZE); + ck_tile::index_t block_table_offset = b * max_num_blocks_per_seq + table_col; + ck_tile::index_t block_idx = block_tables_host[block_table_offset]; + + self(idx) = k(block_idx, idx[1] % problem.BLOCK_SIZE, idx[2], idx[3]); + + }); + v_b.ForEach([&](auto& self, auto idx) { + ck_tile::index_t table_col = int(idx[1] / problem.BLOCK_SIZE); + ck_tile::index_t block_table_offset = b * max_num_blocks_per_seq + table_col; + ck_tile::index_t block_idx = block_tables_host[block_table_offset]; + + self(idx) = v(block_idx, idx[1] % problem.BLOCK_SIZE, idx[2], idx[3]); + }); + // v_b.ForEach([&](auto& self, auto idx) { self(idx) = v(b, idx[1], idx[2], idx[3]); }); - // // Compute reference for this batch segment (host::fmha_fwd expects bshd tensors) - // host::fmha_fwd(q_b, - // k_b, - // v_b, - // problem.mask, - // o_b, - // ck_tile::identity{}, - // ck_tile::identity{}, - // ck_tile::identity{}, - // ck_tile::scales{problem.scale_s}); + // Compute reference for this batch segment (host::fmha_fwd expects bshd tensors) + host::fmha_fwd(q_b, + k_b, + v_b, + problem.mask, + o_b, + ck_tile::identity{}, + ck_tile::identity{}, + ck_tile::identity{}, + ck_tile::scales{problem.scale_s}); - // // Scatter into o_ref's bshd descriptor memory - // for(int s = 0; s < seqlen_q_eff; ++s) - // { - // for(int h = 0; h < problem.nhead_q; ++h) - // { - // for(int d = 0; d < problem.hdim; ++d) - // { - // o_ref(b, s, h, d) = o_b(0, s, h, d); - // } - // } - // } - // } + // Scatter into o_ref's bshd descriptor memory + for(int s = 0; s < seqlen_q_eff; ++s) + { + for(int h = 0; h < problem.nhead_q; ++h) + { + for(int d = 0; d < problem.hdim; ++d) + { + o_ref(b, s, h, d) = o_b(0, s, h, d); + } + } + } + } - // ck_tile::HostTensor o(problem.get_output_shape()); - // o_buf.FromDevice(o.data()); + ck_tile::HostTensor o(problem.get_output_shape()); + o_buf.FromDevice(o.data()); - // const auto [rtol, atol] = [&] { - // if constexpr(std::is_same_v) - // return std::make_tuple(1e-3, 1e-3); - // else - // return std::make_tuple(1e-2, 1e-2); - // }(); - // return ck_tile::check_err(o, o_ref, std::string("found incorrect results!"), rtol, atol); + const auto [rtol, atol] = [&] { + if constexpr(std::is_same_v) + return std::make_tuple(1e-3, 1e-3); + else + return std::make_tuple(1e-2, 1e-2); + }(); + return ck_tile::check_err(o, o_ref, std::string("found incorrect results!"), rtol, atol); return true; }