diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index da0c9ca931..cfb96b7d53 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -259,11 +259,11 @@ class FmhaFwdApiTrait: def skcheck(self) -> str: if self.mode == 'group': return 'true/*group mode skpad always true*/' # group mode only generate spad/skpad == true if self.pipeline_tag == 'qr_async': - if self.skpad == 't' : return f'(a.cu_seqlen_kv_ptr != nullptr) || (a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0)' - else : return f'(a.cu_seqlen_kv_ptr == nullptr) && (a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0)' + if self.skpad == 't' : return f'a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0' + else : return f'a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0' elif self.pipeline_tag in ['qr', 'qs']: if self.skpad == 't' : return f'true /*a.seqlen_k % {self.bn0} != 0*/' # TODO: order of get_pipelines() matters! (ugly) - else : return f'(a.cu_seqlen_kv_ptr == nullptr) && (a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0)' + else : return f'a.seqlen_k % {self.bn0} == 0' elif self.pipeline_tag == 'qr_async_trload': if self.skpad == 't' : return 'true' else: return 'true' diff --git a/example/ck_tile/01_fmha/example_fmha_fwd.cpp b/example/ck_tile/01_fmha/example_fmha_fwd.cpp index 79fda6d564..91cb9f55be 100644 --- a/example/ck_tile/01_fmha/example_fmha_fwd.cpp +++ b/example/ck_tile/01_fmha/example_fmha_fwd.cpp @@ -33,10 +33,6 @@ auto create_args(int argc, char* argv[]) "0", "seqlen_k for new key/value, 0 means not to use this at all; " "-1 to choose s_knew in [1, s] randomly.") - .insert("s_qpad", - "-1", - "seqlen_q stride between 2 batches (group-mode optional).\n" - "Provide positive strides per-batch to simulate physical padding on Q.") .insert("s_kpad", "-1", "seqlen_k stride between 2 batches, currently used in group-mode only\n" @@ -111,15 +107,7 @@ auto create_args(int argc, char* argv[]) .insert("warmup", "5", "number of iterations before benchmark the kernel") .insert("repeat", "20", "number of iterations to benchmark the kernel") .insert("json", "0", "0: No Json, 1: Dump Results in Json format") - .insert("jsonfile", "fmha_fwd.json", "json file name to dump results") - .insert("q_eff_lens", - "", - "Batch-mode only: per-batch effective seqlen for Q (exclude PAD).\n" - "Comma-separated list of length 'b'. If empty, no override.") - .insert("kv_eff_lens", - "", - "Batch-mode only: per-batch effective seqlen for KV (exclude PAD).\n" - "Comma-separated list of length 'b'. If empty, no override."); + .insert("jsonfile", "fmha_fwd.json", "json file name to dump results"); bool result = arg_parser.parse(argc, argv); return std::make_tuple(result, arg_parser); @@ -139,9 +127,6 @@ auto run(const ck_tile::ArgParser& arg_parser) ck_tile::index_t hdim_v = arg_parser.get_int("d_v"); ck_tile::index_t seqlen_knew = arg_parser.get_int("s_knew"); auto seqlen_kpads = arg_parser.get_int_vec("s_kpad"); - auto seqlen_qpads = arg_parser.get_int_vec("s_qpad"); - auto q_eff_lens_per_batch = arg_parser.get_int_vec("q_eff_lens"); - auto kv_eff_lens_per_batch = arg_parser.get_int_vec("kv_eff_lens"); ck_tile::index_t rotary_dim = arg_parser.get_int("rotary_dim"); bool i_perm = arg_parser.get_bool("iperm"); bool o_perm = arg_parser.get_bool("operm"); @@ -189,10 +174,7 @@ auto run(const ck_tile::ArgParser& arg_parser) hdim_q, hdim_v, seqlen_knew, - seqlen_qpads, seqlen_kpads, - q_eff_lens_per_batch, - kv_eff_lens_per_batch, rotary_dim, i_perm, o_perm, diff --git a/example/ck_tile/01_fmha/example_fmha_fwd_v3.cpp b/example/ck_tile/01_fmha/example_fmha_fwd_v3.cpp index 7ddb65a2db..569c98a458 100644 --- a/example/ck_tile/01_fmha/example_fmha_fwd_v3.cpp +++ b/example/ck_tile/01_fmha/example_fmha_fwd_v3.cpp @@ -52,16 +52,7 @@ auto parse_cmd_args(int argc, char* argv[]) -> std::pair get_query_shape() const @@ -183,8 +172,6 @@ struct Problem mask_info mask; TensorLayout input_layout; TensorLayout output_layout; - std::vector q_eff_lens; - std::vector kv_eff_lens; }; struct RunConfig @@ -339,10 +326,8 @@ bool run_impl(const Problem& problem, const RunConfig& run_config) q_buf.ToDevice(q.data()); k_buf.ToDevice(k.data()); v_buf.ToDevice(v.data()); - // Ensure output buffer is zero-initialized so padded regions compare cleanly - o_buf.SetZero(); - ck_tile::fmha_fwd_v3_args args{}; + ck_tile::fmha_fwd_v3_args args; args.data_type = problem.data_type; args.batch = problem.batch; @@ -395,60 +380,6 @@ bool run_impl(const Problem& problem, const RunConfig& run_config) : problem.seqlen_q * problem.hdim; args.batch_stride_o = problem.seqlen_q * problem.nhead_q * problem.hdim; - // Optional cumulative seqlen overrides (exclude PAD) - const bool has_varlen_q = !problem.q_eff_lens.empty() && problem.q_eff_lens[0] != -1; - const bool has_varlen_k = !problem.kv_eff_lens.empty() && problem.kv_eff_lens[0] != -1; - - auto make_effective_vec = [&](const std::vector& opt_vec, ck_tile::index_t fallback) { - std::vector eff; - if(!opt_vec.empty() && opt_vec[0] != -1) - { - eff.assign(opt_vec.begin(), opt_vec.end()); - if(eff.size() < static_cast(problem.batch)) - { - eff.resize(problem.batch, eff.back()); - } - } - else - { - eff.assign(problem.batch, fallback); - } - return eff; - }; - - const auto eff_q_vec = make_effective_vec(problem.q_eff_lens, problem.seqlen_q); - const auto eff_kv_vec = make_effective_vec(problem.kv_eff_lens, problem.seqlen_k); - - // Calculate cumulative sums for kernel arguments if varlen is used - std::vector cuq_cum, cukv_cum; - auto calculate_cumulative = [&](const std::vector& per_batch_vec, - std::vector& cum_vec) { - cum_vec.resize(per_batch_vec.size() + 1); - cum_vec[0] = 0; - for(std::size_t i = 0; i < per_batch_vec.size(); ++i) - cum_vec[i + 1] = cum_vec[i] + per_batch_vec[i]; - }; - - if(has_varlen_q) - { - calculate_cumulative(eff_q_vec, cuq_cum); - } - if(has_varlen_k) - { - calculate_cumulative(eff_kv_vec, cukv_cum); - } - - ck_tile::DeviceMem cuq_buf(!cuq_cum.empty() ? cuq_cum.size() * sizeof(ck_tile::index_t) : 0); - ck_tile::DeviceMem cukv_buf(!cukv_cum.empty() ? cukv_cum.size() * sizeof(ck_tile::index_t) : 0); - cuq_buf.ToDevice(!cuq_cum.empty() ? cuq_cum.data() : nullptr); - cukv_buf.ToDevice(!cukv_cum.empty() ? cukv_cum.data() : nullptr); - args.cu_seqlen_q_ptr = - !cuq_cum.empty() ? reinterpret_cast(cuq_buf.GetDeviceBuffer()) - : nullptr; - args.cu_seqlen_kv_ptr = - !cukv_cum.empty() ? reinterpret_cast(cukv_buf.GetDeviceBuffer()) - : nullptr; - ck_tile::stream_config stream_config{nullptr, true, /*log_level=*/0, @@ -511,72 +442,15 @@ bool run_impl(const Problem& problem, const RunConfig& run_config) o_ref = o_ref.transpose({0, 2, 1, 3}); } - // If variable lengths are provided, compute per-batch references - // with the effective lengths; else compute a single full reference. - if(has_varlen_q || has_varlen_k) - { - // Variable-length aware verification: zero-fill padded region and only compute valid part. - o_ref.SetZero(); - - for(int b = 0; b < problem.batch; ++b) - { - const ck_tile::index_t seqlen_q_eff = eff_q_vec[b]; - const ck_tile::index_t seqlen_kv_eff = eff_kv_vec[b]; - - if(seqlen_q_eff <= 0 || seqlen_kv_eff <= 0) - continue; - - // Slice current batch from inputs (bshd) and build single-batch tensors - ck_tile::HostTensor q_b({1, seqlen_q_eff, problem.nhead_q, problem.hdim}); - ck_tile::HostTensor k_b({1, seqlen_kv_eff, problem.nhead_kv, problem.hdim}); - ck_tile::HostTensor v_b({1, seqlen_kv_eff, problem.nhead_kv, problem.hdim}); - ck_tile::HostTensor o_b({1, seqlen_q_eff, problem.nhead_q, problem.hdim}); - - // Copy effective region - q_b.ForEach([&](auto& self, auto idx) { - // idx: [0, s, h, d] - self(idx) = q(b, idx[1], idx[2], idx[3]); - }); - k_b.ForEach([&](auto& self, auto idx) { self(idx) = k(b, idx[1], idx[2], idx[3]); }); - v_b.ForEach([&](auto& self, auto idx) { self(idx) = v(b, idx[1], idx[2], idx[3]); }); - - // Compute reference for this batch segment (host::fmha_fwd expects bshd tensors) - host::fmha_fwd(q_b, - k_b, - v_b, - problem.mask, - o_b, - ck_tile::identity{}, - ck_tile::identity{}, - ck_tile::identity{}, - ck_tile::scales{problem.softmax_scale}); - - // Scatter into o_ref's bshd descriptor memory - for(int s = 0; s < seqlen_q_eff; ++s) - { - for(int h = 0; h < problem.nhead_q; ++h) - { - for(int d = 0; d < problem.hdim; ++d) - { - o_ref(b, s, h, d) = o_b(0, s, h, d); - } - } - } - } - } - else - { - // No varlen override: compute the full reference once - host::fmha_fwd(q, - k, - v, - problem.mask, - o_ref, - ck_tile::identity{}, - ck_tile::identity{}, - ck_tile::identity{}, - ck_tile::scales{problem.softmax_scale}); - } + host::fmha_fwd(q, + k, + v, + problem.mask, + o_ref, + ck_tile::identity{}, + ck_tile::identity{}, + ck_tile::identity{}, + ck_tile::scales{problem.softmax_scale}); ck_tile::HostTensor o(problem.get_output_shape()); o_buf.FromDevice(o.data()); diff --git a/example/ck_tile/01_fmha/fmha_fwd.hpp b/example/ck_tile/01_fmha/fmha_fwd.hpp index f5dd42a6bd..c41e48e6aa 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd.hpp @@ -162,20 +162,11 @@ struct fmha_fwd_args void* lse_ptr; void* o_ptr; - // Optional cumulative sequence length arrays - // Batch mode: cu_seqlen_* override effective per-batch lengths (exclude PAD) - const ck_tile::index_t* cu_seqlen_q_ptr = nullptr; // [batch+1] - const ck_tile::index_t* cu_seqlen_kv_ptr = nullptr; // [batch+1] - const void* seqstart_q_ptr; const void* seqstart_k_ptr; const void* seqlen_k_ptr; // only used if both 'seqstart_q_ptr' & 'seqstart_k_ptr' are not nullptr - // Group mode: seqstart_padded_* provide physical starts including PAD (optional) - const void* seqstart_padded_q_ptr = nullptr; // [batch+1] - const void* seqstart_padded_k_ptr = nullptr; // [batch+1] - ck_tile::index_t seqlen_q; ck_tile::index_t seqlen_k; ck_tile::index_t batch; @@ -563,9 +554,7 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) args.min_seqlen_q, args.p_drop, args.s_randval, - args.drop_seed_offset, - args.seqstart_padded_q_ptr, - args.seqstart_padded_k_ptr); + args.drop_seed_offset); } else { // create batch mode kernel arguments @@ -611,9 +600,7 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) args.mask_type, args.p_drop, args.s_randval, - args.drop_seed_offset, - args.cu_seqlen_q_ptr, - args.cu_seqlen_kv_ptr); + args.drop_seed_offset); } }(); diff --git a/example/ck_tile/01_fmha/fmha_fwd_runner.hpp b/example/ck_tile/01_fmha/fmha_fwd_runner.hpp index cb5827975e..43f484fe14 100644 --- a/example/ck_tile/01_fmha/fmha_fwd_runner.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd_runner.hpp @@ -151,10 +151,7 @@ fwd_result fmha_fwd_run(mode_enum mode, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t seqlen_knew, - std::vector seqlen_qpads, std::vector seqlen_kpads, - std::vector q_eff_lens_per_batch, - std::vector kv_eff_lens_per_batch, ck_tile::index_t rotary_dim, bool i_perm, bool o_perm, @@ -365,44 +362,6 @@ fwd_result fmha_fwd_run(mode_enum mode, const auto seqstart_k_host = to_seqstarts(seqlen_ks); const auto seqstart_k_with_padding_host = to_seqstarts(seqlen_kpads); - // Optional padded Q seqstarts (group-mode only) - std::vector seqstart_q_with_padding_host; - if(mode == mode_enum::group && !seqlen_qpads.empty() && seqlen_qpads[0] != -1) - { - if(seqlen_qpads.size() < static_cast(batch)) - { - seqlen_qpads.resize(batch, seqlen_qpads.back()); - } - if(seqlen_qpads.size() == static_cast(batch)) - { - seqstart_q_with_padding_host = to_seqstarts( - ck_tile::span(seqlen_qpads.data(), seqlen_qpads.size())); - } - } - - // Optional batch-mode cumulative seqlen overrides - std::vector cuq_cum, cukv_cum; - if(mode == mode_enum::batch) - { - auto calculate_cumulative = [&](std::vector& per_batch_vec, - std::vector& cum_vec) { - if(!per_batch_vec.empty() && per_batch_vec[0] != -1) - { - if(per_batch_vec.size() < static_cast(batch)) - { - per_batch_vec.resize(batch, per_batch_vec.back()); - } - cum_vec.resize(batch + 1); - cum_vec[0] = 0; - for(int i = 0; i < batch; ++i) - cum_vec[i + 1] = cum_vec[i] + per_batch_vec[i]; - } - }; - - calculate_cumulative(q_eff_lens_per_batch, cuq_cum); - calculate_cumulative(kv_eff_lens_per_batch, cukv_cum); - } - using TypeConfig = FmhaFwdTypeConfig; using QDataType = typename TypeConfig::QDataType; @@ -486,15 +445,8 @@ fwd_result fmha_fwd_run(mode_enum mode, // host memory for storing all the tensor elements const ck_tile::index_t shape_batch = (mode == mode_enum::batch ? batch : 1); - // logical(unpadded) total seqlen_q for group; batch uses fixed seqlen - const ck_tile::index_t shape_seqlen_q_lse = - (mode == mode_enum::batch ? seqlen_qs[0] : seqstart_q_host.back()); - // physical(padded) total seqlen_q for group when s_qpad is provided; else use logical const ck_tile::index_t shape_seqlen_q = - (mode == mode_enum::batch - ? seqlen_qs[0] - : (seqstart_q_with_padding_host.empty() ? seqstart_q_host.back() - : seqstart_q_with_padding_host.back())); + (mode == mode_enum::batch ? seqlen_qs[0] : seqstart_q_host.back()); const ck_tile::index_t shape_seqlen_k = (mode == mode_enum::batch ? seqlen_ks[0] : (seqlen_kpads[0] < 0 ? seqstart_k_host.back() @@ -552,7 +504,7 @@ fwd_result fmha_fwd_run(mode_enum mode, // batch mode of lse data layout is [batch, nhead, seqlen_q] // group mode of lse data layout is [nhead, total_seqlen_q] ck_tile::HostTensor lse_host( - lse ? std::array{shape_batch, nhead, shape_seqlen_q_lse} + lse ? std::array{shape_batch, nhead, shape_seqlen_q} : std::array{1, 1, 1} /* dummy shape for simplifying code */); ck_tile::HostTensor o_host( @@ -650,16 +602,6 @@ fwd_result fmha_fwd_run(mode_enum mode, ck_tile::DeviceMem o_buf(o_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem seqstart_q(seqstart_q_host.size() * sizeof(int32_t)); ck_tile::DeviceMem seqstart_k(seqstart_k_host.size() * sizeof(int32_t)); - ck_tile::DeviceMem seqstart_q_padded_buf(seqstart_q_with_padding_host.empty() - ? 0 - : seqstart_q_with_padding_host.size() * - sizeof(int32_t)); - ck_tile::DeviceMem seqstart_k_padded_buf( - seqlen_kpads[0] < 0 ? 0 : seqstart_k_with_padding_host.size() * sizeof(int32_t)); - ck_tile::DeviceMem cu_seqlen_q_buf(cuq_cum.empty() ? 0 - : cuq_cum.size() * sizeof(ck_tile::index_t)); - ck_tile::DeviceMem cu_seqlen_kv_buf( - cukv_cum.empty() ? 0 : cukv_cum.size() * sizeof(ck_tile::index_t)); ck_tile::DeviceMem seqlen_k_buf((mode == mode_enum::batch && use_kvcache) || 0 <= seqlen_kpads[0] ? seqlen_ks.size() * sizeof(int32_t) @@ -751,14 +693,8 @@ fwd_result fmha_fwd_run(mode_enum mode, vnew_buf.ToDevice(vnew_host.data()); bias_buf.ToDevice(bias_host.data()); seqstart_q.ToDevice(seqstart_q_host.data()); - // Keep logical starts in seqstart_k; pass padded K via separate pointer - seqstart_k.ToDevice(seqstart_k_host.data()); - seqstart_q_padded_buf.ToDevice( - seqstart_q_with_padding_host.empty() ? nullptr : seqstart_q_with_padding_host.data()); - seqstart_k_padded_buf.ToDevice(seqlen_kpads[0] < 0 ? nullptr - : seqstart_k_with_padding_host.data()); - cu_seqlen_q_buf.ToDevice(cuq_cum.empty() ? nullptr : cuq_cum.data()); - cu_seqlen_kv_buf.ToDevice(cukv_cum.empty() ? nullptr : cukv_cum.data()); + seqstart_k.ToDevice(seqlen_kpads[0] < 0 ? seqstart_k_host.data() + : seqstart_k_with_padding_host.data()); seqlen_k_buf.ToDevice((mode == mode_enum::batch && use_kvcache) || 0 <= seqlen_kpads[0] ? seqlen_ks.data() : nullptr); @@ -894,8 +830,8 @@ fwd_result fmha_fwd_run(mode_enum mode, const ck_tile::index_t nhead_stride_bias = (i_perm ? 0 * shape_seqlen_q * max_seqlen_k : 0 * max_seqlen_k); const ck_tile::index_t nhead_stride_randval = (shape_seqlen_q * max_seqlen_k); - const ck_tile::index_t nhead_stride_lse = shape_seqlen_q_lse; - const ck_tile::index_t nhead_stride_lse_acc = (num_splits * shape_seqlen_q_lse); + const ck_tile::index_t nhead_stride_lse = shape_seqlen_q; + const ck_tile::index_t nhead_stride_lse_acc = (num_splits * shape_seqlen_q); const ck_tile::index_t nhead_stride_o_acc = (num_splits * shape_seqlen_q * hdim_v); const ck_tile::index_t nhead_stride_o = (o_perm ? shape_seqlen_q * hdim_v : hdim_v); // setup batch_stride_* arguments @@ -910,8 +846,8 @@ fwd_result fmha_fwd_run(mode_enum mode, const ck_tile::index_t batch_stride_vnew = (nhead_k * hdim_v * seqlen_knew); const ck_tile::index_t batch_stride_bias = (0 * nhead * shape_seqlen_q * max_seqlen_k); const ck_tile::index_t batch_stride_randval = (nhead * shape_seqlen_q * max_seqlen_k); - const ck_tile::index_t batch_stride_lse = (nhead * shape_seqlen_q_lse); - const ck_tile::index_t batch_stride_lse_acc = (nhead * num_splits * shape_seqlen_q_lse); + const ck_tile::index_t batch_stride_lse = (nhead * shape_seqlen_q); + const ck_tile::index_t batch_stride_lse_acc = (nhead * num_splits * shape_seqlen_q); const ck_tile::index_t batch_stride_o_acc = (nhead * num_splits * shape_seqlen_q * hdim_v); const ck_tile::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v); const ck_tile::index_t batch_stride_block_table = (max_num_page_blocks / batch); @@ -1025,29 +961,6 @@ fwd_result fmha_fwd_run(mode_enum mode, { args.drop_seed_offset = std::make_pair(drop_seed, drop_offset); } - - // Group-mode: optional physical padded starts for Q/K - if(mode == mode_enum::group) - { - args.seqstart_padded_q_ptr = (seqstart_q_with_padding_host.empty() - ? nullptr - : seqstart_q_padded_buf.GetDeviceBuffer()); - args.seqstart_padded_k_ptr = - (seqlen_kpads[0] < 0 ? nullptr : seqstart_k_padded_buf.GetDeviceBuffer()); - } - - // Batch-mode: optional cumulative effective seqlen overrides - if(mode == mode_enum::batch) - { - args.cu_seqlen_q_ptr = cuq_cum.empty() - ? nullptr - : reinterpret_cast( - cu_seqlen_q_buf.GetDeviceBuffer()); - args.cu_seqlen_kv_ptr = cukv_cum.empty() - ? nullptr - : reinterpret_cast( - cu_seqlen_kv_buf.GetDeviceBuffer()); - } } else if constexpr(std::is_same_v>) { @@ -1254,29 +1167,15 @@ fwd_result fmha_fwd_run(mode_enum mode, for(ck_tile::index_t wb = 0; wb < batch; ++wb) { - ck_tile::index_t real_seqlen_q = seqstart_q_host[wb + 1] - seqstart_q_host[wb]; - ck_tile::index_t real_seqlen_k = seqstart_k_host[wb + 1] - seqstart_k_host[wb]; - if(mode == mode_enum::batch) - { - if(!cuq_cum.empty()) - { - real_seqlen_q = cuq_cum[wb + 1] - cuq_cum[wb]; - } - if(!cukv_cum.empty()) - { - real_seqlen_k = cukv_cum[wb + 1] - cukv_cum[wb]; - } - } + const ck_tile::index_t real_seqlen_q = seqstart_q_host[wb + 1] - seqstart_q_host[wb]; + const ck_tile::index_t real_seqlen_k = seqstart_k_host[wb + 1] - seqstart_k_host[wb]; // adjust matrix index according to the mode const ck_tile::index_t b_idx = (mode == mode_enum::batch ? wb : 0); const ck_tile::index_t cache_b_idx = (use_cache_batch_idx ? cache_batch_idx_host(b_idx) : b_idx); const ck_tile::index_t query_offset = - (mode == mode_enum::batch - ? 0 - : (seqstart_q_with_padding_host.empty() ? seqstart_q_host[wb] - : seqstart_q_with_padding_host[wb])); + (mode == mode_enum::batch ? 0 : seqstart_q_host[wb]); const ck_tile::index_t key_offset = (mode == mode_enum::batch ? 0 @@ -1639,10 +1538,8 @@ fwd_result fmha_fwd_run(mode_enum mode, if(lse) { ck_tile::HostTensor lse_host_result({nhead, real_seqlen_q}); - const ck_tile::index_t query_offset_lse = - (mode == mode_enum::batch ? 0 : seqstart_q_host[wb]); lse_host_result.ForEach([&](auto& self, auto idx) { - self(idx) = lse_host(b_idx, idx[0], idx[1] + query_offset_lse); + self(idx) = lse_host(b_idx, idx[0], idx[1] + query_offset); }); cur_pass = ck_tile::check_err(lse_host_result, diff --git a/example/ck_tile/01_fmha/fmha_fwd_v3.hpp b/example/ck_tile/01_fmha/fmha_fwd_v3.hpp index 4bd1d1a367..10cb5149a4 100644 --- a/example/ck_tile/01_fmha/fmha_fwd_v3.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd_v3.hpp @@ -56,11 +56,6 @@ struct fmha_fwd_v3_args index_t stride_o; index_t nhead_stride_o; index_t batch_stride_o; - - // Optional batch-mode cumulative seqlen overrides (exclude PAD) - // If provided, they override per-batch effective lengths to skip tail padding. - const ck_tile::index_t* cu_seqlen_q_ptr = nullptr; // [batch+1] - const ck_tile::index_t* cu_seqlen_kv_ptr = nullptr; // [batch+1] }; std::ostream& operator<<(std::ostream& stream, const fmha_fwd_v3_args::data_type_enum& data_type); diff --git a/example/ck_tile/01_fmha/fmha_fwd_v3_impl.hpp b/example/ck_tile/01_fmha/fmha_fwd_v3_impl.hpp index 194675f962..e0fbad39a5 100644 --- a/example/ck_tile/01_fmha/fmha_fwd_v3_impl.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd_v3_impl.hpp @@ -158,9 +158,7 @@ float fmha_fwd_v3_kernel_launch(const fmha_fwd_v3_args& args, const stream_confi args.window_size_left, args.window_size_right, args.mask_type, - remap_opt, - args.cu_seqlen_q_ptr, - args.cu_seqlen_kv_ptr); + remap_opt); dim3 grids = Kernel::GridSize(args.batch, args.nhead_q, args.seqlen_q, args.hdim_v); constexpr dim3 blocks = Kernel::BlockSize(); diff --git a/example/ck_tile/01_fmha/script/benchmark_fwd.sh b/example/ck_tile/01_fmha/script/benchmark_fwd.sh index 31ad800039..88c16cceb6 100755 --- a/example/ck_tile/01_fmha/script/benchmark_fwd.sh +++ b/example/ck_tile/01_fmha/script/benchmark_fwd.sh @@ -18,36 +18,3 @@ $EXE -prec=$prec -b=1 -h=$nhead -d=$hdim -s=16384 -iperm=$perm -operm=$perm -kn done done done - -#Padding Benchmarks: batch mode (baseline vs low/med/high pad) -prec="fp16" -base_batch_args="-prec=$prec -mode=0 -b=4 -h=16 -h_k=16 -d=128 -s=1024 -bias=n -mask=0 -lse=0 -iperm=0 -operm=0 -vlayout=r -kname=1 -v=$VALID" - -# baseline (no pad) -$EXE $base_batch_args - -# low pad (≈90–95% effective) -$EXE $base_batch_args -q_eff_lens=1024,960,992,896 -kv_eff_lens=1024,960,992,896 - -# medium pad (≈60–75% effective) -$EXE $base_batch_args -q_eff_lens=896,768,512,640 -kv_eff_lens=896,768,512,640 - -# high pad (≈30–40% effective) -$EXE $base_batch_args -q_eff_lens=512,384,256,320 -kv_eff_lens=512,384,256,320 - -# Padding Benchmarks: group mode (baseline vs low/med/high physical pad) -seqlens_q="1024,768,512,256" -seqlens_k="1024,768,512,256" -base_group_args="-prec=$prec -mode=1 -b=4 -h=16 -h_k=16 -d=128 -s=$seqlens_q -s_k=$seqlens_k -bias=n -mask=0 -lse=0 -iperm=0 -operm=0 -vlayout=r -kname=1 -v=$VALID" - -# baseline (no physical pad) -$EXE $base_group_args - -# low physical pad -$EXE $base_group_args -s_qpad=1152,896,576,320 -s_kpad=1152,896,576,320 - -# medium physical pad -$EXE $base_group_args -s_qpad=1536,1152,768,384 -s_kpad=1536,1152,768,384 - -# high physical pad -$EXE $base_group_args -s_qpad=2048,1536,1024,512 -s_kpad=2048,1536,1024,512 diff --git a/example/ck_tile/01_fmha/script/benchmark_fwd_v3.sh b/example/ck_tile/01_fmha/script/benchmark_fwd_v3.sh index a3f7d68eb3..b847e85398 100755 --- a/example/ck_tile/01_fmha/script/benchmark_fwd_v3.sh +++ b/example/ck_tile/01_fmha/script/benchmark_fwd_v3.sh @@ -23,20 +23,3 @@ done done done done - -# Padding benchmark comparisons for v3 (batch mode only) -# ==== V3 Padding Benchmarks: batch mode (baseline vs low/med/high pad) ==== -prec="fp16" -base_v3_args="-prec=$prec -b=4 -h=16 -d=128 -s=1024 -mask=0 -iperm=0 -operm=0 -v=$VALID" - -# baseline (no pad) -$EXE $base_v3_args - -# low pad (≈90–95% effective) -$EXE $base_v3_args -q_eff_lens=1024,960,992,896 -kv_eff_lens=1024,960,992,896 - -# medium pad (≈60–75% effective) -$EXE $base_v3_args -q_eff_lens=896,768,512,640 -kv_eff_lens=896,768,512,640 - -# high pad (≈30–40% effective) -$EXE $base_v3_args -q_eff_lens=512,384,256,320 -kv_eff_lens=512,384,256,320 diff --git a/example/ck_tile/01_fmha/script/smoke_test_fwd.sh b/example/ck_tile/01_fmha/script/smoke_test_fwd.sh index fca6b8d0cd..afd0c728c6 100755 --- a/example/ck_tile/01_fmha/script/smoke_test_fwd.sh +++ b/example/ck_tile/01_fmha/script/smoke_test_fwd.sh @@ -137,118 +137,9 @@ run_fp16_appendkv_tests() { done ; done ; done } -run_padding_smoke_tests() { - # Padding-only smoke tests for batch/group mode using COMMON_ARGS - local prec="fp16" - - # Batch mode: padding via effective lengths (exclude PAD) - # Use lse=1 to select a non-trload kernel and avoid overly strict tolerance mismatches - local base_batch="-prec=$prec -mode=0 -b=4 -h=16 -h_k=16 -d=128 -s=1024 -bias=n -mask=0 -lse=1 -iperm=0 -operm=0 -vlayout=r -kname=$KNAME $COMMON_ARGS" - # low pad (≈90–95% effective) - $EXE $base_batch -q_eff_lens=1024,960,992,896 -kv_eff_lens=1024,960,992,896 - # medium pad (≈60–75% effective) - $EXE $base_batch -q_eff_lens=896,768,512,640 -kv_eff_lens=896,768,512,640 - # high pad (≈30–40% effective) - $EXE $base_batch -q_eff_lens=512,384,256,320 -kv_eff_lens=512,384,256,320 - - # Group mode: padding via physical stride along seqlen - local seqlens_q="1024,768,512,256" - local seqlens_k="1024,768,512,256" - local base_group="-prec=$prec -mode=1 -b=4 -h=16 -h_k=16 -d=128 -s=$seqlens_q -s_k=$seqlens_k -bias=n -mask=0 -lse=0 -iperm=0 -operm=0 -vlayout=r -kname=$KNAME $COMMON_ARGS" - # low physical pad - $EXE $base_group -s_qpad=1152,896,576,320 -s_kpad=1152,896,576,320 - # medium physical pad - $EXE $base_group -s_qpad=1536,1152,768,384 -s_kpad=1536,1152,768,384 - # high physical pad - $EXE $base_group -s_qpad=2048,1536,1024,512 -s_kpad=2048,1536,1024,512 -} - -run_padding_basic_boundary_tests() { - # Basic padding and boundary tests (reference: smoke_test_fwd_pad.sh) - local prec - local perm - - # Group mode: Q&K padded with per-batch different strides - for prec in fp16 bf16 ; do - for perm in 0 1 ; do - $EXE -prec=$prec -mode=1 -b=2 -h=2 -h_k=1 -d=16 -d_v=32 \ - -s=55 -s_k=256 -s_qpad=64,60 -s_kpad=272,260 \ - -bias=n -p_drop=0.0 -lse=0 -iperm=$perm -operm=$perm \ - -num_splits=1 -page_block_size=0 -cache_batch_idx=0 -kname=$KNAME $COMMON_ARGS - done - done - - # slightly larger, uneven padding strides - for prec in fp16 bf16 ; do - for perm in 0 1 ; do - $EXE -prec=$prec -mode=1 -b=3 -h=2 -h_k=1 -d=64 -d_v=64 \ - -s=50,60,40 -s_k=128,256,192 -s_qpad=64,64,64 -s_kpad=160,288,224 \ - -bias=n -p_drop=0.0 -lse=1 -iperm=$perm -operm=$perm \ - -num_splits=1 -page_block_size=0 -cache_batch_idx=0 -kname=$KNAME $COMMON_ARGS - done - done - - # only K padded; Q unpadded - for prec in fp16 bf16 ; do - for perm in 0 1 ; do - $EXE -prec=$prec -mode=1 -b=2 -h=2 -h_k=1 -d=32 -d_v=64 \ - -s=55 -s_k=256 -s_kpad=272,260 \ - -bias=n -p_drop=0.0 -lse=1 -iperm=$perm -operm=$perm \ - -num_splits=1 -page_block_size=0 -cache_batch_idx=0 -kname=$KNAME $COMMON_ARGS - done - done - - # use cu_seqlen overrides to skip tail PAD - for prec in fp16 bf16 ; do - for perm in 0 1 ; do - $EXE -prec=$prec -mode=0 -b=4 -h=8 -h_k=8 -d=128 -s=3 -s_k=3 \ - -q_eff_lens=1,2,1,2 -kv_eff_lens=1,2,1,2 \ - -bias=n -p_drop=0.0 -lse=1 -iperm=$perm -operm=$perm \ - -num_splits=1 -page_block_size=0 -cache_batch_idx=0 -kname=$KNAME $COMMON_ARGS - - $EXE -prec=$prec -mode=0 -b=2 -h=2 -h_k=1 -d=32 -d_v=64 -s=64 -s_k=256 \ - -q_eff_lens=55,60 -kv_eff_lens=200,256 \ - -bias=n -p_drop=0.0 -lse=0 -iperm=$perm -operm=$perm \ - -num_splits=1 -page_block_size=0 -cache_batch_idx=0 -kname=$KNAME $COMMON_ARGS - done - done - - # no padding (equal), mixed Q/KV, all len=1 - for prec in fp16 bf16 ; do - $EXE -prec=$prec -mode=0 -b=4 -h=8 -d=64 -s=128 -s_k=128 \ - -q_eff_lens=128,128,128,128 -kv_eff_lens=128,128,128,128 \ - -bias=n -p_drop=0.0 -lse=1 -kname=$KNAME $COMMON_ARGS - - $EXE -prec=$prec -mode=0 -b=4 -h=8 -d=64 -s=128 -s_k=128 \ - -q_eff_lens=10,20,30,40 -kv_eff_lens=40,30,20,10 \ - -bias=n -p_drop=0.0 -lse=1 -kname=$KNAME $COMMON_ARGS - - $EXE -prec=$prec -mode=0 -b=4 -h=8 -d=64 -s=128 -s_k=128 \ - -q_eff_lens=1,1,1,1 -kv_eff_lens=1,1,1,1 \ - -bias=n -p_drop=0.0 -lse=1 -kname=$KNAME $COMMON_ARGS - done - - # highly variable logical lengths - for prec in fp16 bf16 ; do - $EXE -prec=$prec -mode=1 -b=4 -h=4 -d=32 \ - -s=1,127,3,65 -s_k=1,127,3,65 -s_kpad=128 \ - -bias=n -p_drop=0.0 -lse=1 -kname=$KNAME $COMMON_ARGS - done - - # GQA + Alibi + Causal mask (keep vlayout row-major for fp16/bf16 - for prec in fp16 bf16 ; do - $EXE -prec=$prec -mode=1 -b=2 -h=16 -h_k=4 -d=128 \ - -s=256,129 -s_k=256,129 -s_kpad=256 \ - -bias=a -mask=t -lse=1 -iperm=0 -operm=0 -vlayout=r \ - -kname=$KNAME $COMMON_ARGS - done -} - set -x run_fp16_bf16_tests -run_padding_smoke_tests -run_padding_basic_boundary_tests run_fp8_tests run_fp8bf16_tests run_fp8fp32_tests diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp index 3f417bc125..58fdad149a 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp @@ -291,11 +291,6 @@ struct FmhaFwdKernel ck_tile::index_t batch_stride_k; ck_tile::index_t batch_stride_v; ck_tile::index_t batch_stride_o; - - // Optional cumulative sequence length pointers for batch mode - // If provided, they override seqlen_q / seqlen_k per-batch to skip tail padding. - const ck_tile::index_t* cu_seqlen_q_ptr = nullptr; // cumulative, length without PAD - const ck_tile::index_t* cu_seqlen_kv_ptr = nullptr; // cumulative, length without PAD }; struct FmhaFwdGroupModeKargs @@ -315,11 +310,6 @@ struct FmhaFwdKernel const int32_t* seqstart_q_ptr; const int32_t* seqstart_k_ptr; const int32_t* seqlen_k_ptr; - - // Optional cumulative padded sequence starts (including PAD tokens) - // Used solely to compute memory offsets when sequences are physically padded. - const int32_t* seqstart_padded_q_ptr = nullptr; - const int32_t* seqstart_padded_k_ptr = nullptr; }; using Kargs = std::conditional_t; @@ -470,105 +460,6 @@ struct FmhaFwdKernel return kargs; } - // Overload: Batch mode with optional cu_seqlen pointers (unpadded cumulative lengths) - template - CK_TILE_HOST static constexpr std::enable_if_t - MakeKargsImpl(const void* q_ptr, - const void* k_ptr, - const void* v_ptr, - const void* bias_ptr, - void* rand_val_ptr, - void* lse_ptr, - void* o_ptr, - ck_tile::index_t seqlen_q, - ck_tile::index_t seqlen_k, - ck_tile::index_t hdim_q, - ck_tile::index_t hdim_v, - ck_tile::index_t num_head_q, - ck_tile::index_t nhead_ratio_qk, - float scale_s, - float scale_p, - float scale_o, - float logits_soft_cap, - ck_tile::index_t stride_q, - ck_tile::index_t stride_k, - ck_tile::index_t stride_v, - ck_tile::index_t stride_bias, - ck_tile::index_t stride_randval, - ck_tile::index_t stride_o, - ck_tile::index_t nhead_stride_q, - ck_tile::index_t nhead_stride_k, - ck_tile::index_t nhead_stride_v, - ck_tile::index_t nhead_stride_bias, - ck_tile::index_t nhead_stride_randval, - ck_tile::index_t nhead_stride_lse, - ck_tile::index_t nhead_stride_o, - ck_tile::index_t batch_stride_q, - ck_tile::index_t batch_stride_k, - ck_tile::index_t batch_stride_v, - ck_tile::index_t batch_stride_bias, - ck_tile::index_t batch_stride_randval, - ck_tile::index_t batch_stride_lse, - ck_tile::index_t batch_stride_o, - ck_tile::index_t window_size_left, - ck_tile::index_t window_size_right, - ck_tile::index_t mask_type, - float p_drop, - bool s_randval, - std::variant, std::pair> - drop_seed_offset, - const ck_tile::index_t* cu_seqlen_q_ptr, - const ck_tile::index_t* cu_seqlen_kv_ptr) - { - auto kargs = MakeKargsImpl(q_ptr, - k_ptr, - v_ptr, - bias_ptr, - rand_val_ptr, - lse_ptr, - o_ptr, - seqlen_q, - seqlen_k, - hdim_q, - hdim_v, - num_head_q, - nhead_ratio_qk, - scale_s, - scale_p, - scale_o, - logits_soft_cap, - stride_q, - stride_k, - stride_v, - stride_bias, - stride_randval, - stride_o, - nhead_stride_q, - nhead_stride_k, - nhead_stride_v, - nhead_stride_bias, - nhead_stride_randval, - nhead_stride_lse, - nhead_stride_o, - batch_stride_q, - batch_stride_k, - batch_stride_v, - batch_stride_bias, - batch_stride_randval, - batch_stride_lse, - batch_stride_o, - window_size_left, - window_size_right, - mask_type, - p_drop, - s_randval, - drop_seed_offset); - - kargs.cu_seqlen_q_ptr = cu_seqlen_q_ptr; - kargs.cu_seqlen_kv_ptr = cu_seqlen_kv_ptr; - return kargs; - } - // std::variant<> can't take in a list initializer, overload for backward compatibility template CK_TILE_HOST static constexpr std::enable_if_t @@ -890,95 +781,6 @@ struct FmhaFwdKernel return kargs; } - // Overload: Group mode with optional padded seqstarts for memory offsets - template - CK_TILE_HOST static constexpr std::enable_if_t - MakeKargsImpl(const void* q_ptr, - const void* k_ptr, - const void* v_ptr, - const void* bias_ptr, - void* rand_val_ptr, - void* lse_ptr, - void* o_ptr, - const void* seqstart_q_ptr, - const void* seqstart_k_ptr, - const void* seqlen_k_ptr, - ck_tile::index_t hdim_q, - ck_tile::index_t hdim_v, - ck_tile::index_t num_head_q, - ck_tile::index_t nhead_ratio_qk, - float scale_s, - float scale_p, - float scale_o, - float logits_soft_cap, - ck_tile::index_t stride_q, - ck_tile::index_t stride_k, - ck_tile::index_t stride_v, - ck_tile::index_t stride_bias, - ck_tile::index_t stride_randval, - ck_tile::index_t stride_o, - ck_tile::index_t nhead_stride_q, - ck_tile::index_t nhead_stride_k, - ck_tile::index_t nhead_stride_v, - ck_tile::index_t nhead_stride_bias, - ck_tile::index_t nhead_stride_randval, - ck_tile::index_t nhead_stride_lse, - ck_tile::index_t nhead_stride_o, - ck_tile::index_t window_size_left, - ck_tile::index_t window_size_right, - ck_tile::index_t mask_type, - ck_tile::index_t min_seqlen_q, - float p_drop, - bool s_randval, - std::variant, std::pair> - drop_seed_offset, - const void* seqstart_padded_q_ptr, - const void* seqstart_padded_k_ptr) - { - auto kargs = MakeKargsImpl(q_ptr, - k_ptr, - v_ptr, - bias_ptr, - rand_val_ptr, - lse_ptr, - o_ptr, - seqstart_q_ptr, - seqstart_k_ptr, - seqlen_k_ptr, - hdim_q, - hdim_v, - num_head_q, - nhead_ratio_qk, - scale_s, - scale_p, - scale_o, - logits_soft_cap, - stride_q, - stride_k, - stride_v, - stride_bias, - stride_randval, - stride_o, - nhead_stride_q, - nhead_stride_k, - nhead_stride_v, - nhead_stride_bias, - nhead_stride_randval, - nhead_stride_lse, - nhead_stride_o, - window_size_left, - window_size_right, - mask_type, - min_seqlen_q, - p_drop, - s_randval, - drop_seed_offset); - - kargs.seqstart_padded_q_ptr = reinterpret_cast(seqstart_padded_q_ptr); - kargs.seqstart_padded_k_ptr = reinterpret_cast(seqstart_padded_k_ptr); - return kargs; - } - // std::variant<> can't take in a list initializer, overload for backward compatibility template CK_TILE_HOST static constexpr std::enable_if_t @@ -1271,44 +1073,35 @@ struct FmhaFwdKernel if constexpr(kIsGroupMode) { - // logical and physical (padded) starts - const long_index_t query_start_unpadded = kargs.seqstart_q_ptr[i_batch]; - const long_index_t key_start_unpadded = kargs.seqstart_k_ptr[i_batch]; + // get starting offset for each batch + const long_index_t query_start = kargs.seqstart_q_ptr[i_batch]; + const long_index_t key_start = kargs.seqstart_k_ptr[i_batch]; - const long_index_t query_start_padded = kargs.seqstart_padded_q_ptr - ? kargs.seqstart_padded_q_ptr[i_batch] - : query_start_unpadded; - const long_index_t key_start_padded = kargs.seqstart_padded_k_ptr - ? kargs.seqstart_padded_k_ptr[i_batch] - : key_start_unpadded; - - // DRAM base offsets use physical padded starts - batch_offset_q = query_start_padded * kargs.stride_q; - batch_offset_k = key_start_padded * kargs.stride_k; + batch_offset_q = query_start * kargs.stride_q; + batch_offset_k = key_start * kargs.stride_k; if constexpr(std::is_same_v) { - batch_offset_v = key_start_padded * kargs.stride_v; + batch_offset_v = key_start * kargs.stride_v; } else { - batch_offset_v = key_start_padded; + batch_offset_v = key_start; } if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) { - batch_offset_bias = query_start_padded * kargs.stride_bias; + batch_offset_bias = query_start * kargs.stride_bias; } if constexpr(kStoreLSE) { - // LSE stays indexed by unpadded starts - batch_offset_lse = query_start_unpadded; + batch_offset_lse = query_start; } if constexpr(kHasDropout) { - batch_offset_randval = query_start_padded * kargs.stride_randval; + batch_offset_randval = query_start * kargs.stride_randval; } - batch_offset_o = query_start_padded * kargs.stride_o; + batch_offset_o = query_start * kargs.stride_o; - // real logical lengths (exclude PAD) + // get real # queries & # keys under group mode const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch; kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0]; @@ -1320,7 +1113,8 @@ struct FmhaFwdKernel } } - // terminate unnecessary blocks earlier + // # of required blocks is different in each groups, terminate unnecessary blocks + // earlier if(kargs.seqlen_q <= i_m0) { return; @@ -1356,18 +1150,6 @@ struct FmhaFwdKernel static_cast(i_batch) * kargs.batch_stride_randval; } batch_offset_o = static_cast(i_batch) * kargs.batch_stride_o; - - // If cumulative seqlen pointers are provided, override per-batch effective lengths - if(kargs.cu_seqlen_q_ptr != nullptr) - { - kargs.seqlen_q = - kargs.cu_seqlen_q_ptr[i_batch + 1] - kargs.cu_seqlen_q_ptr[i_batch]; - } - if(kargs.cu_seqlen_kv_ptr != nullptr) - { - kargs.seqlen_k = - kargs.cu_seqlen_kv_ptr[i_batch + 1] - kargs.cu_seqlen_kv_ptr[i_batch]; - } } // for simplicity, batch stride we just modify the pointer @@ -1766,35 +1548,26 @@ struct FmhaFwdKernel if constexpr(kIsGroupMode) { // get starting offset for each batch - const long_index_t query_start_unpadded = kargs.seqstart_q_ptr[i_batch]; - const long_index_t key_start_unpadded = kargs.seqstart_k_ptr[i_batch]; + const long_index_t query_start = kargs.seqstart_q_ptr[i_batch]; + const long_index_t key_start = kargs.seqstart_k_ptr[i_batch]; - const long_index_t query_start_padded = kargs.seqstart_padded_q_ptr - ? kargs.seqstart_padded_q_ptr[i_batch] - : query_start_unpadded; - const long_index_t key_start_padded = kargs.seqstart_padded_k_ptr - ? kargs.seqstart_padded_k_ptr[i_batch] - : key_start_unpadded; - - batch_offset_q = query_start_padded * kargs.stride_q; - batch_offset_k = key_start_padded * kargs.stride_k; + batch_offset_q = query_start * kargs.stride_q; + batch_offset_k = key_start * kargs.stride_k; if constexpr(std::is_same_v) { - batch_offset_v = key_start_padded * kargs.stride_v; + batch_offset_v = key_start * kargs.stride_v; } else { - // col-major V: offset along seqlen dimension is scalar index - batch_offset_v = key_start_padded; + batch_offset_v = key_start; } if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) { - batch_offset_bias = query_start_padded * kargs.stride_bias; + batch_offset_bias = query_start * kargs.stride_bias; } - // LSE layout is [nhead, total_seqlen], index by unpadded start - batch_offset_lse = query_start_unpadded; - batch_offset_o = query_start_padded * kargs.stride_o; + batch_offset_lse = query_start; + batch_offset_o = query_start * kargs.stride_o; // get real # queries & # keys under group mode kargs.seqlen_q = kargs.seqstart_q_ptr[i_batch + 1] - kargs.seqstart_q_ptr[i_batch]; @@ -1832,18 +1605,6 @@ struct FmhaFwdKernel batch_offset_bias = static_cast(i_batch) * kargs.batch_stride_bias; } - - // If cumulative seqlen pointers are provided, override per-batch effective lengths - if(kargs.cu_seqlen_q_ptr != nullptr) - { - kargs.seqlen_q = - kargs.cu_seqlen_q_ptr[i_batch + 1] - kargs.cu_seqlen_q_ptr[i_batch]; - } - if(kargs.cu_seqlen_kv_ptr != nullptr) - { - kargs.seqlen_k = - kargs.cu_seqlen_kv_ptr[i_batch + 1] - kargs.cu_seqlen_kv_ptr[i_batch]; - } } // for simplicity, batch stride we just modify the pointer diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_v3_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_v3_kernel.hpp index 52b9da40b8..c5e5745817 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_v3_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_v3_kernel.hpp @@ -100,11 +100,6 @@ struct FmhaFwdV3Kernel ck_tile::index_t batch_stride_k; ck_tile::index_t batch_stride_v; ck_tile::index_t batch_stride_o; - - // Optional cumulative sequence length pointers for batch mode - // If provided, they override seqlen_q / seqlen_k per-batch to skip tail padding. - const ck_tile::index_t* cu_seqlen_q_ptr = nullptr; // [batch+1] - const ck_tile::index_t* cu_seqlen_kv_ptr = nullptr; // [batch+1] }; struct FmhaFwdGroupModeKargs @@ -115,11 +110,6 @@ struct FmhaFwdV3Kernel const int32_t* seqstart_q_ptr; const int32_t* seqstart_k_ptr; const int32_t* seqlen_k_ptr; - - // Optional cumulative padded sequence starts (including PAD tokens) - // Used solely to compute memory offsets when sequences are physically padded. - const int32_t* seqstart_padded_q_ptr = nullptr; // [batch+1] - const int32_t* seqstart_padded_k_ptr = nullptr; // [batch+1] }; using Kargs = std::conditional_t; @@ -200,78 +190,6 @@ struct FmhaFwdV3Kernel return kargs; } - // Overload: Batch mode with optional cu_seqlen pointers - template - CK_TILE_HOST static constexpr std::enable_if_t - MakeKargs(const void* q_ptr, - const void* k_ptr, - const void* v_ptr, - void* lse_ptr, - void* o_ptr, - ck_tile::index_t seqlen_q, - ck_tile::index_t seqlen_k, - ck_tile::index_t hdim_q, - ck_tile::index_t hdim_v, - ck_tile::index_t num_head_q, - ck_tile::index_t nhead_ratio_qk, - float scale_s, - ck_tile::index_t stride_q, - ck_tile::index_t stride_k, - ck_tile::index_t stride_v, - ck_tile::index_t stride_o, - ck_tile::index_t nhead_stride_q, - ck_tile::index_t nhead_stride_k, - ck_tile::index_t nhead_stride_v, - ck_tile::index_t nhead_stride_lse, - ck_tile::index_t nhead_stride_o, - ck_tile::index_t batch_stride_q, - ck_tile::index_t batch_stride_k, - ck_tile::index_t batch_stride_v, - ck_tile::index_t batch_stride_lse, - ck_tile::index_t batch_stride_o, - ck_tile::index_t window_size_left, - ck_tile::index_t window_size_right, - ck_tile::index_t mask_type, - ck_tile::index_t remap_opt, - const ck_tile::index_t* cu_seqlen_q_ptr, - const ck_tile::index_t* cu_seqlen_kv_ptr) - { - auto kargs = MakeKargs(q_ptr, - k_ptr, - v_ptr, - lse_ptr, - o_ptr, - seqlen_q, - seqlen_k, - hdim_q, - hdim_v, - num_head_q, - nhead_ratio_qk, - scale_s, - stride_q, - stride_k, - stride_v, - stride_o, - nhead_stride_q, - nhead_stride_k, - nhead_stride_v, - nhead_stride_lse, - nhead_stride_o, - batch_stride_q, - batch_stride_k, - batch_stride_v, - batch_stride_lse, - batch_stride_o, - window_size_left, - window_size_right, - mask_type, - remap_opt); - - kargs.cu_seqlen_q_ptr = cu_seqlen_q_ptr; - kargs.cu_seqlen_kv_ptr = cu_seqlen_kv_ptr; - return kargs; - } - template CK_TILE_HOST static constexpr std::enable_if_t MakeKargs(const void* q_ptr, @@ -342,70 +260,6 @@ struct FmhaFwdV3Kernel return kargs; } - // Overload: Group mode with optional padded seqstarts for memory offsets - template - CK_TILE_HOST static constexpr std::enable_if_t - MakeKargs(const void* q_ptr, - const void* k_ptr, - const void* v_ptr, - void* lse_ptr, - void* o_ptr, - const void* seqstart_q_ptr, - const void* seqstart_k_ptr, - const void* seqlen_k_ptr, - ck_tile::index_t hdim_q, - ck_tile::index_t hdim_v, - ck_tile::index_t num_head_q, - ck_tile::index_t nhead_ratio_qk, - float scale_s, - ck_tile::index_t stride_q, - ck_tile::index_t stride_k, - ck_tile::index_t stride_v, - ck_tile::index_t stride_o, - ck_tile::index_t nhead_stride_q, - ck_tile::index_t nhead_stride_k, - ck_tile::index_t nhead_stride_v, - ck_tile::index_t nhead_stride_lse, - ck_tile::index_t nhead_stride_o, - ck_tile::index_t window_size_left, - ck_tile::index_t window_size_right, - ck_tile::index_t mask_type, - ck_tile::index_t remap_opt, - const void* seqstart_padded_q_ptr, - const void* seqstart_padded_k_ptr) - { - auto kargs = MakeKargs(q_ptr, - k_ptr, - v_ptr, - lse_ptr, - o_ptr, - seqstart_q_ptr, - seqstart_k_ptr, - seqlen_k_ptr, - hdim_q, - hdim_v, - num_head_q, - nhead_ratio_qk, - scale_s, - stride_q, - stride_k, - stride_v, - stride_o, - nhead_stride_q, - nhead_stride_k, - nhead_stride_v, - nhead_stride_lse, - nhead_stride_o, - window_size_left, - window_size_right, - mask_type, - remap_opt); - - kargs.seqstart_padded_q_ptr = reinterpret_cast(seqstart_padded_q_ptr); - kargs.seqstart_padded_k_ptr = reinterpret_cast(seqstart_padded_k_ptr); - return kargs; - } - CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_q_, @@ -519,26 +373,18 @@ struct FmhaFwdV3Kernel if constexpr(kIsGroupMode) { // get starting offset for each batch - const long_index_t query_start_unpadded = kargs.seqstart_q_ptr[i_batch]; - const long_index_t key_start_unpadded = kargs.seqstart_k_ptr[i_batch]; + const long_index_t query_start = kargs.seqstart_q_ptr[i_batch]; + const long_index_t key_start = kargs.seqstart_k_ptr[i_batch]; - const long_index_t query_start_padded = kargs.seqstart_padded_q_ptr - ? kargs.seqstart_padded_q_ptr[i_batch] - : query_start_unpadded; - const long_index_t key_start_padded = kargs.seqstart_padded_k_ptr - ? kargs.seqstart_padded_k_ptr[i_batch] - : key_start_unpadded; - - batch_offset_q = query_start_padded * kargs.stride_q; - batch_offset_k = key_start_padded * kargs.stride_k; - batch_offset_v = key_start_padded * kargs.stride_v; + batch_offset_q = query_start * kargs.stride_q; + batch_offset_k = key_start * kargs.stride_k; + batch_offset_v = key_start * kargs.stride_v; if constexpr(kStoreLSE) { - // LSE layout is [nhead, total_seqlen], index by unpadded start - batch_offset_lse = query_start_unpadded; + batch_offset_lse = query_start; } - batch_offset_o = query_start_padded * kargs.stride_o; + batch_offset_o = query_start * kargs.stride_o; // get real # queries & # keys under group mode const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch; @@ -571,18 +417,6 @@ struct FmhaFwdV3Kernel batch_offset_lse = static_cast(i_batch) * kargs.batch_stride_lse; } batch_offset_o = static_cast(i_batch) * kargs.batch_stride_o; - - // If cumulative seqlen pointers are provided, override per-batch effective lengths - if(kargs.cu_seqlen_q_ptr != nullptr) - { - kargs.seqlen_q = - kargs.cu_seqlen_q_ptr[i_batch + 1] - kargs.cu_seqlen_q_ptr[i_batch]; - } - if(kargs.cu_seqlen_kv_ptr != nullptr) - { - kargs.seqlen_k = - kargs.cu_seqlen_kv_ptr[i_batch + 1] - kargs.cu_seqlen_kv_ptr[i_batch]; - } } // for simplicity, batch stride we just modify the pointer diff --git a/test/ck_tile/fmha/test_fmha_fwd.inc b/test/ck_tile/fmha/test_fmha_fwd.inc index 66d4e3dc21..08abd3358d 100644 --- a/test/ck_tile/fmha/test_fmha_fwd.inc +++ b/test/ck_tile/fmha/test_fmha_fwd.inc @@ -98,10 +98,7 @@ TEST_P(AllLong, Test) hdim_q, hdim_v, 0, // seqlen_knew - {-1}, // seqlen_qpads {seqlen_kpad}, // seqlen_kpads - {}, // q_eff_lens_per_batch - {}, // kv_eff_lens_per_batch 0, // rotary_dim perm, // i_perm perm, // o_perm @@ -163,10 +160,7 @@ TEST_P(HDimPadding, Test) hdim_q, hdim_v, 0, // seqlen_knew - {-1}, // seqlen_qpads {seqlen_kpad}, // seqlen_kpads - {}, // q_eff_lens_per_batch - {}, // kv_eff_lens_per_batch 0, // rotary_dim perm, // i_perm perm, // o_perm @@ -223,10 +217,7 @@ TEST_P(ElementwiseBias, Test) hdim_q, hdim_v, 0, // seqlen_knew - {-1}, // seqlen_qpads {-1}, // seqlen_kpads - {}, // q_eff_lens_per_batch - {}, // kv_eff_lens_per_batch 0, // rotary_dim i_perm, // i_perm false, // o_perm @@ -282,10 +273,7 @@ TEST_P(Alibi, Test) hdim_q, hdim_v, 0, // seqlen_knew - {-1}, // seqlen_qpads {-1}, // seqlen_kpads - {}, // q_eff_lens_per_batch - {}, // kv_eff_lens_per_batch 0, // rotary_dim true, // i_perm true, // o_perm @@ -343,10 +331,7 @@ TEST_P(Dropout, Test) hdim_q, hdim_v, 0, // seqlen_knew - {-1}, // seqlen_qpads {-1}, // seqlen_kpads - {}, // q_eff_lens_per_batch - {}, // kv_eff_lens_per_batch 0, // rotary_dim false, // i_perm false, // o_perm @@ -406,10 +391,7 @@ TEST_P(PagedKV, Test) hdim_q, hdim_v, 0, // seqlen_knew - {-1}, // seqlen_qpads {-1}, // seqlen_kpads - {}, // q_eff_lens_per_batch - {}, // kv_eff_lens_per_batch 0, // rotary_dim i_perm, // i_perm false, // o_perm @@ -475,10 +457,7 @@ TEST_P(SplitKV, Test) hdim_q, hdim_v, 0, // seqlen_knew - {-1}, // seqlen_qpads {-1}, // seqlen_kpads - {}, // q_eff_lens_per_batch - {}, // kv_eff_lens_per_batch 0, // rotary_dim i_perm, // i_perm false, // o_perm @@ -550,10 +529,7 @@ TEST_P(AppendKV, Test) hdim_q, hdim_v, seqlen_knew, // seqlen_knew - {-1}, // seqlen_qpads {-1}, // seqlen_kpads - {}, // q_eff_lens_per_batch - {}, // kv_eff_lens_per_batch 0, // rotary_dim i_perm, // i_perm true, // o_perm @@ -623,10 +599,7 @@ TEST_P(AppendKVRoPE, Test) hdim_q, hdim_v, seqlen_knew, // seqlen_knew - {-1}, // seqlen_qpads {-1}, // seqlen_kpads - {}, // q_eff_lens_per_batch - {}, // kv_eff_lens_per_batch rotary_dim, // rotary_dim i_perm, // i_perm true, // o_perm @@ -650,117 +623,3 @@ TEST_P(AppendKVRoPE, Test) } #endif // CK_TILE_FMHA_FWD_APPENDKV_API - -// --------------------------------------------------------------- -// Additional padding tests (q/kv physical padding & effective len) -// --------------------------------------------------------------- - -// Simple batch-mode test with per-batch Q/KV padding strides and effective lengths -TEST(TestCkTileFmhaFwd, BatchModeQKvPadding) -{ - if constexpr(std::is_same_v) - { - GTEST_SKIP() << "Skip for fp8"; - } - const mode_enum mode = mode_enum::batch; - const int batch = 3; - const int nhead = 2; - const int nhead_k = -1; - const int seqlen_q = 128; - const int seqlen_k = 128; - const int hdim_q = 64; - const int hdim_v = 64; - const int seqlen_knew = 0; - const std::vector seqlen_qpads{}; - const std::vector seqlen_kpads{}; - const std::vector q_eff_lens{120, 128, 100}; - const std::vector kv_eff_lens{110, 128, 90}; - - auto result = fmha_fwd_run(mode, - batch, - nhead, - nhead_k, - {adjust_seqlen(seqlen_q)}, - {adjust_seqlen(seqlen_k)}, - hdim_q, - hdim_v, - seqlen_knew, // seqlen_knew - seqlen_qpads, // seqlen_qpads - seqlen_kpads, // seqlen_kpads - q_eff_lens, // q_eff_lens_per_batch - kv_eff_lens, // kv_eff_lens_per_batch - 0, // rotary_dim - true, // i_perm - true, // o_perm - 0, // scale_s - 0, // logits_soft_cap - def_is_v_rowmajor, - def_lse, // lse - 0, // page_block_size - false, // use_cache_batch_idx - "n", // bias_str - 0.0f, // p_drop - 0, // drop_seed - 0, // drop_offset - false, // drop_prefs - "0", // mask_str - QUANT_ARGS, - true, // is_rotary_interleaved - 1, // num_splits - COMMON_ARGS); - CHECK_RESULT(result); -} - -// Simple group-mode test with uniform seqlen but per-batch padding & effective lengths -TEST(TestCkTileFmhaFwd, GroupModeQKvPadding) -{ - if constexpr(std::is_same_v) - { - GTEST_SKIP() << "Skip for fp8"; - } - const mode_enum mode = mode_enum::group; - const int batch = 2; - const int nhead = 2; - const int nhead_k = -1; - const std::vector seqlen_q{96, 128}; // unpadded - const std::vector seqlen_k{96, 128}; // unpadded - const int hdim_q = 64; - const int hdim_v = 64; - const int seqlen_knew = 0; - const std::vector seqlen_qpads{128, 160}; - const std::vector seqlen_kpads{128, 160}; - - auto result = fmha_fwd_run(mode, - batch, - nhead, - nhead_k, - seqlen_q, - seqlen_k, - hdim_q, - hdim_v, - seqlen_knew, // seqlen_knew - seqlen_qpads, // seqlen_qpads - seqlen_kpads, // seqlen_kpads - {}, // q_eff_lens_per_batch - {}, // kv_eff_lens_per_batch - 0, // rotary_dim - true, // i_perm - true, // o_perm - 0, // scale_s - 0, // logits_soft_cap - def_is_v_rowmajor, - def_lse, // lse - 0, // page_block_size - false, // use_cache_batch_idx - "n", // bias_str - 0.0f, // p_drop - 0, // drop_seed - 0, // drop_offset - false, // drop_prefs - "0", // mask_str - QUANT_ARGS, - true, // is_rotary_interleaved - 1, // num_splits - COMMON_ARGS); - CHECK_RESULT(result); -}